From 4400da1ad3fa5fcb55b6fa694271030fbb2d9a34 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 8 Nov 2024 05:49:45 -0800 Subject: [PATCH] Test x265 on CUDA --- test/decoders/test_video_decoder.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 19ca1a20..8b787322 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -11,7 +11,14 @@ from torchcodec.decoders import _core, VideoDecoder -from ..utils import assert_tensor_close, assert_tensor_equal, H265_VIDEO, NASA_VIDEO +from ..utils import ( + assert_tensor_close, + assert_tensor_equal, + cpu_and_cuda, + get_frame_compare_function, + H265_VIDEO, + NASA_VIDEO, +) class TestVideoDecoder: @@ -405,11 +412,13 @@ def test_get_frame_played_at(self): assert isinstance(decoder.get_frame_played_at(6.02).pts_seconds, float) assert isinstance(decoder.get_frame_played_at(6.02).duration_seconds, float) - def test_get_frame_played_at_h265(self): + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frame_played_at_h265(self, device): # Non-regression test for https://github.com/pytorch/torchcodec/issues/179 - decoder = VideoDecoder(H265_VIDEO.path) - ref_frame6 = H265_VIDEO.get_frame_data_by_index(5) - assert_tensor_equal(ref_frame6, decoder.get_frame_played_at(0.5).data) + decoder = VideoDecoder(H265_VIDEO.path, device=device) + frame_compare_function = get_frame_compare_function(device) + ref_frame6 = H265_VIDEO.get_frame_data_by_index(5).to(device) + frame_compare_function(ref_frame6, decoder.get_frame_played_at(0.5).data) def test_get_frame_played_at_fails(self): decoder = VideoDecoder(NASA_VIDEO.path)