7
7
8
8
from transformers import BertModel
9
9
10
- from utils import COSINE_THRESHOLD , cosine_similarity
10
+ from torch_tensorrt .dynamo .common_utils .test_utils import (
11
+ COSINE_THRESHOLD ,
12
+ cosine_similarity ,
13
+ )
11
14
12
15
13
16
@pytest .mark .unit
@@ -24,13 +27,14 @@ def test_resnet18(ir):
24
27
"device" : torchtrt .Device ("cuda:0" ),
25
28
"enabled_precisions" : {torch .float },
26
29
"ir" : ir ,
30
+ "pass_through_build_failures" : True ,
27
31
}
28
32
29
33
trt_mod = torchtrt .compile (model , ** compile_spec )
30
34
cos_sim = cosine_similarity (model (input ), trt_mod (input ))
31
35
assert (
32
36
cos_sim > COSINE_THRESHOLD ,
33
- f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
37
+ f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
34
38
)
35
39
36
40
# Clean up model env
@@ -54,6 +58,7 @@ def test_mobilenet_v2(ir):
54
58
"device" : torchtrt .Device ("cuda:0" ),
55
59
"enabled_precisions" : {torch .float },
56
60
"ir" : ir ,
61
+ "pass_through_build_failures" : True ,
57
62
}
58
63
59
64
trt_mod = torchtrt .compile (model , ** compile_spec )
@@ -84,6 +89,7 @@ def test_efficientnet_b0(ir):
84
89
"device" : torchtrt .Device ("cuda:0" ),
85
90
"enabled_precisions" : {torch .float },
86
91
"ir" : ir ,
92
+ "pass_through_build_failures" : True ,
87
93
}
88
94
89
95
trt_mod = torchtrt .compile (model , ** compile_spec )
@@ -123,6 +129,7 @@ def test_bert_base_uncased(ir):
123
129
"enabled_precisions" : {torch .float },
124
130
"truncate_long_and_double" : True ,
125
131
"ir" : ir ,
132
+ "pass_through_build_failures" : True ,
126
133
}
127
134
trt_mod = torchtrt .compile (model , ** compile_spec )
128
135
@@ -157,13 +164,14 @@ def test_resnet18_half(ir):
157
164
"device" : torchtrt .Device ("cuda:0" ),
158
165
"enabled_precisions" : {torch .half },
159
166
"ir" : ir ,
167
+ "pass_through_build_failures" : True ,
160
168
}
161
169
162
170
trt_mod = torchtrt .compile (model , ** compile_spec )
163
171
cos_sim = cosine_similarity (model (input ), trt_mod (input ))
164
172
assert (
165
173
cos_sim > COSINE_THRESHOLD ,
166
- f"Resnet50 Half TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
174
+ f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
167
175
)
168
176
169
177
# Clean up model env
0 commit comments