Skip to content

Audio decoding support: range-based core API #538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
ae15304
Add basic range support
NicolasHug Mar 6, 2025
29e0b8d
Add more tests
NicolasHug Mar 6, 2025
cad69da
Merge branch 'main' of github.com:pytorch/torchcodec into audioooooooo
NicolasHug Mar 6, 2025
04f6282
Add separate audio decoding method
NicolasHug Mar 7, 2025
f8dfcda
Merge branch 'main' of github.com:pytorch/torchcodec into audioooooooo
NicolasHug Mar 7, 2025
da40954
More stuff
NicolasHug Mar 7, 2025
3881586
Cleanups
NicolasHug Mar 7, 2025
82bea4a
Remove old code
NicolasHug Mar 7, 2025
ce12f03
More validation, more tests
NicolasHug Mar 8, 2025
59b0d15
remove next() support
NicolasHug Mar 8, 2025
f4bed23
Rename
NicolasHug Mar 8, 2025
fe04cd2
Add support for None stop_seconds
NicolasHug Mar 8, 2025
98fee85
Remove pre-alloc logic
NicolasHug Mar 8, 2025
d2357fe
Add test
NicolasHug Mar 8, 2025
f3b56f8
Add proper error when backward seek is needed
NicolasHug Mar 8, 2025
5f2800a
Cleanup
NicolasHug Mar 8, 2025
2f020f2
Add TODO
NicolasHug Mar 9, 2025
0c11f72
Put back original compilation flags
NicolasHug Mar 9, 2025
de4facc
Fix
NicolasHug Mar 9, 2025
b5f2df0
nit
NicolasHug Mar 9, 2025
09e6f44
Oops, fix
NicolasHug Mar 9, 2025
3d955c1
Add case for start=stop
NicolasHug Mar 9, 2025
d791d2a
Simplify
NicolasHug Mar 9, 2025
0c0f62b
Don't use a lambda
NicolasHug Mar 9, 2025
893c358
Merge branch 'main' of github.com:pytorch/torchcodec into audioooooooo
NicolasHug Mar 11, 2025
c35ae47
Fix
NicolasHug Mar 11, 2025
c453a3c
Address comments
NicolasHug Mar 12, 2025
dafb927
Add comment
NicolasHug Mar 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/torchcodec/decoders/_core/FFMPEGCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,22 @@ int64_t getDuration(const AVFrame* frame) {
#endif
}

int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext) {
int getNumChannels(const AVFrame* avFrame) {
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
int numChannels = avCodecContext->ch_layout.nb_channels;
return avFrame->ch_layout.nb_channels;
#else
int numChannels = avCodecContext->channels;
return av_get_channel_layout_nb_channels(avFrame->channel_layout);
#endif
}

return static_cast<int64_t>(numChannels);
// Returns the number of audio channels configured on the codec context.
// Overload of getNumChannels(const AVFrame*) for when no decoded frame is
// available yet (e.g. while populating stream metadata).
int getNumChannels(const UniqueAVCodecContext& avCodecContext) {
// NOTE(review): guard keys off LIBAVFILTER versions for a libavcodec field;
// presumably the FFmpeg release versions line up — confirm via doc/APIchanges.
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
// Modern FFmpeg: channel info lives in the AVChannelLayout struct.
return avCodecContext->ch_layout.nb_channels;
#else
// Legacy FFmpeg: use the older 'channels' field.
return avCodecContext->channels;
#endif
}

AVIOBytesContext::AVIOBytesContext(
Expand Down
3 changes: 2 additions & 1 deletion src/torchcodec/decoders/_core/FFMPEGCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode);
int64_t getDuration(const UniqueAVFrame& frame);
int64_t getDuration(const AVFrame* frame);

int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext);
int getNumChannels(const AVFrame* avFrame);
int getNumChannels(const UniqueAVCodecContext& avCodecContext);

// Returns true if sws_scale can handle unaligned data.
bool canSwsScaleHandleUnalignedData();
Expand Down
123 changes: 116 additions & 7 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string_view>
Expand Down Expand Up @@ -552,7 +553,8 @@ void VideoDecoder::addAudioStream(int streamIndex) {
containerMetadata_.allStreamMetadata[activeStreamIndex_];
streamMetadata.sampleRate =
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
streamMetadata.numChannels = getNumChannels(streamInfo.codecContext);
streamMetadata.numChannels =
static_cast<int64_t>(getNumChannels(streamInfo.codecContext));
}

// --------------------------------------------------------------------------
Expand All @@ -567,6 +569,7 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {

VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
std::optional<torch::Tensor> preAllocatedOutputTensor) {
validateActiveStream(AVMEDIA_TYPE_VIDEO);
AVFrameStream avFrameStream = decodeAVFrame(
[this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
Expand Down Expand Up @@ -685,6 +688,7 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
}

VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
validateActiveStream(AVMEDIA_TYPE_VIDEO);
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
double frameStartTime =
ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase);
Expand Down Expand Up @@ -757,7 +761,6 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
double startSeconds,
double stopSeconds) {
validateActiveStream(AVMEDIA_TYPE_VIDEO);

const auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];
TORCH_CHECK(
Expand Down Expand Up @@ -835,6 +838,68 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
return frameBatchOutput;
}

// Decodes all audio frames whose samples fall in [startSeconds, stopSeconds)
// and returns them as a single (numChannels, numSamples) tensor.
// stopSecondsOptional == nullopt means "decode until the end of the stream".
torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
    double startSeconds,
    std::optional<double> stopSecondsOptional) {
  validateActiveStream(AVMEDIA_TYPE_AUDIO);

  double stopSeconds =
      stopSecondsOptional.value_or(std::numeric_limits<double>::max());

  TORCH_CHECK(
      startSeconds <= stopSeconds,
      "Start seconds (" + std::to_string(startSeconds) +
          ") must be less than or equal to stop seconds (" +
          std::to_string(stopSeconds) + ").");

  if (startSeconds == stopSeconds) {
    // For consistency with video
    return torch::empty({0});
  }

  StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];

  // TODO-AUDIO This essentially enforces that we don't need to seek
  // (backwards). We should remove it and seek back to the stream's beginning
  // when needed. See test_multiple_calls
  TORCH_CHECK(
      streamInfo.lastDecodedAvFramePts +
              streamInfo.lastDecodedAvFrameDuration <=
          secondsToClosestPts(startSeconds, streamInfo.timeBase),
      "Audio decoder cannot seek backwards, or start from the last decoded frame.");

  setCursorPtsInSeconds(startSeconds);

  // TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
  // cat(). This would save a copy. We know the duration of the output and the
  // sample rate, so in theory we know the number of output samples.
  std::vector<torch::Tensor> tensors;

  auto stopPts = secondsToClosestPts(stopSeconds, streamInfo.timeBase);
  auto finished = false;
  while (!finished) {
    try {
      // Keep decoding as long as the cursor lies past the current frame's
      // end: the filter accepts the first frame whose span contains cursor_.
      AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
        return cursor_ < avFrame->pts + getDuration(avFrame);
      });
      auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
      tensors.push_back(frameOutput.data);
    } catch (const EndOfFileException& e) {
      finished = true;
    }

    // If stopSeconds is in [begin, end] of the last decoded frame, we should
    // stop decoding more frames. Note that if we were to use [begin, end),
    // which may seem more natural, then we would decode the frame starting at
    // stopSeconds, which isn't what we want!
    auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts +
        streamInfo.lastDecodedAvFrameDuration;
    finished |= (streamInfo.lastDecodedAvFramePts) <= stopPts &&
        (stopPts <= lastDecodedAvFrameEnd);
  }
  // Per-frame tensors are (numChannels, numSamplesInFrame); concatenate along
  // the samples dimension.
  return torch::cat(tensors, 1);
}

// --------------------------------------------------------------------------
// SEEKING APIs
// --------------------------------------------------------------------------
Expand Down Expand Up @@ -871,6 +936,10 @@ I P P P I P P P I P P I P P I P
(2) is more efficient than (1) if there is an I frame between x and y.
*/
bool VideoDecoder::canWeAvoidSeeking() const {
const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
return true;
}
int64_t lastDecodedAvFramePts =
streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;
if (cursor_ < lastDecodedAvFramePts) {
Expand All @@ -897,7 +966,7 @@ bool VideoDecoder::canWeAvoidSeeking() const {
// AVFormatContext if it is needed. We can skip seeking in certain cases. See
// the comment of canWeAvoidSeeking() for details.
void VideoDecoder::maybeSeekToBeforeDesiredPts() {
validateActiveStream(AVMEDIA_TYPE_VIDEO);
validateActiveStream();
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];

decodeStats_.numSeeksAttempted++;
Expand Down Expand Up @@ -942,7 +1011,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {

VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
std::function<bool(AVFrame*)> filterFunction) {
validateActiveStream(AVMEDIA_TYPE_VIDEO);
validateActiveStream();

resetDecodeStats();

Expand Down Expand Up @@ -1071,13 +1140,14 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
AVFrame* avFrame = avFrameStream.avFrame.get();
frameOutput.streamIndex = streamIndex;
auto& streamInfo = streamInfos_[streamIndex];
TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
frameOutput.ptsSeconds = ptsToSeconds(
avFrame->pts, formatContext_->streams[streamIndex]->time_base);
frameOutput.durationSeconds = ptsToSeconds(
getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
// TODO: we should fold preAllocatedOutputTensor into AVFrameStream.
if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
convertAudioAVFrameToFrameOutputOnCPU(
avFrameStream, frameOutput, preAllocatedOutputTensor);
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
convertAVFrameToFrameOutputOnCPU(
avFrameStream, frameOutput, preAllocatedOutputTensor);
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) {
Expand Down Expand Up @@ -1253,6 +1323,45 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
}

// Converts a decoded audio AVFrame into a (numChannels, numSamples) float32
// tensor stored in frameOutput.data.
// Currently only AV_SAMPLE_FMT_FLTP (planar float) is supported; any other
// sample format raises. Pre-allocated output tensors are not supported yet.
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
    VideoDecoder::AVFrameStream& avFrameStream,
    FrameOutput& frameOutput,
    std::optional<torch::Tensor> preAllocatedOutputTensor) {
  TORCH_CHECK(
      !preAllocatedOutputTensor.has_value(),
      "pre-allocated audio tensor not supported yet.");

  const AVFrame* avFrame = avFrameStream.avFrame.get();

  auto numSamples = avFrame->nb_samples; // per channel
  auto numChannels = getNumChannels(avFrame);
  torch::Tensor outputData =
      torch::empty({numChannels, numSamples}, torch::kFloat32);

  AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
  // TODO-AUDIO Implement all formats.
  switch (format) {
    case AV_SAMPLE_FMT_FLTP: {
      // Planar float: each channel's samples are contiguous in
      // extended_data[channel], so each channel is one memcpy into the
      // corresponding row of the (numChannels, numSamples) output.
      uint8_t* outputChannelData = static_cast<uint8_t*>(outputData.data_ptr());
      auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
      for (auto channel = 0; channel < numChannels;
           ++channel, outputChannelData += numBytesPerChannel) {
        memcpy(
            outputChannelData,
            avFrame->extended_data[channel],
            numBytesPerChannel);
      }
      break;
    }
    default:
      TORCH_CHECK(
          false,
          "Unsupported audio format (yet!): ",
          av_get_sample_fmt_name(format));
  }
  frameOutput.data = outputData;
}

// --------------------------------------------------------------------------
// OUTPUT ALLOCATION AND SHAPE CONVERSION
// --------------------------------------------------------------------------
Expand Down
10 changes: 10 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,11 @@ class VideoDecoder {
double startSeconds,
double stopSeconds);

// TODO-AUDIO: Should accept sampleRate
torch::Tensor getFramesPlayedInRangeAudio(
double startSeconds,
std::optional<double> stopSecondsOptional = std::nullopt);

class EndOfFileException : public std::runtime_error {
public:
explicit EndOfFileException(const std::string& msg)
Expand Down Expand Up @@ -379,6 +384,11 @@ class VideoDecoder {
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

void convertAudioAVFrameToFrameOutputOnCPU(
AVFrameStream& avFrameStream,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

torch::Tensor convertAVFrameToTensorUsingFilterGraph(const AVFrame* avFrame);

int convertAVFrameToTensorUsingSwsScale(
Expand Down
14 changes: 12 additions & 2 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ namespace facebook::torchcodec {
// https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#readme
TORCH_LIBRARY(torchcodec_ns, m) {
m.impl_abstract_pystub(
"torchcodec.decoders._core.video_decoder_ops",
"//pytorch/torchcodec:torchcodec");
"torchcodec.decoders._core.ops", "//pytorch/torchcodec:torchcodec");
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
m.def(
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
Expand All @@ -48,6 +47,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
"get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> Tensor");
m.def(
"get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
Expand Down Expand Up @@ -289,6 +290,14 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
return makeOpsFrameBatchOutput(result);
}

// Op-layer shim: recovers the VideoDecoder hidden inside the opaque tensor
// and forwards the audio range query to the core API.
torch::Tensor get_frames_by_pts_in_range_audio(
    at::Tensor& decoder,
    double start_seconds,
    std::optional<double> stop_seconds) {
  return unwrapTensorToGetDecoder(decoder)->getFramesPlayedInRangeAudio(
      start_seconds, stop_seconds);
}

std::string quoteValue(const std::string& value) {
return "\"" + value + "\"";
}
Expand Down Expand Up @@ -540,6 +549,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
m.impl("get_frames_at_indices", &get_frames_at_indices);
m.impl("get_frames_in_range", &get_frames_in_range);
m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range);
m.impl("get_frames_by_pts_in_range_audio", &get_frames_by_pts_in_range_audio);
m.impl("get_frames_by_pts", &get_frames_by_pts);
m.impl("_test_frame_pts_equality", &_test_frame_pts_equality);
m.impl(
Expand Down
5 changes: 5 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoderOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
double start_seconds,
double stop_seconds);

torch::Tensor get_frames_by_pts_in_range_audio(
at::Tensor& decoder,
double start_seconds,
std::optional<double> stop_seconds = std::nullopt);

// For testing only. We need to implement this operation as a core library
// function because what we're testing is round-tripping pts values as
// double-precision floating point numbers from C++ to Python and back to C++.
Expand Down
3 changes: 2 additions & 1 deletion src/torchcodec/decoders/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
get_container_metadata_from_header,
VideoStreamMetadata,
)
from .video_decoder_ops import (
from .ops import (
_add_video_stream,
_get_key_frame_indices,
_test_frame_pts_equality,
Expand All @@ -27,6 +27,7 @@
get_frames_at_indices,
get_frames_by_pts,
get_frames_by_pts_in_range,
get_frames_by_pts_in_range_audio,
get_frames_in_range,
get_json_metadata,
get_next_frame,
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/decoders/_core/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import torch

from torchcodec.decoders._core.video_decoder_ops import (
from torchcodec.decoders._core.ops import (
_get_container_json_metadata,
_get_stream_json_metadata,
create_from_file,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def load_torchcodec_extension():
get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default
get_frames_by_pts_in_range_audio = (
torch.ops.torchcodec_ns.get_frames_by_pts_in_range_audio.default
)
get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default
_test_frame_pts_equality = torch.ops.torchcodec_ns._test_frame_pts_equality.default
_get_container_json_metadata = (
Expand Down Expand Up @@ -262,6 +265,17 @@ def get_frames_by_pts_in_range_abstract(
)


@register_fake("torchcodec_ns::get_frames_by_pts_in_range_audio")
def get_frames_by_pts_in_range_audio_abstract(
    decoder: torch.Tensor,
    *,
    start_seconds: float,
    stop_seconds: Optional[float] = None,
) -> torch.Tensor:
    # The real op returns a 2D (num_channels, num_samples) tensor (per-frame
    # tensors are concatenated along dim 1), not a 4D image batch like the
    # video fakes this was adapted from — so register two dynamic dims.
    size = [get_ctx().new_dynamic_size() for _ in range(2)]
    return torch.empty(size)


@register_fake("torchcodec_ns::_get_key_frame_indices")
def get_key_frame_indices_abstract(decoder: torch.Tensor) -> torch.Tensor:
return torch.empty([], dtype=torch.int)
Expand Down
Loading
Loading