Skip to content

Commit c19dac3

Browse files
committed
Make device interface generic
Fixes: pytorch#605 Changes: * Device interface made device agnostic by introducing `class DeviceInterface` from which specific backends should inherit their device specific implementations * Implemented `CudaDevice` derived from `DeviceInterface` * Created device interface registration mechanism (`registerDeviceInterface`) * Created device interface creation mechanism (`createDeviceInterface`) These changes make it possible to replace CUDA specific code in `VideoDecoder.cpp` and `VideoDecoderOps.cpp` with device agnostic code. Signed-off-by: Dmitry Rogozhkin <[email protected]> address comments Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent f416dcf commit c19dac3

9 files changed

+175
-135
lines changed

src/torchcodec/decoders/_core/CMakeLists.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,11 @@ function(make_torchcodec_libraries
6161
AVIOContextHolder.cpp
6262
FFMPEGCommon.cpp
6363
VideoDecoder.cpp
64+
DeviceInterface.cpp
6465
)
6566

6667
if(ENABLE_CUDA)
6768
list(APPEND decoder_sources CudaDevice.cpp)
68-
else()
69-
list(APPEND decoder_sources CPUOnlyDevice.cpp)
7069
endif()
7170

7271
set(decoder_library_dependencies

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

-44
This file was deleted.

src/torchcodec/decoders/_core/CudaDevice.cpp

+19-29
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <torch/types.h>
55
#include <mutex>
66

7-
#include "src/torchcodec/decoders/_core/DeviceInterface.h"
7+
#include "src/torchcodec/decoders/_core/CudaDevice.h"
88
#include "src/torchcodec/decoders/_core/FFMPEGCommon.h"
99
#include "src/torchcodec/decoders/_core/VideoDecoder.h"
1010

@@ -16,6 +16,10 @@ extern "C" {
1616
namespace facebook::torchcodec {
1717
namespace {
1818

19+
bool g_cuda = registerDeviceInterface("cuda", [](const std::string& device) {
20+
return new CudaDevice(device);
21+
});
22+
1923
// We reuse cuda contexts across VideoDecoder instances. This is because
2024
// creating a cuda context is expensive. The cache mechanism is as follows:
2125
// 1. There is a cache of size MAX_CONTEXTS_PER_GPU_IN_CACHE cuda contexts for
@@ -156,39 +160,29 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
156160
device, nonNegativeDeviceIndex, type);
157161
#endif
158162
}
163+
} // namespace
159164

160-
void throwErrorIfNonCudaDevice(const torch::Device& device) {
161-
TORCH_CHECK(
162-
device.type() != torch::kCPU,
163-
"Device functions should only be called if the device is not CPU.")
164-
if (device.type() != torch::kCUDA) {
165-
throw std::runtime_error("Unsupported device: " + device.str());
165+
CudaDevice::CudaDevice(const std::string& device) : DeviceInterface(device) {
166+
if (device_.type() != torch::kCUDA) {
167+
throw std::runtime_error("Unsupported device: " + device_.str());
166168
}
167169
}
168-
} // namespace
169170

170-
void releaseContextOnCuda(
171-
const torch::Device& device,
172-
AVCodecContext* codecContext) {
173-
throwErrorIfNonCudaDevice(device);
174-
addToCacheIfCacheHasCapacity(device, codecContext);
171+
void CudaDevice::releaseContext(AVCodecContext* codecContext) {
172+
addToCacheIfCacheHasCapacity(device_, codecContext);
175173
}
176174

177-
void initializeContextOnCuda(
178-
const torch::Device& device,
179-
AVCodecContext* codecContext) {
180-
throwErrorIfNonCudaDevice(device);
175+
void CudaDevice::initializeContext(AVCodecContext* codecContext) {
181176
// It is important for pytorch itself to create the cuda context. If ffmpeg
182177
// creates the context it may not be compatible with pytorch.
183178
// This is a dummy tensor to initialize the cuda context.
184179
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
185-
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device));
186-
codecContext->hw_device_ctx = getCudaContext(device);
180+
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
181+
codecContext->hw_device_ctx = getCudaContext(device_);
187182
return;
188183
}
189184

190-
void convertAVFrameToFrameOutputOnCuda(
191-
const torch::Device& device,
185+
void CudaDevice::convertAVFrameToFrameOutput(
192186
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
193187
UniqueAVFrame& avFrame,
194188
VideoDecoder::FrameOutput& frameOutput,
@@ -215,11 +209,11 @@ void convertAVFrameToFrameOutputOnCuda(
215209
"x3, got ",
216210
shape);
217211
} else {
218-
dst = allocateEmptyHWCTensor(height, width, videoStreamOptions.device);
212+
dst = allocateEmptyHWCTensor(height, width, device_);
219213
}
220214

221215
// Use the user-requested GPU for running the NPP kernel.
222-
c10::cuda::CUDAGuard deviceGuard(device);
216+
c10::cuda::CUDAGuard deviceGuard(device_);
223217

224218
NppiSize oSizeROI = {width, height};
225219
Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
@@ -247,7 +241,7 @@ void convertAVFrameToFrameOutputOnCuda(
247241
// output.
248242
at::cuda::CUDAEvent nppDoneEvent;
249243
at::cuda::CUDAStream nppStreamWrapper =
250-
c10::cuda::getStreamFromExternal(nppGetStream(), device.index());
244+
c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
251245
nppDoneEvent.record(nppStreamWrapper);
252246
nppDoneEvent.block(at::cuda::getCurrentCUDAStream());
253247

@@ -262,11 +256,7 @@ void convertAVFrameToFrameOutputOnCuda(
262256
// we have to do this because of an FFmpeg bug where hardware decoding is not
263257
// appropriately set, so we just go off and find the matching codec for the CUDA
264258
// device
265-
std::optional<const AVCodec*> findCudaCodec(
266-
const torch::Device& device,
267-
const AVCodecID& codecId) {
268-
throwErrorIfNonCudaDevice(device);
269-
259+
std::optional<const AVCodec*> CudaDevice::findCodec(const AVCodecID& codecId) {
270260
void* i = nullptr;
271261
const AVCodec* codec = nullptr;
272262
while ((codec = av_codec_iterate(&i)) != nullptr) {
+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include "src/torchcodec/decoders/_core/DeviceInterface.h"
10+
11+
namespace facebook::torchcodec {
12+
13+
class CudaDevice : public DeviceInterface {
14+
public:
15+
CudaDevice(const std::string& device);
16+
17+
virtual ~CudaDevice(){};
18+
19+
std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;
20+
21+
void initializeContext(AVCodecContext* codecContext) override;
22+
23+
void convertAVFrameToFrameOutput(
24+
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
25+
UniqueAVFrame& avFrame,
26+
VideoDecoder::FrameOutput& frameOutput,
27+
std::optional<torch::Tensor> preAllocatedOutputTensor =
28+
std::nullopt) override;
29+
30+
void releaseContext(AVCodecContext* codecContext) override;
31+
};
32+
33+
} // namespace facebook::torchcodec
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include "src/torchcodec/decoders/_core/DeviceInterface.h"
8+
#include <map>
9+
#include <mutex>
10+
11+
namespace facebook::torchcodec {
12+
13+
namespace {
14+
std::mutex g_interface_mutex;
15+
std::map<std::string, CreateDeviceInterfaceFn> g_interface_map;
16+
17+
std::string getDeviceType(const std::string& device) {
18+
size_t pos = device.find(':');
19+
if (pos == std::string::npos) {
20+
return device;
21+
}
22+
return device.substr(0, pos);
23+
}
24+
25+
} // namespace
26+
27+
bool registerDeviceInterface(
28+
const std::string deviceType,
29+
CreateDeviceInterfaceFn createInterface) {
30+
std::scoped_lock lock(g_interface_mutex);
31+
TORCH_CHECK(
32+
g_interface_map.find(deviceType) == g_interface_map.end(),
33+
"Device interface already registered for ",
34+
deviceType);
35+
g_interface_map.insert({deviceType, createInterface});
36+
return true;
37+
}
38+
39+
std::shared_ptr<DeviceInterface> createDeviceInterface(
40+
const std::string device) {
41+
// TODO: remove once DeviceInterface for CPU is implemented
42+
if (device == "cpu") {
43+
return nullptr;
44+
}
45+
46+
std::scoped_lock lock(g_interface_mutex);
47+
std::string deviceType = getDeviceType(device);
48+
TORCH_CHECK(
49+
g_interface_map.find(deviceType) != g_interface_map.end(),
50+
"Unsupported device: ",
51+
device);
52+
53+
return std::shared_ptr<DeviceInterface>(g_interface_map[deviceType](device));
54+
}
55+
56+
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/DeviceInterface.h

+38-20
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#pragma once
88

99
#include <torch/types.h>
10+
#include <functional>
1011
#include <memory>
1112
#include <stdexcept>
1213
#include <string>
@@ -23,25 +24,42 @@ namespace facebook::torchcodec {
2324
// deviceFunction(device, ...);
2425
// }
2526

26-
// Initialize the hardware device that is specified in `device`. Some builds
27-
// support CUDA and others only support CPU.
28-
void initializeContextOnCuda(
29-
const torch::Device& device,
30-
AVCodecContext* codecContext);
31-
32-
void convertAVFrameToFrameOutputOnCuda(
33-
const torch::Device& device,
34-
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
35-
UniqueAVFrame& avFrame,
36-
VideoDecoder::FrameOutput& frameOutput,
37-
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
38-
39-
void releaseContextOnCuda(
40-
const torch::Device& device,
41-
AVCodecContext* codecContext);
42-
43-
std::optional<const AVCodec*> findCudaCodec(
44-
const torch::Device& device,
45-
const AVCodecID& codecId);
27+
class DeviceInterface {
28+
public:
29+
DeviceInterface(const std::string& device) : device_(device) {}
30+
31+
virtual ~DeviceInterface(){};
32+
33+
torch::Device& device() {
34+
return device_;
35+
};
36+
37+
virtual std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) = 0;
38+
39+
// Initialize the hardware device that is specified in `device`. Some builds
40+
// support CUDA and others only support CPU.
41+
virtual void initializeContext(AVCodecContext* codecContext) = 0;
42+
43+
virtual void convertAVFrameToFrameOutput(
44+
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
45+
UniqueAVFrame& avFrame,
46+
VideoDecoder::FrameOutput& frameOutput,
47+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt) = 0;
48+
49+
virtual void releaseContext(AVCodecContext* codecContext) = 0;
50+
51+
protected:
52+
torch::Device device_;
53+
};
54+
55+
using CreateDeviceInterfaceFn =
56+
std::function<DeviceInterface*(const std::string& device)>;
57+
58+
bool registerDeviceInterface(
59+
const std::string deviceType,
60+
const CreateDeviceInterfaceFn createInterface);
61+
62+
std::shared_ptr<DeviceInterface> createDeviceInterface(
63+
const std::string device);
4664

4765
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)