@@ -23,6 +23,7 @@ extern "C" {
23
23
#include < libavutil/imgutils.h>
24
24
#include < libavutil/log.h>
25
25
#include < libavutil/pixdesc.h>
26
+ #include < libswresample/swresample.h>
26
27
#include < libswscale/swscale.h>
27
28
}
28
29
@@ -559,6 +560,12 @@ void VideoDecoder::addAudioStream(int streamIndex) {
559
560
static_cast <int64_t >(streamInfo.codecContext ->sample_rate );
560
561
streamMetadata.numChannels =
561
562
static_cast <int64_t >(getNumChannels (streamInfo.codecContext ));
563
+
564
+ // FFmpeg docs say that the decoder will try to decode natively in this
565
+ // format, if it can. Docs don't say what the decoder does when it doesn't
566
+ // support that format, but it looks like it does nothing, so this probably
567
+ // doesn't hurt.
568
+ streamInfo.codecContext ->request_sample_fmt = AV_SAMPLE_FMT_FLTP;
562
569
}
563
570
564
571
// --------------------------------------------------------------------------
@@ -1350,37 +1357,89 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
1350
1357
!preAllocatedOutputTensor.has_value (),
1351
1358
" pre-allocated audio tensor not supported yet." );
1352
1359
1353
- const AVFrame* avFrame = avFrameStream.avFrame .get ();
1360
+ AVSampleFormat sourceSampleFormat =
1361
+ static_cast <AVSampleFormat>(avFrameStream.avFrame ->format );
1362
+ AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
1363
+
1364
+ UniqueAVFrame convertedAVFrame;
1365
+ if (sourceSampleFormat != desiredSampleFormat) {
1366
+ convertedAVFrame = convertAudioAVFrameSampleFormat (
1367
+ avFrameStream.avFrame , sourceSampleFormat, desiredSampleFormat);
1368
+ }
1369
+ const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
1370
+ ? convertedAVFrame
1371
+ : avFrameStream.avFrame ;
1372
+
1373
+ AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
1374
+ TORCH_CHECK (
1375
+ format == desiredSampleFormat,
1376
+ " Something went wrong, the frame didn't get converted to the desired format. " ,
1377
+ " Desired format = " ,
1378
+ av_get_sample_fmt_name (desiredSampleFormat),
1379
+ " source format = " ,
1380
+ av_get_sample_fmt_name (format));
1354
1381
1355
1382
auto numSamples = avFrame->nb_samples ; // per channel
1356
1383
auto numChannels = getNumChannels (avFrame);
1357
1384
torch::Tensor outputData =
1358
1385
torch::empty ({numChannels, numSamples}, torch::kFloat32 );
1359
1386
1360
- AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
1361
- // TODO-AUDIO Implement all formats.
1362
- switch (format) {
1363
- case AV_SAMPLE_FMT_FLTP: {
1364
- uint8_t * outputChannelData = static_cast <uint8_t *>(outputData.data_ptr ());
1365
- auto numBytesPerChannel = numSamples * av_get_bytes_per_sample (format);
1366
- for (auto channel = 0 ; channel < numChannels;
1367
- ++channel, outputChannelData += numBytesPerChannel) {
1368
- memcpy (
1369
- outputChannelData,
1370
- avFrame->extended_data [channel],
1371
- numBytesPerChannel);
1372
- }
1373
- break ;
1374
- }
1375
- default :
1376
- TORCH_CHECK (
1377
- false ,
1378
- " Unsupported audio format (yet!): " ,
1379
- av_get_sample_fmt_name (format));
1387
+ uint8_t * outputChannelData = static_cast <uint8_t *>(outputData.data_ptr ());
1388
+ auto numBytesPerChannel = numSamples * av_get_bytes_per_sample (format);
1389
+ for (auto channel = 0 ; channel < numChannels;
1390
+ ++channel, outputChannelData += numBytesPerChannel) {
1391
+ memcpy (
1392
+ outputChannelData, avFrame->extended_data [channel], numBytesPerChannel);
1380
1393
}
1381
1394
frameOutput.data = outputData;
1382
1395
}
1383
1396
1397
+ UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat (
1398
+ const UniqueAVFrame& avFrame,
1399
+ AVSampleFormat sourceSampleFormat,
1400
+ AVSampleFormat desiredSampleFormat
1401
+
1402
+ ) {
1403
+ auto & streamInfo = streamInfos_[activeStreamIndex_];
1404
+ const auto & streamMetadata =
1405
+ containerMetadata_.allStreamMetadata [activeStreamIndex_];
1406
+ int sampleRate = static_cast <int >(streamMetadata.sampleRate .value ());
1407
+
1408
+ if (!streamInfo.swrContext ) {
1409
+ createSwrContext (
1410
+ streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat);
1411
+ }
1412
+
1413
+ UniqueAVFrame convertedAVFrame (av_frame_alloc ());
1414
+ TORCH_CHECK (
1415
+ convertedAVFrame,
1416
+ " Could not allocate frame for sample format conversion." );
1417
+
1418
+ setChannelLayout (convertedAVFrame, avFrame);
1419
+ convertedAVFrame->format = static_cast <int >(desiredSampleFormat);
1420
+ convertedAVFrame->sample_rate = avFrame->sample_rate ;
1421
+ convertedAVFrame->nb_samples = avFrame->nb_samples ;
1422
+
1423
+ auto status = av_frame_get_buffer (convertedAVFrame.get (), 0 );
1424
+ TORCH_CHECK (
1425
+ status == AVSUCCESS,
1426
+ " Could not allocate frame buffers for sample format conversion: " ,
1427
+ getFFMPEGErrorStringFromErrorCode (status));
1428
+
1429
+ auto numSampleConverted = swr_convert (
1430
+ streamInfo.swrContext .get (),
1431
+ convertedAVFrame->data ,
1432
+ convertedAVFrame->nb_samples ,
1433
+ static_cast <const uint8_t **>(const_cast <const uint8_t **>(avFrame->data )),
1434
+ avFrame->nb_samples );
1435
+ TORCH_CHECK (
1436
+ numSampleConverted > 0 ,
1437
+ " Error in swr_convert: " ,
1438
+ getFFMPEGErrorStringFromErrorCode (numSampleConverted));
1439
+
1440
+ return convertedAVFrame;
1441
+ }
1442
+
1384
1443
// --------------------------------------------------------------------------
1385
1444
// OUTPUT ALLOCATION AND SHAPE CONVERSION
1386
1445
// --------------------------------------------------------------------------
@@ -1614,6 +1673,25 @@ void VideoDecoder::createSwsContext(
1614
1673
streamInfo.swsContext .reset (swsContext);
1615
1674
}
1616
1675
1676
+ void VideoDecoder::createSwrContext (
1677
+ StreamInfo& streamInfo,
1678
+ int sampleRate,
1679
+ AVSampleFormat sourceSampleFormat,
1680
+ AVSampleFormat desiredSampleFormat) {
1681
+ auto swrContext = allocateSwrContext (
1682
+ streamInfo.codecContext ,
1683
+ sampleRate,
1684
+ sourceSampleFormat,
1685
+ desiredSampleFormat);
1686
+
1687
+ auto status = swr_init (swrContext);
1688
+ TORCH_CHECK (
1689
+ status == AVSUCCESS,
1690
+ " Couldn't initialize SwrContext: " ,
1691
+ getFFMPEGErrorStringFromErrorCode (status));
1692
+ streamInfo.swrContext .reset (swrContext);
1693
+ }
1694
+
1617
1695
// --------------------------------------------------------------------------
1618
1696
// PTS <-> INDEX CONVERSIONS
1619
1697
// --------------------------------------------------------------------------
0 commit comments