From e4f05ce1df1911e6594165bcd107fd5535f97fa1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 29 Apr 2025 17:48:23 +0100 Subject: [PATCH 1/3] Enforce that encode() cannot be called twice --- src/torchcodec/_core/Encoder.cpp | 9 ++++++--- src/torchcodec/_core/Encoder.h | 2 ++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 114e8600..9b75f4fa 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -208,9 +208,12 @@ torch::Tensor AudioEncoder::encodeToTensor() { } void AudioEncoder::encode() { - // TODO-ENCODING: Need to check, but consecutive calls to encode() are - // probably invalid. We can address this once we (re)design the public and - // private encoding APIs. + // To be on the safe side we enforce that encode() can only be called once on + // an encoder object. Whether this is actually necessary is unknown, so this + // may be relaxed if needed. + TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice."); + encodeWasCalled_ = true; + UniqueAVFrame avFrame(av_frame_alloc()); TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); // Default to 256 like in torchaudio diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 17f09d59..bf31c31b 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -49,5 +49,7 @@ class AudioEncoder { // Stores the AVIOContext for the output tensor buffer. std::unique_ptr avioContextHolder_; + + bool encodeWasCalled_ = false; }; } // namespace facebook::torchcodec From 8a02d262f7d63ce89d9799c4916dd00e2d782e29 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 30 Apr 2025 11:14:35 +0100 Subject: [PATCH 2/3] Address other todos --- src/torchcodec/_core/Encoder.cpp | 44 ++++++++++++++++------------- src/torchcodec/_core/custom_ops.cpp | 2 -- test/test_ops.py | 9 ++++-- 3 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index f5ecfbf4..1c876f4e 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -14,6 +14,18 @@ torch::Tensor validateWf(torch::Tensor wf) { "waveform must have float32 dtype, got ", wf.dtype()); TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim()); + + // We enforce this, but if we get user reports we should investigate whether + // that's actually needed. + int numChannels = static_cast(wf.sizes()[0]); + TORCH_CHECK( + numChannels <= AV_NUM_DATA_POINTERS, + "Trying to encode ", + numChannels, + " channels, but FFmpeg only supports ", + AV_NUM_DATA_POINTERS, + " channels per frame."); + return wf.contiguous(); } @@ -164,18 +176,7 @@ void AudioEncoder::initializeEncoder( // what the `.sample_fmt` defines. avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec); - int numChannels = static_cast(wf_.sizes()[0]); - TORCH_CHECK( - // TODO-ENCODING is this even true / needed? We can probably support more - // with non-planar data? - numChannels <= AV_NUM_DATA_POINTERS, - "Trying to encode ", - numChannels, - " channels, but FFmpeg only supports ", - AV_NUM_DATA_POINTERS, - " channels per frame."); - - setDefaultChannelLayout(avCodecContext_, numChannels); + setDefaultChannelLayout(avCodecContext_, static_cast(wf_.sizes()[0])); int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); TORCH_CHECK( @@ -325,14 +326,17 @@ void AudioEncoder::encodeInnerLoop( ReferenceAVPacket packet(autoAVPacket); status = avcodec_receive_packet(avCodecContext_.get(), packet.get()); if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) { - // TODO-ENCODING this is from TorchAudio, probably needed, but not sure. - // if (status == AVERROR_EOF) { - // status = av_interleaved_write_frame(avFormatContext_.get(), - // nullptr); TORCH_CHECK( - // status == AVSUCCESS, - // "Failed to flush packet ", - // getFFMPEGErrorStringFromErrorCode(status)); - // } + if (status == AVERROR_EOF) { + // Flush the packets that were potentially buffered by + // av_interleaved_write_frame(). See corresponding block in + // TorchAudio: + // https://github.com/pytorch/audio/blob/d60ce09e2c532d5bf2e05619e700ab520543465e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L21 + status = av_interleaved_write_frame(avFormatContext_.get(), nullptr); + TORCH_CHECK( + status == AVSUCCESS, + "Failed to flush packet: ", + getFFMPEGErrorStringFromErrorCode(status)); + } return; } TORCH_CHECK( diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 2f470617..813c53a7 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -394,8 +394,6 @@ void encode_audio_to_file( .encode(); } -// TODO-ENCODING is "format" a good parameter name?? It kinda conflicts with -// "sample_format" which we may eventually want to expose. at::Tensor encode_audio_to_tensor( const at::Tensor wf, int64_t sample_rate, diff --git a/test/test_ops.py b/test/test_ops.py index ddca330a..6e53d27b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1132,11 +1132,11 @@ def test_bad_input(self, tmp_path): with pytest.raises(RuntimeError, match="No such file or directory"): encode_audio_to_file( - wf=torch.rand(10, 10), sample_rate=10, filename="./bad/path.mp3" + wf=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3" ) with pytest.raises(RuntimeError, match="Check the desired extension"): encode_audio_to_file( - wf=torch.rand(10, 10), sample_rate=10, filename="./file.bad_extension" + wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension" ) with pytest.raises(RuntimeError, match="invalid sample rate=10"): @@ -1153,6 +1153,11 @@ def test_bad_input(self, tmp_path): bit_rate=-1, # bad ) + with pytest.raises(RuntimeError, match="Trying to encode 10 channels"): + encode_audio_to_file( + wf=torch.rand(10, 20), sample_rate=10, filename="doesnt_matter" + ) + @pytest.mark.parametrize( "encode_method", (encode_audio_to_file, encode_audio_to_tensor) ) From 45b8fe1d64ab2968abc851b3b9f7470d27f7ab3d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 30 Apr 2025 15:28:56 +0100 Subject: [PATCH 3/3] Encode into file-like --- src/torchcodec/_core/AVIOContextHolder.cpp | 8 +++--- src/torchcodec/_core/AVIOFileLikeContext.cpp | 10 +++++++- src/torchcodec/_core/AVIOFileLikeContext.h | 1 + src/torchcodec/_core/Encoder.cpp | 26 ++++++++++++++++++++ src/torchcodec/_core/Encoder.h | 8 ++++++ src/torchcodec/_core/__init__.py | 1 + src/torchcodec/_core/ops.py | 11 +++++++++ src/torchcodec/_core/pybind_ops.cpp | 19 ++++++++++++++ 8 files changed, 79 insertions(+), 5 deletions(-) diff --git a/src/torchcodec/_core/AVIOContextHolder.cpp b/src/torchcodec/_core/AVIOContextHolder.cpp index e0462c28..99db8988 100644 --- a/src/torchcodec/_core/AVIOContextHolder.cpp +++ b/src/torchcodec/_core/AVIOContextHolder.cpp @@ -23,10 +23,10 @@ void AVIOContextHolder::createAVIOContext( buffer != nullptr, "Failed to allocate buffer of size " + std::to_string(bufferSize)); - TORCH_CHECK( - (seek != nullptr) && ((write != nullptr) ^ (read != nullptr)), - "seek method must be defined, and either write or read must be defined. " - "But not both!") + // TORCH_CHECK( + // (seek != nullptr) && ((write != nullptr) ^ (read != nullptr)), + // "seek method must be defined, and either write or read must be + // defined. " "But not both!") avioContext_.reset(avioAllocContext( buffer, bufferSize, diff --git a/src/torchcodec/_core/AVIOFileLikeContext.cpp b/src/torchcodec/_core/AVIOFileLikeContext.cpp index 5497f89b..3870e5a1 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/_core/AVIOFileLikeContext.cpp @@ -23,7 +23,7 @@ AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike) py::hasattr(fileLike, "seek"), "File like object must implement a seek method."); } - createAVIOContext(&read, nullptr, &seek, &fileLike_); + createAVIOContext(&read, &write, &seek, &fileLike_); } int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) { @@ -77,4 +77,12 @@ int64_t AVIOFileLikeContext::seek(void* opaque, int64_t offset, int whence) { return py::cast((*fileLike)->attr("seek")(offset, whence)); } +int AVIOFileLikeContext::write(void* opaque, const uint8_t* buf, int buf_size) { + auto fileLike = static_cast(opaque); + py::gil_scoped_acquire gil; + py::bytes bytes_obj(reinterpret_cast(buf), buf_size); + + return py::cast((*fileLike)->attr("write")(bytes_obj)); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/AVIOFileLikeContext.h b/src/torchcodec/_core/AVIOFileLikeContext.h index 3e80f1c6..00948515 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.h +++ b/src/torchcodec/_core/AVIOFileLikeContext.h @@ -24,6 +24,7 @@ class AVIOFileLikeContext : public AVIOContextHolder { private: static int read(void* opaque, uint8_t* buf, int buf_size); static int64_t seek(void* opaque, int64_t offset, int whence); + static int write(void* opaque, const uint8_t* buf, int buf_size); // Note that we dynamically allocate the Python object because we need to // strictly control when its destructor is called. We must hold the GIL diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 1c876f4e..9277ccaa 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -1,6 +1,7 @@ #include #include "src/torchcodec/_core/AVIOBytesContext.h" +#include "src/torchcodec/_core/AVIOContextHolder.h" #include "src/torchcodec/_core/Encoder.h" #include "torch/types.h" @@ -148,6 +149,31 @@ AudioEncoder::AudioEncoder( initializeEncoder(sampleRate, bitRate); } +// TODO this sucks, shouldn't need 2 separate constructors for AVIOContextHolder +AudioEncoder::AudioEncoder( + const torch::Tensor wf, + int sampleRate, + std::string_view formatName, + std::unique_ptr avioContextHolder, + std::optional bitRate) + : wf_(validateWf(wf)), avioContextHolderrrr_(std::move(avioContextHolder)) { + setFFmpegLogLevel(); + AVFormatContext* avFormatContext = nullptr; + int status = avformat_alloc_output_context2( + &avFormatContext, nullptr, formatName.data(), nullptr); + + TORCH_CHECK( + avFormatContext != nullptr, + "Couldn't allocate AVFormatContext. ", + "Check the desired extension? ", + getFFMPEGErrorStringFromErrorCode(status)); + avFormatContext_.reset(avFormatContext); + + avFormatContext_->pb = avioContextHolderrrr_->getAVIOContext(); + + initializeEncoder(sampleRate, bitRate); +} + void AudioEncoder::initializeEncoder( int sampleRate, std::optional bitRate) { diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index bf31c31b..37d9c703 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -1,6 +1,7 @@ #pragma once #include #include "src/torchcodec/_core/AVIOBytesContext.h" +#include "src/torchcodec/_core/AVIOContextHolder.h" #include "src/torchcodec/_core/FFMPEGCommon.h" namespace facebook::torchcodec { @@ -28,6 +29,12 @@ class AudioEncoder { std::string_view formatName, std::unique_ptr avioContextHolder, std::optional bitRate = std::nullopt); + AudioEncoder( + const torch::Tensor wf, + int sampleRate, + std::string_view formatName, + std::unique_ptr avioContextHolder, + std::optional bitRate = std::nullopt); void encode(); torch::Tensor encodeToTensor(); @@ -49,6 +56,7 @@ class AudioEncoder { // Stores the AVIOContext for the output tensor buffer. std::unique_ptr avioContextHolder_; + std::unique_ptr avioContextHolderrrr_; // EWWWWW bool encodeWasCalled_ = false; }; diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index 77fc7b85..3d340bff 100644 --- a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -23,6 +23,7 @@ create_from_file_like, create_from_tensor, encode_audio_to_file, + encode_audio_to_file_like, encode_audio_to_tensor, get_ffmpeg_library_versions, get_frame_at_index, diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index e9b4faec..e1205794 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -153,6 +153,17 @@ def create_from_file_like( return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode)) +def encode_audio_to_file_like( + file_like: Union[io.RawIOBase, io.BufferedReader], + wf: torch.Tensor, + sample_rate: int, + format: str, + bit_rate: Optional[int] = None, +): + assert _pybind_ops is not None + _pybind_ops.encode_audio_to_file_like(file_like, wf, sample_rate, format, bit_rate) + + # ============================== # Abstract impl for the operators. Needed by torch.compile. # ============================== diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index 6f873f5a..7d1d2e76 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -10,6 +10,7 @@ #include #include "src/torchcodec/_core/AVIOFileLikeContext.h" +#include "src/torchcodec/_core/Encoder.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" namespace py = pybind11; @@ -38,8 +39,26 @@ int64_t create_from_file_like( return reinterpret_cast(decoder); } +void encode_audio_to_file_like( + py::object file_like, + // const at::Tensor wf, + [[maybe_unused]] int wf, + int64_t sample_rate, + std::string_view format, + std::optional bit_rate = std::nullopt) { + auto avioContextHolder = std::make_unique(file_like); + AudioEncoder( + torch::empty({2, 1000}, torch::kFloat32), + sample_rate, // TODO need validateSampleRate + format, + std::move(avioContextHolder), + bit_rate) + .encode(); +} + PYBIND11_MODULE(decoder_core_pybind_ops, m) { m.def("create_from_file_like", &create_from_file_like); + m.def("encode_audio_to_file_like", &encode_audio_to_file_like); } } // namespace facebook::torchcodec