Skip to content

Commit 5402e7d

Browse files
authored
Move stream options and frame output structs to dedicated headers (#620)
Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent 8b19f45 commit 5402e7d

10 files changed

+226
-189
lines changed

Diff for: src/torchcodec/_core/CudaDevice.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,9 @@ void CudaDevice::initializeContext(AVCodecContext* codecContext) {
190190
}
191191

192192
void CudaDevice::convertAVFrameToFrameOutput(
193-
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
193+
const VideoStreamOptions& videoStreamOptions,
194194
UniqueAVFrame& avFrame,
195-
SingleStreamDecoder::FrameOutput& frameOutput,
195+
FrameOutput& frameOutput,
196196
std::optional<torch::Tensor> preAllocatedOutputTensor) {
197197
TORCH_CHECK(
198198
avFrame->format == AV_PIX_FMT_CUDA,

Diff for: src/torchcodec/_core/CudaDevice.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ class CudaDevice : public DeviceInterface {
2121
void initializeContext(AVCodecContext* codecContext) override;
2222

2323
void convertAVFrameToFrameOutput(
24-
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
24+
const VideoStreamOptions& videoStreamOptions,
2525
UniqueAVFrame& avFrame,
26-
SingleStreamDecoder::FrameOutput& frameOutput,
26+
FrameOutput& frameOutput,
2727
std::optional<torch::Tensor> preAllocatedOutputTensor =
2828
std::nullopt) override;
2929

Diff for: src/torchcodec/_core/DeviceInterface.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
#include <stdexcept>
1313
#include <string>
1414
#include "FFMPEGCommon.h"
15-
#include "src/torchcodec/_core/SingleStreamDecoder.h"
15+
#include "src/torchcodec/_core/Frame.h"
16+
#include "src/torchcodec/_core/StreamOptions.h"
1617

1718
namespace facebook::torchcodec {
1819

@@ -41,9 +42,9 @@ class DeviceInterface {
4142
virtual void initializeContext(AVCodecContext* codecContext) = 0;
4243

4344
virtual void convertAVFrameToFrameOutput(
44-
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
45+
const VideoStreamOptions& videoStreamOptions,
4546
UniqueAVFrame& avFrame,
46-
SingleStreamDecoder::FrameOutput& frameOutput,
47+
FrameOutput& frameOutput,
4748
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt) = 0;
4849

4950
protected:

Diff for: src/torchcodec/_core/Frame.h

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <torch/types.h>
10+
#include "src/torchcodec/_core/Metadata.h"
11+
#include "src/torchcodec/_core/StreamOptions.h"
12+
13+
namespace facebook::torchcodec {
14+
15+
// All public video decoding entry points return either a FrameOutput or a
16+
// FrameBatchOutput.
17+
// They are the equivalent of the user-facing Frame and FrameBatch classes in
18+
// Python. They contain RGB decoded frames along with some associated data
19+
// like PTS and duration.
20+
// FrameOutput is also relevant for audio decoding, typically as the output of
21+
// getNextFrame(), or as a temporary output variable.
22+
// Result type for a single decoded frame: the decoded tensor plus two timing
// values expressed as doubles.
struct FrameOutput {
23+
// data shape is:
24+
// - 3D (C, H, W) or (H, W, C) for videos
25+
// - 2D (numChannels, numSamples) for audio
26+
torch::Tensor data;
27+
// Presumably the presentation timestamp of the frame, in seconds (inferred
// from the name; confirm against the code that fills it in).
// NOTE(review): there is no default member initializer, so this value is
// indeterminate in a default-initialized FrameOutput until it is assigned.
double ptsSeconds;
28+
// Presumably the duration of the frame, in seconds (inferred from the name).
// Like ptsSeconds, it has no default member initializer.
double durationSeconds;
29+
};
30+
31+
// Result type for a batch of N decoded video frames. All three members are
// tensors; the per-member comments give their expected ranks.
struct FrameBatchOutput {
32+
torch::Tensor data; // 4D: of shape NCHW or NHWC.
33+
torch::Tensor ptsSeconds; // 1D of shape (N,)
34+
torch::Tensor durationSeconds; // 1D of shape (N,)
35+
36+
// Only declared here; the definition lives out of line. Because this
// constructor is user-declared, the struct is not an aggregate and has no
// implicit default constructor, so all three arguments are always required.
// NOTE(review): presumably allocates the three tensors for numFrames frames,
// taking the frame dimensions from the options or the metadata - confirm
// against the definition.
explicit FrameBatchOutput(
37+
int64_t numFrames,
38+
const VideoStreamOptions& videoStreamOptions,
39+
const StreamMetadata& streamMetadata);
40+
};
41+
42+
// Result type for decoded audio samples: one tensor plus a single timing
// value for the whole output.
struct AudioFramesOutput {
43+
torch::Tensor data; // shape is (numChannels, numSamples)
44+
// Presumably the presentation timestamp, in seconds, of the first returned
// sample (inferred from the name; confirm against the producer).
// NOTE(review): no default member initializer, so the value is indeterminate
// in a default-initialized AudioFramesOutput until it is assigned.
double ptsSeconds;
45+
};
46+
47+
} // namespace facebook::torchcodec

Diff for: src/torchcodec/_core/Metadata.h

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <optional>
10+
#include <string>
11+
#include <vector>
12+
13+
extern "C" {
14+
#include <libavcodec/avcodec.h>
15+
#include <libavutil/avutil.h>
16+
}
17+
18+
namespace facebook::torchcodec {
19+
20+
// Metadata describing one stream of a container. Every member except
// streamIndex and mediaType is a std::optional and therefore starts out empty;
// readers must check for a value before dereferencing.
struct StreamMetadata {
21+
// Common (video and audio) fields derived from the AVStream.
22+
// NOTE(review): streamIndex and mediaType are not optional and have no
// default member initializer, so they are indeterminate in a
// default-initialized StreamMetadata until assigned.
int streamIndex;
23+
// See this link for what various values are available:
24+
// https://ffmpeg.org/doxygen/trunk/group__lavu__misc.html#ga9a84bba4713dfced21a1a56163be1f48
25+
AVMediaType mediaType;
26+
std::optional<AVCodecID> codecId;
27+
std::optional<std::string> codecName;
28+
std::optional<double> durationSeconds;
29+
std::optional<double> beginStreamFromHeader;
30+
std::optional<int64_t> numFrames;
31+
std::optional<int64_t> numKeyFrames;
32+
std::optional<double> averageFps;
33+
std::optional<double> bitRate;
34+
35+
// More accurate duration, obtained by scanning the file.
// The *FromScan members below are grouped under this heading; presumably they
// stay empty until a file scan has been performed (confirm against the
// scanning code).
36+
// These presentation timestamps are in time base.
37+
std::optional<int64_t> minPtsFromScan;
38+
std::optional<int64_t> maxPtsFromScan;
39+
// These presentation timestamps are in seconds.
40+
std::optional<double> minPtsSecondsFromScan;
41+
std::optional<double> maxPtsSecondsFromScan;
42+
// This can be useful for index-based seeking.
43+
std::optional<int64_t> numFramesFromScan;
44+
45+
// Video-only fields derived from the AVCodecContext.
46+
std::optional<int64_t> width;
47+
std::optional<int64_t> height;
48+
49+
// Audio-only fields
50+
std::optional<int64_t> sampleRate;
51+
std::optional<int64_t> numChannels;
52+
std::optional<std::string> sampleFormat;
53+
};
54+
55+
// Metadata for a whole container: one StreamMetadata entry per stream plus
// container-wide values. The two stream counters default to 0 and every
// std::optional member defaults to empty.
struct ContainerMetadata {
56+
// Per-stream metadata entries.
// NOTE(review): presumably indexed by stream index - confirm against the code
// that populates it.
std::vector<StreamMetadata> allStreamMetadata;
57+
int numAudioStreams = 0;
58+
int numVideoStreams = 0;
59+
// Note that this is the container-level duration, which is usually the max
60+
// of all stream durations available in the container.
61+
std::optional<double> durationSeconds;
62+
// Total BitRate level information at the container level in bit/s
63+
std::optional<double> bitRate;
64+
// If set, this is the index to the default audio stream.
65+
std::optional<int> bestAudioStreamIndex;
66+
// If set, this is the index to the default video stream.
67+
std::optional<int> bestVideoStreamIndex;
68+
};
69+
70+
} // namespace facebook::torchcodec

Diff for: src/torchcodec/_core/SingleStreamDecoder.cpp

+26-34
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include <sstream>
1414
#include <stdexcept>
1515
#include <string_view>
16-
#include "src/torchcodec/_core/DeviceInterface.h"
1716
#include "torch/types.h"
1817

1918
extern "C" {
@@ -350,8 +349,7 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
350349
scannedAllStreams_ = true;
351350
}
352351

353-
SingleStreamDecoder::ContainerMetadata
354-
SingleStreamDecoder::getContainerMetadata() const {
352+
// Returns the decoder's container metadata. The return type is a value, not a
// reference, so the caller receives its own copy of containerMetadata_.
ContainerMetadata SingleStreamDecoder::getContainerMetadata() const {
355353
return containerMetadata_;
356354
}
357355

@@ -406,7 +404,7 @@ void SingleStreamDecoder::addStream(
406404
streamInfo.stream = formatContext_->streams[activeStreamIndex_];
407405
streamInfo.avMediaType = mediaType;
408406

409-
deviceInterface = createDeviceInterface(device);
407+
deviceInterface_ = createDeviceInterface(device);
410408

411409
// This should never happen, checking just to be safe.
412410
TORCH_CHECK(
@@ -418,9 +416,9 @@ void SingleStreamDecoder::addStream(
418416
// TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
419417
// addStream() which is supposed to be generic
420418
if (mediaType == AVMEDIA_TYPE_VIDEO) {
421-
if (deviceInterface) {
419+
if (deviceInterface_) {
422420
avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
423-
deviceInterface->findCodec(streamInfo.stream->codecpar->codec_id)
421+
deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id)
424422
.value_or(avCodec));
425423
}
426424
}
@@ -438,8 +436,8 @@ void SingleStreamDecoder::addStream(
438436

439437
// TODO_CODE_QUALITY same as above.
440438
if (mediaType == AVMEDIA_TYPE_VIDEO) {
441-
if (deviceInterface) {
442-
deviceInterface->initializeContext(codecContext);
439+
if (deviceInterface_) {
440+
deviceInterface_->initializeContext(codecContext);
443441
}
444442
}
445443

@@ -501,9 +499,8 @@ void SingleStreamDecoder::addVideoStream(
501499
// swscale requires widths to be multiples of 32:
502500
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
503501
// so we fall back to filtergraph if the width is not a multiple of 32.
504-
auto defaultLibrary = (width % 32 == 0)
505-
? SingleStreamDecoder::ColorConversionLibrary::SWSCALE
506-
: SingleStreamDecoder::ColorConversionLibrary::FILTERGRAPH;
502+
auto defaultLibrary = (width % 32 == 0) ? ColorConversionLibrary::SWSCALE
503+
: ColorConversionLibrary::FILTERGRAPH;
507504

508505
streamInfo.colorConversionLibrary =
509506
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
@@ -539,30 +536,29 @@ void SingleStreamDecoder::addAudioStream(
539536
// HIGH-LEVEL DECODING ENTRY-POINTS
540537
// --------------------------------------------------------------------------
541538

542-
SingleStreamDecoder::FrameOutput SingleStreamDecoder::getNextFrame() {
539+
// Public entry point for fetching the next frame of the active stream.
// Delegates the decoding to getNextFrameInternal() and then post-processes the
// tensor for video streams only.
FrameOutput SingleStreamDecoder::getNextFrame() {
543540
auto output = getNextFrameInternal();
544541
// Only video output goes through maybePermuteHWC2CHW(); output of any other
// media type is returned exactly as the internal call produced it.
if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) {
545542
output.data = maybePermuteHWC2CHW(output.data);
546543
}
547544
return output;
548545
}
549546

550-
SingleStreamDecoder::FrameOutput SingleStreamDecoder::getNextFrameInternal(
547+
// Validates the active stream, obtains an AVFrame from decodeAVFrame() and
// converts it to a FrameOutput. preAllocatedOutputTensor is forwarded
// unchanged to convertAVFrameToFrameOutput().
FrameOutput SingleStreamDecoder::getNextFrameInternal(
551548
std::optional<torch::Tensor> preAllocatedOutputTensor) {
552549
validateActiveStream();
553550
// The predicate accepts a frame once its pts has reached cursor_.
// NOTE(review): presumably decodeAVFrame() keeps decoding until the predicate
// returns true - confirm against its definition.
UniqueAVFrame avFrame = decodeAVFrame(
554551
[this](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; });
555552
return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
556553
}
557554

558-
SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFrameAtIndex(
559-
int64_t frameIndex) {
555+
// Public entry point for fetching the frame at frameIndex. Delegates to
// getFrameAtIndexInternal() and then passes the resulting tensor through
// maybePermuteHWC2CHW(). Unlike getNextFrame(), there is no media-type check
// before the permutation here.
FrameOutput SingleStreamDecoder::getFrameAtIndex(int64_t frameIndex) {
560556
auto frameOutput = getFrameAtIndexInternal(frameIndex);
561557
frameOutput.data = maybePermuteHWC2CHW(frameOutput.data);
562558
return frameOutput;
563559
}
564560

565-
SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
561+
FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
566562
int64_t frameIndex,
567563
std::optional<torch::Tensor> preAllocatedOutputTensor) {
568564
validateActiveStream(AVMEDIA_TYPE_VIDEO);
@@ -577,7 +573,7 @@ SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
577573
return getNextFrameInternal(preAllocatedOutputTensor);
578574
}
579575

580-
SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
576+
FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
581577
const std::vector<int64_t>& frameIndices) {
582578
validateActiveStream(AVMEDIA_TYPE_VIDEO);
583579

@@ -636,7 +632,7 @@ SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
636632
return frameBatchOutput;
637633
}
638634

639-
SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesInRange(
635+
FrameBatchOutput SingleStreamDecoder::getFramesInRange(
640636
int64_t start,
641637
int64_t stop,
642638
int64_t step) {
@@ -670,8 +666,7 @@ SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesInRange(
670666
return frameBatchOutput;
671667
}
672668

673-
SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFramePlayedAt(
674-
double seconds) {
669+
FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) {
675670
validateActiveStream(AVMEDIA_TYPE_VIDEO);
676671
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
677672
double frameStartTime =
@@ -711,7 +706,7 @@ SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFramePlayedAt(
711706
return frameOutput;
712707
}
713708

714-
SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
709+
FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
715710
const std::vector<double>& timestamps) {
716711
validateActiveStream(AVMEDIA_TYPE_VIDEO);
717712

@@ -741,8 +736,7 @@ SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
741736
return getFramesAtIndices(frameIndices);
742737
}
743738

744-
SingleStreamDecoder::FrameBatchOutput
745-
SingleStreamDecoder::getFramesPlayedInRange(
739+
FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
746740
double startSeconds,
747741
double stopSeconds) {
748742
validateActiveStream(AVMEDIA_TYPE_VIDEO);
@@ -875,8 +869,7 @@ SingleStreamDecoder::getFramesPlayedInRange(
875869
// [2] If you're brave and curious, you can read the long "Seek offset for
876870
// audio" note in https://github.com/pytorch/torchcodec/pull/507/files, which
877871
// sums up past (and failed) attempts at working around this issue.
878-
SingleStreamDecoder::AudioFramesOutput
879-
SingleStreamDecoder::getFramesPlayedInRangeAudio(
872+
AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio(
880873
double startSeconds,
881874
std::optional<double> stopSecondsOptional) {
882875
validateActiveStream(AVMEDIA_TYPE_AUDIO);
@@ -1196,8 +1189,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
11961189
// AVFRAME <-> FRAME OUTPUT CONVERSION
11971190
// --------------------------------------------------------------------------
11981191

1199-
SingleStreamDecoder::FrameOutput
1200-
SingleStreamDecoder::convertAVFrameToFrameOutput(
1192+
FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
12011193
UniqueAVFrame& avFrame,
12021194
std::optional<torch::Tensor> preAllocatedOutputTensor) {
12031195
// Convert the frame to tensor.
@@ -1210,11 +1202,11 @@ SingleStreamDecoder::convertAVFrameToFrameOutput(
12101202
formatContext_->streams[activeStreamIndex_]->time_base);
12111203
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
12121204
convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput);
1213-
} else if (!deviceInterface) {
1205+
} else if (!deviceInterface_) {
12141206
convertAVFrameToFrameOutputOnCPU(
12151207
avFrame, frameOutput, preAllocatedOutputTensor);
1216-
} else if (deviceInterface) {
1217-
deviceInterface->convertAVFrameToFrameOutput(
1208+
} else if (deviceInterface_) {
1209+
deviceInterface_->convertAVFrameToFrameOutput(
12181210
streamInfo.videoStreamOptions,
12191211
avFrame,
12201212
frameOutput,
@@ -1547,7 +1539,7 @@ std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() {
15471539
// OUTPUT ALLOCATION AND SHAPE CONVERSION
15481540
// --------------------------------------------------------------------------
15491541

1550-
SingleStreamDecoder::FrameBatchOutput::FrameBatchOutput(
1542+
FrameBatchOutput::FrameBatchOutput(
15511543
int64_t numFrames,
15521544
const VideoStreamOptions& videoStreamOptions,
15531545
const StreamMetadata& streamMetadata)
@@ -2047,15 +2039,15 @@ FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) {
20472039
}
20482040

20492041
FrameDims getHeightAndWidthFromOptionsOrMetadata(
2050-
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
2051-
const SingleStreamDecoder::StreamMetadata& streamMetadata) {
2042+
const VideoStreamOptions& videoStreamOptions,
2043+
const StreamMetadata& streamMetadata) {
20522044
return FrameDims(
20532045
videoStreamOptions.height.value_or(*streamMetadata.height),
20542046
videoStreamOptions.width.value_or(*streamMetadata.width));
20552047
}
20562048

20572049
FrameDims getHeightAndWidthFromOptionsOrAVFrame(
2058-
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
2050+
const VideoStreamOptions& videoStreamOptions,
20592051
const UniqueAVFrame& avFrame) {
20602052
return FrameDims(
20612053
videoStreamOptions.height.value_or(avFrame->height),

0 commit comments

Comments
 (0)