Skip to content

Commit e5a38ff

Browse files
committed
fix: Fix python API tests for mobilenet v2
This commit modifies test cases to use traced model instead of scripting. Model execution with mobilenet(using torch.jit.script) has problems with Pytorch 1.10 Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 4d95b04 commit e5a38ff

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

Diff for: tests/py/test_api.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ def test_from_torch_tensor(self):
5151
"enabled_precisions": {torch.float}
5252
}
5353

54-
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
55-
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
54+
trt_mod = trtorch.compile(self.traced_model, compile_spec)
55+
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
5656
self.assertTrue(same < 2e-2)
5757

5858
def test_device(self):
5959
compile_spec = {"inputs": [self.input], "device": trtorch.Device("gpu:0"), "enabled_precisions": {torch.float}}
6060

61-
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
62-
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
61+
trt_mod = trtorch.compile(self.traced_model, compile_spec)
62+
same = (trt_mod(self.input) - self.traced_model(self.input)).abs().max()
6363
self.assertTrue(same < 2e-2)
6464

6565

@@ -169,7 +169,7 @@ class TestPTtoTRTtoPT(ModelTestCase):
169169

170170
def setUp(self):
171171
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
172-
self.ts_model = torch.jit.script(self.model)
172+
self.ts_model = torch.jit.trace(self.model, [self.input])
173173

174174
def test_pt_to_trt_to_pt(self):
175175
compile_spec = {

0 commit comments

Comments
 (0)