Skip to content

Commit 844afc6

Browse files
authored
feat: Wrap ExportedPrograms transformations with an API, allow dynamo.compile to accept graphmodules. (#2388)
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 26d6b2e commit 844afc6

File tree

6 files changed

+68
-116
lines changed

6 files changed

+68
-116
lines changed

docsrc/user_guide/saving_models.rst

+6-8
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ The following code illustrates this approach.
2222
model = MyModel().eval().cuda()
2323
inputs = torch.randn((1, 3, 224, 224)).cuda()
2424
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) # Output is a torch.fx.GraphModule
25-
trt_script_model = torch.jit.trace(trt_gm, inputs)
26-
torch.jit.save(trt_script_model, "trt_model.ts")
25+
trt_traced_model = torch_tensorrt.dynamo.serialize(trt_gm, inputs)
26+
torch.jit.save(trt_traced_model, "trt_model.ts")
2727
2828
# Later, you can load it and run inference
2929
model = torch.jit.load("trt_model.ts").cuda()
@@ -37,21 +37,19 @@ b) ExportedProgram
3737
3838
import torch
3939
import torch_tensorrt
40-
from torch_tensorrt.dynamo.export import transform, create_exported_program
4140
4241
model = MyModel().eval().cuda()
4342
inputs = torch.randn((1, 3, 224, 224)).cuda()
4443
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) # Output is a torch.fx.GraphModule
4544
# Transform and create an exported program
46-
trt_gm = transform(trt_gm, inputs)
47-
trt_exp_program = create_exported_program(trt_gm, call_spec, trt_gm.state_dict())
48-
torch._export.save(trt_exp_program, "trt_model.ep")
45+
trt_exp_program = torch_tensorrt.dynamo.serialize(trt_gm, inputs, call_spec, ir="exported_program")
46+
torch.export.save(trt_exp_program, "trt_model.ep")
4947
5048
# Later, you can load it and run inference
51-
model = torch._export.load("trt_model.ep")
49+
model = torch.export.load("trt_model.ep")
5250
model(inputs)
5351
54-
`torch_tensorrt.dynamo.export.transform` inlines the submodules within a GraphModule to their corresponding nodes and stiches all the nodes together.
52+
`torch_tensorrt.dynamo.serialize` inlines the submodules within a GraphModule to their corresponding nodes, stitches all the nodes together and creates an ExportedProgram.
5553
This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
5654

5755
NOTE: Saving models using `ExportedProgram` is experimental. Here is a known issue: https://github.com/pytorch/TensorRT/issues/2341

py/requirements.txt

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
numpy
22
packaging
33
pybind11==2.6.2
4-
--extra-index-url https://download.pytorch.org/whl/nightly/cu121
5-
torch>=2.1.0,<2.2.0
6-
torchvision>=0.16.0,<0.17.0
4+
torch==2.1.0
5+
torchvision==0.16.0
76
--extra-index-url https://pypi.ngc.nvidia.com
87
tensorrt==8.6.1
98
pyyaml

py/torch_tensorrt/dynamo/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
DYNAMO_CONVERTERS,
1717
dynamo_tensorrt_converter,
1818
)
19+
from .export import serialize

py/torch_tensorrt/dynamo/compile.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747

4848
def compile(
49-
exported_program: ExportedProgram,
49+
exported_program: Union[torch.fx.GraphModule, ExportedProgram],
5050
inputs: Any,
5151
*,
5252
device: Optional[Union[Device, torch.device, str]] = DEVICE,
@@ -86,7 +86,15 @@ def compile(
8686
inputs = prepare_inputs(inputs)
8787
device = to_torch_tensorrt_device(device)
8888

89-
gm = exported_program.module()
89+
if isinstance(exported_program, torch.fx.GraphModule):
90+
gm = exported_program
91+
elif isinstance(exported_program, ExportedProgram):
92+
gm = exported_program.module()
93+
else:
94+
raise AssertionError(
95+
f"Input graph should either be an ExportedProgram or a GraphModule but got type {type(exported_program)}"
96+
)
97+
9098
logger.debug("Input graph: " + str(gm.graph))
9199

92100
# Apply lowering on the graph module

py/torch_tensorrt/dynamo/export.py

+39-28
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import copy
22
import operator
3-
from typing import Any, Dict, Sequence, Tuple, Union, cast
3+
from typing import Any, Dict, Sequence, Tuple, cast
44

55
import torch
66
from torch._export.exported_program import CallSpec
@@ -10,28 +10,42 @@
1010
from torch_tensorrt.dynamo import partitioning
1111

1212

13-
def transform(
14-
gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]
15-
) -> torch.fx.GraphModule:
16-
# Run shape analysis
17-
_, outputs_map = partitioning.run_shape_analysis(gm, inputs)
18-
19-
# Inline TensorRT submodules
20-
inline_trt_modules(gm, outputs_map)
21-
22-
# Inline pytorch submodules
23-
inline_torch_modules(gm)
24-
25-
# Lift constant buffers and parameters in the graph
26-
# torch.export serialization expects them to be lifted
27-
lift_constant_pass(gm)
28-
29-
# Clean the graph
30-
gm.delete_all_unused_submodules()
31-
gm.graph.eliminate_dead_code()
32-
gm.graph.lint()
33-
34-
return gm
13+
def serialize(
14+
gm: torch.fx.GraphModule,
15+
inputs: Sequence[torch.Tensor],
16+
call_spec: CallSpec = None,
17+
ir: str = "torchscript",
18+
) -> ExportedProgram:
19+
if ir == "torchscript":
20+
return torch.jit.trace(gm, inputs)
21+
elif ir == "exported_program":
22+
assert call_spec
23+
# Run shape analysis
24+
_, outputs_map = partitioning.run_shape_analysis(gm, inputs)
25+
26+
# Inline TensorRT submodules
27+
inline_trt_modules(gm, outputs_map)
28+
29+
# Inline pytorch submodules
30+
inline_torch_modules(gm)
31+
32+
# Lift constant buffers and parameters in the graph
33+
# torch.export serialization expects them to be lifted
34+
lift_constant_pass(gm)
35+
36+
# Clean the graph
37+
gm.delete_all_unused_submodules()
38+
gm.graph.eliminate_dead_code()
39+
gm.graph.lint()
40+
41+
# Create an exported program with the TRT GraphModule
42+
exp_program = create_trt_exp_program(gm, call_spec)
43+
44+
return exp_program
45+
else:
46+
raise ValueError(
47+
f"Invalid ir : {ir} provided for serialization. Options include torchscript | exported_program"
48+
)
3549

3650

3751
def lift_constant_pass(trt_gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
@@ -115,7 +129,6 @@ def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
115129

116130
# Copy all nodes in the submodule into gm and
117131
# store the output node of this submodule which is now present in gm
118-
119132
submodule_output = gm.graph.graph_copy(submodule.graph, val_map)
120133

121134
# Get their references (since we copied) in the parent graph (gm)
@@ -174,9 +187,7 @@ def copy_submodule_attributes(
174187

175188

176189
def create_trt_exp_program(
177-
gm: torch.fx.GraphModule,
178-
call_spec: CallSpec,
179-
state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
190+
gm: torch.fx.GraphModule, call_spec: CallSpec
180191
) -> ExportedProgram:
181192
"""Creates a new Exported Program. This function takes an torch.fx.GraphModule which has TRT engines
182193
and constructs an Exported Program object with the new IO node names, call_spec and state_dict
@@ -208,7 +219,7 @@ def create_trt_exp_program(
208219
)
209220

210221
trt_exp_program = ExportedProgram(
211-
gm, gm.graph, trt_graph_signature, call_spec, state_dict, {}, [], []
222+
gm, gm.graph, trt_graph_signature, call_spec, gm.state_dict(), {}, [], []
212223
)
213224

214225
return trt_exp_program

tests/py/dynamo/models/test_export_serde.py

+10-75
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import torch_tensorrt as torchtrt
77
import torchvision.models as models
88
from torch._export.serde.serialize import deserialize, serialize
9-
from torch_tensorrt.dynamo.export import create_trt_exp_program, transform
109
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
1110

1211
assertions = unittest.TestCase()
@@ -45,9 +44,8 @@ def forward(self, x):
4544

4645
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
4746
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
48-
trt_gm = transform(trt_gm, [input])
49-
trt_exp_program = create_trt_exp_program(
50-
trt_gm, exp_program.call_spec, trt_gm.state_dict()
47+
trt_exp_program = torchtrt.dynamo.serialize(
48+
trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
5149
)
5250
serialized_prog = serialize(trt_exp_program)
5351
deserialized_prog = deserialize(*serialized_prog)
@@ -100,11 +98,9 @@ def forward(self, x):
10098

10199
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
102100
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
103-
trt_gm = transform(trt_gm, [input])
104-
trt_exp_program = create_trt_exp_program(
105-
trt_gm, exp_program.call_spec, trt_gm.state_dict()
101+
trt_exp_program = torchtrt.dynamo.serialize(
102+
trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
106103
)
107-
108104
serialized_prog = serialize(trt_exp_program)
109105
deserialized_prog = deserialize(*serialized_prog)
110106
# Check Pyt and TRT exported program outputs
@@ -161,11 +157,9 @@ def forward(self, x):
161157

162158
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
163159
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
164-
trt_gm = transform(trt_gm, [input])
165-
trt_exp_program = create_trt_exp_program(
166-
trt_gm, exp_program.call_spec, trt_gm.state_dict()
160+
trt_exp_program = torchtrt.dynamo.serialize(
161+
trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
167162
)
168-
169163
torch._export.save(trt_exp_program, "/tmp/trt.ep")
170164
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
171165

@@ -224,11 +218,9 @@ def forward(self, x):
224218

225219
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
226220
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
227-
trt_gm = transform(trt_gm, [input])
228-
trt_exp_program = create_trt_exp_program(
229-
trt_gm, exp_program.call_spec, trt_gm.state_dict()
221+
trt_exp_program = torchtrt.dynamo.serialize(
222+
trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
230223
)
231-
232224
torch._export.save(trt_exp_program, "/tmp/trt.ep")
233225
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
234226

@@ -270,9 +262,8 @@ def test_resnet18_save_load(ir):
270262

271263
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
272264
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
273-
trt_gm = transform(trt_gm, [input])
274-
trt_exp_program = create_trt_exp_program(
275-
trt_gm, exp_program.call_spec, trt_gm.state_dict()
265+
trt_exp_program = torchtrt.dynamo.serialize(
266+
trt_gm, [input], call_spec=exp_program.call_spec, ir="exported_program"
276267
)
277268
torch._export.save(trt_exp_program, "/tmp/trt.ep")
278269
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
@@ -291,59 +282,3 @@ def test_resnet18_save_load(ir):
291282
cos_sim > COSINE_THRESHOLD,
292283
msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
293284
)
294-
295-
296-
# Enable this test once this issue is resolved https://github.com/pytorch/TensorRT/issues/2341
297-
# @pytest.mark.unit
298-
# def test_hybrid_conv_fallback(ir):
299-
# """
300-
# This tests export save and load functionality on a hybrid
301-
# model where a conv (a weighted layer) has been forced to fallback to Pytorch.
302-
# """
303-
304-
# class MyModule(torch.nn.Module):
305-
# def __init__(self):
306-
# super().__init__()
307-
# self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
308-
# self.relu = torch.nn.ReLU()
309-
310-
# def forward(self, x):
311-
# conv = self.conv(x)
312-
# relu = self.relu(conv)
313-
# mul = relu * 0.5
314-
# return mul
315-
316-
# model = MyModule().eval().cuda()
317-
# input = torch.randn((1, 3, 224, 224)).to("cuda")
318-
319-
# compile_spec = {
320-
# "inputs": [
321-
# torchtrt.Input(
322-
# input.shape, dtype=torch.float, format=torch.contiguous_format
323-
# )
324-
# ],
325-
# "ir": ir,
326-
# "min_block_size": 1,
327-
# "torch_executed_ops": "torch.ops.aten.convolution.default",
328-
# }
329-
330-
# trt_exp_program = torchtrt.compile(model, **compile_spec)
331-
# torch._export.save(trt_exp_program, "/tmp/trt.ep")
332-
# deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
333-
334-
# outputs_pyt = model(input)
335-
# outputs_trt = trt_exp_program(input)
336-
# for idx in range(len(outputs_pyt)):
337-
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
338-
# assertions.assertTrue(
339-
# cos_sim > COSINE_THRESHOLD,
340-
# msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
341-
# )
342-
343-
# outputs_trt_deser = deser_trt_exp_program(input)
344-
# for idx in range(len(outputs_pyt)):
345-
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
346-
# assertions.assertTrue(
347-
# cos_sim > COSINE_THRESHOLD,
348-
# msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
349-
# )

0 commit comments

Comments
 (0)