1
1
import torch
2
2
import timm
3
3
import pytest
4
+ import unittest
4
5
5
6
import torch_tensorrt as torchtrt
6
7
import torchvision .models as models
12
13
cosine_similarity ,
13
14
)
14
15
16
+ assertions = unittest .TestCase ()
17
+
15
18
16
19
@pytest .mark .unit
17
20
def test_resnet18 (ir ):
@@ -32,9 +35,9 @@ def test_resnet18(ir):
32
35
33
36
trt_mod = torchtrt .compile (model , ** compile_spec )
34
37
cos_sim = cosine_similarity (model (input ), trt_mod (input ))
35
- assert (
38
+ assertions . assertTrue (
36
39
cos_sim > COSINE_THRESHOLD ,
37
- f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
40
+ msg = f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
38
41
)
39
42
40
43
# Clean up model env
@@ -63,9 +66,9 @@ def test_mobilenet_v2(ir):
63
66
64
67
trt_mod = torchtrt .compile (model , ** compile_spec )
65
68
cos_sim = cosine_similarity (model (input ), trt_mod (input ))
66
- assert (
69
+ assertions . assertTrue (
67
70
cos_sim > COSINE_THRESHOLD ,
68
- f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
71
+ msg = f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
69
72
)
70
73
71
74
# Clean up model env
@@ -94,9 +97,9 @@ def test_efficientnet_b0(ir):
94
97
95
98
trt_mod = torchtrt .compile (model , ** compile_spec )
96
99
cos_sim = cosine_similarity (model (input ), trt_mod (input ))
97
- assert (
100
+ assertions . assertTrue (
98
101
cos_sim > COSINE_THRESHOLD ,
99
- f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
102
+ msg = f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
100
103
)
101
104
102
105
# Clean up model env
@@ -138,9 +141,9 @@ def test_bert_base_uncased(ir):
138
141
for key in model_outputs .keys ():
139
142
out , trt_out = model_outputs [key ], trt_model_outputs [key ]
140
143
cos_sim = cosine_similarity (out , trt_out )
141
- assert (
144
+ assertions . assertTrue (
142
145
cos_sim > COSINE_THRESHOLD ,
143
- f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
146
+ msg = f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
144
147
)
145
148
146
149
# Clean up model env
@@ -169,9 +172,9 @@ def test_resnet18_half(ir):
169
172
170
173
trt_mod = torchtrt .compile (model , ** compile_spec )
171
174
cos_sim = cosine_similarity (model (input ), trt_mod (input ))
172
- assert (
175
+ assertions . assertTrue (
173
176
cos_sim > COSINE_THRESHOLD ,
174
- f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
177
+ msg = f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
175
178
)
176
179
177
180
# Clean up model env
0 commit comments