#include <torch/types.h>
#include <mutex>

- #include "src/torchcodec/decoders/_core/DeviceInterface.h"
+ #include "src/torchcodec/decoders/_core/CudaDevice.h"
#include "src/torchcodec/decoders/_core/FFMPEGCommon.h"
#include "src/torchcodec/decoders/_core/VideoDecoder.h"

@@ -16,6 +16,10 @@ extern "C" {
namespace facebook::torchcodec {
namespace {

+ bool g_cuda = registerDeviceInterface("cuda", [](const std::string& device) {
+   return new CudaDevice(device);
+ });
+
// We reuse cuda contexts across VideoDecoder instances. This is because
// creating a cuda context is expensive. The cache mechanism is as follows:
// 1. There is a cache of size MAX_CONTEXTS_PER_GPU_IN_CACHE cuda contexts for
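The g_cuda initializer above is how the CUDA backend now self-registers with the generic device-interface layer instead of being hard-wired into the decoder. The sketch below illustrates the kind of registry a registerDeviceInterface() call plausibly fills in; the CreatorFunc alias, the map, and deviceRegistry() are assumptions for illustration, not torchcodec's actual implementation.

// Sketch of a device-interface registry (assumed names, not torchcodec code).
#include <functional>
#include <map>
#include <string>
#include <utility>

class DeviceInterface; // defined elsewhere (DeviceInterface.h in torchcodec)

using CreatorFunc = std::function<DeviceInterface*(const std::string&)>;

// Meyers singleton so the map exists before any static registration runs.
std::map<std::string, CreatorFunc>& deviceRegistry() {
  static std::map<std::string, CreatorFunc> registry;
  return registry;
}

// Returning bool lets callers assign the result to a global (like g_cuda),
// which forces the registration to happen during static initialization.
bool registerDeviceInterface(const std::string& deviceType, CreatorFunc creator) {
  return deviceRegistry().emplace(deviceType, std::move(creator)).second;
}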
@@ -156,39 +160,29 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
      device, nonNegativeDeviceIndex, type);
#endif
}
+ } // namespace

- void throwErrorIfNonCudaDevice(const torch::Device& device) {
-   TORCH_CHECK(
-       device.type() != torch::kCPU,
-       "Device functions should only be called if the device is not CPU.")
-   if (device.type() != torch::kCUDA) {
-     throw std::runtime_error("Unsupported device: " + device.str());
+ CudaDevice::CudaDevice(const std::string& device) : DeviceInterface(device) {
+   if (device_.type() != torch::kCUDA) {
+     throw std::runtime_error("Unsupported device: " + device_.str());
  }
}
- } // namespace

- void releaseContextOnCuda(
-     const torch::Device& device,
-     AVCodecContext* codecContext) {
-   throwErrorIfNonCudaDevice(device);
-   addToCacheIfCacheHasCapacity(device, codecContext);
+ void CudaDevice::releaseContext(AVCodecContext* codecContext) {
+   addToCacheIfCacheHasCapacity(device_, codecContext);
}

- void initializeContextOnCuda(
-     const torch::Device& device,
-     AVCodecContext* codecContext) {
-   throwErrorIfNonCudaDevice(device);
+ void CudaDevice::initializeContext(AVCodecContext* codecContext) {
  // It is important for pytorch itself to create the cuda context. If ffmpeg
  // creates the context it may not be compatible with pytorch.
  // This is a dummy tensor to initialize the cuda context.
  torch::Tensor dummyTensorForCudaInitialization = torch::empty(
-       {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device));
-   codecContext->hw_device_ctx = getCudaContext(device);
+       {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
+   codecContext->hw_device_ctx = getCudaContext(device_);
  return;
}

- void convertAVFrameToFrameOutputOnCuda(
-     const torch::Device& device,
+ void CudaDevice::convertAVFrameToFrameOutput(
    const VideoDecoder::VideoStreamOptions& videoStreamOptions,
    UniqueAVFrame& avFrame,
    VideoDecoder::FrameOutput& frameOutput,
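For context on initializeContext() above: the dummy tensor forces PyTorch, rather than FFmpeg, to create the CUDA context, and getCudaContext() then hands FFmpeg a hardware device context to attach to the codec. The sketch below shows that idea with plain libtorch and FFmpeg calls; it is an assumption-level illustration (hypothetical function name, no per-GPU caching), not the code behind getCudaContext().

// Sketch only: touch the GPU via PyTorch first, then attach an FFmpeg CUDA
// hardware device context to the codec context.
#include <stdexcept>
#include <torch/types.h>

extern "C" {
#include <libavcodec/avcodec.h>
#include <libavutil/hwcontext.h>
}

void initializeCudaCodecContextSketch(
    AVCodecContext* codecContext,
    const torch::Device& device) {
  // A 1-byte allocation on the target GPU is enough to make PyTorch (not
  // FFmpeg) own the CUDA context.
  torch::Tensor dummy = torch::empty(
      {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device));

  // Create an FFmpeg CUDA device context and attach it so decoding runs on
  // the GPU. torchcodec's getCudaContext() additionally caches these.
  AVBufferRef* hwDeviceCtx = nullptr;
  int err = av_hwdevice_ctx_create(
      &hwDeviceCtx, AV_HWDEVICE_TYPE_CUDA, nullptr, nullptr, 0);
  if (err < 0) {
    throw std::runtime_error("Failed to create CUDA hardware device context");
  }
  codecContext->hw_device_ctx = hwDeviceCtx;
}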
@@ -215,11 +209,11 @@ void convertAVFrameToFrameOutputOnCuda(
      "x3, got ",
      shape);
  } else {
-     dst = allocateEmptyHWCTensor(height, width, videoStreamOptions.device);
+     dst = allocateEmptyHWCTensor(height, width, device_);
  }

  // Use the user-requested GPU for running the NPP kernel.
-   c10::cuda::CUDAGuard deviceGuard(device);
+   c10::cuda::CUDAGuard deviceGuard(device_);

  NppiSize oSizeROI = {width, height};
  Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
@@ -247,7 +241,7 @@ void convertAVFrameToFrameOutputOnCuda(
  // output.
  at::cuda::CUDAEvent nppDoneEvent;
  at::cuda::CUDAStream nppStreamWrapper =
-       c10::cuda::getStreamFromExternal(nppGetStream(), device.index());
+       c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
  nppDoneEvent.record(nppStreamWrapper);
  nppDoneEvent.block(at::cuda::getCurrentCUDAStream());

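The event above is what keeps the NPP color conversion and PyTorch ordered: NPP work is enqueued on nppGetStream(), so an event recorded on that stream and blocked on by the current PyTorch stream guarantees the output tensor is not consumed before the conversion finishes. A minimal, self-contained sketch of that pattern (hypothetical function name; c10 variant of getCurrentCUDAStream):

// Sketch: make PyTorch's current CUDA stream wait for work already enqueued
// on NPP's stream.
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <nppcore.h>

void syncCurrentStreamWithNppSketch(c10::DeviceIndex deviceIndex) {
  at::cuda::CUDAEvent nppDoneEvent;
  at::cuda::CUDAStream nppStream =
      c10::cuda::getStreamFromExternal(nppGetStream(), deviceIndex);
  // Record the event on NPP's stream...
  nppDoneEvent.record(nppStream);
  // ...and have PyTorch's current stream block until it fires.
  nppDoneEvent.block(c10::cuda::getCurrentCUDAStream(deviceIndex));
}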
@@ -262,11 +256,7 @@ void convertAVFrameToFrameOutputOnCuda(
// we have to do this because of an FFmpeg bug where hardware decoding is not
// appropriately set, so we just go off and find the matching codec for the CUDA
// device
- std::optional<const AVCodec*> findCudaCodec(
-     const torch::Device& device,
-     const AVCodecID& codecId) {
-   throwErrorIfNonCudaDevice(device);
-
+ std::optional<const AVCodec*> CudaDevice::findCodec(const AVCodecID& codecId) {
  void* i = nullptr;
  const AVCodec* codec = nullptr;
  while ((codec = av_codec_iterate(&i)) != nullptr) {
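The while loop above is cut off by the diff. The usual FFmpeg pattern for this kind of search is to keep iterating codecs until one matches the codec ID, is a decoder, and advertises a CUDA hardware configuration. The sketch below follows that generic pattern and is an assumption about the omitted body, not necessarily what CudaDevice::findCodec() does verbatim.

// Sketch: find a decoder for codecId that supports CUDA hardware decoding.
extern "C" {
#include <libavcodec/avcodec.h>
}
#include <optional>

std::optional<const AVCodec*> findCudaDecoderSketch(AVCodecID codecId) {
  void* i = nullptr;
  const AVCodec* codec = nullptr;
  while ((codec = av_codec_iterate(&i)) != nullptr) {
    if (codec->id != codecId || !av_codec_is_decoder(codec)) {
      continue;
    }
    // Keep only codecs that advertise a CUDA hardware configuration.
    for (int j = 0;; ++j) {
      const AVCodecHWConfig* config = avcodec_get_hw_config(codec, j);
      if (config == nullptr) {
        break;
      }
      if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
        return codec;
      }
    }
  }
  return std::nullopt;
}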