Skip to content

Commit cb7a547

Browse files
committed
refactor(//tests/py): Apply linting to new tests
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 88d07a9 commit cb7a547

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

tests/py/test_api.py

+3
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ def test_compile_script(self):
4545
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
4646
self.assertTrue(same < 2e-3)
4747

48+
4849
class TestPTtoTRTtoPT(ModelTestCase):
50+
4951
def setUp(self):
5052
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
5153
self.ts_model = torch.jit.script(self.model)
@@ -67,6 +69,7 @@ def test_pt_to_trt_to_pt(self):
6769
same = (trt_mod(self.input) - self.ts_model(self.input)).abs().max()
6870
self.assertTrue(same < 2e-3)
6971

72+
7073
class TestCheckMethodOpSupport(unittest.TestCase):
7174

7275
def setUp(self):

tests/py/test_trt_intercompatability.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99

1010
class TestPyTorchToTRTEngine(ModelTestCase):
11+
1112
def setUp(self):
1213
self.input = torch.randn((1, 3, 224, 224)).to("cuda:0")
1314
self.ts_model = torch.jit.script(self.model)
@@ -32,10 +33,13 @@ def test_pt_to_trt(self):
3233
with engine.create_execution_context() as ctx:
3334
out = torch.empty(size=tuple(engine.get_binding_shape(1))).to("cuda:0")
3435
bindings = [self.input.contiguous().data_ptr(), out.contiguous().data_ptr()]
35-
ctx.execute_async(batch_size=1, bindings=bindings, stream_handle=torch.cuda.current_stream(device='cuda:0').cuda_stream)
36+
ctx.execute_async(batch_size=1,
37+
bindings=bindings,
38+
stream_handle=torch.cuda.current_stream(device='cuda:0').cuda_stream)
3639
same = (out - self.ts_model(self.input)).abs().max()
3740
self.assertTrue(same < 2e-3)
3841

42+
3943
def test_suite():
4044
suite = unittest.TestSuite()
4145
suite.addTest(TestPyTorchToTRTEngine.parametrize(TestPyTorchToTRTEngine, model=models.resnet18(pretrained=True)))

0 commit comments

Comments
 (0)