diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index bfc19cce45..8a6063c33d 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -141,10 +141,36 @@ jobs: cd tests/py/dynamo ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py - ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py popd + tests-py-dynamo-serde: + name: Test dynamo export serde [Python] + needs: [generate-matrix, build] + strategy: + fail-fast: false + matrix: + include: + - repository: pytorch/tensorrt + package-name: torch_tensorrt + pre-script: packaging/pre_build_script.sh + uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main + with: + job-name: tests-py-dynamo-serde + repository: "pytorch/tensorrt" + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: ${{ matrix.pre-script }} + script: | + export USE_HOST_DEPS=1 + pushd . + cd tests/py/dynamo + ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver + ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py + popd + tests-py-torch-compile-be: name: Test torch compile backend [Python] needs: [generate-matrix, build] diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 6e43a23903..9acb073c62 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -47,6 +47,7 @@ class _ShapeMode(Enum): high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET torch_dtype: torch.dtype = torch.float32 torch_tensor: torch.Tensor = None + name: str = "" def __init__(self, *args: Any, **kwargs: Any) -> None: """__init__ Method for torch_tensorrt.Input @@ -68,7 +69,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: format (torch.memory_format or torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW) tensor_domain (Tuple(float, float), optional): The domain of allowed values for the tensor, as interval notation: [tensor_domain[0], tensor_domain[1]). Note: Entering "None" (or not specifying) will set the bound to [0, 2) - + torch_tensor (torch.Tensor): Holds a corresponding torch tensor with this Input. + name (str, optional): Name of this input in the input nn.Module's forward function. Used to specify dynamic shapes for the corresponding input in dynamo tracer. Examples: - Input([1,3,32,32], dtype=torch.float32, format=torch.channel_last) - Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW) @@ -180,6 +182,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: else: self.torch_tensor = self.example_tensor() + if "name" in kwargs: + self.name = kwargs["name"] + def __str__(self) -> str: if self.shape_mode == Input._ShapeMode.STATIC: return "Input(shape={}, dtype={}, format={}, domain=[{}, {}))".format( diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index d31be8a413..c6895e7907 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -34,7 +34,7 @@ convert_module, repair_long_or_double_inputs, ) -from torch_tensorrt.dynamo.lowering import apply_lowering_passes +from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions from torch_tensorrt.dynamo.utils import ( get_torch_inputs, prepare_inputs, @@ -146,6 +146,13 @@ def compile( inputs = prepare_inputs(inputs) device = to_torch_tensorrt_device(device) + if not isinstance(exported_program, ExportedProgram): + raise AssertionError( + f"Input graph should be an ExportedProgram but got type {type(exported_program)}" + ) + exported_program = exported_program.run_decompositions( + get_decompositions(enable_experimental_decompositions) + ) gm = exported_program.module() logger.debug("Input graph: " + str(gm.graph)) diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index fa1bd214ac..df9150ea2d 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -229,6 +229,8 @@ def create_trt_exp_program( """ input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] output_nodes = [node for node in gm.graph.nodes if node.op == "output"] + assert output_nodes + output_nodes = output_nodes[0].args[0] input_specs = [ InputSpec(InputKind.USER_INPUT, TensorArgument(name=node.name), node.target) @@ -276,6 +278,7 @@ def inline_trt_modules( (trt_module_node.args, trt_module.engine), ) trt_node.meta["val"] = [] + assert num_outputs > 0 # Generate meta data for TRT node (a FakeTensor with corresponding output shape) for idx in range(num_outputs): trt_node.meta["val"].append( @@ -292,12 +295,16 @@ def inline_trt_modules( # Insert getitem nodes as outputs (for export serialization to work) with gm.graph.inserting_after(trt_node): getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0)) + getitem_output.meta["val"] = trt_node.meta["val"] trt_module_node.replace_all_uses_with(getitem_output) else: # Multiple outputs case: # Replace uses of submodule with the trt_node. # getitem nodes are already added inherently by the partitioner trt_module_node.replace_all_uses_with(trt_node) + getitem_nodes = trt_node.users + for idx, getitem_node in enumerate(getitem_nodes): + getitem_node.meta["val"] = trt_node.meta["val"][idx] # Erase the TRT submodule (call_module) node. gm.graph.erase_node(trt_module_node) diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index 5fdca08399..11c0f6b3ac 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -1,45 +1,20 @@ from __future__ import annotations import logging -import unittest.mock -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Tuple import torch -from torch._export import dynamic_dim, export -from torch_tensorrt._Device import Device +from torch.export import Dim, export from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo._defaults import ( - DEBUG, - DEVICE, - ENABLE_EXPERIMENTAL_DECOMPOSITIONS, - default_device, -) -from torch_tensorrt.dynamo.lowering import get_decompositions +from torch_tensorrt.dynamo._defaults import DEBUG, default_device from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device logger = logging.getLogger(__name__) -def get_random_tensor( - shape: List[Any], dtype: torch.dtype, device: torch.device -) -> torch.Tensor: - if dtype == torch.int32 or dtype == torch.int64: - return torch.randint(2, 10, shape, dtype=dtype, device=device) - elif dtype in (torch.float64, torch.float32, torch.float16): - return torch.randn(shape, dtype=dtype, device=device) - else: - logger.critical( - "Invalid dtype detected in creating input tensors for tracing the graph." - ) - raise - - def trace( mod: torch.nn.Module | torch.fx.GraphModule, inputs: Tuple[Any, ...], - device: Optional[Union[Device, torch.device, str]] = DEVICE, - debug: bool = DEBUG, - enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS, **kwargs: Any, ) -> torch.export.ExportedProgram: """Exports a ``torch.export.ExportedProgram`` from a ``torch.nn.Module`` or ``torch.fx.GraphModule`` specifically targeting being compiled with Torch-TensorRT @@ -65,9 +40,9 @@ def trace( torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings ] Keyword Arguments: - device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on :: + device (Union(torch.device, dict)): Target device for TensorRT engines to run on :: - device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True) + device=torch.device("cuda:0") debug (bool): Enable debuggable engine enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the grap easier to covert to TensorRT, potentially increasing the amount of graphs run in TensorRT. @@ -77,50 +52,36 @@ def trace( """ # Set log level at the top of compilation (torch_tensorrt.dynamo) + debug = kwargs.get("debug", DEBUG) if debug: set_log_level(logger.parent, logging.DEBUG) - device = to_torch_device(device if device else default_device()) - # Determine the dynamic dimension and setup constraints to input dimensions as dictated by TensorRT - # Torch dynamo does not allow 0/1 value for dynamic dimensions - # for inputs during tracing. Hence we create new inputs for export + device = to_torch_device(kwargs.get("device", default_device())) torch_inputs = get_torch_inputs(inputs, device) - trace_inputs = [] - constraints = [] - for idx, input in enumerate(inputs): - if input.shape_mode == Input._ShapeMode.DYNAMIC: + dynamic_shapes = {} + for input in inputs: + if isinstance(input, Input) and input.shape_mode == Input._ShapeMode.DYNAMIC: + if not input.name: + raise AssertionError( + f"Expected a name for a dynamic input with shape {input.shape} but found none" + ) min_shape = input.shape["min_shape"] opt_shape = input.shape["opt_shape"] max_shape = input.shape["max_shape"] assert len(min_shape) == len(opt_shape) == len(max_shape) - - constraint_dims = [] - new_shape = [] + dynamic_dims = {} for dim in range(len(min_shape)): if min_shape[dim] == opt_shape[dim] == max_shape[dim]: - new_shape.append(torch_inputs[idx].shape[dim]) + continue else: - constraint_dims.append(dim) - if torch_inputs[idx].shape[dim] == 1: - new_shape.append(torch_inputs[idx].shape[dim] + 1) - else: - new_shape.append(torch_inputs[idx].shape[dim]) - - trace_input = get_random_tensor(new_shape, torch_inputs[idx].dtype, device) + dynamic_dims[dim] = Dim( + input.name + "_" + str(dim), + min=min_shape[dim], + max=max_shape[dim], + ) - for dim in constraint_dims: - if min_shape[dim] > 1: - constraints.append(min_shape[dim] <= dynamic_dim(trace_input, dim)) - if max_shape[dim] > 1: - constraints.append(dynamic_dim(trace_input, dim) <= max_shape[dim]) - trace_inputs.append(trace_input) - else: - trace_inputs.append(torch_inputs[idx]) + dynamic_shapes[input.name] = dynamic_dims - with unittest.mock.patch( - "torch._export.DECOMP_TABLE", - get_decompositions(enable_experimental_decompositions), - ): - exp_program = export(mod, tuple(trace_inputs), constraints=constraints) + exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=dynamic_shapes) return exp_program diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py index 057a95879d..d110845145 100644 --- a/tests/py/dynamo/models/test_dyn_models.py +++ b/tests/py/dynamo/models/test_dyn_models.py @@ -36,6 +36,7 @@ def forward(self, x): opt_shape=(4, 3, 224, 224), max_shape=(8, 3, 224, 224), dtype=torch.float32, + name="x", ) ], "device": torchtrt.Device("cuda:0"), @@ -88,6 +89,7 @@ def forward(self, x): opt_shape=(4, 3, 224, 224), max_shape=(8, 3, 224, 224), dtype=torch.float32, + name="x", ) ], "device": torchtrt.Device("cuda:0"), diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py index d66ef5d89e..f5911cb940 100644 --- a/tests/py/dynamo/models/test_export_serde.py +++ b/tests/py/dynamo/models/test_export_serde.py @@ -5,7 +5,6 @@ import torch import torch_tensorrt as torchtrt import torchvision.models as models -from torch._export.serde.serialize import deserialize, serialize from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity assertions = unittest.TestCase() @@ -45,8 +44,8 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") - serialized_prog = serialize(trt_exp_program) - deserialized_prog = deserialize(*serialized_prog) + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") # Check Pyt and TRT exported program outputs cos_sim = cosine_similarity(model(input), trt_exp_program(input)[0]) @@ -55,7 +54,7 @@ def forward(self, x): msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) # Check Pyt and deserialized TRT exported program outputs - cos_sim = cosine_similarity(model(input), deserialized_prog(input)[0]) + cos_sim = cosine_similarity(model(input), deser_trt_exp_program(input)[0]) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", @@ -97,8 +96,8 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") - serialized_prog = serialize(trt_exp_program) - deserialized_prog = deserialize(*serialized_prog) + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") # Check Pyt and TRT exported program outputs outputs_pyt = model(input) outputs_trt = trt_exp_program(input) @@ -110,7 +109,7 @@ def forward(self, x): ) # Check Pyt and deserialized TRT exported program outputs - outputs_trt_deser = deserialized_prog(input) + outputs_trt_deser = deser_trt_exp_program(input) for idx in range(len(outputs_pyt)): cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) assertions.assertTrue( @@ -154,8 +153,8 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") - torch._export.save(trt_exp_program, "/tmp/trt.ep") - deser_trt_exp_program = torch._export.load("/tmp/trt.ep") + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") outputs_pyt = model(input) outputs_trt = trt_exp_program(input) @@ -213,8 +212,8 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") - torch._export.save(trt_exp_program, "/tmp/trt.ep") - deser_trt_exp_program = torch._export.load("/tmp/trt.ep") + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") outputs_pyt = model(input) outputs_trt = trt_exp_program(input) @@ -234,45 +233,45 @@ def forward(self, x): ) -# TODO (peri044) : Enable this test once the _frozen_param0 attribute resulting in sym_int ops issue is fixed. -# @pytest.mark.unit -# def test_resnet18_save_load(ir): -# """ -# This tests export save and load functionality on Resnet18 model -# """ -# model = models.resnet18().eval().cuda() -# input = torch.randn((1, 3, 224, 224)).to("cuda") +@pytest.mark.unit +def test_resnet18_save_load(ir): + """ + This tests export save and load functionality on Resnet18 model + """ + model = models.resnet18().eval().cuda() + input = torch.randn((1, 3, 224, 224)).to("cuda") -# compile_spec = { -# "inputs": [ -# torchtrt.Input( -# input.shape, dtype=torch.float, format=torch.contiguous_format -# ) -# ], -# "ir": ir, -# "min_block_size": 1, -# } + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "ir": ir, + "min_block_size": 1, + } -# exp_program = torchtrt.dynamo.trace(model, **compile_spec) -# trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) -# trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") -# torch._export.save(trt_exp_program, "/tmp/trt.ep") -# deser_trt_exp_program = torch._export.load("/tmp/trt.ep") + exp_program = torchtrt.dynamo.trace(model, **compile_spec) + trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) + trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") -# outputs_pyt = model(input) -# outputs_trt = trt_exp_program(input) -# cos_sim = cosine_similarity(outputs_pyt, outputs_trt) -# assertions.assertTrue( -# cos_sim > COSINE_THRESHOLD, -# msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", -# ) + outputs_pyt = model(input) + outputs_trt = trt_exp_program(input) + cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) -# outputs_trt_deser = deser_trt_exp_program(input) -# cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser) -# assertions.assertTrue( -# cos_sim > COSINE_THRESHOLD, -# msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", -# ) + outputs_trt_deser = deser_trt_exp_program(input) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) # Enable this test once this issue is resolved https://github.com/pytorch/TensorRT/issues/2341 @@ -310,8 +309,8 @@ def forward(self, x): # } # trt_exp_program = torchtrt.compile(model, **compile_spec) -# torch._export.save(trt_exp_program, "/tmp/trt.ep") -# deser_trt_exp_program = torch._export.load("/tmp/trt.ep") +# torch.export.save(trt_exp_program, "/tmp/trt.ep") +# deser_trt_exp_program = torch.export.load("/tmp/trt.ep") # outputs_pyt = model(input) # outputs_trt = trt_exp_program(input)