 #include <torch/types.h>
 #include <mutex>

-#include "src/torchcodec/_core/DeviceInterface.h"
+#include "src/torchcodec/_core/CudaDevice.h"
 #include "src/torchcodec/_core/FFMPEGCommon.h"
 #include "src/torchcodec/_core/SingleStreamDecoder.h"

@@ -16,6 +16,10 @@ extern "C" {
 namespace facebook::torchcodec {
 namespace {

+bool g_cuda = registerDeviceInterface("cuda", [](const std::string& device) {
+  return new CudaDevice(device);
+});
+
 // We reuse cuda contexts across VideoDecoder instances. This is because
 // creating a cuda context is expensive. The cache mechanism is as follows:
 // 1. There is a cache of size MAX_CONTEXTS_PER_GPU_IN_CACHE cuda contexts for
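The new g_cuda global above uses the self-registration idiom: a namespace-scope bool is initialized by a call to registerDeviceInterface at static-initialization time, so the CUDA backend registers itself before main() runs. Below is a minimal sketch of that idiom, assuming a simplified registry keyed by device name; the real registerDeviceInterface and DeviceInterface are declared in DeviceInterface.h and are not shown in this diff.

    #include <functional>
    #include <map>
    #include <string>

    struct Iface {};  // stand-in for the real DeviceInterface base class

    // Factory that builds the interface for a device string such as "cuda:0".
    using Factory = std::function<Iface*(const std::string&)>;

    // Meyers-singleton registry: constructed on first use, so it is safe to
    // touch from other static initializers.
    std::map<std::string, Factory>& registry() {
      static std::map<std::string, Factory> r;
      return r;
    }

    // Returning a bool lets callers trigger registration by initializing a global.
    bool registerDeviceInterface(const std::string& name, Factory factory) {
      registry()[name] = std::move(factory);
      return true;
    }

    // Mirrors the diff: runs before main(), so "cuda" is always registered.
    bool g_cuda_example =
        registerDeviceInterface("cuda", [](const std::string& device) {
          return new Iface();  // the real code returns new CudaDevice(device)
        });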
@@ -158,39 +162,29 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
       device, nonNegativeDeviceIndex, type);
 #endif
 }
+} // namespace

-void throwErrorIfNonCudaDevice(const torch::Device& device) {
-  TORCH_CHECK(
-      device.type() != torch::kCPU,
-      "Device functions should only be called if the device is not CPU.")
-  if (device.type() != torch::kCUDA) {
-    throw std::runtime_error("Unsupported device: " + device.str());
+CudaDevice::CudaDevice(const std::string& device) : DeviceInterface(device) {
+  if (device_.type() != torch::kCUDA) {
+    throw std::runtime_error("Unsupported device: " + device_.str());
   }
 }
-} // namespace

-void releaseContextOnCuda(
-    const torch::Device& device,
-    AVCodecContext* codecContext) {
-  throwErrorIfNonCudaDevice(device);
-  addToCacheIfCacheHasCapacity(device, codecContext);
+void CudaDevice::releaseContext(AVCodecContext* codecContext) {
+  addToCacheIfCacheHasCapacity(device_, codecContext);
 }

-void initializeContextOnCuda(
-    const torch::Device& device,
-    AVCodecContext* codecContext) {
-  throwErrorIfNonCudaDevice(device);
+void CudaDevice::initializeContext(AVCodecContext* codecContext) {
   // It is important for pytorch itself to create the cuda context. If ffmpeg
   // creates the context it may not be compatible with pytorch.
   // This is a dummy tensor to initialize the cuda context.
   torch::Tensor dummyTensorForCudaInitialization = torch::empty(
-      {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device));
-  codecContext->hw_device_ctx = getCudaContext(device);
+      {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
+  codecContext->hw_device_ctx = getCudaContext(device_);
   return;
 }

-void convertAVFrameToFrameOutputOnCuda(
-    const torch::Device& device,
+void CudaDevice::convertAVFrameToFrameOutput(
     const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
     UniqueAVFrame& avFrame,
     SingleStreamDecoder::FrameOutput& frameOutput,
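In the hunk above, initializeContext attaches a cached FFmpeg CUDA device context (via getCudaContext) to the codec, after forcing PyTorch to create its own CUDA context with a dummy tensor. As a rough sketch of what producing such an AVBufferRef looks like without the cache, assuming FFmpeg's hwcontext API; createCudaHwDeviceCtx is a hypothetical helper, not part of this diff.

    extern "C" {
    #include <libavutil/hwcontext.h>
    }
    #include <stdexcept>
    #include <string>

    // Create an FFmpeg CUDA hardware device context for a given GPU index.
    AVBufferRef* createCudaHwDeviceCtx(int deviceIndex) {
      AVBufferRef* ctx = nullptr;
      std::string index = std::to_string(deviceIndex);  // "0", "1", ...
      int err = av_hwdevice_ctx_create(
          &ctx, AV_HWDEVICE_TYPE_CUDA, index.c_str(), nullptr, 0);
      if (err < 0) {
        throw std::runtime_error("av_hwdevice_ctx_create failed for CUDA device");
      }
      // The caller owns the reference; assign it to codecContext->hw_device_ctx
      // or release it with av_buffer_unref(&ctx).
      return ctx;
    }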
@@ -217,11 +211,11 @@ void convertAVFrameToFrameOutputOnCuda(
       "x3, got ",
       shape);
   } else {
-    dst = allocateEmptyHWCTensor(height, width, videoStreamOptions.device);
+    dst = allocateEmptyHWCTensor(height, width, device_);
   }

   // Use the user-requested GPU for running the NPP kernel.
-  c10::cuda::CUDAGuard deviceGuard(device);
+  c10::cuda::CUDAGuard deviceGuard(device_);

   NppiSize oSizeROI = {width, height};
   Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]};
@@ -249,7 +243,7 @@ void convertAVFrameToFrameOutputOnCuda(
   // output.
   at::cuda::CUDAEvent nppDoneEvent;
   at::cuda::CUDAStream nppStreamWrapper =
-      c10::cuda::getStreamFromExternal(nppGetStream(), device.index());
+      c10::cuda::getStreamFromExternal(nppGetStream(), device_.index());
   nppDoneEvent.record(nppStreamWrapper);
   nppDoneEvent.block(at::cuda::getCurrentCUDAStream());

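The hunk above wraps the raw NPP stream in a PyTorch stream object and uses a CUDA event so that PyTorch's current stream waits for the NPP color conversion before reading its output. A minimal sketch of that ordering pattern in isolation, assuming the ATen/c10 CUDA headers; syncExternalStreamWithCurrent is a hypothetical helper name.

    #include <ATen/cuda/CUDAEvent.h>
    #include <c10/cuda/CUDAStream.h>
    #include <cuda_runtime.h>

    // Make PyTorch's current stream wait for all work queued on an external
    // (non-PyTorch) CUDA stream, without a full device synchronization.
    void syncExternalStreamWithCurrent(cudaStream_t externalStream, int deviceIndex) {
      c10::cuda::CUDAStream wrapped =
          c10::cuda::getStreamFromExternal(externalStream, deviceIndex);
      at::cuda::CUDAEvent done;
      done.record(wrapped);                           // marks the end of the external work
      done.block(c10::cuda::getCurrentCUDAStream());  // current stream waits on it
    }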
@@ -264,11 +258,7 @@ void convertAVFrameToFrameOutputOnCuda(
 // we have to do this because of an FFmpeg bug where hardware decoding is not
 // appropriately set, so we just go off and find the matching codec for the CUDA
 // device
-std::optional<const AVCodec*> findCudaCodec(
-    const torch::Device& device,
-    const AVCodecID& codecId) {
-  throwErrorIfNonCudaDevice(device);
-
+std::optional<const AVCodec*> CudaDevice::findCodec(const AVCodecID& codecId) {
   void* i = nullptr;
   const AVCodec* codec = nullptr;
   while ((codec = av_codec_iterate(&i)) != nullptr) {
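The truncated loop above iterates every codec FFmpeg knows about; the body, which this diff does not show, has to pick a decoder for the requested codec ID that advertises CUDA support. A sketch of that search using FFmpeg's codec and hardware-config iteration APIs; findCudaDecoder is a hypothetical stand-in for the real method body.

    extern "C" {
    #include <libavcodec/avcodec.h>
    #include <libavutil/hwcontext.h>
    }
    #include <optional>

    std::optional<const AVCodec*> findCudaDecoder(AVCodecID codecId) {
      void* i = nullptr;
      const AVCodec* codec = nullptr;
      while ((codec = av_codec_iterate(&i)) != nullptr) {
        if (codec->id != codecId || !av_codec_is_decoder(codec)) {
          continue;
        }
        // A codec may expose several hardware configs; look for the CUDA one.
        for (int j = 0;; ++j) {
          const AVCodecHWConfig* config = avcodec_get_hw_config(codec, j);
          if (config == nullptr) {
            break;  // no more hardware configs for this codec
          }
          if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
            return codec;
          }
        }
      }
      return std::nullopt;
    }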