Skip to content

feat: Transition export workflows to use torch._export APIs #2195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 56 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from 53 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
f1f202e
feat: Move tracing to use aot export apis
peri044 Aug 8, 2023
abaf047
chore: minor changes
peri044 Aug 9, 2023
bb1f3cf
chore: minor changes
peri044 Aug 11, 2023
3d05b4d
chore: Rebase with main
peri044 Aug 11, 2023
8d99be5
chore: rebase
peri044 Aug 16, 2023
0aad214
chore: minor logging updates
peri044 Aug 17, 2023
8899735
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
8af2627
fix: Refactor tensor freezing in Dynamo
gs-olive Aug 5, 2023
f6969be
Key op fixes for failing tests
gs-olive Aug 5, 2023
bad1594
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
db56dd6
chore: Move to new export APIs
peri044 Aug 17, 2023
bf961f5
chore: rebase with dynamo_tensor_freeze branch
peri044 Aug 17, 2023
b13aa82
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
dd95620
fix: Refactor tensor freezing in Dynamo
gs-olive Aug 5, 2023
6bd3c64
Key op fixes for failing tests
gs-olive Aug 5, 2023
248073f
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
3e5f434
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
6bf6945
fix: Refactor tensor freezing in Dynamo
gs-olive Aug 5, 2023
3b6e1e7
Key op fixes for failing tests
gs-olive Aug 5, 2023
2107d8e
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
fd5a41e
chore: add BERT test case
peri044 Aug 18, 2023
f047651
chore: remove pdb
peri044 Aug 21, 2023
ab76c0d
chore: rebase
peri044 Aug 23, 2023
e4df382
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
d022f4a
fix: Refactor tensor freezing in Dynamo
gs-olive Aug 5, 2023
9610ba7
Key op fixes for failing tests
gs-olive Aug 5, 2023
e19aae7
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
2860be6
Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor…
peri044 Aug 25, 2023
51266db
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
2005db7
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
a8cb1fe
fix: Move tracer code into try/except
gs-olive Aug 29, 2023
7ff9309
Custom implementation of AOT for compile
gs-olive Aug 29, 2023
692921e
Move fixes into Dynamo directory
gs-olive Aug 30, 2023
e926724
chore: rebase
peri044 Sep 5, 2023
27681c2
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
056cbf3
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
ece276c
fix: Move tracer code into try/except
gs-olive Aug 29, 2023
73a0bce
Custom implementation of AOT for compile
gs-olive Aug 29, 2023
890ba72
Move fixes into Dynamo directory
gs-olive Aug 30, 2023
980dc1c
chore: rebase
peri044 Sep 9, 2023
dfc4899
Move fixes into Dynamo directory
gs-olive Aug 30, 2023
09b099a
Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor…
peri044 Sep 9, 2023
157bb2d
chore: updates
peri044 Sep 9, 2023
0005a31
Move fixes into Dynamo directory
gs-olive Aug 30, 2023
5526bca
Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor…
peri044 Sep 11, 2023
3420fb0
chore: updates
peri044 Sep 11, 2023
399f929
feat: Add preliminary support for freezing tensors in Dynamo
gs-olive Jul 21, 2023
4b44ff2
fix: Add constant folding utility to freezing
gs-olive Aug 12, 2023
a94a075
fix: Move tracer code into try/except
gs-olive Aug 29, 2023
4e308f1
Custom implementation of AOT for compile
gs-olive Aug 29, 2023
95d3f98
Move fixes into Dynamo directory
gs-olive Aug 30, 2023
529262a
Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor…
peri044 Sep 11, 2023
aee529b
chore: rebase
peri044 Sep 12, 2023
e6d2d8d
chore: address review comments
peri044 Sep 12, 2023
6cd2bab
Merge branch 'main' into export_prototype
peri044 Sep 12, 2023
c7b2f3c
chore: fix imports
peri044 Sep 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 18 additions & 146 deletions py/torch_tensorrt/dynamo/aten_tracer.py
Original file line number Diff line number Diff line change
@@ -1,160 +1,32 @@
from __future__ import annotations

import copy
import logging
import sys
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
import unittest.mock
from typing import Any, Tuple

import torch
import torch._dynamo as torchdynamo
from torch.fx.passes.infra.pass_base import PassResult
from torch_tensorrt.dynamo.utils import req_torch_version
from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
compose_bmm,
compose_chunk,
compose_getitem_slice,
remove_ops,
replace_aten_op_with_indices,
replace_aten_reshape_alias_with_replace,
replace_builtin_ops,
replace_inplace_ops,
replace_native_layernorm_with_layernorm,
replace_transpose_mm_op_with_linear,
run_const_fold,
)
from typing_extensions import TypeAlias

Value: TypeAlias = Union[Tuple["Value", ...], List["Value"], Dict[str, "Value"]]
from torch._export import export
from torch_tensorrt.dynamo.backend.backends import constant_fold
from torch_tensorrt.dynamo.lowering import get_decompositions

logger = logging.getLogger(__name__)


class DynamoConfig:
"""
Manage Exir-specific configurations of Dynamo.
"""

def __init__(
self,
capture_scalar_outputs: bool = True,
guard_nn_modules: bool = True,
dynamic_shapes: bool = True,
specialize_int: bool = True,
verbose: bool = True,
) -> None:
self.capture_scalar_outputs = capture_scalar_outputs
self.guard_nn_modules = guard_nn_modules
self.dynamic_shapes = dynamic_shapes
self.specialize_int = specialize_int
self.verbose = verbose

def activate(self) -> None:
torchdynamo.config.capture_scalar_outputs = self.capture_scalar_outputs
torchdynamo.config.guard_nn_modules = self.guard_nn_modules
torchdynamo.config.dynamic_shapes = self.dynamic_shapes
torchdynamo.config.specialize_int = self.specialize_int
torchdynamo.config.verbose = self.verbose

def deactivate(self) -> None:
torchdynamo.config.capture_scalar_outputs = True
torchdynamo.config.guard_nn_modules = True
torchdynamo.config.dynamic_shapes = True
torchdynamo.config.specialize_int = True
torchdynamo.config.verbose = True


@contextmanager
def using_config(config: DynamoConfig) -> Generator[DynamoConfig, None, None]:
config.activate()
try:
yield config
finally:
config.deactivate()


@contextmanager
def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None, None]:
"""
Temporarily increase the python interpreter stack recursion limit.
This is mostly used for pickling large scale modules.
"""
default = sys.getrecursionlimit()
if limit > default:
sys.setrecursionlimit(limit)
try:
yield
finally:
sys.setrecursionlimit(default)


@req_torch_version("2.dev")
def dynamo_trace(
f: Callable[..., Value],
# pyre-ignore
args: Tuple[Any, ...],
aten_graph: bool,
tracing_mode: str = "real",
dynamo_config: Optional[DynamoConfig] = None,
) -> Any: # Tuple[torch.fx.GraphModule, Set[_guards.Guard]]:
"""
TODO: Once we fully migrate to torchdynamo frontend, we will remove
this config option alltogether. For now, it helps with quick
experiments with playing around with TorchDynamo
"""
if dynamo_config is None:
dynamo_config = DynamoConfig()
with using_config(dynamo_config), setting_python_recursive_limit(2000):
torchdynamo.reset()
try:
return torchdynamo.export(
f,
*copy.deepcopy(args),
aten_graph=aten_graph,
tracing_mode=tracing_mode,
)
except torchdynamo.exc.Unsupported as exc:
raise RuntimeError(
"The user code is using a feature we don't support. "
"Please try torchdynamo.explain() to get possible the reasons",
) from exc
except Exception as exc:
raise RuntimeError(
"torchdynamo internal error occured. Please see above stacktrace"
) from exc


@req_torch_version("2.dev")
def trace(
model: torch.nn.Module | torch.fx.GraphModule,
inputs: Tuple[Any, ...],
**kwargs: Any,
) -> torch.fx.GraphModule:
"""
Optimized trace with necessary passes which re-compose some ops or replace some ops
These passes should be general and functional purpose
"""
passes_list = [
compose_bmm,
compose_chunk,
compose_getitem_slice,
replace_aten_reshape_alias_with_replace,
replace_aten_op_with_indices,
replace_transpose_mm_op_with_linear, # after compose_bmm
replace_native_layernorm_with_layernorm,
remove_ops,
replace_builtin_ops, # after replace_native_layernorm_with_layernorm
replace_inplace_ops, # remove it once functionalization is enabled
]

fx_module, __package__ = dynamo_trace(model, inputs, True, "symbolic")

for passes in passes_list:
pr: PassResult = passes(fx_module)
fx_module = pr.graph_module

fx_module(*inputs)

fx_module = run_const_fold(fx_module)
logger.info("Post export graph : %s\n", fx_module.graph)
return fx_module
# Set log level at the top of compilation (torch_tensorrt.dynamo)
if ("debug" in kwargs and kwargs["debug"]) and logger.parent:
logger.parent.setLevel(logging.DEBUG)
experimental_decompositions = kwargs.get(
"enable_experimental_decompositions", False
)
with unittest.mock.patch(
"torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
):
graph_module = export(model, tuple(inputs)).module()
constant_fold(graph_module)
logger.debug("Post export graph: " + str(graph_module.graph))
return graph_module
67 changes: 58 additions & 9 deletions tests/py/dynamo/models/test_models_export.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import torch
import timm
import pytest
import unittest

import pytest
import timm
import torch
import torch_tensorrt as torchtrt
import torchvision.models as models

from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
from transformers import BertModel

from torch_tensorrt.dynamo.utils import (
COSINE_THRESHOLD,
cosine_similarity,
)
from transformers.utils.fx import symbolic_trace as transformers_trace

assertions = unittest.TestCase()

Expand Down Expand Up @@ -118,6 +114,59 @@ def test_efficientnet_b0(ir):
torch.cuda.empty_cache()


@pytest.mark.unit
def test_bert_base_uncased(ir):
model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
model = (
transformers_trace(model, input_names=["input_ids", "attention_mask"])
.eval()
.cuda()
)

compile_spec = {
"inputs": [
torchtrt.Input(
input.shape,
dtype=input.dtype,
format=torch.contiguous_format,
),
torchtrt.Input(
input.shape,
dtype=input.dtype,
format=torch.contiguous_format,
),
],
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.float},
"truncate_long_and_double": True,
"ir": ir,
"min_block_size": 10,
"torch_executed_ops": {"torch.ops.aten.gelu.default"},
}
trt_mod = torchtrt.compile(model, **compile_spec)
model_outputs = model(input, input2)
trt_model_outputs = trt_mod(input, input2)
assertions.assertTrue(
len(model_outputs) == len(trt_model_outputs),
msg=f"Number of outputs for BERT model compilation is different with Pytorch {len(model_outputs)} and TensorRT {len(trt_model_outputs)}. Please check the compilation.",
)
for index, key in enumerate(model_outputs):
out, trt_out = model_outputs[key], trt_model_outputs[index]
cos_sim = cosine_similarity(out, trt_out)
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
torch._dynamo.reset()

with torch.no_grad():
torch.cuda.empty_cache()


@pytest.mark.unit
def test_resnet18_half(ir):
model = models.resnet18(pretrained=True).eval().to("cuda").half()
Expand Down