From e5a38ffeb6d6b6d4592d9bd1ffe7aaa77a0f34d7 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 19 Oct 2021 18:20:40 -0700 Subject: [PATCH] fix: Fix python API tests for mobilenet v2 This commit modifies test cases to use traced model instead of scripting. Model execution with mobilenet(using torch.jit.script) has problems with Pytorch 1.10 Signed-off-by: Dheeraj Peri --- tests/py/test_api.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/py/test_api.py b/tests/py/test_api.py index ca308e54da..1ce99dde94 100644 --- a/tests/py/test_api.py +++ b/tests/py/test_api.py @@ -51,15 +51,15 @@ def test_from_torch_tensor(self): "enabled_precisions": {torch.float} } - trt_mod = trtorch.compile(self.scripted_model, compile_spec) - same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max() + trt_mod = trtorch.compile(self.traced_model, compile_spec) + same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max() self.assertTrue(same < 2e-2) def test_device(self): compile_spec = {"inputs": [self.input], "device": trtorch.Device("gpu:0"), "enabled_precisions": {torch.float}} - trt_mod = trtorch.compile(self.scripted_model, compile_spec) - same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max() + trt_mod = trtorch.compile(self.traced_model, compile_spec) + same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max() self.assertTrue(same < 2e-2) @@ -169,7 +169,7 @@ class TestPTtoTRTtoPT(ModelTestCase): def setUp(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") - self.ts_model = torch.jit.script(self.model) + self.ts_model = torch.jit.trace(self.model, [self.input]) def test_pt_to_trt_to_pt(self): compile_spec = {