Skip to content

Commit 8b19f45

Browse files
authored
Make device interface generic (#606)
Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent e611e29 commit 8b19f45

10 files changed

+197
-142
lines changed

Diff for: src/torchcodec/_core/CMakeLists.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ function(make_torchcodec_libraries
6060
set(decoder_sources
6161
AVIOContextHolder.cpp
6262
FFMPEGCommon.cpp
63+
DeviceInterface.cpp
6364
SingleStreamDecoder.cpp
6465
# TODO: lib name should probably not be "*_decoder*" now that it also
6566
# contains an encoder
@@ -68,8 +69,6 @@ function(make_torchcodec_libraries
6869

6970
if(ENABLE_CUDA)
7071
list(APPEND decoder_sources CudaDevice.cpp)
71-
else()
72-
list(APPEND decoder_sources CPUOnlyDevice.cpp)
7372
endif()
7473

7574
set(decoder_library_dependencies

Diff for: src/torchcodec/_core/CPUOnlyDevice.cpp

-45
This file was deleted.

Diff for: src/torchcodec/_core/CudaDevice.cpp

+27-32
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <torch/types.h>
55
#include <mutex>
66

7-
#include "src/torchcodec/_core/DeviceInterface.h"
7+
#include "src/torchcodec/_core/CudaDevice.h"
88
#include "src/torchcodec/_core/FFMPEGCommon.h"
99
#include "src/torchcodec/_core/SingleStreamDecoder.h"
1010

@@ -16,6 +16,10 @@ extern "C" {
1616
namespace facebook::torchcodec {
1717
namespace {
1818

19+
bool g_cuda = registerDeviceInterface(
20+
torch::kCUDA,
21+
[](const torch::Device& device) { return new CudaDevice(device); });
22+
1923
// We reuse cuda contexts across VideoDecoder instances. This is because
2024
// creating a cuda context is expensive. The cache mechanism is as follows:
2125
// 1. There is a cache of size MAX_CONTEXTS_PER_GPU_IN_CACHE cuda contexts for
@@ -49,7 +53,7 @@ torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) {
4953

5054
void addToCacheIfCacheHasCapacity(
5155
const torch::Device& device,
52-
AVCodecContext* codecContext) {
56+
AVBufferRef* hwContext) {
5357
torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
5458
if (static_cast<int>(deviceIndex) >= MAX_CUDA_GPUS) {
5559
return;
@@ -60,8 +64,7 @@ void addToCacheIfCacheHasCapacity(
6064
MAX_CONTEXTS_PER_GPU_IN_CACHE) {
6165
return;
6266
}
63-
g_cached_hw_device_ctxs[deviceIndex].push_back(codecContext->hw_device_ctx);
64-
codecContext->hw_device_ctx = nullptr;
67+
g_cached_hw_device_ctxs[deviceIndex].push_back(av_buffer_ref(hwContext));
6568
}
6669

6770
AVBufferRef* getFromCache(const torch::Device& device) {
@@ -158,39 +161,35 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
158161
device, nonNegativeDeviceIndex, type);
159162
#endif
160163
}
164+
} // namespace
161165

162-
void throwErrorIfNonCudaDevice(const torch::Device& device) {
163-
TORCH_CHECK(
164-
device.type() != torch::kCPU,
165-
"Device functions should only be called if the device is not CPU.")
166-
if (device.type() != torch::kCUDA) {
167-
throw std::runtime_error("Unsupported device: " + device.str());
166+
// Builds a CUDA device interface for `device`.
//
// Throws std::runtime_error if `device` is not a CUDA device.
CudaDevice::CudaDevice(const torch::Device& device) : DeviceInterface(device) {
  const bool isCudaDevice = (device_.type() == torch::kCUDA);
  if (!isCudaDevice) {
    throw std::runtime_error("Unsupported device: " + device_.str());
  }
}
170-
} // namespace
171171

172-
void releaseContextOnCuda(
173-
const torch::Device& device,
174-
AVCodecContext* codecContext) {
175-
throwErrorIfNonCudaDevice(device);
176-
addToCacheIfCacheHasCapacity(device, codecContext);
172+
// Releases the FFmpeg HW device context held by this object, if any.
//
// The context is first offered to the per-GPU cache, which takes its own
// reference when it has capacity; this object's reference is then dropped.
CudaDevice::~CudaDevice() {
  if (ctx_ == nullptr) {
    // initializeContext() was never called (or failed): nothing to release.
    return;
  }
  addToCacheIfCacheHasCapacity(device_, ctx_);
  av_buffer_unref(&ctx_);
}
178178

179-
void initializeContextOnCuda(
180-
const torch::Device& device,
181-
AVCodecContext* codecContext) {
182-
throwErrorIfNonCudaDevice(device);
179+
// Attaches a CUDA HW device context to `codecContext`.
//
// This object keeps its own reference to the context in `ctx_` (released in
// the destructor) and gives `codecContext` a separate reference, so each side
// can unref independently.
//
// Fails a TORCH_CHECK if `codecContext` is null, if a context was already
// initialized for this object, or if the HW device context (or a new
// reference to it) could not be obtained.
void CudaDevice::initializeContext(AVCodecContext* codecContext) {
  TORCH_CHECK(codecContext != nullptr, "codecContext must not be null");
  TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized");

  // It is important for pytorch itself to create the cuda context. If ffmpeg
  // creates the context it may not be compatible with pytorch.
  // This is a dummy tensor to initialize the cuda context.
  torch::Tensor dummyTensorForCudaInitialization = torch::empty(
      {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));

  ctx_ = getCudaContext(device_);
  TORCH_CHECK(
      ctx_ != nullptr,
      "Failed to obtain a CUDA HW device context for ",
      device_.str());

  // av_buffer_ref() returns NULL when it cannot allocate the new reference;
  // surface that here instead of letting the decoder run without a HW context.
  AVBufferRef* codecContextRef = av_buffer_ref(ctx_);
  TORCH_CHECK(
      codecContextRef != nullptr,
      "Failed to create a reference to the CUDA HW device context for ",
      device_.str());
  codecContext->hw_device_ctx = codecContextRef;
}
191191

192-
void convertAVFrameToFrameOutputOnCuda(
193-
const torch::Device& device,
192+
void CudaDevice::convertAVFrameToFrameOutput(
194193
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
195194
UniqueAVFrame& avFrame,
196195
SingleStreamDecoder::FrameOutput& frameOutput,
@@ -217,11 +216,11 @@ void convertAVFrameToFrameOutputOnCuda(
217216
"x3, got ",
218217
shape);
219218
} else {
220-
dst = allocateEmptyHWCTensor(height, width, videoStreamOptions.device);
219+
dst = allocateEmptyHWCTensor(height, width, device_);
221220
}
222221

223222
// Use the user-requested GPU for running the NPP kernel.
224-
c10::cuda::CUDAGuard deviceGuard(device);
223+
c10::cuda::CUDAGuard deviceGuard(device_);
225224

226225
NppiSize oSizeROI = {width, height};
227226
Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
@@ -249,7 +248,7 @@ void convertAVFrameToFrameOutputOnCuda(
249248
// output.
250249
at::cuda::CUDAEvent nppDoneEvent;
251250
at::cuda::CUDAStream nppStreamWrapper =
252-
c10::cuda::getStreamFromExternal(nppGetStream(), device.index());
251+
c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
253252
nppDoneEvent.record(nppStreamWrapper);
254253
nppDoneEvent.block(at::cuda::getCurrentCUDAStream());
255254

@@ -264,11 +263,7 @@ void convertAVFrameToFrameOutputOnCuda(
264263
// we have to do this because of an FFmpeg bug where hardware decoding is not
265264
// appropriately set, so we just go off and find the matching codec for the CUDA
266265
// device
267-
std::optional<const AVCodec*> findCudaCodec(
268-
const torch::Device& device,
269-
const AVCodecID& codecId) {
270-
throwErrorIfNonCudaDevice(device);
271-
266+
std::optional<const AVCodec*> CudaDevice::findCodec(const AVCodecID& codecId) {
272267
void* i = nullptr;
273268
const AVCodec* codec = nullptr;
274269
while ((codec = av_codec_iterate(&i)) != nullptr) {

Diff for: src/torchcodec/_core/CudaDevice.h

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include "src/torchcodec/_core/DeviceInterface.h"
10+
11+
namespace facebook::torchcodec {
12+
13+
class CudaDevice : public DeviceInterface {
14+
public:
15+
CudaDevice(const torch::Device& device);
16+
17+
virtual ~CudaDevice();
18+
19+
std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;
20+
21+
void initializeContext(AVCodecContext* codecContext) override;
22+
23+
void convertAVFrameToFrameOutput(
24+
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
25+
UniqueAVFrame& avFrame,
26+
SingleStreamDecoder::FrameOutput& frameOutput,
27+
std::optional<torch::Tensor> preAllocatedOutputTensor =
28+
std::nullopt) override;
29+
30+
private:
31+
AVBufferRef* ctx_ = nullptr;
32+
};
33+
34+
} // namespace facebook::torchcodec

Diff for: src/torchcodec/_core/DeviceInterface.cpp

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include "src/torchcodec/_core/DeviceInterface.h"

#include <algorithm>
#include <map>
#include <mutex>
#include <string>
#include <utility>
10+
11+
namespace facebook::torchcodec {
12+
13+
namespace {
14+
std::mutex g_interface_mutex;
15+
std::map<torch::DeviceType, CreateDeviceInterfaceFn> g_interface_map;
16+
17+
std::string getDeviceType(const std::string& device) {
18+
size_t pos = device.find(':');
19+
if (pos == std::string::npos) {
20+
return device;
21+
}
22+
return device.substr(0, pos);
23+
}
24+
25+
} // namespace
26+
27+
bool registerDeviceInterface(
28+
torch::DeviceType deviceType,
29+
CreateDeviceInterfaceFn createInterface) {
30+
std::scoped_lock lock(g_interface_mutex);
31+
TORCH_CHECK(
32+
g_interface_map.find(deviceType) == g_interface_map.end(),
33+
"Device interface already registered for ",
34+
deviceType);
35+
g_interface_map.insert({deviceType, createInterface});
36+
return true;
37+
}
38+
39+
torch::Device createTorchDevice(const std::string device) {
40+
// TODO: remove once DeviceInterface for CPU is implemented
41+
if (device == "cpu") {
42+
return torch::kCPU;
43+
}
44+
45+
std::scoped_lock lock(g_interface_mutex);
46+
std::string deviceType = getDeviceType(device);
47+
auto deviceInterface = std::find_if(
48+
g_interface_map.begin(),
49+
g_interface_map.end(),
50+
[&](const std::pair<torch::DeviceType, CreateDeviceInterfaceFn>& arg) {
51+
return device.rfind(
52+
torch::DeviceTypeName(arg.first, /*lcase*/ true), 0) == 0;
53+
});
54+
TORCH_CHECK(
55+
deviceInterface != g_interface_map.end(), "Unsupported device: ", device);
56+
57+
return torch::Device(device);
58+
}
59+
60+
std::unique_ptr<DeviceInterface> createDeviceInterface(
61+
const torch::Device& device) {
62+
auto deviceType = device.type();
63+
// TODO: remove once DeviceInterface for CPU is implemented
64+
if (deviceType == torch::kCPU) {
65+
return nullptr;
66+
}
67+
68+
std::scoped_lock lock(g_interface_mutex);
69+
TORCH_CHECK(
70+
g_interface_map.find(deviceType) != g_interface_map.end(),
71+
"Unsupported device: ",
72+
device);
73+
74+
return std::unique_ptr<DeviceInterface>(g_interface_map[deviceType](device));
75+
}
76+
77+
} // namespace facebook::torchcodec

Diff for: src/torchcodec/_core/DeviceInterface.h

+38-20
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#pragma once
88

99
#include <torch/types.h>
10+
#include <functional>
1011
#include <memory>
1112
#include <stdexcept>
1213
#include <string>
@@ -23,25 +24,42 @@ namespace facebook::torchcodec {
2324
// deviceFunction(device, ...);
2425
// }
2526

26-
// Initialize the hardware device that is specified in `device`. Some builds
27-
// support CUDA and others only support CPU.
28-
void initializeContextOnCuda(
29-
const torch::Device& device,
30-
AVCodecContext* codecContext);
31-
32-
void convertAVFrameToFrameOutputOnCuda(
33-
const torch::Device& device,
34-
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
35-
UniqueAVFrame& avFrame,
36-
SingleStreamDecoder::FrameOutput& frameOutput,
37-
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
38-
39-
void releaseContextOnCuda(
40-
const torch::Device& device,
41-
AVCodecContext* codecContext);
42-
43-
std::optional<const AVCodec*> findCudaCodec(
44-
const torch::Device& device,
45-
const AVCodecID& codecId);
27+
class DeviceInterface {
28+
public:
29+
DeviceInterface(const torch::Device& device) : device_(device) {}
30+
31+
virtual ~DeviceInterface(){};
32+
33+
torch::Device& device() {
34+
return device_;
35+
};
36+
37+
virtual std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) = 0;
38+
39+
// Initialize the hardware device that is specified in `device`. Some builds
40+
// support CUDA and others only support CPU.
41+
virtual void initializeContext(AVCodecContext* codecContext) = 0;
42+
43+
virtual void convertAVFrameToFrameOutput(
44+
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
45+
UniqueAVFrame& avFrame,
46+
SingleStreamDecoder::FrameOutput& frameOutput,
47+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt) = 0;
48+
49+
protected:
50+
torch::Device device_;
51+
};
52+
53+
using CreateDeviceInterfaceFn =
54+
std::function<DeviceInterface*(const torch::Device& device)>;
55+
56+
bool registerDeviceInterface(
57+
torch::DeviceType deviceType,
58+
const CreateDeviceInterfaceFn createInterface);
59+
60+
torch::Device createTorchDevice(const std::string device);
61+
62+
std::unique_ptr<DeviceInterface> createDeviceInterface(
63+
const torch::Device& device);
4664

4765
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)