diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.cpp b/src/torchcodec/decoders/_core/FFMPEGCommon.cpp index 4185d9b94..b7dbd8ef0 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.cpp @@ -60,15 +60,22 @@ int64_t getDuration(const AVFrame* frame) { #endif } -int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext) { +int getNumChannels(const AVFrame* avFrame) { #if LIBAVFILTER_VERSION_MAJOR > 8 || \ (LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44) - int numChannels = avCodecContext->ch_layout.nb_channels; + return avFrame->ch_layout.nb_channels; #else - int numChannels = avCodecContext->channels; + return av_get_channel_layout_nb_channels(avFrame->channel_layout); #endif +} - return static_cast(numChannels); +int getNumChannels(const UniqueAVCodecContext& avCodecContext) { +#if LIBAVFILTER_VERSION_MAJOR > 8 || \ + (LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44) + return avCodecContext->ch_layout.nb_channels; +#else + return avCodecContext->channels; +#endif } AVIOBytesContext::AVIOBytesContext( diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.h b/src/torchcodec/decoders/_core/FFMPEGCommon.h index 0454058bc..88a81d18b 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.h +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.h @@ -139,7 +139,8 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode); int64_t getDuration(const UniqueAVFrame& frame); int64_t getDuration(const AVFrame* frame); -int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext); +int getNumChannels(const AVFrame* avFrame); +int getNumChannels(const UniqueAVCodecContext& avCodecContext); // Returns true if sws_scale can handle unaligned data. 
bool canSwsScaleHandleUnalignedData(); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 97214cec1..c0738e570 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -552,7 +553,8 @@ void VideoDecoder::addAudioStream(int streamIndex) { containerMetadata_.allStreamMetadata[activeStreamIndex_]; streamMetadata.sampleRate = static_cast(streamInfo.codecContext->sample_rate); - streamMetadata.numChannels = getNumChannels(streamInfo.codecContext); + streamMetadata.numChannels = + static_cast(getNumChannels(streamInfo.codecContext)); } // -------------------------------------------------------------------------- @@ -567,6 +569,7 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() { VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal( std::optional preAllocatedOutputTensor) { + validateActiveStream(AVMEDIA_TYPE_VIDEO); AVFrameStream avFrameStream = decodeAVFrame( [this](AVFrame* avFrame) { return avFrame->pts >= cursor_; }); return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor); @@ -685,6 +688,7 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) { } VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) { + validateActiveStream(AVMEDIA_TYPE_VIDEO); StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; double frameStartTime = ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase); @@ -757,7 +761,6 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange( double startSeconds, double stopSeconds) { validateActiveStream(AVMEDIA_TYPE_VIDEO); - const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; TORCH_CHECK( @@ -835,6 +838,68 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange( return frameBatchOutput; } 
+torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio( + double startSeconds, + std::optional stopSecondsOptional) { + validateActiveStream(AVMEDIA_TYPE_AUDIO); + + double stopSeconds = + stopSecondsOptional.value_or(std::numeric_limits::max()); + + TORCH_CHECK( + startSeconds <= stopSeconds, + "Start seconds (" + std::to_string(startSeconds) + + ") must be less than or equal to stop seconds (" + + std::to_string(stopSeconds) + ")."); + + if (startSeconds == stopSeconds) { + // For consistency with video + return torch::empty({0}); + } + + StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; + + // TODO-AUDIO This essentially enforces that we don't need to seek (backwards). + // We should remove it and seek back to the stream's beginning when needed. + // See test_multiple_calls + TORCH_CHECK( + streamInfo.lastDecodedAvFramePts + + streamInfo.lastDecodedAvFrameDuration <= + secondsToClosestPts(startSeconds, streamInfo.timeBase), + "Audio decoder cannot seek backwards, or start from the last decoded frame."); + + setCursorPtsInSeconds(startSeconds); + + // TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec + + // cat(). This would save a copy. We know the duration of the output and the + // sample rate, so in theory we know the number of output samples. + std::vector tensors; + + auto stopPts = secondsToClosestPts(stopSeconds, streamInfo.timeBase); + auto finished = false; + while (!finished) { + try { + AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) { + return cursor_ < avFrame->pts + getDuration(avFrame); + }); + auto frameOutput = convertAVFrameToFrameOutput(avFrameStream); + tensors.push_back(frameOutput.data); + } catch (const EndOfFileException&) { + finished = true; + } + + // If stopSeconds is in [begin, end] of the last decoded frame, we should + // stop decoding more frames. 
Note that if we were to use [begin, end), + // which may seem more natural, then we would decode the frame starting at + // stopSeconds, which isn't what we want! + auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts + + streamInfo.lastDecodedAvFrameDuration; + finished |= (streamInfo.lastDecodedAvFramePts) <= stopPts && + (stopPts <= lastDecodedAvFrameEnd); + } + return torch::cat(tensors, 1); +} + // -------------------------------------------------------------------------- // SEEKING APIs // -------------------------------------------------------------------------- @@ -871,6 +936,10 @@ I P P P I P P P I P P I P P I P (2) is more efficient than (1) if there is an I frame between x and y. */ bool VideoDecoder::canWeAvoidSeeking() const { + const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_); + if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { + return true; + } int64_t lastDecodedAvFramePts = streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts; if (cursor_ < lastDecodedAvFramePts) { @@ -897,7 +966,7 @@ bool VideoDecoder::canWeAvoidSeeking() const { // AVFormatContext if it is needed. We can skip seeking in certain cases. See // the comment of canWeAvoidSeeking() for details. 
void VideoDecoder::maybeSeekToBeforeDesiredPts() { - validateActiveStream(AVMEDIA_TYPE_VIDEO); + validateActiveStream(); StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; decodeStats_.numSeeksAttempted++; @@ -942,7 +1011,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame( std::function filterFunction) { - validateActiveStream(AVMEDIA_TYPE_VIDEO); + validateActiveStream(); resetDecodeStats(); @@ -1071,13 +1140,14 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput( AVFrame* avFrame = avFrameStream.avFrame.get(); frameOutput.streamIndex = streamIndex; auto& streamInfo = streamInfos_[streamIndex]; - TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO); frameOutput.ptsSeconds = ptsToSeconds( avFrame->pts, formatContext_->streams[streamIndex]->time_base); frameOutput.durationSeconds = ptsToSeconds( getDuration(avFrame), formatContext_->streams[streamIndex]->time_base); - // TODO: we should fold preAllocatedOutputTensor into AVFrameStream. 
- if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) { + if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { + convertAudioAVFrameToFrameOutputOnCPU( + avFrameStream, frameOutput, preAllocatedOutputTensor); + } else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) { convertAVFrameToFrameOutputOnCPU( avFrameStream, frameOutput, preAllocatedOutputTensor); } else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) { @@ -1253,6 +1323,45 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph( filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8}); } +void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU( + VideoDecoder::AVFrameStream& avFrameStream, + FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor) { + TORCH_CHECK( + !preAllocatedOutputTensor.has_value(), + "pre-allocated audio tensor not supported yet."); + + const AVFrame* avFrame = avFrameStream.avFrame.get(); + + auto numSamples = avFrame->nb_samples; // per channel + auto numChannels = getNumChannels(avFrame); + torch::Tensor outputData = + torch::empty({numChannels, numSamples}, torch::kFloat32); + + AVSampleFormat format = static_cast(avFrame->format); + // TODO-AUDIO Implement all formats. 
+ switch (format) { + case AV_SAMPLE_FMT_FLTP: { + uint8_t* outputChannelData = static_cast(outputData.data_ptr()); + auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format); + for (auto channel = 0; channel < numChannels; + ++channel, outputChannelData += numBytesPerChannel) { + memcpy( + outputChannelData, + avFrame->extended_data[channel], + numBytesPerChannel); + } + break; + } + default: + TORCH_CHECK( + false, + "Unsupported audio format (yet!): ", + av_get_sample_fmt_name(format)); + } + frameOutput.data = outputData; +} + // -------------------------------------------------------------------------- // OUTPUT ALLOCATION AND SHAPE CONVERSION // -------------------------------------------------------------------------- diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index a41ea50c2..66b9d93c4 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -221,6 +221,11 @@ class VideoDecoder { double startSeconds, double stopSeconds); + // TODO-AUDIO: Should accept sampleRate + torch::Tensor getFramesPlayedInRangeAudio( + double startSeconds, + std::optional stopSecondsOptional = std::nullopt); + class EndOfFileException : public std::runtime_error { public: explicit EndOfFileException(const std::string& msg) @@ -379,6 +384,11 @@ class VideoDecoder { FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt); + void convertAudioAVFrameToFrameOutputOnCPU( + AVFrameStream& avFrameStream, + FrameOutput& frameOutput, + std::optional preAllocatedOutputTensor = std::nullopt); + torch::Tensor convertAVFrameToTensorUsingFilterGraph(const AVFrame* avFrame); int convertAVFrameToTensorUsingSwsScale( diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index bb13e113d..9eb61ac20 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ 
b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -25,8 +25,7 @@ namespace facebook::torchcodec { // https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#readme TORCH_LIBRARY(torchcodec_ns, m) { m.impl_abstract_pystub( - "torchcodec.decoders._core.video_decoder_ops", - "//pytorch/torchcodec:torchcodec"); + "torchcodec.decoders._core.ops", "//pytorch/torchcodec:torchcodec"); m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); @@ -48,6 +47,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { "get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)"); + m.def( + "get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> Tensor"); m.def( "get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)"); m.def("_get_key_frame_indices(Tensor(a!) 
decoder) -> Tensor"); @@ -289,6 +290,14 @@ OpsFrameBatchOutput get_frames_by_pts_in_range( return makeOpsFrameBatchOutput(result); } +torch::Tensor get_frames_by_pts_in_range_audio( + at::Tensor& decoder, + double start_seconds, + std::optional stop_seconds) { + auto videoDecoder = unwrapTensorToGetDecoder(decoder); + return videoDecoder->getFramesPlayedInRangeAudio(start_seconds, stop_seconds); +} + std::string quoteValue(const std::string& value) { return "\"" + value + "\""; } @@ -540,6 +549,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("get_frames_at_indices", &get_frames_at_indices); m.impl("get_frames_in_range", &get_frames_in_range); m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range); + m.impl("get_frames_by_pts_in_range_audio", &get_frames_by_pts_in_range_audio); m.impl("get_frames_by_pts", &get_frames_by_pts); m.impl("_test_frame_pts_equality", &_test_frame_pts_equality); m.impl( diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 034a8842a..c8d324075 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -112,6 +112,11 @@ OpsFrameBatchOutput get_frames_by_pts_in_range( double start_seconds, double stop_seconds); +torch::Tensor get_frames_by_pts_in_range_audio( + at::Tensor& decoder, + double start_seconds, + std::optional stop_seconds = std::nullopt); + // For testing only. We need to implement this operation as a core library // function because what we're testing is round-tripping pts values as // double-precision floating point numbers from C++ to Python and back to C++. 
diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py index 7dcb866c9..490e3d834 100644 --- a/src/torchcodec/decoders/_core/__init__.py +++ b/src/torchcodec/decoders/_core/__init__.py @@ -12,7 +12,7 @@ get_container_metadata_from_header, VideoStreamMetadata, ) -from .video_decoder_ops import ( +from .ops import ( _add_video_stream, _get_key_frame_indices, _test_frame_pts_equality, @@ -27,6 +27,7 @@ get_frames_at_indices, get_frames_by_pts, get_frames_by_pts_in_range, + get_frames_by_pts_in_range_audio, get_frames_in_range, get_json_metadata, get_next_frame, diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py index ff48b72a4..fcfddecc9 100644 --- a/src/torchcodec/decoders/_core/_metadata.py +++ b/src/torchcodec/decoders/_core/_metadata.py @@ -12,7 +12,7 @@ import torch -from torchcodec.decoders._core.video_decoder_ops import ( +from torchcodec.decoders._core.ops import ( _get_container_json_metadata, _get_stream_json_metadata, create_from_file, diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/ops.py similarity index 95% rename from src/torchcodec/decoders/_core/video_decoder_ops.py rename to src/torchcodec/decoders/_core/ops.py index 190384684..74796a172 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -78,6 +78,9 @@ def load_torchcodec_extension(): get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default +get_frames_by_pts_in_range_audio = ( + torch.ops.torchcodec_ns.get_frames_by_pts_in_range_audio.default +) get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default _test_frame_pts_equality = torch.ops.torchcodec_ns._test_frame_pts_equality.default _get_container_json_metadata = ( 
@@ -262,6 +265,17 @@ def get_frames_by_pts_in_range_abstract( ) +@register_fake("torchcodec_ns::get_frames_by_pts_in_range_audio") +def get_frames_by_pts_in_range_audio_abstract( + decoder: torch.Tensor, + *, + start_seconds: float, + stop_seconds: Optional[float] = None, +) -> torch.Tensor: + frames_shape = [get_ctx().new_dynamic_size() for _ in range(2)] + return torch.empty(frames_shape) + + @register_fake("torchcodec_ns::_get_key_frame_indices") def get_key_frame_indices_abstract(decoder: torch.Tensor) -> torch.Tensor: return torch.empty([], dtype=torch.int) diff --git a/test/decoders/test_decoders.py b/test/decoders/test_decoders.py index 24558d8cb..e03391204 100644 --- a/test/decoders/test_decoders.py +++ b/test/decoders/test_decoders.py @@ -22,6 +22,7 @@ get_ffmpeg_major_version, H265_VIDEO, in_fbcode, + NASA_AUDIO, NASA_AUDIO_MP3, NASA_VIDEO, ) @@ -32,7 +33,7 @@ class TestDecoder: "Decoder, asset", ( (VideoDecoder, NASA_VIDEO), - (AudioDecoder, NASA_VIDEO), + (AudioDecoder, NASA_AUDIO), (AudioDecoder, NASA_AUDIO_MP3), ), ) @@ -939,11 +940,18 @@ def get_some_frames(decoder): class TestAudioDecoder: - def test_metadata(self): - decoder = AudioDecoder(NASA_VIDEO.path) + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_metadata(self, asset): + decoder = AudioDecoder(asset.path) assert isinstance(decoder.metadata, AudioStreamMetadata) - assert decoder.stream_index == decoder.metadata.stream_index == 4 - assert decoder.metadata.duration_seconds == pytest.approx(13.056) - assert decoder.metadata.sample_rate == 16_000 - assert decoder.metadata.num_channels == 2 + assert ( + decoder.stream_index + == decoder.metadata.stream_index + == asset.default_stream_index + ) + assert decoder.metadata.duration_seconds == pytest.approx( + asset.duration_seconds + ) + assert decoder.metadata.sample_rate == asset.sample_rate + assert decoder.metadata.num_channels == asset.num_channels diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index 
ab0e2bb09..e33b9941d 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -30,6 +30,7 @@ get_frames_at_indices, get_frames_by_pts, get_frames_by_pts_in_range, + get_frames_by_pts_in_range_audio, get_frames_in_range, get_json_metadata, get_next_frame, @@ -39,6 +40,7 @@ from ..utils import ( assert_frames_equal, cpu_and_cuda, + NASA_AUDIO, NASA_AUDIO_MP3, NASA_VIDEO, needs_cuda, @@ -49,7 +51,7 @@ INDEX_OF_FRAME_AT_6_SECONDS = 180 -class TestOps: +class TestVideoOps: @pytest.mark.parametrize("device", cpu_and_cuda()) def test_seek_and_next(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) @@ -396,12 +398,6 @@ def test_video_get_json_metadata_with_stream(self): assert metadata_dict["minPtsSecondsFromScan"] == 0 assert metadata_dict["maxPtsSecondsFromScan"] == 13.013 - def test_audio_get_json_metadata(self): - decoder = create_from_file(str(NASA_AUDIO_MP3.path)) - metadata = get_json_metadata(decoder) - metadata_dict = json.loads(metadata) - assert metadata_dict["durationSeconds"] == pytest.approx(13.248, abs=0.01) - def test_get_ffmpeg_version(self): ffmpeg_dict = get_ffmpeg_library_versions() assert len(ffmpeg_dict["libavcodec"]) == 3 @@ -620,6 +616,8 @@ def test_cuda_decoder(self): duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3 ) + +class TestAudioOps: @pytest.mark.parametrize( "method", ( @@ -628,22 +626,192 @@ def test_cuda_decoder(self): partial(get_frames_in_range, start=4, stop=5), partial(get_frame_at_pts, seconds=2), partial(get_frames_by_pts, timestamps=[0, 1.5]), - partial(get_frames_by_pts_in_range, start_seconds=0, stop_seconds=1), + partial(get_next_frame), ), ) def test_audio_bad_method(self, method): - decoder = create_from_file(str(NASA_AUDIO_MP3.path), seek_mode="approximate") + decoder = create_from_file(str(NASA_AUDIO.path), seek_mode="approximate") add_audio_stream(decoder) with pytest.raises(RuntimeError, match="The method you called isn't supported"): method(decoder) def 
test_audio_bad_seek_mode(self): - decoder = create_from_file(str(NASA_AUDIO_MP3.path), seek_mode="exact") + decoder = create_from_file(str(NASA_AUDIO.path), seek_mode="exact") with pytest.raises( RuntimeError, match="seek_mode must be 'approximate' for audio" ): add_audio_stream(decoder) + @pytest.mark.parametrize( + "range", + ( + "begin_to_end", + "begin_to_None", + "begin_to_beyond_end", + "at_frame_boundaries", + "not_at_frame_boundaries", + ), + ) + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_get_frames_by_pts_in_range_audio(self, range, asset): + if range == "begin_to_end": + start_seconds, stop_seconds = 0, asset.duration_seconds + elif range == "begin_to_None": + start_seconds, stop_seconds = 0, None + elif range == "begin_to_beyond_end": + start_seconds, stop_seconds = 0, asset.duration_seconds + 10 + elif range == "at_frame_boundaries": + start_seconds = asset.get_frame_info(idx=10).pts_seconds + stop_seconds = asset.get_frame_info(idx=40).pts_seconds + else: + assert range == "not_at_frame_boundaries" + start_frame_info = asset.get_frame_info(idx=10) + stop_frame_info = asset.get_frame_info(idx=40) + start_seconds = start_frame_info.pts_seconds + ( + start_frame_info.duration_seconds / 2 + ) + stop_seconds = stop_frame_info.pts_seconds + ( + stop_frame_info.duration_seconds / 2 + ) + + ref_start_index = asset.get_frame_index(pts_seconds=start_seconds) + if range == "begin_to_None": + ref_stop_index = ( + asset.get_frame_index(pts_seconds=asset.duration_seconds) + 1 + ) + elif range == "at_frame_boundaries": + ref_stop_index = asset.get_frame_index(pts_seconds=stop_seconds) + else: + ref_stop_index = asset.get_frame_index(pts_seconds=stop_seconds) + 1 + reference_frames = asset.get_frame_data_by_range( + start=ref_start_index, + stop=ref_stop_index, + ) + + decoder = create_from_file(str(asset.path), seek_mode="approximate") + add_audio_stream(decoder) + + frames = get_frames_by_pts_in_range_audio( + decoder, 
start_seconds=start_seconds, stop_seconds=stop_seconds + ) + + torch.testing.assert_close(frames, reference_frames) + + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_decode_epsilon_range(self, asset): + decoder = create_from_file(str(asset.path), seek_mode="approximate") + add_audio_stream(decoder) + + start_seconds = 5 + frames = get_frames_by_pts_in_range_audio( + decoder, start_seconds=start_seconds, stop_seconds=start_seconds + 1e-5 + ) + torch.testing.assert_close( + frames, + asset.get_frame_data_by_index( + asset.get_frame_index(pts_seconds=start_seconds) + ), + ) + + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_decode_just_one_frame_at_boundaries(self, asset): + decoder = create_from_file(str(asset.path), seek_mode="approximate") + add_audio_stream(decoder) + + start_seconds = asset.get_frame_info(idx=10).pts_seconds + stop_seconds = asset.get_frame_info(idx=11).pts_seconds + frames = get_frames_by_pts_in_range_audio( + decoder, start_seconds=start_seconds, stop_seconds=stop_seconds + ) + torch.testing.assert_close( + frames, + asset.get_frame_data_by_index( + asset.get_frame_index(pts_seconds=start_seconds) + ), + ) + + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_decode_start_equal_stop(self, asset): + decoder = create_from_file(str(asset.path), seek_mode="approximate") + add_audio_stream(decoder) + frames = get_frames_by_pts_in_range_audio( + decoder, start_seconds=1, stop_seconds=1 + ) + assert frames.shape == (0,) + + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_multiple_calls(self, asset): + # Ensure that multiple calls are OK as long as we're decoding + # "sequentially", i.e. we don't require a backwards seek. + # And ensure a proper error is raised in such case. + # TODO-AUDIO We shouldn't error, we should just implement the seeking + # back to the beginning of the stream. 
+ + def get_reference_frames(start_seconds, stop_seconds): + # This stateless helper exists for convenience, to avoid + # complicating this test with pts-to-index conversions. Eventually + # we should remove it and just rely on the asset's methods. + # Using this helper is OK for now: we're comparing a decoder which + # seeks multiple times with a decoder which seeks only once (the one + # here, treated as the reference) + decoder = create_from_file(str(asset.path), seek_mode="approximate") + add_audio_stream(decoder) + + return get_frames_by_pts_in_range_audio( + decoder, start_seconds=start_seconds, stop_seconds=stop_seconds + ) + + decoder = create_from_file(str(asset.path), seek_mode="approximate") + add_audio_stream(decoder) + + start_seconds, stop_seconds = 0, 2 + frames = get_frames_by_pts_in_range_audio( + decoder, start_seconds=start_seconds, stop_seconds=stop_seconds + ) + torch.testing.assert_close( + frames, get_reference_frames(start_seconds, stop_seconds) + ) + + # "seeking" forward is OK + start_seconds, stop_seconds = 3, 4 + frames = get_frames_by_pts_in_range_audio( + decoder, start_seconds=start_seconds, stop_seconds=stop_seconds + ) + torch.testing.assert_close( + frames, get_reference_frames(start_seconds, stop_seconds) + ) + + # Starting at the frame immediately after the previous one is OK + index_of_frame_at_4 = asset.get_frame_index(pts_seconds=4) + start_seconds, stop_seconds = ( + asset.get_frame_info(idx=index_of_frame_at_4 + 1).pts_seconds, + 5, + ) + frames = get_frames_by_pts_in_range_audio( + decoder, start_seconds=start_seconds, stop_seconds=stop_seconds + ) + torch.testing.assert_close( + frames, get_reference_frames(start_seconds, stop_seconds) + ) + + # but starting immediately on the same frame raises + expected_match = "Audio decoder cannot seek backwards" + with pytest.raises(RuntimeError, match=expected_match): + get_frames_by_pts_in_range_audio( + decoder, start_seconds=stop_seconds, stop_seconds=6 + ) + + with 
pytest.raises(RuntimeError, match=expected_match): + get_frames_by_pts_in_range_audio( + decoder, start_seconds=stop_seconds + 1e-4, stop_seconds=6 + ) + + # and seeking backwards doesn't work either + with pytest.raises(RuntimeError, match=expected_match): + frames = get_frames_by_pts_in_range_audio( + decoder, start_seconds=0, stop_seconds=2 + ) + if __name__ == "__main__": pytest.main() diff --git a/test/resources/nasa_13013.mp4.audio.mp3.stream0.all_frames.pt b/test/resources/nasa_13013.mp4.audio.mp3.stream0.all_frames.pt new file mode 100644 index 000000000..ffcd57bed Binary files /dev/null and b/test/resources/nasa_13013.mp4.audio.mp3.stream0.all_frames.pt differ diff --git a/test/resources/nasa_13013.mp4.stream4.all_frames.pt b/test/resources/nasa_13013.mp4.stream4.all_frames.pt new file mode 100644 index 000000000..e2362375f Binary files /dev/null and b/test/resources/nasa_13013.mp4.stream4.all_frames.pt differ diff --git a/test/utils.py b/test/utils.py index 9186e6608..c6ce0ec8a 100644 --- a/test/utils.py +++ b/test/utils.py @@ -4,8 +4,8 @@ import pathlib import sys -from dataclasses import dataclass -from typing import Dict, Optional, Union +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Union import numpy as np import pytest @@ -90,11 +90,6 @@ def _get_file_path(filename: str) -> pathlib.Path: return pathlib.Path(__file__).parent / "resources" / filename -def _load_tensor_from_file(filename: str) -> torch.Tensor: - file_path = _get_file_path(filename) - return torch.load(file_path, weights_only=True).permute(2, 0, 1) - - @dataclass class TestFrameInfo: pts_seconds: float @@ -113,6 +108,7 @@ class TestAudioStreamInfo: sample_rate: int num_channels: int duration_seconds: float + num_frames: int @dataclass @@ -171,12 +167,7 @@ def to_tensor(self) -> torch.Tensor: def get_frame_data_by_index( self, idx: int, *, stream_index: Optional[int] = None ) -> torch.Tensor: - if stream_index is None: - stream_index = 
self.default_stream_index - - return _load_tensor_from_file( - f"{self.filename}.stream{stream_index}.frame{idx:06d}.pt" - ) + raise NotImplementedError("Override in child classes") def get_frame_data_by_range( self, @@ -186,11 +177,7 @@ def get_frame_data_by_range( *, stream_index: Optional[int] = None, ) -> torch.Tensor: - tensors = [ - self.get_frame_data_by_index(i, stream_index=stream_index) - for i in range(start, stop, step) - ] - return torch.stack(tensors) + raise NotImplementedError("Override in child classes") def get_pts_seconds_by_range( self, @@ -244,6 +231,32 @@ def empty_duration_seconds(self) -> torch.Tensor: @dataclass class TestVideo(TestContainerFile): + """Base class for the *video* streams of a video container""" + + def get_frame_data_by_index( + self, idx: int, *, stream_index: Optional[int] = None + ) -> torch.Tensor: + if stream_index is None: + stream_index = self.default_stream_index + + file_path = _get_file_path( + f"{self.filename}.stream{stream_index}.frame{idx:06d}.pt" + ) + return torch.load(file_path, weights_only=True).permute(2, 0, 1) + + def get_frame_data_by_range( + self, + start: int, + stop: int, + step: int = 1, + *, + stream_index: Optional[int] = None, + ) -> torch.Tensor: + tensors = [ + self.get_frame_data_by_index(i, stream_index=stream_index) + for i in range(start, stop, step) + ] + return torch.stack(tensors) @property def width(self) -> int: @@ -303,15 +316,115 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor: frames={}, # Automatically loaded from json file ) -NASA_AUDIO_MP3 = TestContainerFile( + +@dataclass +class TestAudio(TestContainerFile): + """Base class for the *audio* streams of a container (potentially a video), + or a pure audio file""" + + stream_infos: Dict[int, TestAudioStreamInfo] + # stream_index -> list of 2D frame tensors of shape (num_channels, num_samples_in_that_frame) + # num_samples_in_that_frame isn't necessarily constant for a given stream. 
+ _reference_frames: Dict[int, List[torch.Tensor]] = field(default_factory=dict) + + # Storing each individual frame is too expensive for audio, because there's + # a massive overhead in the binary format saved by pytorch. Saving all the + # frames in a single file uses 1.6MB while saving all frames in individual + # files uses 302MB (yes). + # So we store the reference frames in a single file, and load/cache those + # when the TestAudio instance is created. + def __post_init__(self): + super().__post_init__() + for stream_index in self.stream_infos: + frames_data_path = _get_file_path( + f"{self.filename}.stream{stream_index}.all_frames.pt" + ) + + self._reference_frames[stream_index] = torch.load( + frames_data_path, weights_only=True + ) + + def get_frame_data_by_index( + self, idx: int, *, stream_index: Optional[int] = None + ) -> torch.Tensor: + if stream_index is None: + stream_index = self.default_stream_index + + return self._reference_frames[stream_index][idx] + + def get_frame_data_by_range( + self, + start: int, + stop: int, + step: int = 1, + *, + stream_index: Optional[int] = None, + ) -> torch.Tensor: + tensors = [ + self.get_frame_data_by_index(i, stream_index=stream_index) + for i in range(start, stop, step) + ] + return torch.cat(tensors, dim=-1) + + def get_frame_index( + self, *, pts_seconds: float, stream_index: Optional[int] = None + ) -> int: + if stream_index is None: + stream_index = self.default_stream_index + + if pts_seconds <= self.frames[stream_index][0].pts_seconds: + # Special case for e.g. NASA_AUDIO_MP3 whose first frame's pts is + # 0.13~, not 0. 
+ return 0 + try: + # Could use bisect() to make this faster if needed + return next( + frame_index + for (frame_index, frame_info) in self.frames[stream_index].items() + if frame_info.pts_seconds + <= pts_seconds + < frame_info.pts_seconds + frame_info.duration_seconds + ) + except StopIteration: + return len(self.frames[stream_index]) - 1 + + @property + def sample_rate(self) -> int: + return self.stream_infos[self.default_stream_index].sample_rate + + @property + def num_channels(self) -> int: + return self.stream_infos[self.default_stream_index].num_channels + + @property + def duration_seconds(self) -> float: + return self.stream_infos[self.default_stream_index].duration_seconds + + @property + def num_frames(self) -> int: + return self.stream_infos[self.default_stream_index].num_frames + + +NASA_AUDIO_MP3 = TestAudio( filename="nasa_13013.mp4.audio.mp3", default_stream_index=0, + frames={}, # Automatically loaded from json file stream_infos={ 0: TestAudioStreamInfo( - sample_rate=8_000, num_channels=2, duration_seconds=13.248 + sample_rate=8_000, num_channels=2, duration_seconds=13.248, num_frames=183 ) }, +) + +NASA_AUDIO = TestAudio( + filename="nasa_13013.mp4", + default_stream_index=4, frames={}, # Automatically loaded from json file + stream_infos={ + 4: TestAudioStreamInfo( + sample_rate=16_000, num_channels=2, duration_seconds=13.056, num_frames=204 + ) + }, ) H265_VIDEO = TestVideo(