feat: Implement Dynamic shapes + fallback support for export path #2271
Changes from 86 commits
@@ -60,6 +60,8 @@ class DataType {
   enum Value : int8_t {
     /// INT64
     kLong,
+    /// FP64
+    kDouble,
     /// FP32
     kFloat,
     /// FP16
@@ -0,0 +1,218 @@
.. _runtime:

Dynamic shapes with Torch-TensorRT
====================================

By default, you can run a PyTorch model with varied input shapes, and the output shapes are determined eagerly.
However, Torch-TensorRT is an AOT compiler which requires prior information about the input shapes to compile and optimize the model.
In the case of dynamic input shapes, we must provide the (min_shape, opt_shape, max_shape) arguments so that the model can be optimized for
this range of input shapes. An example usage of static and dynamic shapes is as follows.

NOTE: The following code uses the dynamo IR. In the case of the TorchScript IR, swap out ``ir=dynamo`` for ``ir=ts``; the behavior is exactly the same.

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    # Compile with static shapes
    inputs = torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)
    # or compile with dynamic shapes
    inputs = torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
                                  opt_shape=[4, 3, 224, 224],
                                  max_shape=[8, 3, 224, 224],
                                  dtype=torch.float32)
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

Under the hood
--------------

There are two phases of compilation when we use the ``torch_tensorrt.compile`` API with ``ir=dynamo`` (default).

- aten_tracer.trace (which uses torch.export to trace the graph with the given inputs)

  In the tracing phase, we use torch.export along with the constraints. In the case of
  dynamic shaped inputs, the range can be provided to the tracing via constraints. Please
  refer to this `docstring <https://github.com/pytorch/pytorch/blob/5dcee01c2b89f6bedeef9dd043fd8d6728286582/torch/export/__init__.py#L372-L434>`_
  for detailed information on how to set constraints. In short, we create new inputs for
  torch.export tracing and provide constraints on the min and max values (provided by the user) that a particular dimension can take.
  Please take a look at the ``aten_tracer.py`` file to understand how this works under the hood.

- dynamo.compile (which compiles a torch.fx.GraphModule object using TensorRT)

  In the conversion to TensorRT, we use the user-provided dynamic shape inputs.
  We perform shape analysis using dummy inputs (across the min, opt and max shapes) and store the
  intermediate output shapes, which can be used when the graph has a mix of PyTorch
  and TensorRT submodules. A sketch of this shape analysis is shown after this list.
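The following is a minimal sketch of that shape-analysis idea, not Torch-TensorRT's actual implementation: it uses FX's ``ShapeProp`` to record the shape every node produces for each of the three optimization profiles. The helper name and the profile dictionary layout are illustrative.

.. code-block:: python

    import torch
    from torch.fx.passes.shape_prop import ShapeProp

    def record_profile_shapes(gm: torch.fx.GraphModule, profiles: dict) -> dict:
        """Hypothetical helper: propagate shapes once per profile
        (min/opt/max) and collect every intermediate output shape."""
        all_shapes = {}
        for name, shape in profiles.items():
            # Assumes the module and its inputs live on the GPU.
            dummy = torch.randn(shape).cuda()
            # ShapeProp annotates each node's meta with the tensor
            # metadata produced at that node.
            ShapeProp(gm).propagate(dummy)
            all_shapes[name] = {
                node.name: tuple(node.meta["tensor_meta"].shape)
                for node in gm.graph.nodes
                if "tensor_meta" in node.meta
            }
        return all_shapes

    # e.g. record_profile_shapes(graph_module,
    #                            {"min": (1, 3, 224, 224),
    #                             "opt": (4, 3, 224, 224),
    #                             "max": (8, 3, 224, 224)})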
Custom Constraints
------------------

Given an input ``x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype)``,
Torch-TensorRT automatically sets the constraints during ``torch.export`` tracing as follows

.. code-block:: python

    for dim in constraint_dims:
        if min_shape[dim] > 1:
            constraints.append(min_shape[dim] <= dynamic_dim(trace_input, dim))
        if max_shape[dim] > 1:
            constraints.append(dynamic_dim(trace_input, dim) <= max_shape[dim])

Sometimes, we might need to set additional constraints, and TorchDynamo errors out if we don't specify them.
For example, when compiling a BERT model there are two inputs, and a constraint has to be set involving the sequence length of these two inputs.

.. code-block:: python

    constraints.append(dynamic_dim(trace_inputs[0], 0) == dynamic_dim(trace_inputs[1], 0))

If you have to provide any custom constraints to your model, the overall workflow for model compilation using ``ir=dynamo`` would involve a few steps.

.. code-block:: python

    import unittest.mock

    import torch
    import torch_tensorrt
    from torch.export import dynamic_dim, export
    from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions

    # Assume the model has two inputs.
    # int32 token-like inputs (torch.randn does not support integer dtypes)
    model = MyModel()
    torch_input_1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).cuda()
    torch_input_2 = torch.randint(0, 5, (1, 14), dtype=torch.int32).cuda()

    dynamic_inputs = [torch_tensorrt.Input(min_shape=[1, 14],
                                           opt_shape=[4, 14],
                                           max_shape=[8, 14],
                                           dtype=torch.int32),
                      torch_tensorrt.Input(min_shape=[1, 14],
                                           opt_shape=[4, 14],
                                           max_shape=[8, 14],
                                           dtype=torch.int32)]

    # Export the model with additional constraints
    constraints = []
    # The following constraints are automatically added by Torch-TensorRT in the
    # general case when you call torch_tensorrt.compile directly on MyModel()
    constraints.append(dynamic_dim(torch_input_1, 0) < 8)
    constraints.append(dynamic_dim(torch_input_2, 0) < 8)
    # This is an additional constraint as instructed by TorchDynamo
    constraints.append(dynamic_dim(torch_input_1, 0) == dynamic_dim(torch_input_2, 0))
    with unittest.mock.patch(
        "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
    ):
        graph_module = export(
            model, (torch_input_1, torch_input_2), constraints=constraints
        ).module()

    # Use the dynamo.compile API
    trt_mod = torch_tensorrt.dynamo.compile(graph_module, inputs=dynamic_inputs, **compile_spec)

Limitations
-----------

If there are operations in the graph that use the dynamic dimension of the input, PyTorch
introduces ``torch.ops.aten.sym_size.int`` ops in the graph. Currently, we cannot handle these operators and
the compilation results in undefined behavior. We plan to add support for these operators and implement
robust support for shape tensors in the next release. Here is an example of the limitation described above

.. code-block:: python

    import torch
    import torch_tensorrt

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.avgpool(x)
            out = torch.flatten(x, 1)
            return out

    model = MyModule().eval().cuda()
    # Compile with dynamic shapes
    inputs = torch_tensorrt.Input(min_shape=(1, 512, 1, 1),
                                  opt_shape=(4, 512, 1, 1),
                                  max_shape=(8, 512, 1, 1),
                                  dtype=torch.float32)
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

The traced graph of ``MyModule()`` looks as follows

.. code-block:: python

    Post export graph: graph():
        %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
        %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%arg0_1, [-1, -2], True), kwargs = {})
        %sym_size : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%arg0_1, 0), kwargs = {})
        %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%mean, [%sym_size, 512]), kwargs = {})
        return (view,)

Here the ``%sym_size`` node captures the dynamic batch size and uses it in the ``aten.view`` layer. This requires shape tensor support,
which will be part of our next release.
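As a quick, illustrative check (not part of the library API), you can scan an exported graph for these ops before handing it to the compiler; ``uses_dynamic_dims`` is a hypothetical helper name.

.. code-block:: python

    import torch

    def uses_dynamic_dims(gm: torch.fx.GraphModule) -> bool:
        """Hypothetical helper: report whether any node consumes a dynamic
        dimension via aten.sym_size, i.e. whether this limitation applies."""
        return any(
            node.op == "call_function"
            and node.target == torch.ops.aten.sym_size.int
            for node in gm.graph.nodes
        )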
Workaround (BERT static compilation example)
--------------------------------------------

If you encounter the issues mentioned in the **Limitations** section,
you can compile the model in static mode with the maximum input size that can occur. Smaller inputs
can then be padded accordingly, as sketched after the example below. This is only a workaround until we address the limitations.

.. code-block:: python

    import torch
    import torch_tensorrt
    from transformers import BertModel
    from transformers.utils.fx import symbolic_trace as transformers_trace

    model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()

    # Input sequence length is 20.
    input1 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")
    input2 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")

    model = transformers_trace(model, input_names=["input_ids", "attention_mask"]).eval().cuda()
    trt_mod = torch_tensorrt.compile(model, inputs=[input1, input2], **compile_spec)
    model_outputs = model(input1, input2)

    # If you have a sequence of length 14, pad 6 zero tokens and run inference,
    # or recompile for a sequence length of 14.
    input1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
    input2 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
    trt_mod = torch_tensorrt.compile(model, inputs=[input1, input2], **compile_spec)
    model_outputs = model(input1, input2)
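For reference, the padding path might look like the following sketch; ``input1_short`` and ``mask_short`` are hypothetical names, and the padded positions are zeroed in the attention mask so they are ignored.

.. code-block:: python

    import torch
    import torch.nn.functional as F

    # A length-14 sequence, while the engine was compiled for length 20.
    input1_short = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
    mask_short = torch.ones((1, 14), dtype=torch.int32).to("cuda")

    # Right-pad 6 positions: zero token ids, and zeros in the attention
    # mask so the padding does not contribute to attention.
    input1_padded = F.pad(input1_short, (0, 6), value=0)  # shape (1, 20)
    mask_padded = F.pad(mask_short, (0, 6), value=0)      # shape (1, 20)

    model_outputs = trt_mod(input1_padded, mask_padded)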
Dynamic shapes with ir=torch_compile
------------------------------------

``torch_tensorrt.compile(model, inputs, ir="torch_compile")`` returns a torch.compile boxed function with the backend
configured to TensorRT. In the case of ``ir=torch_compile``, users have to recompile for different input shapes.
In the future, we plan to explore the option of compiling with dynamic shapes in the first execution of the model.

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()
    trt_gm = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs)
    # Compilation happens when you call the model
    trt_gm(inputs)

    # Recompilation happens with a modified batch size
    inputs_bs2 = torch.randn((2, 3, 224, 224), dtype=torch.float32).cuda()
    trt_gm = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs_bs2)
@@ -46,6 +46,7 @@ class _ShapeMode(Enum):
     low_tensor_domain_incl: float = 0.0
     high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
     torch_dtype: torch.dtype = torch.float32
+    torch_tensor: torch.Tensor = None

     def __init__(self, *args: Any, **kwargs: Any) -> None:
         """__init__ Method for torch_tensorrt.Input
@@ -171,6 +172,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

         self.tensor_domain = Input._parse_tensor_domain(domain)

+        if "torch_tensor" in kwargs:
+            self.torch_tensor = kwargs["torch_tensor"]
+        else:
+            if self.shape_mode == Input._ShapeMode.DYNAMIC:
+                self.torch_tensor = self.example_tensor("opt_shape")
+            else:
+                self.torch_tensor = self.example_tensor()
Comment on lines +175 to +181

What if …

So the use case you are saying is if the user provides …
We could maybe pass a warning to users that torch_tensor is not being used in such cases.

I see - I believe the feature ask from #2323 is to allow providing tensors for each of …

     def __str__(self) -> str:
         if self.shape_mode == Input._ShapeMode.STATIC:
             return "Input(shape={}, dtype={}, format={}, domain=[{}, {}))".format(
@@ -220,6 +229,8 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:
             return _enums.dtype.half
         elif dtype == torch.float:
             return _enums.dtype.float
+        elif dtype == torch.float64:
+            return _enums.dtype.double
         elif dtype == torch.bool:
             return _enums.dtype.bool
         else:
@@ -249,6 +260,8 @@ def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
             return torch.float
         elif dtype == _enums.dtype.bool:
             return torch.bool
+        elif dtype == _enums.dtype.double:
+            return torch.float64
         else:
             # Default torch_dtype used in FX path
             return torch.float32
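With the two added branches, ``torch.float64`` now maps to the double dtype in both directions instead of hitting the error/default fallback paths. A minimal usage sketch, assuming the public ``Input`` API:

.. code-block:: python

    import torch
    import torch_tensorrt

    # FP64 inputs can now be described directly.
    inp = torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float64)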
@@ -354,7 +367,7 @@ def from_tensor(
             )
             else torch.channels_last
         )
-        return cls(shape=t.shape, dtype=t.dtype, format=frmt)
+        return cls(shape=t.shape, dtype=t.dtype, format=frmt, torch_tensor=t)

     @classmethod
     def from_tensors(
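A sketch of what this change enables, assuming the public API: the source tensor now rides along on the ``Input`` object, so downstream shape analysis can reuse it.

.. code-block:: python

    import torch
    import torch_tensorrt

    t = torch.randn(1, 3, 224, 224).cuda()
    inp = torch_tensorrt.Input.from_tensor(t)
    # With this change, the original tensor is preserved on the Input.
    assert inp.torch_tensor is t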
Add a reference to this document in index.rst to have it reflect on the webpage.

done. thanks