@@ -38,6 +38,14 @@ def __init__(self):
38
38
def get_frames_from_video(self, video_file, pts_list):
    # Decode and return the frames of `video_file` at each presentation
    # timestamp in `pts_list`.  No-op here; concrete decoder subclasses
    # provide the implementation.
    # NOTE(review): sibling declarations carry @abc.abstractmethod — this
    # one presumably does too, but the decorator is outside the visible
    # hunk; confirm against the full file.
    pass
41
@abc.abstractmethod
def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
    # Decode `numFramesToDecode` frames sequentially from `video_file`.
    # Abstract: concrete decoder subclasses must override.
    pass
45
@abc.abstractmethod
def decode_and_transform(self, video_file, pts_list, height, width, device):
    # Decode the frames of `video_file` at the timestamps in `pts_list`,
    # then resize each frame to (height, width) on `device` — the
    # decode + transform step of the DataLoader-inspired workload.
    # Abstract: concrete decoder subclasses must override.
    pass
41
49
42
50
class DecordAccurate (AbstractDecoder ):
43
51
def __init__ (self ):
@@ -89,8 +97,10 @@ def __init__(self, backend):
89
97
self ._backend = backend
90
98
self ._print_each_iteration_time = False
91
99
import torchvision # noqa: F401
100
+ from torchvision .transforms import v2 as transforms_v2
92
101
93
102
self .torchvision = torchvision
103
+ self .transforms_v2 = transforms_v2
94
104
95
105
def get_frames_from_video (self , video_file , pts_list ):
96
106
self .torchvision .set_video_backend (self ._backend )
@@ -111,6 +121,20 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
111
121
frames .append (frame ["data" ].permute (1 , 2 , 0 ))
112
122
return frames
113
123
124
def decode_and_transform(self, video_file, pts_list, height, width, device):
    """Decode the frames at `pts_list` and resize each to (height, width).

    Seeks to each timestamp with torchvision's VideoReader, then resizes
    every decoded frame on `device`.  Returns a list of channels-first
    frame tensors, matching the layout produced by the TorchCodec and
    TorchAudio decode_and_transform implementations.
    """
    self.torchvision.set_video_backend(self._backend)
    reader = self.torchvision.io.VideoReader(video_file, "video")
    frames = []
    for pts in pts_list:
        reader.seek(pts)
        frame = next(reader)
        # Keep the frame in CHW layout.  transforms_v2.functional.resize
        # interprets the *last two* dimensions as (H, W), so permuting to
        # HWC first (as get_frames_from_video does) would make resize
        # operate on the (W, C) axes — the wrong ones.
        frames.append(frame["data"])
    frames = [
        self.transforms_v2.functional.resize(frame.to(device), (height, width))
        for frame in frames
    ]
    return frames
114
138
115
139
class TorchCodecCore (AbstractDecoder ):
116
140
def __init__ (self , num_threads = None , color_conversion_library = None , device = "cpu" ):
@@ -239,6 +263,10 @@ def __init__(self, num_ffmpeg_threads=None, device="cpu"):
239
263
)
240
264
self ._device = device
241
265
266
+ from torchvision .transforms import v2 as transforms_v2
267
+
268
+ self .transforms_v2 = transforms_v2
269
+
242
270
def get_frames_from_video (self , video_file , pts_list ):
243
271
decoder = VideoDecoder (
244
272
video_file , num_ffmpeg_threads = self ._num_ffmpeg_threads , device = self ._device
@@ -258,6 +286,14 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
258
286
break
259
287
return frames
260
288
289
def decode_and_transform(self, video_file, pts_list, height, width, device):
    # Batch-decode all requested timestamps in a single call, then resize
    # the whole stacked batch at once.
    # NOTE(review): the `device` parameter is unused here — decoding and
    # the resize both happen on self._device; callers presumably keep the
    # two in sync.  Confirm at the call site.
    video_decoder = VideoDecoder(
        video_file,
        num_ffmpeg_threads=self._num_ffmpeg_threads,
        device=self._device,
    )
    batch = video_decoder.get_frames_played_at(pts_list)
    resized = self.transforms_v2.functional.resize(batch.data, (height, width))
    return resized
261
297
262
298
@torch .compile (fullgraph = True , backend = "eager" )
263
299
def compiled_seek_and_next (decoder , pts ):
@@ -299,7 +335,9 @@ def __init__(self):
299
335
300
336
self .torchaudio = torchaudio
301
337
302
- pass
338
+ from torchvision .transforms import v2 as transforms_v2
339
+
340
+ self .transforms_v2 = transforms_v2
303
341
304
342
def get_frames_from_video (self , video_file , pts_list ):
305
343
stream_reader = self .torchaudio .io .StreamReader (src = video_file )
@@ -325,6 +363,21 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
325
363
326
364
return frames
327
365
366
def decode_and_transform(self, video_file, pts_list, height, width, device):
    # Seek to each timestamp and decode a single frame per seek, then
    # resize every decoded frame to (height, width) on `device`.
    reader = self.torchaudio.io.StreamReader(src=video_file)
    reader.add_basic_video_stream(frames_per_chunk=1)
    decoded = []
    for pts in pts_list:
        reader.seek(pts)
        reader.fill_buffer()
        chunks = reader.pop_chunks()
        # One stream configured, one frame per chunk: take that frame.
        decoded.append(chunks[0][0])
    return [
        self.transforms_v2.functional.resize(frame.to(device), (height, width))
        for frame in decoded
    ]
328
381
329
382
def create_torchcodec_decoder_from_file (video_file ):
330
383
video_decoder = create_from_file (video_file )
@@ -443,7 +496,7 @@ def plot_data(df_data, plot_path):
443
496
444
497
# Set the title for the subplot
445
498
base_video = Path (video ).name .removesuffix (".mp4" )
446
- ax .set_title (f"{ base_video } \n { vcount } x { vtype } " , fontsize = 11 )
499
+ ax .set_title (f"{ base_video } \n { vtype } " , fontsize = 11 )
447
500
448
501
# Plot bars with error bars
449
502
ax .barh (
@@ -486,6 +539,14 @@ class BatchParameters:
486
539
batch_size : int
487
540
488
541
542
@dataclass
class DataLoaderInspiredWorkloadParameters:
    """Configuration for the DataLoader-inspired benchmark workload.

    The workload runs each decoder's decode_and_transform() across a
    thread pool (see run_batch_using_threads), decoding frames and
    resizing them to a fixed size on a chosen device.
    """

    batch_parameters: BatchParameters  # thread count and batch size for the threaded run
    resize_height: int  # target frame height passed to decode_and_transform
    resize_width: int  # target frame width passed to decode_and_transform
    resize_device: str  # device frames are moved to before resizing
489
550
def run_batch_using_threads (
490
551
function ,
491
552
* args ,
@@ -525,6 +586,7 @@ def run_benchmarks(
525
586
num_sequential_frames_from_start : list [int ],
526
587
min_runtime_seconds : float ,
527
588
benchmark_video_creation : bool ,
589
+ dataloader_parameters : DataLoaderInspiredWorkloadParameters = None ,
528
590
batch_parameters : BatchParameters = None ,
529
591
) -> list [dict [str , str | float | int ]]:
530
592
# Ensure that we have the same seed across benchmark runs.
@@ -550,6 +612,39 @@ def run_benchmarks(
550
612
for decoder_name , decoder in decoder_dict .items ():
551
613
print (f"video={ video_file_path } , decoder={ decoder_name } " )
552
614
615
+ if dataloader_parameters :
616
+ bp = dataloader_parameters .batch_parameters
617
+ dataloader_result = benchmark .Timer (
618
+ stmt = "run_batch_using_threads(decoder.decode_and_transform, video_file, pts_list, height, width, device, batch_parameters=batch_parameters)" ,
619
+ globals = {
620
+ "video_file" : str (video_file_path ),
621
+ "pts_list" : uniform_pts_list ,
622
+ "decoder" : decoder ,
623
+ "run_batch_using_threads" : run_batch_using_threads ,
624
+ "batch_parameters" : dataloader_parameters .batch_parameters ,
625
+ "height" : dataloader_parameters .resize_height ,
626
+ "width" : dataloader_parameters .resize_width ,
627
+ "device" : dataloader_parameters .resize_device ,
628
+ },
629
+ label = f"video={ video_file_path } { metadata_label } " ,
630
+ sub_label = decoder_name ,
631
+ description = f"dataloader[threads={ bp .num_threads } batch_size={ bp .batch_size } ] { num_samples } decode_and_transform()" ,
632
+ )
633
+ results .append (
634
+ dataloader_result .blocked_autorange (
635
+ min_run_time = min_runtime_seconds
636
+ )
637
+ )
638
+ df_data .append (
639
+ convert_result_to_df_item (
640
+ results [- 1 ],
641
+ decoder_name ,
642
+ video_file_path ,
643
+ num_samples * dataloader_parameters .batch_parameters .batch_size ,
644
+ f"dataloader[threads={ bp .num_threads } batch_size={ bp .batch_size } ] { num_samples } x decode_and_transform()" ,
645
+ )
646
+ )
647
+
553
648
for kind , pts_list in [
554
649
("uniform" , uniform_pts_list ),
555
650
("random" , random_pts_list ),
0 commit comments