Skip to content

Commit 5713507

Browse files
authored
Support all audio formats by converting to FLTP (#556)
1 parent b32aabe commit 5713507

File tree

8 files changed

+199
-24
lines changed

8 files changed

+199
-24
lines changed

Diff for: src/torchcodec/decoders/_core/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ if(DEFINED ENV{BUILD_AGAINST_ALL_FFMPEG_FROM_S3})
7777
)
7878

7979

80-
make_torchcodec_library(libtorchcodec4 ffmpeg4)
8180
make_torchcodec_library(libtorchcodec7 ffmpeg7)
8281
make_torchcodec_library(libtorchcodec6 ffmpeg6)
8382
make_torchcodec_library(libtorchcodec5 ffmpeg5)
83+
make_torchcodec_library(libtorchcodec4 ffmpeg4)
8484

8585
else()
8686
message(
@@ -97,6 +97,7 @@ else()
9797
libavformat
9898
libavcodec
9999
libavutil
100+
libswresample
100101
libswscale
101102
)
102103

Diff for: src/torchcodec/decoders/_core/FFMPEGCommon.cpp

+52-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ int64_t getDuration(const AVFrame* frame) {
6060
#endif
6161
}
6262

63-
int getNumChannels(const AVFrame* avFrame) {
63+
int getNumChannels(const UniqueAVFrame& avFrame) {
6464
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
6565
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
6666
return avFrame->ch_layout.nb_channels;
@@ -78,6 +78,57 @@ int getNumChannels(const UniqueAVCodecContext& avCodecContext) {
7878
#endif
7979
}
8080

81+
void setChannelLayout(
82+
UniqueAVFrame& dstAVFrame,
83+
const UniqueAVFrame& srcAVFrame) {
84+
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
85+
dstAVFrame->ch_layout = srcAVFrame->ch_layout;
86+
#else
87+
dstAVFrame->channel_layout = srcAVFrame->channel_layout;
88+
#endif
89+
}
90+
91+
SwrContext* allocateSwrContext(
92+
UniqueAVCodecContext& avCodecContext,
93+
int sampleRate,
94+
AVSampleFormat sourceSampleFormat,
95+
AVSampleFormat desiredSampleFormat) {
96+
SwrContext* swrContext = nullptr;
97+
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
98+
AVChannelLayout layout = avCodecContext->ch_layout;
99+
auto status = swr_alloc_set_opts2(
100+
&swrContext,
101+
&layout,
102+
desiredSampleFormat,
103+
sampleRate,
104+
&layout,
105+
sourceSampleFormat,
106+
sampleRate,
107+
0,
108+
nullptr);
109+
110+
TORCH_CHECK(
111+
status == AVSUCCESS,
112+
"Couldn't create SwrContext: ",
113+
getFFMPEGErrorStringFromErrorCode(status));
114+
#else
115+
int64_t layout = static_cast<int64_t>(avCodecContext->channel_layout);
116+
swrContext = swr_alloc_set_opts(
117+
nullptr,
118+
layout,
119+
desiredSampleFormat,
120+
sampleRate,
121+
layout,
122+
sourceSampleFormat,
123+
sampleRate,
124+
0,
125+
nullptr);
126+
#endif
127+
128+
TORCH_CHECK(swrContext != nullptr, "Couldn't create swrContext");
129+
return swrContext;
130+
}
131+
81132
AVIOBytesContext::AVIOBytesContext(
82133
const void* data,
83134
size_t dataSize,

Diff for: src/torchcodec/decoders/_core/FFMPEGCommon.h

+13-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ extern "C" {
2222
#include <libavutil/opt.h>
2323
#include <libavutil/pixfmt.h>
2424
#include <libavutil/version.h>
25+
#include <libswresample/swresample.h>
2526
#include <libswscale/swscale.h>
2627
}
2728

@@ -67,6 +68,8 @@ using UniqueAVIOContext = std::
6768
unique_ptr<AVIOContext, Deleterp<AVIOContext, void, avio_context_free>>;
6869
using UniqueSwsContext =
6970
std::unique_ptr<SwsContext, Deleter<SwsContext, void, sws_freeContext>>;
71+
using UniqueSwrContext =
72+
std::unique_ptr<SwrContext, Deleterp<SwrContext, void, swr_free>>;
7073

7174
// These 2 classes share the same underlying AVPacket object. They are meant to
7275
// be used in tandem, like so:
@@ -139,9 +142,18 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode);
139142
int64_t getDuration(const UniqueAVFrame& frame);
140143
int64_t getDuration(const AVFrame* frame);
141144

142-
int getNumChannels(const AVFrame* avFrame);
145+
int getNumChannels(const UniqueAVFrame& avFrame);
143146
int getNumChannels(const UniqueAVCodecContext& avCodecContext);
144147

148+
void setChannelLayout(
149+
UniqueAVFrame& dstAVFrame,
150+
const UniqueAVFrame& srcAVFrame);
151+
SwrContext* allocateSwrContext(
152+
UniqueAVCodecContext& avCodecContext,
153+
int sampleRate,
154+
AVSampleFormat sourceSampleFormat,
155+
AVSampleFormat desiredSampleFormat);
156+
145157
// Returns true if sws_scale can handle unaligned data.
146158
bool canSwsScaleHandleUnalignedData();
147159

Diff for: src/torchcodec/decoders/_core/VideoDecoder.cpp

+99-21
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ extern "C" {
2323
#include <libavutil/imgutils.h>
2424
#include <libavutil/log.h>
2525
#include <libavutil/pixdesc.h>
26+
#include <libswresample/swresample.h>
2627
#include <libswscale/swscale.h>
2728
}
2829

@@ -559,6 +560,12 @@ void VideoDecoder::addAudioStream(int streamIndex) {
559560
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
560561
streamMetadata.numChannels =
561562
static_cast<int64_t>(getNumChannels(streamInfo.codecContext));
563+
564+
// FFmpeg docs say that the decoder will try to decode natively in this
565+
// format, if it can. Docs don't say what the decoder does when it doesn't
566+
// support that format, but it looks like it does nothing, so this probably
567+
// doesn't hurt.
568+
streamInfo.codecContext->request_sample_fmt = AV_SAMPLE_FMT_FLTP;
562569
}
563570

564571
// --------------------------------------------------------------------------
@@ -1350,37 +1357,89 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13501357
!preAllocatedOutputTensor.has_value(),
13511358
"pre-allocated audio tensor not supported yet.");
13521359

1353-
const AVFrame* avFrame = avFrameStream.avFrame.get();
1360+
AVSampleFormat sourceSampleFormat =
1361+
static_cast<AVSampleFormat>(avFrameStream.avFrame->format);
1362+
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
1363+
1364+
UniqueAVFrame convertedAVFrame;
1365+
if (sourceSampleFormat != desiredSampleFormat) {
1366+
convertedAVFrame = convertAudioAVFrameSampleFormat(
1367+
avFrameStream.avFrame, sourceSampleFormat, desiredSampleFormat);
1368+
}
1369+
const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
1370+
? convertedAVFrame
1371+
: avFrameStream.avFrame;
1372+
1373+
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
1374+
TORCH_CHECK(
1375+
format == desiredSampleFormat,
1376+
"Something went wrong, the frame didn't get converted to the desired format. ",
1377+
"Desired format = ",
1378+
av_get_sample_fmt_name(desiredSampleFormat),
1379+
"source format = ",
1380+
av_get_sample_fmt_name(format));
13541381

13551382
auto numSamples = avFrame->nb_samples; // per channel
13561383
auto numChannels = getNumChannels(avFrame);
13571384
torch::Tensor outputData =
13581385
torch::empty({numChannels, numSamples}, torch::kFloat32);
13591386

1360-
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
1361-
// TODO-AUDIO Implement all formats.
1362-
switch (format) {
1363-
case AV_SAMPLE_FMT_FLTP: {
1364-
uint8_t* outputChannelData = static_cast<uint8_t*>(outputData.data_ptr());
1365-
auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
1366-
for (auto channel = 0; channel < numChannels;
1367-
++channel, outputChannelData += numBytesPerChannel) {
1368-
memcpy(
1369-
outputChannelData,
1370-
avFrame->extended_data[channel],
1371-
numBytesPerChannel);
1372-
}
1373-
break;
1374-
}
1375-
default:
1376-
TORCH_CHECK(
1377-
false,
1378-
"Unsupported audio format (yet!): ",
1379-
av_get_sample_fmt_name(format));
1387+
uint8_t* outputChannelData = static_cast<uint8_t*>(outputData.data_ptr());
1388+
auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
1389+
for (auto channel = 0; channel < numChannels;
1390+
++channel, outputChannelData += numBytesPerChannel) {
1391+
memcpy(
1392+
outputChannelData, avFrame->extended_data[channel], numBytesPerChannel);
13801393
}
13811394
frameOutput.data = outputData;
13821395
}
13831396

1397+
UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat(
1398+
const UniqueAVFrame& avFrame,
1399+
AVSampleFormat sourceSampleFormat,
1400+
AVSampleFormat desiredSampleFormat
1401+
1402+
) {
1403+
auto& streamInfo = streamInfos_[activeStreamIndex_];
1404+
const auto& streamMetadata =
1405+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
1406+
int sampleRate = static_cast<int>(streamMetadata.sampleRate.value());
1407+
1408+
if (!streamInfo.swrContext) {
1409+
createSwrContext(
1410+
streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat);
1411+
}
1412+
1413+
UniqueAVFrame convertedAVFrame(av_frame_alloc());
1414+
TORCH_CHECK(
1415+
convertedAVFrame,
1416+
"Could not allocate frame for sample format conversion.");
1417+
1418+
setChannelLayout(convertedAVFrame, avFrame);
1419+
convertedAVFrame->format = static_cast<int>(desiredSampleFormat);
1420+
convertedAVFrame->sample_rate = avFrame->sample_rate;
1421+
convertedAVFrame->nb_samples = avFrame->nb_samples;
1422+
1423+
auto status = av_frame_get_buffer(convertedAVFrame.get(), 0);
1424+
TORCH_CHECK(
1425+
status == AVSUCCESS,
1426+
"Could not allocate frame buffers for sample format conversion: ",
1427+
getFFMPEGErrorStringFromErrorCode(status));
1428+
1429+
auto numSampleConverted = swr_convert(
1430+
streamInfo.swrContext.get(),
1431+
convertedAVFrame->data,
1432+
convertedAVFrame->nb_samples,
1433+
static_cast<const uint8_t**>(const_cast<const uint8_t**>(avFrame->data)),
1434+
avFrame->nb_samples);
1435+
TORCH_CHECK(
1436+
numSampleConverted > 0,
1437+
"Error in swr_convert: ",
1438+
getFFMPEGErrorStringFromErrorCode(numSampleConverted));
1439+
1440+
return convertedAVFrame;
1441+
}
1442+
13841443
// --------------------------------------------------------------------------
13851444
// OUTPUT ALLOCATION AND SHAPE CONVERSION
13861445
// --------------------------------------------------------------------------
@@ -1614,6 +1673,25 @@ void VideoDecoder::createSwsContext(
16141673
streamInfo.swsContext.reset(swsContext);
16151674
}
16161675

1676+
void VideoDecoder::createSwrContext(
1677+
StreamInfo& streamInfo,
1678+
int sampleRate,
1679+
AVSampleFormat sourceSampleFormat,
1680+
AVSampleFormat desiredSampleFormat) {
1681+
auto swrContext = allocateSwrContext(
1682+
streamInfo.codecContext,
1683+
sampleRate,
1684+
sourceSampleFormat,
1685+
desiredSampleFormat);
1686+
1687+
auto status = swr_init(swrContext);
1688+
TORCH_CHECK(
1689+
status == AVSUCCESS,
1690+
"Couldn't initialize SwrContext: ",
1691+
getFFMPEGErrorStringFromErrorCode(status));
1692+
streamInfo.swrContext.reset(swrContext);
1693+
}
1694+
16171695
// --------------------------------------------------------------------------
16181696
// PTS <-> INDEX CONVERSIONS
16191697
// --------------------------------------------------------------------------

Diff for: src/torchcodec/decoders/_core/VideoDecoder.h

+12
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ class VideoDecoder {
355355
FilterGraphContext filterGraphContext;
356356
ColorConversionLibrary colorConversionLibrary = FILTERGRAPH;
357357
UniqueSwsContext swsContext;
358+
UniqueSwrContext swrContext;
358359

359360
// Used to know whether a new FilterGraphContext or UniqueSwsContext should
360361
// be created before decoding a new frame.
@@ -402,6 +403,11 @@ class VideoDecoder {
402403
const AVFrame* avFrame,
403404
torch::Tensor& outputTensor);
404405

406+
UniqueAVFrame convertAudioAVFrameSampleFormat(
407+
const UniqueAVFrame& avFrame,
408+
AVSampleFormat sourceSampleFormat,
409+
AVSampleFormat desiredSampleFormat);
410+
405411
// --------------------------------------------------------------------------
406412
// COLOR CONVERSION LIBRARIES HANDLERS CREATION
407413
// --------------------------------------------------------------------------
@@ -416,6 +422,12 @@ class VideoDecoder {
416422
const DecodedFrameContext& frameContext,
417423
const enum AVColorSpace colorspace);
418424

425+
void createSwrContext(
426+
StreamInfo& streamInfo,
427+
int sampleRate,
428+
AVSampleFormat sourceSampleFormat,
429+
AVSampleFormat desiredSampleFormat);
430+
419431
// --------------------------------------------------------------------------
420432
// PTS <-> INDEX CONVERSIONS
421433
// --------------------------------------------------------------------------

Diff for: test/decoders/test_decoders.py

+18
Original file line numberDiff line numberDiff line change
@@ -1070,3 +1070,21 @@ def test_frame_start_is_not_zero(self):
10701070

10711071
reference_frames = asset.get_frame_data_by_range(start=0, stop=stop_frame_index)
10721072
torch.testing.assert_close(samples.data, reference_frames)
1073+
1074+
def test_single_channel(self):
1075+
asset = SINE_MONO_S32
1076+
decoder = AudioDecoder(asset.path)
1077+
1078+
samples = decoder.get_samples_played_in_range(start_seconds=0, stop_seconds=2)
1079+
assert samples.data.shape[0] == asset.num_channels == 1
1080+
1081+
def test_format_conversion(self):
1082+
asset = SINE_MONO_S32
1083+
decoder = AudioDecoder(asset.path)
1084+
assert decoder.metadata.sample_format == asset.sample_format == "s32"
1085+
1086+
all_samples = decoder.get_samples_played_in_range(start_seconds=0)
1087+
assert all_samples.data.dtype == torch.float32
1088+
1089+
reference_frames = asset.get_frame_data_by_range(start=0, stop=asset.num_frames)
1090+
torch.testing.assert_close(all_samples.data, reference_frames)
266 KB
Binary file not shown.

Diff for: test/utils.py

+3
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,9 @@ def sample_format(self) -> str:
444444
},
445445
)
446446

447+
# Note that the file itself is s32 sample format, but the reference frames are
448+
# stored as fltp. We can add the s32 original reference frames once we support
449+
# decoding to non-fltp format, but for now we don't need to.
447450
SINE_MONO_S32 = TestAudio(
448451
filename="sine_mono_s32.wav",
449452
default_stream_index=0,

0 commit comments

Comments
 (0)