Skip to content

Commit faa2fd9

Browse files
ahmadsharif1 and NicolasHug
authored and committed
Added a decode+resize benchmark and cuda decoder (pytorch#378)
1 parent 0256788 commit faa2fd9

File tree

5 files changed

+407
-128
lines changed

5 files changed

+407
-128
lines changed

README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,8 @@ format you want. Refer to Nvidia's GPU support matrix for more details
185185

186186
## Benchmark Results
187187

188-
The following was generated by running [our benchmark script](./benchmarks/decoders/generate_readme_data.py) on a lightly loaded 56-core machine.
188+
The following was generated by running [our benchmark script](./benchmarks/decoders/generate_readme_data.py) on a lightly loaded 22-core machine with an Nvidia A100 with
189+
5 [NVDEC decoders](https://docs.nvidia.com/video-technologies/video-codec-sdk/12.1/nvdec-application-note/index.html#).
189190

190191
![benchmark_results](./benchmarks/decoders/benchmark_readme_chart.png)
191192

benchmarks/decoders/benchmark_decoders_library.py

+97-2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ def __init__(self):
3838
def get_frames_from_video(self, video_file, pts_list):
3939
pass
4040

41+
@abc.abstractmethod
def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
    """Decode the first numFramesToDecode frames of video_file, in order."""
    pass

@abc.abstractmethod
def decode_and_transform(self, video_file, pts_list, height, width, device):
    """Decode the frames played at each pts in pts_list, then resize them
    to (height, width); the resize runs on `device`."""
    pass
48+
4149

4250
class DecordAccurate(AbstractDecoder):
4351
def __init__(self):
@@ -89,8 +97,10 @@ def __init__(self, backend):
8997
self._backend = backend
9098
self._print_each_iteration_time = False
9199
import torchvision # noqa: F401
100+
from torchvision.transforms import v2 as transforms_v2
92101

93102
self.torchvision = torchvision
103+
self.transforms_v2 = transforms_v2
94104

95105
def get_frames_from_video(self, video_file, pts_list):
96106
self.torchvision.set_video_backend(self._backend)
@@ -111,6 +121,20 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
111121
frames.append(frame["data"].permute(1, 2, 0))
112122
return frames
113123

124+
def decode_and_transform(self, video_file, pts_list, height, width, device):
    """Seek-and-decode one frame per pts with the torchvision VideoReader,
    then resize every decoded frame to (height, width) on `device`."""
    self.torchvision.set_video_backend(self._backend)
    video_reader = self.torchvision.io.VideoReader(video_file, "video")
    resize = self.transforms_v2.functional.resize
    decoded = []
    for timestamp in pts_list:
        # Decoding is sequential and stateful: seek first, then pull one frame.
        video_reader.seek(timestamp)
        raw = next(video_reader)
        # CHW -> HWC, matching the other get_frames_* methods of this decoder.
        decoded.append(raw["data"].permute(1, 2, 0))
    return [resize(tensor.to(device), (height, width)) for tensor in decoded]
137+
114138

115139
class TorchCodecCore(AbstractDecoder):
116140
def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"):
@@ -239,6 +263,10 @@ def __init__(self, num_ffmpeg_threads=None, device="cpu"):
239263
)
240264
self._device = device
241265

266+
from torchvision.transforms import v2 as transforms_v2
267+
268+
self.transforms_v2 = transforms_v2
269+
242270
def get_frames_from_video(self, video_file, pts_list):
243271
decoder = VideoDecoder(
244272
video_file, num_ffmpeg_threads=self._num_ffmpeg_threads, device=self._device
@@ -258,6 +286,14 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
258286
break
259287
return frames
260288

289+
def decode_and_transform(self, video_file, pts_list, height, width, device):
    """Decode the frames played at pts_list with TorchCodec, then resize the
    whole batch to (height, width) in one call.

    NOTE(review): the `device` argument is not referenced here — frames are
    decoded (and therefore resized) on self._device instead; confirm this
    asymmetry with the other decoders is intended.
    """
    video_decoder = VideoDecoder(
        video_file, num_ffmpeg_threads=self._num_ffmpeg_threads, device=self._device
    )
    batch = video_decoder.get_frames_played_at(pts_list)
    # Batched resize: .data is a stacked tensor, so one resize covers all frames.
    return self.transforms_v2.functional.resize(batch.data, (height, width))
296+
261297

262298
@torch.compile(fullgraph=True, backend="eager")
263299
def compiled_seek_and_next(decoder, pts):
@@ -299,7 +335,9 @@ def __init__(self):
299335

300336
self.torchaudio = torchaudio
301337

302-
pass
338+
from torchvision.transforms import v2 as transforms_v2
339+
340+
self.transforms_v2 = transforms_v2
303341

304342
def get_frames_from_video(self, video_file, pts_list):
305343
stream_reader = self.torchaudio.io.StreamReader(src=video_file)
@@ -325,6 +363,21 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
325363

326364
return frames
327365

366+
def decode_and_transform(self, video_file, pts_list, height, width, device):
    """Seek-and-decode one frame per pts with torchaudio's StreamReader,
    then resize every decoded frame to (height, width) on `device`."""
    stream_reader = self.torchaudio.io.StreamReader(src=video_file)
    stream_reader.add_basic_video_stream(frames_per_chunk=1)
    resize = self.transforms_v2.functional.resize
    decoded = []
    for timestamp in pts_list:
        stream_reader.seek(timestamp)
        stream_reader.fill_buffer()
        chunks = stream_reader.pop_chunks()
        # One stream, one frame per chunk: chunks[0][0] is the frame tensor.
        decoded.append(chunks[0][0])
    return [resize(tensor.to(device), (height, width)) for tensor in decoded]
380+
328381

329382
def create_torchcodec_decoder_from_file(video_file):
330383
video_decoder = create_from_file(video_file)
@@ -443,7 +496,7 @@ def plot_data(df_data, plot_path):
443496

444497
# Set the title for the subplot
445498
base_video = Path(video).name.removesuffix(".mp4")
446-
ax.set_title(f"{base_video}\n{vcount} x {vtype}", fontsize=11)
499+
ax.set_title(f"{base_video}\n{vtype}", fontsize=11)
447500

448501
# Plot bars with error bars
449502
ax.barh(
@@ -486,6 +539,14 @@ class BatchParameters:
486539
batch_size: int
487540

488541

542+
@dataclass
class DataLoaderInspiredWorkloadParameters:
    """Configuration for the dataloader-style benchmark: threaded, batched
    decode_and_transform() calls followed by a resize.

    Field order is part of the generated __init__ signature — do not reorder.
    """

    # How the decode calls are fanned out (thread count, batch size).
    batch_parameters: BatchParameters
    # Target size of the post-decode resize.
    resize_height: int
    resize_width: int
    # Device the resize runs on (e.g. "cpu" or "cuda").
    resize_device: str
548+
549+
489550
def run_batch_using_threads(
490551
function,
491552
*args,
@@ -525,6 +586,7 @@ def run_benchmarks(
525586
num_sequential_frames_from_start: list[int],
526587
min_runtime_seconds: float,
527588
benchmark_video_creation: bool,
589+
dataloader_parameters: DataLoaderInspiredWorkloadParameters = None,
528590
batch_parameters: BatchParameters = None,
529591
) -> list[dict[str, str | float | int]]:
530592
# Ensure that we have the same seed across benchmark runs.
@@ -550,6 +612,39 @@ def run_benchmarks(
550612
for decoder_name, decoder in decoder_dict.items():
551613
print(f"video={video_file_path}, decoder={decoder_name}")
552614

615+
if dataloader_parameters:
616+
bp = dataloader_parameters.batch_parameters
617+
dataloader_result = benchmark.Timer(
618+
stmt="run_batch_using_threads(decoder.decode_and_transform, video_file, pts_list, height, width, device, batch_parameters=batch_parameters)",
619+
globals={
620+
"video_file": str(video_file_path),
621+
"pts_list": uniform_pts_list,
622+
"decoder": decoder,
623+
"run_batch_using_threads": run_batch_using_threads,
624+
"batch_parameters": dataloader_parameters.batch_parameters,
625+
"height": dataloader_parameters.resize_height,
626+
"width": dataloader_parameters.resize_width,
627+
"device": dataloader_parameters.resize_device,
628+
},
629+
label=f"video={video_file_path} {metadata_label}",
630+
sub_label=decoder_name,
631+
description=f"dataloader[threads={bp.num_threads} batch_size={bp.batch_size}] {num_samples} decode_and_transform()",
632+
)
633+
results.append(
634+
dataloader_result.blocked_autorange(
635+
min_run_time=min_runtime_seconds
636+
)
637+
)
638+
df_data.append(
639+
convert_result_to_df_item(
640+
results[-1],
641+
decoder_name,
642+
video_file_path,
643+
num_samples * dataloader_parameters.batch_parameters.batch_size,
644+
f"dataloader[threads={bp.num_threads} batch_size={bp.batch_size}] {num_samples} x decode_and_transform()",
645+
)
646+
)
647+
553648
for kind, pts_list in [
554649
("uniform", uniform_pts_list),
555650
("random", random_pts_list),
28.5 KB
Loading

0 commit comments

Comments
 (0)