 import torch_tensorrt as torchtrt
 import torchvision.models as models
 from torch._export.serde.serialize import deserialize, serialize
-from torch_tensorrt.dynamo.export import create_trt_exp_program, transform
 from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 
 assertions = unittest.TestCase()
@@ -45,9 +44,8 @@ def forward(self, x):
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    trt_gm = transform(trt_gm, [input])
-    trt_exp_program = create_trt_exp_program(
-        trt_gm, exp_program.call_spec, trt_gm.state_dict()
+    trt_exp_program = torchtrt.dynamo.serialize(
+        trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
     )
     serialized_prog = serialize(trt_exp_program)
     deserialized_prog = deserialize(*serialized_prog)
@@ -100,11 +98,9 @@ def forward(self, x):
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    trt_gm = transform(trt_gm, [input])
-    trt_exp_program = create_trt_exp_program(
-        trt_gm, exp_program.call_spec, trt_gm.state_dict()
+    trt_exp_program = torchtrt.dynamo.serialize(
+        trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
     )
-
     serialized_prog = serialize(trt_exp_program)
     deserialized_prog = deserialize(*serialized_prog)
     # Check Pyt and TRT exported program outputs
@@ -161,11 +157,9 @@ def forward(self, x):
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    trt_gm = transform(trt_gm, [input])
-    trt_exp_program = create_trt_exp_program(
-        trt_gm, exp_program.call_spec, trt_gm.state_dict()
+    trt_exp_program = torchtrt.dynamo.serialize(
+        trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
     )
-
     torch._export.save(trt_exp_program, "/tmp/trt.ep")
     deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
 
@@ -224,11 +218,9 @@ def forward(self, x):
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    trt_gm = transform(trt_gm, [input])
-    trt_exp_program = create_trt_exp_program(
-        trt_gm, exp_program.call_spec, trt_gm.state_dict()
+    trt_exp_program = torchtrt.dynamo.serialize(
+        trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
     )
-
     torch._export.save(trt_exp_program, "/tmp/trt.ep")
     deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
 
@@ -270,9 +262,8 @@ def test_resnet18_save_load(ir):
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    trt_gm = transform(trt_gm, [input])
-    trt_exp_program = create_trt_exp_program(
-        trt_gm, exp_program.call_spec, trt_gm.state_dict()
+    trt_exp_program = torchtrt.dynamo.serialize(
+        trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
     )
     torch._export.save(trt_exp_program, "/tmp/trt.ep")
     deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
@@ -291,59 +282,3 @@ def test_resnet18_save_load(ir):
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
-
-
-# Enable this test once this issue is resolved https://github.com/pytorch/TensorRT/issues/2341
-# @pytest.mark.unit
-# def test_hybrid_conv_fallback(ir):
-#     """
-#     This tests export save and load functionality on a hybrid
-#     model where a conv (a weighted layer) has been forced to fallback to Pytorch.
-#     """
-
-#     class MyModule(torch.nn.Module):
-#         def __init__(self):
-#             super().__init__()
-#             self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
-#             self.relu = torch.nn.ReLU()
-
-#         def forward(self, x):
-#             conv = self.conv(x)
-#             relu = self.relu(conv)
-#             mul = relu * 0.5
-#             return mul
-
-#     model = MyModule().eval().cuda()
-#     input = torch.randn((1, 3, 224, 224)).to("cuda")
-
-#     compile_spec = {
-#         "inputs": [
-#             torchtrt.Input(
-#                 input.shape, dtype=torch.float, format=torch.contiguous_format
-#             )
-#         ],
-#         "ir": ir,
-#         "min_block_size": 1,
-#         "torch_executed_ops": "torch.ops.aten.convolution.default",
-#     }
-
-#     trt_exp_program = torchtrt.compile(model, **compile_spec)
-#     torch._export.save(trt_exp_program, "/tmp/trt.ep")
-#     deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
-
-#     outputs_pyt = model(input)
-#     outputs_trt = trt_exp_program(input)
-#     for idx in range(len(outputs_pyt)):
-#         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
-#         assertions.assertTrue(
-#             cos_sim > COSINE_THRESHOLD,
-#             msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
-#         )
-
-#     outputs_trt_deser = deser_trt_exp_program(input)
-#     for idx in range(len(outputs_pyt)):
-#         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
-#         assertions.assertTrue(
-#             cos_sim > COSINE_THRESHOLD,
-#             msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
-#         )
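
For readers skimming the diff, the flow the updated tests exercise is sketched below. Only the calls visible in the hunks above are taken as given (torchtrt.dynamo.trace, torchtrt.dynamo.compile, the new torchtrt.dynamo.serialize entry point, and the torch._export save/load round trip); the model choice and compile_spec contents are illustrative assumptions, not part of this change.

# Minimal sketch of the serialization flow these tests exercise; the
# torchtrt.dynamo.serialize signature is taken from the hunks above,
# while the model and compile_spec below are illustrative assumptions.
import torch
import torch_tensorrt as torchtrt
import torchvision.models as models
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

model = models.resnet18().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")
compile_spec = {
    "inputs": [
        torchtrt.Input(
            input.shape, dtype=torch.float, format=torch.contiguous_format
        )
    ],
    "min_block_size": 1,
}

# Trace and compile through the dynamo frontend, as the tests do.
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)

# Single entry point replacing the old transform() + create_trt_exp_program() pair.
trt_exp_program = torchtrt.dynamo.serialize(
    trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
)

# Round-trip through torch._export and compare against eager PyTorch outputs.
torch._export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")

outputs_pyt = model(input)
outputs_trt = deser_trt_exp_program(input)
for idx in range(len(outputs_pyt)):
    cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
    assert cos_sim > COSINE_THRESHOLD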