diff --git a/tests/py/test_api.py b/tests/py/test_api.py index ca308e54da..1ce99dde94 100644 --- a/tests/py/test_api.py +++ b/tests/py/test_api.py @@ -51,15 +51,15 @@ def test_from_torch_tensor(self): "enabled_precisions": {torch.float} } - trt_mod = trtorch.compile(self.scripted_model, compile_spec) - same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max() + trt_mod = trtorch.compile(self.traced_model, compile_spec) + same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max() self.assertTrue(same < 2e-2) def test_device(self): compile_spec = {"inputs": [self.input], "device": trtorch.Device("gpu:0"), "enabled_precisions": {torch.float}} - trt_mod = trtorch.compile(self.scripted_model, compile_spec) - same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max() + trt_mod = trtorch.compile(self.traced_model, compile_spec) + same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max() self.assertTrue(same < 2e-2) @@ -169,7 +169,7 @@ class TestPTtoTRTtoPT(ModelTestCase): def setUp(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") - self.ts_model = torch.jit.script(self.model) + self.ts_model = torch.jit.trace(self.model, [self.input]) def test_pt_to_trt_to_pt(self): compile_spec = {