Skip to content

Commit de7c3c0

Browse files
committed
Move timeBase to convertAVFrameToFrameOutput
Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent 5248cbd commit de7c3c0

7 files changed

+29
-32
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ namespace {
1616

1717
bool g_cpu = registerDeviceInterface(
1818
torch::kCPU,
19-
[](const torch::Device& device, const AVRational& timeBase) {
20-
return new CpuDeviceInterface(device, timeBase);
21-
});
19+
[](const torch::Device& device) { return new CpuDeviceInterface(device); });
2220

2321
} // namespace
2422

@@ -36,10 +34,8 @@ bool CpuDeviceInterface::DecodedFrameContext::operator!=(
3634
return !(*this == other);
3735
}
3836

39-
CpuDeviceInterface::CpuDeviceInterface(
40-
const torch::Device& device,
41-
const AVRational& timeBase)
42-
: DeviceInterface(device, timeBase) {
37+
CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
38+
: DeviceInterface(device) {
4339
if (device_.type() != torch::kCPU) {
4440
throw std::runtime_error("Unsupported device: " + device_.str());
4541
}
@@ -56,6 +52,7 @@ CpuDeviceInterface::CpuDeviceInterface(
5652
// `dimension_order` parameter. It's up to callers to re-shape it if needed.
5753
void CpuDeviceInterface::convertAVFrameToFrameOutput(
5854
const VideoStreamOptions& videoStreamOptions,
55+
const AVRational& timeBase,
5956
UniqueAVFrame& avFrame,
6057
FrameOutput& frameOutput,
6158
std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -136,7 +133,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
136133
frameOutput.data = outputTensor;
137134
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
138135
if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
139-
createFilterGraph(frameContext, videoStreamOptions);
136+
createFilterGraph(frameContext, videoStreamOptions, timeBase);
140137
prevFrameContext_ = frameContext;
141138
}
142139
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
@@ -215,7 +212,8 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
215212

216213
void CpuDeviceInterface::createFilterGraph(
217214
const DecodedFrameContext& frameContext,
218-
const VideoStreamOptions& videoStreamOptions) {
215+
const VideoStreamOptions& videoStreamOptions,
216+
const AVRational& timeBase) {
219217
filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
220218
TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);
221219

@@ -231,7 +229,7 @@ void CpuDeviceInterface::createFilterGraph(
231229
filterArgs << "video_size=" << frameContext.decodedWidth << "x"
232230
<< frameContext.decodedHeight;
233231
filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
234-
filterArgs << ":time_base=" << timeBase_.num << "/" << timeBase_.den;
232+
filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
235233
filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
236234
<< frameContext.decodedAspectRatio.den;
237235

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace facebook::torchcodec {
1313

1414
class CpuDeviceInterface : public DeviceInterface {
1515
public:
16-
CpuDeviceInterface(const torch::Device& device, const AVRational& timeBase);
16+
CpuDeviceInterface(const torch::Device& device);
1717

1818
virtual ~CpuDeviceInterface() {}
1919

@@ -27,6 +27,7 @@ class CpuDeviceInterface : public DeviceInterface {
2727

2828
void convertAVFrameToFrameOutput(
2929
const VideoStreamOptions& videoStreamOptions,
30+
const AVRational& timeBase,
3031
UniqueAVFrame& avFrame,
3132
FrameOutput& frameOutput,
3233
std::optional<torch::Tensor> preAllocatedOutputTensor =
@@ -63,7 +64,8 @@ class CpuDeviceInterface : public DeviceInterface {
6364

6465
void createFilterGraph(
6566
const DecodedFrameContext& frameContext,
66-
const VideoStreamOptions& videoStreamOptions);
67+
const VideoStreamOptions& videoStreamOptions,
68+
const AVRational& timeBase);
6769

6870
// color-conversion fields. Only one of FilterGraphContext and
6971
// UniqueSwsContext should be non-null.

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@ extern "C" {
1515
namespace facebook::torchcodec {
1616
namespace {
1717

18-
bool g_cuda = registerDeviceInterface(
19-
torch::kCUDA,
20-
[](const torch::Device& device, const AVRational& timeBase) {
21-
return new CudaDeviceInterface(device, timeBase);
18+
bool g_cuda =
19+
registerDeviceInterface(torch::kCUDA, [](const torch::Device& device) {
20+
return new CudaDeviceInterface(device);
2221
});
2322

2423
// We reuse cuda contexts across VideoDeoder instances. This is because
@@ -164,10 +163,8 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
164163
}
165164
} // namespace
166165

167-
CudaDeviceInterface::CudaDeviceInterface(
168-
const torch::Device& device,
169-
const AVRational& timeBase)
170-
: DeviceInterface(device, timeBase) {
166+
CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
167+
: DeviceInterface(device) {
171168
if (device_.type() != torch::kCUDA) {
172169
throw std::runtime_error("Unsupported device: " + device_.str());
173170
}
@@ -195,6 +192,7 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
195192

196193
void CudaDeviceInterface::convertAVFrameToFrameOutput(
197194
const VideoStreamOptions& videoStreamOptions,
195+
[[maybe_unused]] const AVRational& timeBase,
198196
UniqueAVFrame& avFrame,
199197
FrameOutput& frameOutput,
200198
std::optional<torch::Tensor> preAllocatedOutputTensor) {

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace facebook::torchcodec {
1212

1313
class CudaDeviceInterface : public DeviceInterface {
1414
public:
15-
CudaDeviceInterface(const torch::Device& device, const AVRational& timeBase);
15+
CudaDeviceInterface(const torch::Device& device);
1616

1717
virtual ~CudaDeviceInterface();
1818

@@ -22,6 +22,7 @@ class CudaDeviceInterface : public DeviceInterface {
2222

2323
void convertAVFrameToFrameOutput(
2424
const VideoStreamOptions& videoStreamOptions,
25+
const AVRational& timeBase,
2526
UniqueAVFrame& avFrame,
2627
FrameOutput& frameOutput,
2728
std::optional<torch::Tensor> preAllocatedOutputTensor =

src/torchcodec/_core/DeviceInterface.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ torch::Device createTorchDevice(const std::string device) {
6363
}
6464

6565
std::unique_ptr<DeviceInterface> createDeviceInterface(
66-
const torch::Device& device,
67-
const AVRational& timeBase) {
66+
const torch::Device& device) {
6867
auto deviceType = device.type();
6968
std::scoped_lock lock(g_interface_mutex);
7069
TORCH_CHECK(
@@ -73,7 +72,7 @@ std::unique_ptr<DeviceInterface> createDeviceInterface(
7372
device);
7473

7574
return std::unique_ptr<DeviceInterface>(
76-
(*g_interface_map)[deviceType](device, timeBase));
75+
(*g_interface_map)[deviceType](device));
7776
}
7877

7978
} // namespace facebook::torchcodec

src/torchcodec/_core/DeviceInterface.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ namespace facebook::torchcodec {
2727

2828
class DeviceInterface {
2929
public:
30-
DeviceInterface(const torch::Device& device, const AVRational& timeBase)
31-
: device_(device), timeBase_(timeBase) {}
30+
DeviceInterface(const torch::Device& device) : device_(device) {}
3231

3332
virtual ~DeviceInterface(){};
3433

@@ -44,17 +43,17 @@ class DeviceInterface {
4443

4544
virtual void convertAVFrameToFrameOutput(
4645
const VideoStreamOptions& videoStreamOptions,
46+
const AVRational& timeBase,
4747
UniqueAVFrame& avFrame,
4848
FrameOutput& frameOutput,
4949
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt) = 0;
5050

5151
protected:
5252
torch::Device device_;
53-
AVRational timeBase_;
5453
};
5554

56-
using CreateDeviceInterfaceFn = std::function<
57-
DeviceInterface*(const torch::Device& device, const AVRational& timeBase)>;
55+
using CreateDeviceInterfaceFn =
56+
std::function<DeviceInterface*(const torch::Device& device)>;
5857

5958
bool registerDeviceInterface(
6059
torch::DeviceType deviceType,
@@ -63,7 +62,6 @@ bool registerDeviceInterface(
6362
torch::Device createTorchDevice(const std::string device);
6463

6564
std::unique_ptr<DeviceInterface> createDeviceInterface(
66-
const torch::Device& device,
67-
const AVRational& timeBase);
65+
const torch::Device& device);
6866

6967
} // namespace facebook::torchcodec

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ void SingleStreamDecoder::addStream(
363363
streamInfo.stream = formatContext_->streams[activeStreamIndex_];
364364
streamInfo.avMediaType = mediaType;
365365

366-
deviceInterface_ = createDeviceInterface(device, streamInfo.timeBase);
366+
deviceInterface_ = createDeviceInterface(device);
367367

368368
// This should never happen, checking just to be safe.
369369
TORCH_CHECK(
@@ -1151,6 +1151,7 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
11511151
} else if (deviceInterface_) {
11521152
deviceInterface_->convertAVFrameToFrameOutput(
11531153
streamInfo.videoStreamOptions,
1154+
streamInfo.timeBase,
11541155
avFrame,
11551156
frameOutput,
11561157
preAllocatedOutputTensor);

0 commit comments

Comments
 (0)