Skip to content

chore: Switch to new export apis #2376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion .github/workflows/build-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,36 @@ jobs:
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
popd

tests-py-dynamo-serde:
name: Test dynamo export serde [Python]
needs: [generate-matrix, build]
strategy:
fail-fast: false
matrix:
include:
- repository: pytorch/tensorrt
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
with:
job-name: tests-py-dynamo-serde
repository: "pytorch/tensorrt"
ref: ""
test-infra-repository: pytorch/test-infra
test-infra-ref: main
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
pre-script: ${{ matrix.pre-script }}
script: |
export USE_HOST_DEPS=1
pushd .
cd tests/py/dynamo
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
popd

tests-py-torch-compile-be:
name: Test torch compile backend [Python]
needs: [generate-matrix, build]
Expand Down
7 changes: 6 additions & 1 deletion py/torch_tensorrt/_Input.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class _ShapeMode(Enum):
high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
torch_dtype: torch.dtype = torch.float32
torch_tensor: torch.Tensor = None
name: str = ""

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""__init__ Method for torch_tensorrt.Input
Expand All @@ -68,7 +69,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
format (torch.memory_format or torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
tensor_domain (Tuple(float, float), optional): The domain of allowed values for the tensor, as interval notation: [tensor_domain[0], tensor_domain[1]).
Note: Entering "None" (or not specifying) will set the bound to [0, 2)

torch_tensor (torch.Tensor): Holds a corresponding torch tensor with this Input.
name (str, optional): Name of this input in the pytorch graph. Used to specify dynamic shapes in dynamo tracer.
Examples:
- Input([1,3,32,32], dtype=torch.float32, format=torch.channel_last)
- Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW)
Expand Down Expand Up @@ -180,6 +182,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
else:
self.torch_tensor = self.example_tensor()

if "name" in kwargs:
self.name = kwargs["name"]

def __str__(self) -> str:
if self.shape_mode == Input._ShapeMode.STATIC:
return "Input(shape={}, dtype={}, format={}, domain=[{}, {}))".format(
Expand Down
7 changes: 7 additions & 0 deletions py/torch_tensorrt/dynamo/_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def create_trt_exp_program(
"""
input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
output_nodes = [node for node in gm.graph.nodes if node.op == "output"]
assert output_nodes
output_nodes = output_nodes[0].args[0]

input_specs = [
InputSpec(InputKind.USER_INPUT, TensorArgument(name=node.name), node.target)
Expand Down Expand Up @@ -276,6 +278,7 @@ def inline_trt_modules(
(trt_module_node.args, trt_module.engine),
)
trt_node.meta["val"] = []
assert num_outputs
# Generate meta data for TRT node (a FakeTensor with corresponding output shape)
for idx in range(num_outputs):
trt_node.meta["val"].append(
Expand All @@ -292,12 +295,16 @@ def inline_trt_modules(
# Insert getitem nodes as outputs (for export serialization to work)
with gm.graph.inserting_after(trt_node):
getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0))
getitem_output.meta["val"] = trt_node.meta["val"]
trt_module_node.replace_all_uses_with(getitem_output)
else:
# Multiple outputs case:
# Replace uses of submodule with the trt_node.
# getitem nodes are already added inherently by the partitioner
trt_module_node.replace_all_uses_with(trt_node)
getitem_nodes = trt_node.users
for idx, getitem_node in enumerate(getitem_nodes):
getitem_node.meta["val"] = trt_node.meta["val"][idx]

# Erase the TRT submodule (call_module) node.
gm.graph.erase_node(trt_module_node)
Expand Down
82 changes: 29 additions & 53 deletions py/torch_tensorrt/dynamo/_tracer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from __future__ import annotations

import logging
import unittest.mock
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, Tuple, Union

import torch
from torch._export import dynamic_dim, export
from torch_tensorrt._Device import Device
from torch.export import Dim, export
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._defaults import (
DEBUG,
Expand All @@ -20,24 +18,10 @@
logger = logging.getLogger(__name__)


def get_random_tensor(
shape: List[Any], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
if dtype == torch.int32 or dtype == torch.int64:
return torch.randint(2, 10, shape, dtype=dtype, device=device)
elif dtype in (torch.float64, torch.float32, torch.float16):
return torch.randn(shape, dtype=dtype, device=device)
else:
logger.critical(
"Invalid dtype detected in creating input tensors for tracing the graph."
)
raise


def trace(
mod: torch.nn.Module | torch.fx.GraphModule,
inputs: Tuple[Any, ...],
device: Optional[Union[Device, torch.device, str]] = DEVICE,
device: Optional[Union[torch.device, str]] = DEVICE,
debug: bool = DEBUG,
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
**kwargs: Any,
Expand Down Expand Up @@ -65,9 +49,9 @@ def trace(
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
]
Keyword Arguments:
device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
device (Union(torch.device, dict)): Target device for TensorRT engines to run on ::

device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
device=torch.device("cuda:0")

debug (bool): Enable debuggable engine
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the grap easier to covert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
Expand All @@ -81,46 +65,38 @@ def trace(
set_log_level(logger.parent, logging.DEBUG)
device = to_torch_device(device if device else default_device())

# Determine the dynamic dimension and setup constraints to input dimensions as dictated by TensorRT
# Torch dynamo does not allow 0/1 value for dynamic dimensions
# for inputs during tracing. Hence we create new inputs for export
device = to_torch_device(kwargs.get("device", default_device()))
torch_inputs = get_torch_inputs(inputs, device)
trace_inputs = []
constraints = []
for idx, input in enumerate(inputs):
dynamic_shapes = {}
for input in inputs:
if input.shape_mode == Input._ShapeMode.DYNAMIC:
if not input.name:
raise AssertionError(
f"Expected a name for a dynamic input with shape {input.shape} but found none"
)
min_shape = input.shape["min_shape"]
opt_shape = input.shape["opt_shape"]
max_shape = input.shape["max_shape"]
assert len(min_shape) == len(opt_shape) == len(max_shape)

constraint_dims = []
new_shape = []
dynamic_dims = {}
for dim in range(len(min_shape)):
if min_shape[dim] == opt_shape[dim] == max_shape[dim]:
new_shape.append(torch_inputs[idx].shape[dim])
continue
else:
constraint_dims.append(dim)
if torch_inputs[idx].shape[dim] == 1:
new_shape.append(torch_inputs[idx].shape[dim] + 1)
else:
new_shape.append(torch_inputs[idx].shape[dim])

trace_input = get_random_tensor(new_shape, torch_inputs[idx].dtype, device)

for dim in constraint_dims:
if min_shape[dim] > 1:
constraints.append(min_shape[dim] <= dynamic_dim(trace_input, dim))
if max_shape[dim] > 1:
constraints.append(dynamic_dim(trace_input, dim) <= max_shape[dim])
trace_inputs.append(trace_input)
else:
trace_inputs.append(torch_inputs[idx])

with unittest.mock.patch(
"torch._export.DECOMP_TABLE",
get_decompositions(enable_experimental_decompositions),
):
exp_program = export(mod, tuple(trace_inputs), constraints=constraints)
dynamic_dims[dim] = Dim(
input.name + "_" + str(dim),
min=min_shape[dim],
max=max_shape[dim],
)

dynamic_shapes[input.name] = dynamic_dims

experimental_decompositions = kwargs.get(
"enable_experimental_decompositions", ENABLE_EXPERIMENTAL_DECOMPOSITIONS
)

exp_program = export(
mod, tuple(torch_inputs), dynamic_shapes=dynamic_shapes
).run_decompositions(get_decompositions(experimental_decompositions))

return exp_program
2 changes: 2 additions & 0 deletions tests/py/dynamo/models/test_dyn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def forward(self, x):
opt_shape=(4, 3, 224, 224),
max_shape=(8, 3, 224, 224),
dtype=torch.float32,
name="x",
)
],
"device": torchtrt.Device("cuda:0"),
Expand Down Expand Up @@ -88,6 +89,7 @@ def forward(self, x):
opt_shape=(4, 3, 224, 224),
max_shape=(8, 3, 224, 224),
dtype=torch.float32,
name="x",
)
],
"device": torchtrt.Device("cuda:0"),
Expand Down
90 changes: 45 additions & 45 deletions tests/py/dynamo/models/test_export_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def forward(self, x):
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
serialized_prog = serialize(trt_exp_program)
deserialized_prog = deserialize(*serialized_prog)
torch.export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")

# Check Pyt and TRT exported program outputs
cos_sim = cosine_similarity(model(input), trt_exp_program(input)[0])
Expand All @@ -55,7 +55,7 @@ def forward(self, x):
msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)
# Check Pyt and deserialized TRT exported program outputs
cos_sim = cosine_similarity(model(input), deserialized_prog(input)[0])
cos_sim = cosine_similarity(model(input), deser_trt_exp_program(input)[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
Expand Down Expand Up @@ -97,8 +97,8 @@ def forward(self, x):
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
serialized_prog = serialize(trt_exp_program)
deserialized_prog = deserialize(*serialized_prog)
torch.export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
# Check Pyt and TRT exported program outputs
outputs_pyt = model(input)
outputs_trt = trt_exp_program(input)
Expand All @@ -110,7 +110,7 @@ def forward(self, x):
)

# Check Pyt and deserialized TRT exported program outputs
outputs_trt_deser = deserialized_prog(input)
outputs_trt_deser = deser_trt_exp_program(input)
for idx in range(len(outputs_pyt)):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
assertions.assertTrue(
Expand Down Expand Up @@ -154,8 +154,8 @@ def forward(self, x):
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
torch._export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
torch.export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")

outputs_pyt = model(input)
outputs_trt = trt_exp_program(input)
Expand Down Expand Up @@ -213,8 +213,8 @@ def forward(self, x):
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
torch._export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
torch.export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")

outputs_pyt = model(input)
outputs_trt = trt_exp_program(input)
Expand All @@ -234,45 +234,45 @@ def forward(self, x):
)


# TODO (peri044) : Enable this test once the _frozen_param0 attribute resulting in sym_int ops issue is fixed.
# @pytest.mark.unit
# def test_resnet18_save_load(ir):
# """
# This tests export save and load functionality on Resnet18 model
# """
# model = models.resnet18().eval().cuda()
# input = torch.randn((1, 3, 224, 224)).to("cuda")
@pytest.mark.unit
def test_resnet18_save_load(ir):
"""
This tests export save and load functionality on Resnet18 model
"""
model = models.resnet18().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

# compile_spec = {
# "inputs": [
# torchtrt.Input(
# input.shape, dtype=torch.float, format=torch.contiguous_format
# )
# ],
# "ir": ir,
# "min_block_size": 1,
# }
compile_spec = {
"inputs": [
torchtrt.Input(
input.shape, dtype=torch.float, format=torch.contiguous_format
)
],
"ir": ir,
"min_block_size": 1,
}

# exp_program = torchtrt.dynamo.trace(model, **compile_spec)
# trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
# trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
# torch._export.save(trt_exp_program, "/tmp/trt.ep")
# deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
torch._export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch._export.load("/tmp/trt.ep")

# outputs_pyt = model(input)
# outputs_trt = trt_exp_program(input)
# cos_sim = cosine_similarity(outputs_pyt, outputs_trt)
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )
outputs_pyt = model(input)
outputs_trt = trt_exp_program(input)
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# outputs_trt_deser = deser_trt_exp_program(input)
# cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser)
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )
outputs_trt_deser = deser_trt_exp_program(input)

cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


# Enable this test once this issue is resolved https://github.com/pytorch/TensorRT/issues/2341
Expand Down