Skip to content

chore: Switch to new export apis #2376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion py/torch_tensorrt/_Input.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class _ShapeMode(Enum):
high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
torch_dtype: torch.dtype = torch.float32
torch_tensor: torch.Tensor = None
name: str = ""

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""__init__ Method for torch_tensorrt.Input
Expand All @@ -68,7 +69,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
format (torch.memory_format or torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
tensor_domain (Tuple(float, float), optional): The domain of allowed values for the tensor, as interval notation: [tensor_domain[0], tensor_domain[1]).
Note: Entering "None" (or not specifying) will set the bound to [0, 2)

torch_tensor (torch.Tensor): Holds a corresponding torch tensor with this Input.
name (str, optional): Name of this input in the pytorch graph. Used to specify dynamic shapes in dynamo tracer.
Examples:
- Input([1,3,32,32], dtype=torch.float32, format=torch.channel_last)
- Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW)
Expand Down Expand Up @@ -180,6 +182,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
else:
self.torch_tensor = self.example_tensor()

if "name" in kwargs:
self.name = kwargs["name"]

def __str__(self) -> str:
if self.shape_mode == Input._ShapeMode.STATIC:
return "Input(shape={}, dtype={}, format={}, domain=[{}, {}))".format(
Expand Down
78 changes: 25 additions & 53 deletions py/torch_tensorrt/dynamo/_tracer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from __future__ import annotations

import logging
import unittest.mock
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, Tuple, Union

import torch
from torch._export import dynamic_dim, export
from torch_tensorrt._Device import Device
from torch.export import Dim, export
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._defaults import (
DEBUG,
Expand All @@ -20,24 +18,10 @@
logger = logging.getLogger(__name__)


def get_random_tensor(
shape: List[Any], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
if dtype == torch.int32 or dtype == torch.int64:
return torch.randint(2, 10, shape, dtype=dtype, device=device)
elif dtype in (torch.float64, torch.float32, torch.float16):
return torch.randn(shape, dtype=dtype, device=device)
else:
logger.critical(
"Invalid dtype detected in creating input tensors for tracing the graph."
)
raise


def trace(
mod: torch.nn.Module | torch.fx.GraphModule,
inputs: Tuple[Any, ...],
device: Optional[Union[Device, torch.device, str]] = DEVICE,
device: Optional[Union[torch.device, str]] = DEVICE,
debug: bool = DEBUG,
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
**kwargs: Any,
Expand Down Expand Up @@ -65,9 +49,9 @@ def trace(
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
]
Keyword Arguments:
device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
device (Union(torch.device, dict)): Target device for TensorRT engines to run on ::

device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
device=torch.device("cuda:0")

debug (bool): Enable debuggable engine
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the grap easier to covert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
Expand All @@ -81,46 +65,34 @@ def trace(
set_log_level(logger.parent, logging.DEBUG)
device = to_torch_device(device if device else default_device())

# Determine the dynamic dimension and setup constraints to input dimensions as dictated by TensorRT
# Torch dynamo does not allow 0/1 value for dynamic dimensions
# for inputs during tracing. Hence we create new inputs for export
device = to_torch_device(kwargs.get("device", default_device()))
torch_inputs = get_torch_inputs(inputs, device)
trace_inputs = []
constraints = []
for idx, input in enumerate(inputs):
dynamic_shapes = {}
for input in inputs:
if input.shape_mode == Input._ShapeMode.DYNAMIC:
min_shape = input.shape["min_shape"]
opt_shape = input.shape["opt_shape"]
max_shape = input.shape["max_shape"]
assert len(min_shape) == len(opt_shape) == len(max_shape)

constraint_dims = []
new_shape = []
dynamic_dims = {}
for dim in range(len(min_shape)):
if min_shape[dim] == opt_shape[dim] == max_shape[dim]:
new_shape.append(torch_inputs[idx].shape[dim])
continue
else:
constraint_dims.append(dim)
if torch_inputs[idx].shape[dim] == 1:
new_shape.append(torch_inputs[idx].shape[dim] + 1)
else:
new_shape.append(torch_inputs[idx].shape[dim])

trace_input = get_random_tensor(new_shape, torch_inputs[idx].dtype, device)

for dim in constraint_dims:
if min_shape[dim] > 1:
constraints.append(min_shape[dim] <= dynamic_dim(trace_input, dim))
if max_shape[dim] > 1:
constraints.append(dynamic_dim(trace_input, dim) <= max_shape[dim])
trace_inputs.append(trace_input)
else:
trace_inputs.append(torch_inputs[idx])

with unittest.mock.patch(
"torch._export.DECOMP_TABLE",
get_decompositions(enable_experimental_decompositions),
):
exp_program = export(mod, tuple(trace_inputs), constraints=constraints)
dynamic_dims[dim] = Dim(
input.name + "_" + str(dim),
min=min_shape[dim],
max=max_shape[dim],
)

dynamic_shapes[input.name] = dynamic_dims

experimental_decompositions = kwargs.get(
"enable_experimental_decompositions", ENABLE_EXPERIMENTAL_DECOMPOSITIONS
)

exp_program = export(
mod, tuple(torch_inputs), dynamic_shapes=dynamic_shapes
).run_decompositions(get_decompositions(experimental_decompositions))

return exp_program
2 changes: 2 additions & 0 deletions tests/py/dynamo/models/test_dyn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def forward(self, x):
opt_shape=(4, 3, 224, 224),
max_shape=(8, 3, 224, 224),
dtype=torch.float32,
name="x",
)
],
"device": torchtrt.Device("cuda:0"),
Expand Down Expand Up @@ -88,6 +89,7 @@ def forward(self, x):
opt_shape=(4, 3, 224, 224),
max_shape=(8, 3, 224, 224),
dtype=torch.float32,
name="x",
)
],
"device": torchtrt.Device("cuda:0"),
Expand Down