From d06c74af7f748ee6028eab1fed5ed3346882995e Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 3 Oct 2023 11:55:48 -0700 Subject: [PATCH 01/14] chore: Switch to new export apis Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/aten_tracer.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/dynamo/aten_tracer.py b/py/torch_tensorrt/dynamo/aten_tracer.py index da346635a2..def04e7057 100644 --- a/py/torch_tensorrt/dynamo/aten_tracer.py +++ b/py/torch_tensorrt/dynamo/aten_tracer.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import unittest.mock from typing import Any, List, Tuple import torch @@ -77,12 +76,9 @@ def trace( experimental_decompositions = kwargs.get( "enable_experimental_decompositions", False ) - with unittest.mock.patch( - "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions) - ): - graph_module = export( - model, tuple(trace_inputs), constraints=constraints - ).module() - logger.debug("Post export graph: " + str(graph_module.graph)) - return graph_module + exp_program = export( + model, tuple(trace_inputs), constraints=constraints + ).run_decompositions(get_decompositions(experimental_decompositions)) + + return exp_program From ad3b0311b33508a85ae33dfdd591962561e453ac Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 19 Oct 2023 15:16:13 -0700 Subject: [PATCH 02/14] feat: Add support for dynamic shapes and remove constraints API Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/_Input.py | 7 ++- py/torch_tensorrt/dynamo/aten_tracer.py | 53 +++++------------------ tests/py/dynamo/models/test_dyn_models.py | 2 + 3 files changed, 20 insertions(+), 42 deletions(-) diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 6e43a23903..4dd3cf62c2 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -47,6 +47,7 @@ class _ShapeMode(Enum): high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET torch_dtype: torch.dtype = torch.float32 torch_tensor: torch.Tensor = None + name: str = "" def __init__(self, *args: Any, **kwargs: Any) -> None: """__init__ Method for torch_tensorrt.Input @@ -68,7 +69,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: format (torch.memory_format or torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW) tensor_domain (Tuple(float, float), optional): The domain of allowed values for the tensor, as interval notation: [tensor_domain[0], tensor_domain[1]). Note: Entering "None" (or not specifying) will set the bound to [0, 2) - + torch_tensor (torch.Tensor): Holds a corresponding torch tensor with this Input. + name (str, optional): Name of this input in the pytorch graph. Used to specify dynamic shapes in dynamo tracer. Examples: - Input([1,3,32,32], dtype=torch.float32, format=torch.channel_last) - Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW) @@ -180,6 +182,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: else: self.torch_tensor = self.example_tensor() + if "name" in kwargs: + self.name = kwargs["name"] + def __str__(self) -> str: if self.shape_mode == Input._ShapeMode.STATIC: return "Input(shape={}, dtype={}, format={}, domain=[{}, {}))".format( diff --git a/py/torch_tensorrt/dynamo/aten_tracer.py b/py/torch_tensorrt/dynamo/aten_tracer.py index f6d0ad4625..c894ca6f3c 100644 --- a/py/torch_tensorrt/dynamo/aten_tracer.py +++ b/py/torch_tensorrt/dynamo/aten_tracer.py @@ -1,10 +1,10 @@ from __future__ import annotations import logging -from typing import Any, List, Tuple +from typing import Any, Tuple import torch -from torch._export import dynamic_dim, export +from torch.export import Dim, export from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._defaults import ( ENABLE_EXPERIMENTAL_DECOMPOSITIONS, @@ -16,20 +16,6 @@ logger = logging.getLogger(__name__) -def get_random_tensor( - shape: List[Any], dtype: torch.dtype, device: torch.device -) -> torch.Tensor: - if dtype == torch.int32 or dtype == torch.int64: - return torch.randint(2, 10, shape, dtype=dtype, device=device) - elif dtype in (torch.float64, torch.float32, torch.float16): - return torch.randn(shape, dtype=dtype, device=device) - else: - logger.critical( - "Invalid dtype detected in creating input tensors for tracing the graph." - ) - raise - - def trace( model: torch.nn.Module | torch.fx.GraphModule, inputs: Tuple[Any, ...], @@ -39,49 +25,34 @@ def trace( if "debug" in kwargs and kwargs["debug"]: set_log_level(logger.parent, logging.DEBUG) - # Determine the dynamic dimension and setup constraints to input dimensions as dictated by TensorRT - # Torch dynamo does not allow 0/1 value for dynamic dimensions - # for inputs during tracing. Hence we create new inputs for export device = to_torch_device(kwargs.get("device", default_device())) torch_inputs = get_torch_inputs(inputs, device) - trace_inputs = [] - constraints = [] + dynamic_shapes = {} for idx, input in enumerate(inputs): if input.shape_mode == Input._ShapeMode.DYNAMIC: min_shape = input.shape["min_shape"] opt_shape = input.shape["opt_shape"] max_shape = input.shape["max_shape"] assert len(min_shape) == len(opt_shape) == len(max_shape) - - constraint_dims = [] - new_shape = [] + dynamic_dims = {} for dim in range(len(min_shape)): if min_shape[dim] == opt_shape[dim] == max_shape[dim]: - new_shape.append(torch_inputs[idx].shape[dim]) + continue else: - constraint_dims.append(dim) - if torch_inputs[idx].shape[dim] == 1: - new_shape.append(torch_inputs[idx].shape[dim] + 1) - else: - new_shape.append(torch_inputs[idx].shape[dim]) - - trace_input = get_random_tensor(new_shape, torch_inputs[idx].dtype, device) + dynamic_dims[dim] = Dim( + input.name + "_" + str(dim), + min=min_shape[dim], + max=max_shape[dim], + ) - for dim in constraint_dims: - if min_shape[dim] > 1: - constraints.append(min_shape[dim] <= dynamic_dim(trace_input, dim)) - if max_shape[dim] > 1: - constraints.append(dynamic_dim(trace_input, dim) <= max_shape[dim]) - trace_inputs.append(trace_input) - else: - trace_inputs.append(torch_inputs[idx]) + dynamic_shapes[input.name] = dynamic_dims experimental_decompositions = kwargs.get( "enable_experimental_decompositions", ENABLE_EXPERIMENTAL_DECOMPOSITIONS ) exp_program = export( - model, tuple(trace_inputs), constraints=constraints + model, tuple(torch_inputs), dynamic_shapes=dynamic_shapes ).run_decompositions(get_decompositions(experimental_decompositions)) return exp_program diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py index 057a95879d..d110845145 100644 --- a/tests/py/dynamo/models/test_dyn_models.py +++ b/tests/py/dynamo/models/test_dyn_models.py @@ -36,6 +36,7 @@ def forward(self, x): opt_shape=(4, 3, 224, 224), max_shape=(8, 3, 224, 224), dtype=torch.float32, + name="x", ) ], "device": torchtrt.Device("cuda:0"), @@ -88,6 +89,7 @@ def forward(self, x): opt_shape=(4, 3, 224, 224), max_shape=(8, 3, 224, 224), dtype=torch.float32, + name="x", ) ], "device": torchtrt.Device("cuda:0"), From 2d09249ca83a71b3d347bc63a4a9080f3b8a73a9 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 16 Nov 2023 10:22:27 -0800 Subject: [PATCH 03/14] chore: updates Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/_tracer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index 43812fd062..cd705916b4 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -70,6 +70,10 @@ def trace( dynamic_shapes = {} for input in inputs: if input.shape_mode == Input._ShapeMode.DYNAMIC: + if not input.name: + raise AssertionError( + f"Expected a name for a dynamic input with shape {input.shape} but found none" + ) min_shape = input.shape["min_shape"] opt_shape = input.shape["opt_shape"] max_shape = input.shape["max_shape"] From d58bc5df85c49982e6d6c559647ba585541bd38d Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 16 Nov 2023 17:04:49 -0800 Subject: [PATCH 04/14] chore: isolate serde tests in a separate job Signed-off-by: Dheeraj Peri --- .github/workflows/build-test.yml | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index bfc19cce45..afa4cd5ed7 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -141,10 +141,35 @@ jobs: cd tests/py/dynamo ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py - ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py popd + tests-py-dynamo-serde: + name: Test dynamo export serde [Python] + needs: [generate-matrix, build] + strategy: + fail-fast: false + matrix: + include: + - repository: pytorch/tensorrt + package-name: torch_tensorrt + pre-script: packaging/pre_build_script.sh + uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main + with: + job-name: tests-py-dynamo-serde + repository: "pytorch/tensorrt" + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: ${{ matrix.pre-script }} + script: | + export USE_HOST_DEPS=1 + pushd . + cd tests/py/dynamo + ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py + popd + tests-py-torch-compile-be: name: Test torch compile backend [Python] needs: [generate-matrix, build] From bc7eaa3db0ef6e4534152dff226c321693019462 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 17 Nov 2023 16:56:22 -0800 Subject: [PATCH 05/14] chore: fix output node Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/_exporter.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index fa1bd214ac..017bb8df99 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -228,7 +228,10 @@ def create_trt_exp_program( and constructs an Exported Program object with the new IO node names and state_dict """ input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] - output_nodes = [node for node in gm.graph.nodes if node.op == "output"] + # output_nodes = [node for node in gm.graph.nodes if node.op == "output"] + # graph = [placeholder, conv1, relu, trt_node, getitem, output] + # output_nodes[0].args[0] + output_nodes = list(gm.graph.nodes)[-1].args[0] input_specs = [ InputSpec(InputKind.USER_INPUT, TensorArgument(name=node.name), node.target) @@ -292,6 +295,13 @@ def inline_trt_modules( # Insert getitem nodes as outputs (for export serialization to work) with gm.graph.inserting_after(trt_node): getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0)) + getitem_output.meta["val"] = cast( + FakeTensor, + torch.empty_strided( + tuple(outputs_map[trt_module_node.name][idx]), + tuple([1] * len(outputs_map[trt_module_node.name][idx])), + ), + ) trt_module_node.replace_all_uses_with(getitem_output) else: # Multiple outputs case: From ab92f25ec6d1ff360cea390429e92b725a569a7b Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 20 Nov 2023 13:07:39 -0800 Subject: [PATCH 06/14] chore: fix export serde tests Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/_exporter.py | 12 ++- tests/py/dynamo/models/test_export_serde.py | 89 +++++++++++---------- 2 files changed, 50 insertions(+), 51 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index 017bb8df99..8c06b3ca2f 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -279,6 +279,7 @@ def inline_trt_modules( (trt_module_node.args, trt_module.engine), ) trt_node.meta["val"] = [] + assert num_outputs # Generate meta data for TRT node (a FakeTensor with corresponding output shape) for idx in range(num_outputs): trt_node.meta["val"].append( @@ -295,19 +296,16 @@ def inline_trt_modules( # Insert getitem nodes as outputs (for export serialization to work) with gm.graph.inserting_after(trt_node): getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0)) - getitem_output.meta["val"] = cast( - FakeTensor, - torch.empty_strided( - tuple(outputs_map[trt_module_node.name][idx]), - tuple([1] * len(outputs_map[trt_module_node.name][idx])), - ), - ) + getitem_output.meta["val"] = trt_node.meta["val"] trt_module_node.replace_all_uses_with(getitem_output) else: # Multiple outputs case: # Replace uses of submodule with the trt_node. # getitem nodes are already added inherently by the partitioner trt_module_node.replace_all_uses_with(trt_node) + getitem_nodes = trt_node.users + for idx, getitem_node in enumerate(getitem_nodes): + getitem_node.meta["val"] = trt_node.meta["val"][idx] # Erase the TRT submodule (call_module) node. gm.graph.erase_node(trt_module_node) diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py index d66ef5d89e..061a983e08 100644 --- a/tests/py/dynamo/models/test_export_serde.py +++ b/tests/py/dynamo/models/test_export_serde.py @@ -45,8 +45,8 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") - serialized_prog = serialize(trt_exp_program) - deserialized_prog = deserialize(*serialized_prog) + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") # Check Pyt and TRT exported program outputs cos_sim = cosine_similarity(model(input), trt_exp_program(input)[0]) @@ -55,7 +55,7 @@ def forward(self, x): msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) # Check Pyt and deserialized TRT exported program outputs - cos_sim = cosine_similarity(model(input), deserialized_prog(input)[0]) + cos_sim = cosine_similarity(model(input), deser_trt_exp_program(input)[0]) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", @@ -97,8 +97,8 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") - serialized_prog = serialize(trt_exp_program) - deserialized_prog = deserialize(*serialized_prog) + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") # Check Pyt and TRT exported program outputs outputs_pyt = model(input) outputs_trt = trt_exp_program(input) @@ -110,7 +110,7 @@ def forward(self, x): ) # Check Pyt and deserialized TRT exported program outputs - outputs_trt_deser = deserialized_prog(input) + outputs_trt_deser = deser_trt_exp_program(input) for idx in range(len(outputs_pyt)): cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) assertions.assertTrue( @@ -154,8 +154,8 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") - torch._export.save(trt_exp_program, "/tmp/trt.ep") - deser_trt_exp_program = torch._export.load("/tmp/trt.ep") + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") outputs_pyt = model(input) outputs_trt = trt_exp_program(input) @@ -213,8 +213,8 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") - torch._export.save(trt_exp_program, "/tmp/trt.ep") - deser_trt_exp_program = torch._export.load("/tmp/trt.ep") + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") outputs_pyt = model(input) outputs_trt = trt_exp_program(input) @@ -235,44 +235,45 @@ def forward(self, x): # TODO (peri044) : Enable this test once the _frozen_param0 attribute resulting in sym_int ops issue is fixed. -# @pytest.mark.unit -# def test_resnet18_save_load(ir): -# """ -# This tests export save and load functionality on Resnet18 model -# """ -# model = models.resnet18().eval().cuda() -# input = torch.randn((1, 3, 224, 224)).to("cuda") +@pytest.mark.unit +def test_resnet18_save_load(ir): + """ + This tests export save and load functionality on Resnet18 model + """ + model = models.resnet18().eval().cuda() + input = torch.randn((1, 3, 224, 224)).to("cuda") -# compile_spec = { -# "inputs": [ -# torchtrt.Input( -# input.shape, dtype=torch.float, format=torch.contiguous_format -# ) -# ], -# "ir": ir, -# "min_block_size": 1, -# } + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "ir": ir, + "min_block_size": 1, + } -# exp_program = torchtrt.dynamo.trace(model, **compile_spec) -# trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) -# trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") -# torch._export.save(trt_exp_program, "/tmp/trt.ep") -# deser_trt_exp_program = torch._export.load("/tmp/trt.ep") + exp_program = torchtrt.dynamo.trace(model, **compile_spec) + trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) + trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") + torch._export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch._export.load("/tmp/trt.ep") -# outputs_pyt = model(input) -# outputs_trt = trt_exp_program(input) -# cos_sim = cosine_similarity(outputs_pyt, outputs_trt) -# assertions.assertTrue( -# cos_sim > COSINE_THRESHOLD, -# msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", -# ) + outputs_pyt = model(input) + outputs_trt = trt_exp_program(input) + cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) -# outputs_trt_deser = deser_trt_exp_program(input) -# cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser) -# assertions.assertTrue( -# cos_sim > COSINE_THRESHOLD, -# msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", -# ) + outputs_trt_deser = deser_trt_exp_program(input) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) # Enable this test once this issue is resolved https://github.com/pytorch/TensorRT/issues/2341 From 7a0d7bbc4842915454078ed826889e0179ede627 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 20 Nov 2023 15:35:49 -0800 Subject: [PATCH 07/14] chore: fix tests Signed-off-by: Dheeraj Peri --- .github/workflows/build-test.yml | 1 + tests/py/dynamo/models/test_export_serde.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index afa4cd5ed7..8a6063c33d 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -167,6 +167,7 @@ jobs: export USE_HOST_DEPS=1 pushd . cd tests/py/dynamo + ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py popd diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py index 061a983e08..9978395f40 100644 --- a/tests/py/dynamo/models/test_export_serde.py +++ b/tests/py/dynamo/models/test_export_serde.py @@ -234,7 +234,6 @@ def forward(self, x): ) -# TODO (peri044) : Enable this test once the _frozen_param0 attribute resulting in sym_int ops issue is fixed. @pytest.mark.unit def test_resnet18_save_load(ir): """ From 8903382b914e4e182ce3950c506df990035d157a Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 28 Nov 2023 13:29:11 -0800 Subject: [PATCH 08/14] chore: remove redundant comments Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/_exporter.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index 8c06b3ca2f..25c48291e7 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -228,10 +228,9 @@ def create_trt_exp_program( and constructs an Exported Program object with the new IO node names and state_dict """ input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] - # output_nodes = [node for node in gm.graph.nodes if node.op == "output"] - # graph = [placeholder, conv1, relu, trt_node, getitem, output] - # output_nodes[0].args[0] - output_nodes = list(gm.graph.nodes)[-1].args[0] + output_nodes = [node for node in gm.graph.nodes if node.op == "output"] + assert output_nodes + output_nodes = output_nodes[0].args[0] input_specs = [ InputSpec(InputKind.USER_INPUT, TensorArgument(name=node.name), node.target) From 0db52359219aa8873cb30fb2a2c0c9a2aeadfe67 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 30 Nov 2023 15:38:48 -0800 Subject: [PATCH 09/14] chore: edit docstring Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/_Input.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 4dd3cf62c2..9acb073c62 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -70,7 +70,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: tensor_domain (Tuple(float, float), optional): The domain of allowed values for the tensor, as interval notation: [tensor_domain[0], tensor_domain[1]). Note: Entering "None" (or not specifying) will set the bound to [0, 2) torch_tensor (torch.Tensor): Holds a corresponding torch tensor with this Input. - name (str, optional): Name of this input in the pytorch graph. Used to specify dynamic shapes in dynamo tracer. + name (str, optional): Name of this input in the input nn.Module's forward function. Used to specify dynamic shapes for the corresponding input in dynamo tracer. Examples: - Input([1,3,32,32], dtype=torch.float32, format=torch.channel_last) - Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW) From b78fe7a7027d774cf4d4a74a865204e90b3edf46 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Sun, 3 Dec 2023 17:32:42 -0800 Subject: [PATCH 10/14] chore: address review comments Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/_compiler.py | 9 ++++++++- py/torch_tensorrt/dynamo/_exporter.py | 2 +- py/torch_tensorrt/dynamo/_tracer.py | 5 +---- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index d31be8a413..c6895e7907 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -34,7 +34,7 @@ convert_module, repair_long_or_double_inputs, ) -from torch_tensorrt.dynamo.lowering import apply_lowering_passes +from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions from torch_tensorrt.dynamo.utils import ( get_torch_inputs, prepare_inputs, @@ -146,6 +146,13 @@ def compile( inputs = prepare_inputs(inputs) device = to_torch_tensorrt_device(device) + if not isinstance(exported_program, ExportedProgram): + raise AssertionError( + f"Input graph should be an ExportedProgram but got type {type(exported_program)}" + ) + exported_program = exported_program.run_decompositions( + get_decompositions(enable_experimental_decompositions) + ) gm = exported_program.module() logger.debug("Input graph: " + str(gm.graph)) diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index 25c48291e7..df9150ea2d 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -278,7 +278,7 @@ def inline_trt_modules( (trt_module_node.args, trt_module.engine), ) trt_node.meta["val"] = [] - assert num_outputs + assert num_outputs > 0 # Generate meta data for TRT node (a FakeTensor with corresponding output shape) for idx in range(num_outputs): trt_node.meta["val"].append( diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index cd705916b4..c5089c5c3d 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -12,7 +12,6 @@ ENABLE_EXPERIMENTAL_DECOMPOSITIONS, default_device, ) -from torch_tensorrt.dynamo.lowering import get_decompositions from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device logger = logging.getLogger(__name__) @@ -95,8 +94,6 @@ def trace( "enable_experimental_decompositions", ENABLE_EXPERIMENTAL_DECOMPOSITIONS ) - exp_program = export( - mod, tuple(torch_inputs), dynamic_shapes=dynamic_shapes - ).run_decompositions(get_decompositions(experimental_decompositions)) + exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=dynamic_shapes) return exp_program From 4df8f6dda730a10817ac99eb724f9d4f529f2767 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 4 Dec 2023 14:52:18 -0800 Subject: [PATCH 11/14] chore: fix kwargs in tracer Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/_tracer.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index c5089c5c3d..6d24f88e8b 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -1,17 +1,12 @@ from __future__ import annotations import logging -from typing import Any, Optional, Tuple, Union +from typing import Any, Tuple import torch from torch.export import Dim, export from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo._defaults import ( - DEBUG, - DEVICE, - ENABLE_EXPERIMENTAL_DECOMPOSITIONS, - default_device, -) +from torch_tensorrt.dynamo._defaults import DEBUG, default_device from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device logger = logging.getLogger(__name__) @@ -20,9 +15,6 @@ def trace( mod: torch.nn.Module | torch.fx.GraphModule, inputs: Tuple[Any, ...], - device: Optional[Union[torch.device, str]] = DEVICE, - debug: bool = DEBUG, - enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS, **kwargs: Any, ) -> torch.export.ExportedProgram: """Exports a ``torch.export.ExportedProgram`` from a ``torch.nn.Module`` or ``torch.fx.GraphModule`` specifically targeting being compiled with Torch-TensorRT @@ -60,9 +52,9 @@ def trace( """ # Set log level at the top of compilation (torch_tensorrt.dynamo) + debug = kwargs.get("debug", DEBUG) if debug: set_log_level(logger.parent, logging.DEBUG) - device = to_torch_device(device if device else default_device()) device = to_torch_device(kwargs.get("device", default_device())) torch_inputs = get_torch_inputs(inputs, device) @@ -90,10 +82,6 @@ def trace( dynamic_shapes[input.name] = dynamic_dims - experimental_decompositions = kwargs.get( - "enable_experimental_decompositions", ENABLE_EXPERIMENTAL_DECOMPOSITIONS - ) - exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=dynamic_shapes) return exp_program From 536cbe6f6beed078415bca178dd87539224870c3 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 4 Dec 2023 14:57:32 -0800 Subject: [PATCH 12/14] chore: tracer inputs check Signed-off-by: Dheeraj Peri --- py/torch_tensorrt/dynamo/_tracer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index 6d24f88e8b..11c0f6b3ac 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -60,7 +60,7 @@ def trace( torch_inputs = get_torch_inputs(inputs, device) dynamic_shapes = {} for input in inputs: - if input.shape_mode == Input._ShapeMode.DYNAMIC: + if isinstance(input, Input) and input.shape_mode == Input._ShapeMode.DYNAMIC: if not input.name: raise AssertionError( f"Expected a name for a dynamic input with shape {input.shape} but found none" From 073b3e2f4ad63069b192be4e463c526235d1e770 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 5 Dec 2023 12:29:55 -0800 Subject: [PATCH 13/14] chore: Update test case export utilities Signed-off-by: Dheeraj Peri --- tests/py/dynamo/models/test_export_serde.py | 112 ++++++++++---------- 1 file changed, 55 insertions(+), 57 deletions(-) diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py index 9978395f40..c9750c1a50 100644 --- a/tests/py/dynamo/models/test_export_serde.py +++ b/tests/py/dynamo/models/test_export_serde.py @@ -5,7 +5,6 @@ import torch import torch_tensorrt as torchtrt import torchvision.models as models -from torch._export.serde.serialize import deserialize, serialize from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity assertions = unittest.TestCase() @@ -255,8 +254,8 @@ def test_resnet18_save_load(ir): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") - torch._export.save(trt_exp_program, "/tmp/trt.ep") - deser_trt_exp_program = torch._export.load("/tmp/trt.ep") + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") outputs_pyt = model(input) outputs_trt = trt_exp_program(input) @@ -275,57 +274,56 @@ def test_resnet18_save_load(ir): ) -# Enable this test once this issue is resolved https://github.com/pytorch/TensorRT/issues/2341 -# @pytest.mark.unit -# def test_hybrid_conv_fallback(ir): -# """ -# This tests export save and load functionality on a hybrid -# model where a conv (a weighted layer) has been forced to fallback to Pytorch. -# """ - -# class MyModule(torch.nn.Module): -# def __init__(self): -# super().__init__() -# self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) -# self.relu = torch.nn.ReLU() - -# def forward(self, x): -# conv = self.conv(x) -# relu = self.relu(conv) -# mul = relu * 0.5 -# return mul - -# model = MyModule().eval().cuda() -# input = torch.randn((1, 3, 224, 224)).to("cuda") - -# compile_spec = { -# "inputs": [ -# torchtrt.Input( -# input.shape, dtype=torch.float, format=torch.contiguous_format -# ) -# ], -# "ir": ir, -# "min_block_size": 1, -# "torch_executed_ops": "torch.ops.aten.convolution.default", -# } - -# trt_exp_program = torchtrt.compile(model, **compile_spec) -# torch._export.save(trt_exp_program, "/tmp/trt.ep") -# deser_trt_exp_program = torch._export.load("/tmp/trt.ep") - -# outputs_pyt = model(input) -# outputs_trt = trt_exp_program(input) -# for idx in range(len(outputs_pyt)): -# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx]) -# assertions.assertTrue( -# cos_sim > COSINE_THRESHOLD, -# msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", -# ) - -# outputs_trt_deser = deser_trt_exp_program(input) -# for idx in range(len(outputs_pyt)): -# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) -# assertions.assertTrue( -# cos_sim > COSINE_THRESHOLD, -# msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", -# ) +@pytest.mark.unit +def test_hybrid_conv_fallback(ir): + """ + This tests export save and load functionality on a hybrid + model where a conv (a weighted layer) has been forced to fallback to Pytorch. + """ + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + conv = self.conv(x) + relu = self.relu(conv) + mul = relu * 0.5 + return mul + + model = MyModule().eval().cuda() + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "ir": ir, + "min_block_size": 1, + "torch_executed_ops": "torch.ops.aten.convolution.default", + } + + trt_exp_program = torchtrt.compile(model, **compile_spec) + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") + + outputs_pyt = model(input) + outputs_trt = trt_exp_program(input) + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + outputs_trt_deser = deser_trt_exp_program(input) + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) From 282339bc2bfe90494105cb79ea0eaa2ac10cd81a Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 5 Dec 2023 14:20:02 -0800 Subject: [PATCH 14/14] chore: disable test_hybrid_conv_fallback Signed-off-by: Dheeraj Peri --- tests/py/dynamo/models/test_export_serde.py | 107 ++++++++++---------- 1 file changed, 54 insertions(+), 53 deletions(-) diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py index c9750c1a50..f5911cb940 100644 --- a/tests/py/dynamo/models/test_export_serde.py +++ b/tests/py/dynamo/models/test_export_serde.py @@ -274,56 +274,57 @@ def test_resnet18_save_load(ir): ) -@pytest.mark.unit -def test_hybrid_conv_fallback(ir): - """ - This tests export save and load functionality on a hybrid - model where a conv (a weighted layer) has been forced to fallback to Pytorch. - """ - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) - self.relu = torch.nn.ReLU() - - def forward(self, x): - conv = self.conv(x) - relu = self.relu(conv) - mul = relu * 0.5 - return mul - - model = MyModule().eval().cuda() - input = torch.randn((1, 3, 224, 224)).to("cuda") - - compile_spec = { - "inputs": [ - torchtrt.Input( - input.shape, dtype=torch.float, format=torch.contiguous_format - ) - ], - "ir": ir, - "min_block_size": 1, - "torch_executed_ops": "torch.ops.aten.convolution.default", - } - - trt_exp_program = torchtrt.compile(model, **compile_spec) - torch.export.save(trt_exp_program, "/tmp/trt.ep") - deser_trt_exp_program = torch.export.load("/tmp/trt.ep") - - outputs_pyt = model(input) - outputs_trt = trt_exp_program(input) - for idx in range(len(outputs_pyt)): - cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx]) - assertions.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - - outputs_trt_deser = deser_trt_exp_program(input) - for idx in range(len(outputs_pyt)): - cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) - assertions.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) +# Enable this test once this issue is resolved https://github.com/pytorch/TensorRT/issues/2341 +# @pytest.mark.unit +# def test_hybrid_conv_fallback(ir): +# """ +# This tests export save and load functionality on a hybrid +# model where a conv (a weighted layer) has been forced to fallback to Pytorch. +# """ + +# class MyModule(torch.nn.Module): +# def __init__(self): +# super().__init__() +# self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) +# self.relu = torch.nn.ReLU() + +# def forward(self, x): +# conv = self.conv(x) +# relu = self.relu(conv) +# mul = relu * 0.5 +# return mul + +# model = MyModule().eval().cuda() +# input = torch.randn((1, 3, 224, 224)).to("cuda") + +# compile_spec = { +# "inputs": [ +# torchtrt.Input( +# input.shape, dtype=torch.float, format=torch.contiguous_format +# ) +# ], +# "ir": ir, +# "min_block_size": 1, +# "torch_executed_ops": "torch.ops.aten.convolution.default", +# } + +# trt_exp_program = torchtrt.compile(model, **compile_spec) +# torch.export.save(trt_exp_program, "/tmp/trt.ep") +# deser_trt_exp_program = torch.export.load("/tmp/trt.ep") + +# outputs_pyt = model(input) +# outputs_trt = trt_exp_program(input) +# for idx in range(len(outputs_pyt)): +# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx]) +# assertions.assertTrue( +# cos_sim > COSINE_THRESHOLD, +# msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", +# ) + +# outputs_trt_deser = deser_trt_exp_program(input) +# for idx in range(len(outputs_pyt)): +# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) +# assertions.assertTrue( +# cos_sim > COSINE_THRESHOLD, +# msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", +# )