diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 2ba9f4d754..6be97d42d9 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -1,12 +1,16 @@
 from __future__ import annotations

 import logging
-from functools import partial
-from typing import Any, Callable, Sequence
+import unittest
+from typing import Any, Callable, Dict, Optional, Sequence

 import torch
 import torch._dynamo as td
-from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
+import torch.utils._pytree as pytree
+from torch._dynamo.utils import detect_fake_mode
+from torch._functorch.aot_autograd import _aot_export_function
+from torch._inductor.constant_folding import ConstantFolder, replace_node_with_constant
+from torch._ops import OpOverload
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
@@ -33,8 +37,7 @@ def torch_tensorrt_backend(

     DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend

-    compiled_mod: torch.nn.Module = DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
-    return compiled_mod
+    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)


 @td.register_backend(name="aot_torch_tensorrt_aten")  # type: ignore[misc]
@@ -42,22 +45,7 @@ def aot_torch_tensorrt_aten_backend(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
 ) -> torch.nn.Module:
     settings = parse_dynamo_kwargs(kwargs)
-
-    custom_backend = partial(
-        _pretraced_backend,
-        settings=settings,
-    )
-
-    # Perform Pre-AOT Lowering for Module-Level Replacement
-    gm = pre_aot_substitutions(gm)
-
-    # Invoke AOTAutograd to translate operators to aten
-    return aot_module_simplified(
-        gm,
-        sample_inputs,
-        fw_compiler=make_boxed_compiler(custom_backend),
-        decompositions=get_decompositions(settings.enable_experimental_decompositions),
-    )
+    return _pretraced_backend(gm, sample_inputs, settings)


 def _pretraced_backend(
@@ -75,22 +63,44 @@ def _pretraced_backend(
         Compiled FX GraphModule
     """
     try:
-        logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
+        logger.debug("Pre-AOT Autograd graph:\n" + str(gm.graph))
+
+        # Perform Pre-AOT Lowering for Module-Level Replacement
+        gm = pre_aot_substitutions(gm)
+
+        fake_mode = detect_fake_mode(sample_inputs)
+
+        # Place backend tracing within FakeTensor context allowing non-fake Tensors
+        with unittest.mock.patch.object(
+            fake_mode, "allow_non_fake_inputs", True
+        ), fake_mode:
+            # Invoke AOTAutograd to translate operators to aten
+            graph_module = aot_export_for_compile(
+                gm,
+                sample_inputs,
+                decompositions=get_decompositions(
+                    settings.enable_experimental_decompositions
+                ),
+            )

-        trt_compiled = compile_module(
-            gm,
-            sample_inputs,
-            settings=settings,
-        )
-        return trt_compiled
-    except AssertionError:
+            logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
+
+            constant_fold(graph_module)
+
+            trt_compiled = compile_module(
+                graph_module,
+                sample_inputs,
+                settings=settings,
+            )
+            return trt_compiled
+    except (AssertionError, RuntimeError):
         if not settings.pass_through_build_failures:
             logger.warning(
                 "TRT conversion failed on the subgraph. See trace above. "
                 + "Returning GraphModule forward instead.",
                 exc_info=True,
             )
-            return gm.forward
+            return gm
         else:
             logger.critical(
                 "Halting compilation on build failure since "
@@ -100,3 +110,82 @@
                 + "specify pass_through_build_failures=False."
             )
             raise
+
+
+@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
+def constant_fold(gm: torch.fx.GraphModule) -> Any:
+    """Adapted from:
+    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
+
+    Folds constants in the graph module, not skipping constructors
+
+    Modifies the graph in-place and replaces nodes with constants
+    """
+    cf = ConstantFolder(gm, skip_constructors=False)
+    cf.run()
+
+    for node, constant in cf.node_replacements.items():
+        replace_node_with_constant(gm, node, constant)
+
+    erased_params = []
+    for node in gm.graph.nodes:
+        if node.op == "get_attr" and len(node.users) == 0:
+            delattr(gm, node.target)
+            erased_params.append(node)
+
+    for node in erased_params:
+        gm.graph.erase_node(node)
+
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+
+
+def aot_export_for_compile(
+    func: torch.fx.GraphModule,
+    args: Sequence[torch.Tensor],
+    *,
+    decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
+) -> torch.fx.GraphModule:
+    """Adapted from:
+    https://github.com/pytorch/pytorch/blob/1a5fdc2458b98697c75c32eb6f4b8b34d76429cf/torch/_functorch/aot_autograd.py#L4084-L4158
+
+    Removed check for input aliasing in resultant subgraph - TRT is functional-only
+
+    Exports the function to ATen for torch compile
+    """
+    # Trace function with input arguments and decompositions
+    with torch.no_grad():
+        fx_g, metadata, in_spec, out_spec = _aot_export_function(
+            func,
+            args,
+            decompositions=decompositions,
+        )
+
+    # No input mutations
+    if (
+        len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
+        != 0
+    ):
+        raise RuntimeError(
+            f"aot_export_for_compile does not support input mutations. {str(metadata)}"
+        )
+    # No pytrees
+    if type(in_spec) == pytree.LeafSpec:
+        raise RuntimeError(
+            f"aot_export_for_compile requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
+        )
+    if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
+        raise RuntimeError(
+            f"aot_export_for_compile requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
+        )
+    if type(out_spec) == pytree.LeafSpec:
+        raise RuntimeError(
+            f"aot_export_for_compile requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
+        )
+    if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
+        raise RuntimeError(
+            f"aot_export_for_compile requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
+        )
+
+    return fx_g
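For context, a minimal sketch of how this new export path is exercised end to end, mirroring `test_tensor_freeze_attr` further below (assumes a CUDA device and an installed `torch_tensorrt`):

    import torch
    import torch_tensorrt

    class MatmulConst(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # A non-parameter attribute: becomes a get_attr node that
            # constant folding can freeze into the TRT engine
            self.const = torch.ones((8, 2), device="cuda")

        def forward(self, x):
            return x @ self.const

    inputs = [torch.ones(7, 8).cuda()]
    # "torch_compile" routes through the backends defined in this file
    optimized = torch_tensorrt.compile(
        torch.fx.symbolic_trace(MatmulConst()),
        "torch_compile",
        inputs,
        min_block_size=1,
    )
    print(optimized(*inputs).shape)  # torch.Size([7, 2])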
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 29485a919b..9f3dc5deb9 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -3,7 +3,7 @@
 from datetime import datetime
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set

-import numpy
+import numpy as np

 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
@@ -11,6 +11,7 @@
 import torch.fx
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes.shape_prop import TensorMetadata
+from torch.utils._python_dispatch import _disable_current_modes
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name
 from torch_tensorrt.fx.observer import Observer
@@ -169,7 +170,7 @@ def run(

         cache = None
         if timing_cache:
-            cache_file = numpy.array(timing_cache)
+            cache_file = np.array(timing_cache)
             cache = builder_config.create_timing_cache(cache_file.tobytes())
         else:
             cache = builder_config.create_timing_cache(b"")
@@ -323,6 +324,21 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
         assert self._cur_node_name is not None
         return converter(self.network, target, args, kwargs, self._cur_node_name)

+    def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
+        with _disable_current_modes():
+            from torch_tensorrt.fx.converters import to_numpy
+
+            frozen_attr = self.fetch_attr(target)
+
+            if isinstance(frozen_attr, torch.nn.Parameter):
+                constant_tensor = frozen_attr.data
+            else:
+                constant_tensor = frozen_attr
+
+            network_constant = to_numpy(constant_tensor)
+
+        return network_constant
+
     def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
         assert isinstance(target, str)
         converter = CONVERTERS.get(self._cur_node)
@@ -344,6 +360,17 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
         else:
             outputs = (args[0],)

+        for output_idx in range(len(outputs)):
+            from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
+
+            output = outputs[output_idx]
+
+            if not isinstance(output, trt.tensorrt.ITensor):
+                new_output = get_trt_tensor(self.network, output, target)
+                outputs = (
+                    outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :]
+                )
+
         if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
             raise RuntimeError("TensorRT requires all outputs to be Tensor!")
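The two interpreter additions above materialize frozen attributes as numpy constants and wrap any non-ITensor graph output in a TensorRT constant layer. A rough standalone illustration of that wrapping step, using only the public TensorRT Python API (names and shapes here are arbitrary):

    import numpy as np
    import tensorrt as trt

    builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )

    # A constant-valued "output", e.g. the result of a fully folded subgraph
    value = np.ones((2, 3), dtype=np.float32)

    # Equivalent of get_trt_tensor for a non-ITensor output
    const_layer = network.add_constant(value.shape, value)
    const_layer.name = "output_constant"
    network.mark_output(const_layer.get_output(0))  # now a legal TRT output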
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 9fcf959346..42d6165256 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -94,7 +94,7 @@ def aten_ops_fmod(
     return impl.elementwise.fmod(network, target, SourceIR.ATEN, name, args[0], args[1])


-@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
+@dynamo_tensorrt_converter(torch.ops.aten.relu.default)  # type: ignore[misc]
 def aten_ops_relu(
     network: TRTNetwork,
     target: Target,
@@ -111,7 +111,7 @@ def aten_ops_relu(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default)
+@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default)  # type: ignore[misc]
 def aten_ops_sigmoid(
     network: TRTNetwork,
     target: Target,
@@ -128,7 +128,7 @@ def aten_ops_sigmoid(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.tanh.default)
+@dynamo_tensorrt_converter(torch.ops.aten.tanh.default)  # type: ignore[misc]
 def aten_ops_tanh(
     network: TRTNetwork,
     target: Target,
@@ -145,7 +145,7 @@ def aten_ops_tanh(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default)
+@dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default)  # type: ignore[misc]
 def aten_ops_leaky_relu(
     network: TRTNetwork,
     target: Target,
@@ -163,7 +163,7 @@ def aten_ops_leaky_relu(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.elu.default)
+@dynamo_tensorrt_converter(torch.ops.aten.elu.default)  # type: ignore[misc]
 def aten_ops_elu(
     network: TRTNetwork,
     target: Target,
@@ -182,7 +182,7 @@ def aten_ops_elu(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.softplus.default)
+@dynamo_tensorrt_converter(torch.ops.aten.softplus.default)  # type: ignore[misc]
 def aten_ops_softplus(
     network: TRTNetwork,
     target: Target,
@@ -200,7 +200,7 @@ def aten_ops_softplus(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
+@dynamo_tensorrt_converter(torch.ops.aten.clip.default)  # type: ignore[misc]
 def aten_ops_clip(
     network: TRTNetwork,
     target: Target,
@@ -219,7 +219,7 @@ def aten_ops_clip(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)
+@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)  # type: ignore[misc]
 def aten_ops_hard_sigmoid(
     network: TRTNetwork,
     target: Target,
@@ -296,7 +296,7 @@ def aten_ops_rsqrt(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.neg.default)
+@dynamo_tensorrt_converter(torch.ops.aten.neg.default)  # type: ignore[misc]
 def aten_ops_neg(
     network: TRTNetwork,
     target: Target,
@@ -304,18 +304,12 @@ def aten_ops_neg(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = args[0]
-    if (isinstance(input_val, TRTTensor)) and (
-        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
-    ):
-        input_val = cast_trt_tensor(network, input_val, trt.float32, name)
-
     return impl.unary.neg(
         network,
         target,
         SourceIR.ATEN,
         name,
-        input_val,
+        args[0],
     )


@@ -503,7 +497,7 @@ def aten_ops_clone(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
+@dynamo_tensorrt_converter(torch.ops.aten.expand.default)  # type: ignore[misc]
 def aten_ops_expand(
     network: TRTNetwork,
     target: Target,
@@ -533,7 +527,7 @@ def amax_param_validator(amax_node: Node) -> bool:

 @dynamo_tensorrt_converter(
     torch.ops.aten.amax.default, capability_validator=amax_param_validator
-)
+)  # type: ignore[misc]
 def aten_ops_amax(
     network: TRTNetwork,
     target: Target,
@@ -552,8 +546,8 @@ def aten_ops_amax(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.sum.default)
-@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)
+@dynamo_tensorrt_converter(torch.ops.aten.sum.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)  # type: ignore[misc]
 def aten_ops_sum(
     network: TRTNetwork,
     target: Target,
@@ -946,8 +940,8 @@ def aten_ops_isinf(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)  # type: ignore[misc]
 def aten_ops_add(
     network: TRTNetwork,
     target: Target,
@@ -978,8 +972,8 @@ def aten_ops_add(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)  # type: ignore[misc]
 def aten_ops_mul(
     network: TRTNetwork,
     target: Target,
@@ -997,7 +991,7 @@ def aten_ops_mul(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)
+@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)  # type: ignore[misc]
 def aten_ops_max(
     network: TRTNetwork,
     target: Target,
@@ -1015,7 +1009,7 @@ def aten_ops_max(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)
+@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)  # type: ignore[misc]
 def aten_ops_min(
     network: TRTNetwork,
     target: Target,
@@ -1033,8 +1027,8 @@ def aten_ops_min(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)  # type: ignore[misc]
 def aten_ops_sub(
     network: TRTNetwork,
     target: Target,
@@ -1065,10 +1059,10 @@ def aten_ops_sub(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
-@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)
-@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)  # type: ignore[misc]
 def aten_ops_div(
     network: TRTNetwork,
     target: Target,
@@ -1111,9 +1105,9 @@ def aten_ops_div(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)
-@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)  # type: ignore[misc]
 def aten_ops_pow(
     network: TRTNetwork,
     target: Target,
@@ -1131,8 +1125,8 @@ def aten_ops_pow(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
-@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)  # type: ignore[misc]
 def aten_ops_floor_div(
     network: TRTNetwork,
     target: Target,
@@ -1150,7 +1144,7 @@ def aten_ops_floor_div(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)
+@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)  # type: ignore[misc]
 def aten_ops_logical_and(
     network: TRTNetwork,
     target: Target,
@@ -1168,7 +1162,7 @@ def aten_ops_logical_and(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)
+@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)  # type: ignore[misc]
 def aten_ops_logical_or(
     network: TRTNetwork,
     target: Target,
@@ -1186,7 +1180,7 @@ def aten_ops_logical_or(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)
+@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)  # type: ignore[misc]
 def aten_ops_logical_xor(
     network: TRTNetwork,
     target: Target,
@@ -1204,8 +1198,8 @@ def aten_ops_logical_xor(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)  # type: ignore[misc]
 def aten_ops_equal(
     network: TRTNetwork,
     target: Target,
@@ -1223,8 +1217,8 @@ def aten_ops_equal(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)  # type: ignore[misc]
 def aten_ops_greater(
     network: TRTNetwork,
     target: Target,
@@ -1242,8 +1236,8 @@ def aten_ops_greater(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)  # type: ignore[misc]
 def aten_ops_less(
     network: TRTNetwork,
     target: Target,
@@ -1267,7 +1261,7 @@ def conv_param_validator(conv_node: Node) -> bool:

 @dynamo_tensorrt_converter(
     torch.ops.aten.convolution.default, capability_validator=conv_param_validator
-)
+)  # type: ignore[misc]
 def aten_ops_convolution(
     network: TRTNetwork,
     target: Target,
@@ -1291,7 +1285,8 @@ def aten_ops_convolution(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.linear.default)
+@dynamo_tensorrt_converter(torch.ops.aten.linear.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.linear)  # type: ignore[misc]
 def aten_ops_linear(
     network: TRTNetwork,
     target: Target,
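For reference, a sketch of how converters hook into this registry with a capability validator, following the `amax` and `convolution` patterns above (the decorator's import location and the `impl` call are assumptions based on this file's conventions; the op and validator are hypothetical):

    import torch
    from torch.fx.node import Node, Target

    from torch_tensorrt.dynamo._SourceIR import SourceIR
    from torch_tensorrt.dynamo.conversion import impl
    from torch_tensorrt.dynamo.conversion.converter_registry import (
        dynamo_tensorrt_converter,  # assumed import path
    )


    def gelu_param_validator(node: Node) -> bool:
        # Hypothetical gate: only claim the node for the default approximation
        return node.kwargs.get("approximate", "none") == "none"


    @dynamo_tensorrt_converter(
        torch.ops.aten.gelu.default, capability_validator=gelu_param_validator
    )  # type: ignore[misc]
    def aten_ops_gelu(network, target: Target, args, kwargs, name: str):
        # Delegate to a shared implementation, as the converters above do
        return impl.activation.gelu(network, target, SourceIR.ATEN, name, args[0])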
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index c5df3f9752..1d8dfecf3b 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1,14 +1,16 @@
 import functools
 import logging
 import re
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple, Union

+import numpy as np
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.fx.converters.converter_utils import (
     Frameworks,
     get_axes_for_reduce_op,
+    to_numpy,
     unified_dtype_converter,
 )
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
@@ -185,11 +187,85 @@ def extend_attr_to_tuple(
     if isinstance(val, list):
         val = tuple(val)
-    return val
+
+    if isinstance(val, tuple):
+        return val
+    else:
+        raise AssertionError(f"Could not extend attribute {val}")


-def cast_int_or_float_to_bool(network: TRTNetwork, name: str, tensor: TRTTensor):
+def cast_int_or_float_to_bool(
+    network: TRTNetwork, name: str, tensor: TRTTensor
+) -> TRTTensor:
     if tensor.dtype != trt.bool:
         return cast_trt_tensor(network, tensor, trt.bool, name)

     return tensor
+
+
+def create_constant(
+    network: TRTNetwork,
+    value: Union[int, float, np.ndarray, torch.Tensor],
+    name: str,
+    dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]],
+) -> TRTTensor:
+    """
+    Add a TensorRT constant layer whose value is `value` to `network`.
+    Args:
+        network (TRTNetwork): A TensorRT network to which we want to add
+            a constant layer.
+        value (Union[int, float, np.ndarray, torch.Tensor]): A literal value, Numpy array,
+            or a PyTorch tensor that will be used as the value of the added TensorRT Constant layer.
+        name (str): Name of the added TensorRT Constant layer.
+        dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
+            If a dtype is given, we will convert the type of the given `value` to this dtype.
+    Returns:
+        A TensorRT ITensor that represents the given value.
+    """
+    constant = network.add_constant(
+        (1,) if isinstance(value, (int, float)) else value.shape,
+        to_numpy(value, dtype).copy(),
+    )
+    constant.name = name
+    return constant.get_output(0)
+
+
+def get_trt_tensor(
+    network: TRTNetwork,
+    input_val: Any,
+    name: str,
+    dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None,
+) -> TRTTensor:
+    """
+    Given a value of arbitrary type, try to convert it to a TensorRT ITensor.
+    An AssertionError is raised if the conversion fails.
+    Args:
+        network (TRTNetwork): A TensorRT network. If we want to
+            add a TensorRT Constant layer, we will add it to this network.
+        input_val (Any): A value that we want to convert to a TensorRT ITensor.
+        name (str): The name of the created TensorRT Constant layer if there's
+            one.
+        dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
+            If dtype is provided, the given value will be converted to this dtype.
+    Returns:
+        A TensorRT ITensor that represents the given value.
+    """
+    # TRT cannot add constants of bool type. As a workaround, we 1) cast the value
+    # to int here and 2) cast it back to bool later.
+    # This is useful for logical operations, which require inputs to be of bool type
+    if isinstance(input_val, bool):
+        input_val = int(input_val)
+    elif isinstance(input_val, torch.Tensor) and (
+        input_val.dtype == torch.bool or input_val.dtype == torch.int64
+    ):
+        input_val = input_val.to(torch.int32)
+    elif isinstance(input_val, np.ndarray) and (
+        input_val.dtype == np.bool_ or input_val.dtype == np.int64
+    ):
+        input_val = input_val.astype(np.int32)
+
+    if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)):
+        return create_constant(network, input_val, name, dtype)
+    elif isinstance(input_val, TRTTensor):
+        return input_val
+    else:
+        raise AssertionError(f"Cannot convert {input_val} to TRT constant")
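A quick sketch of `get_trt_tensor` in action outside the interpreter (assumes a local TensorRT install; the network is a throwaway):

    import numpy as np
    import tensorrt as trt

    from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor

    builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )

    # Python scalars, numpy arrays, and torch Tensors all become constant ITensors
    scalar_const = get_trt_tensor(network, 2.0, "scalar_const")
    array_const = get_trt_tensor(network, np.ones((2, 2), dtype=np.float32), "array_const")
    # bool inputs are cast to int32 first, per the workaround above
    flag_const = get_trt_tensor(network, True, "flag_const")

    print(scalar_const.shape, array_const.shape, flag_const.dtype)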
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
index b81418490c..9c225357b5 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
@@ -1,19 +1,17 @@
 from typing import Optional

+import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion.converter_utils import broadcastable
-from torch_tensorrt.dynamo.conversion.impl.slice import expand
-from torch_tensorrt.fx.converters.converter_utils import (
-    broadcast,
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    broadcastable,
     get_trt_tensor,
-    set_layer_name,
 )
+from torch_tensorrt.dynamo.conversion.impl.slice import expand
+from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

-import tensorrt as trt
-

 def where(
     network: TRTNetwork,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py
index ff7deb0962..ebe4e37c9e 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py
@@ -7,11 +7,13 @@
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion import impl
-from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    extend_attr_to_tuple,
+    get_trt_tensor,
+)
 from torch_tensorrt.fx.converters.converter_utils import (
     SourceIR,
     get_dyn_range,
-    get_trt_tensor,
     has_dynamic_shape,
     mark_as_int8_layer,
     set_layer_name,
@@ -27,8 +29,8 @@ def convNd(
     name: str,
     is_conv1d: bool,
     input: TRTTensor,
-    weight: Union[TRTTensor, torch.Tensor],
-    bias: Optional[Union[TRTTensor, torch.Tensor]],
+    weight: Union[TRTTensor, torch.Tensor, np.ndarray],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     stride: Optional[Union[int, Sequence[int]]],
     padding: Optional[Union[int, Sequence[int]]],
     dilation: Optional[Union[int, Sequence[int]]],
@@ -97,19 +99,28 @@ def convNd(
     if isinstance(bias, TRTTensor):
         conv_layer.set_input(2, bias)

+    # Cast certain fields to tuples, in accordance with TRT requirements
+    padding = (padding,) if isinstance(padding, int) else padding
+    stride = (stride,) if isinstance(stride, int) else stride
+    dilation = (dilation,) if isinstance(dilation, int) else dilation
+
     # Expand parameters manually for Conv1D computations
     if is_conv1d:
-        padding = tuple(padding) + (0,)
-        stride = extend_attr_to_tuple(stride, 2)
-        dilation = extend_attr_to_tuple(dilation, 2)
+        padding = (tuple(padding) + (0,)) if padding is not None else padding
+        stride = extend_attr_to_tuple(stride, 2) if stride is not None else stride
+        dilation = (
+            extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
+        )

     set_layer_name(conv_layer, target, name, source_ir)

     # Set relevant attributes of convolution layer
-    conv_layer.padding_nd = padding
-    conv_layer.stride_nd = stride
-    conv_layer.dilation_nd = dilation
-
+    if padding is not None:
+        conv_layer.padding_nd = padding
+    if stride is not None:
+        conv_layer.stride_nd = stride
+    if dilation is not None:
+        conv_layer.dilation_nd = dilation

     if groups is not None:
         conv_layer.num_groups = groups
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
index 46380cbec7..95dcd88a75 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -7,10 +7,12 @@
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    cast_trt_tensor,
+    get_trt_tensor,
+)
 from torch_tensorrt.fx.converters.converter_utils import (
     broadcast,
-    get_trt_tensor,
     set_layer_name,
     squeeze_left,
 )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
index f5d46efc17..75ff33f26f 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py
@@ -7,17 +7,14 @@
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_int_int_div_trt_tensor,
     cast_int_or_float_to_bool,
+    get_trt_tensor,
 )
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
 )
 from torch_tensorrt.dynamo.conversion.impl.unary import sign
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
-from torch_tensorrt.fx.converters.converter_utils import (
-    get_trt_tensor,
-    set_layer_name,
-    squeeze_left,
-)
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name, squeeze_left
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
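To make the new `convNd` parameter handling concrete, here is the normalization it performs, distilled into a standalone sketch (plain Python; `extend_attr_to_tuple` is modeled here as broadcasting a scalar or 1-tuple to length 2):

    from typing import Optional, Sequence, Union

    def sketch_normalize_conv_params(
        padding: Optional[Union[int, Sequence[int]]],
        stride: Optional[Union[int, Sequence[int]]],
        dilation: Optional[Union[int, Sequence[int]]],
        is_conv1d: bool,
    ):
        # Cast scalar fields to tuples, in accordance with TRT requirements
        padding = (padding,) if isinstance(padding, int) else padding
        stride = (stride,) if isinstance(stride, int) else stride
        dilation = (dilation,) if isinstance(dilation, int) else dilation

        # Conv1D runs as a 2D convolution, so parameters gain a trailing dim;
        # None values pass through, and the TRT attribute is simply not set
        if is_conv1d:
            padding = (tuple(padding) + (0,)) if padding is not None else None
            stride = (tuple(stride) * 2)[:2] if stride is not None else None
            dilation = (tuple(dilation) * 2)[:2] if dilation is not None else None

        return padding, stride, dilation

    print(sketch_normalize_conv_params(1, 2, 1, is_conv1d=True))
    # -> ((1, 0), (2, 2), (1, 1))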
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py
index 26064f621c..8ddfdf015f 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py
@@ -3,7 +3,8 @@
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.fx.converters.converter_utils import get_trt_tensor, set_layer_name
+from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


diff --git a/py/torch_tensorrt/dynamo/conversion/impl/linear.py b/py/torch_tensorrt/dynamo/conversion/impl/linear.py
index 0a98087bce..cad97a5c9a 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/linear.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/linear.py
@@ -5,7 +5,7 @@
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion import impl
-from torch_tensorrt.fx.converters.converter_utils import SourceIR, get_trt_tensor
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


diff --git a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py
index 4b69b09d2a..a62d24121f 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py
@@ -3,11 +3,8 @@
 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.fx.converters.converter_utils import (
-    broadcast,
-    get_trt_tensor,
-    set_layer_name,
-)
+from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
+from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index 2ab74ef86b..7822b515f8 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -2,6 +2,7 @@
 from typing import Any, List, Optional, Sequence, Union, cast

 import numpy as np
+import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -19,8 +20,6 @@
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import get_dynamic_dims

-import tensorrt as trt
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)


@@ -101,9 +100,15 @@ def layer_norm(
             "of the TensorRT region!"
         )

-    gamma = weight.detach().cpu().float().numpy()
+    gamma = (
+        weight.detach().cpu().float().numpy()
+        if isinstance(weight, torch.Tensor)
+        else weight
+    )
     gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
-    beta = bias.detach().cpu().float().numpy()
+    beta = (
+        bias.detach().cpu().float().numpy() if isinstance(bias, torch.Tensor) else bias
+    )
     beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
     eps_field = trt.PluginField(
         "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
index 9929e59d86..fae22888d8 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
@@ -2,9 +2,9 @@
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
-    get_trt_tensor,
     set_layer_name,
 )
 from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor
diff --git a/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py
index e69b9987c7..70cc5424af 100644
--- a/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py
+++ b/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py
@@ -81,10 +81,6 @@ def pre_aot_substitutions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     """
     logger.debug("Pre-module replacement graph:\n" + str(gm.graph))

-    # Ensure all parameters are in inference mode
-    for param in gm.parameters():
-        param.requires_grad = False
-
     # Iterate over graph nodes, extracting module calls, to check for interceptions
     for n in gm.graph.nodes:
         exists_in_registry = False
@@ -128,7 +124,6 @@ def pre_aot_substitutions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:

             # Replace all original node uses and clean up graph
             n.replace_all_uses_with(new_node)
-            gm.graph.eliminate_dead_code()
             gm.graph.lint()
             gm.recompile()

@@ -142,7 +137,6 @@ def pre_aot_substitutions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
             continue

     # Perform cleanup and recompilation before returning module
-    gm.graph.eliminate_dead_code()
     gm.graph.lint()
     gm.recompile()

diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
index f25bd2df12..5399bc5d6f 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
@@ -43,7 +43,9 @@ def is_node_supported(
     ) -> bool:
         node_name = ConverterRegistry.qualified_name_or_str(node.target)

-        if node in CONVERTERS and node_name not in self.torch_executed_ops:
+        if (
+            node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
+        ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
             if not node.is_impure():
                 if node_name not in self.supported_operators:
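The partitioner predicates (here and in the global partitioner below) now also admit frozen constants, i.e. `get_attr` nodes produced by the constant-folding pass. In isolation, the added condition is simply the following (the "constant" naming convention is assumed from the check itself):

    import torch

    def is_frozen_constant(node: torch.fx.Node) -> bool:
        # get_attr nodes created by replace_node_with_constant are assumed to
        # carry "constant" in their target names, so they may run inside TRT
        return node.op == "get_attr" and "constant" in str(node.target)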
diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
index bdb15b3394..19fccfc73f 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
@@ -153,7 +153,9 @@ def is_node_supported(
     ) -> bool:
         node_name = ConverterRegistry.qualified_name_or_str(node.target)

-        if node in CONVERTERS and node_name not in self.torch_executed_ops:
+        if (
+            node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
+        ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
             if not node.is_impure():
                 if node_name not in self.supported_operators:
diff --git a/tests/py/dynamo/backend/test_specialized_models.py b/tests/py/dynamo/backend/test_specialized_models.py
index 143aa9b241..1b9e5fb337 100644
--- a/tests/py/dynamo/backend/test_specialized_models.py
+++ b/tests/py/dynamo/backend/test_specialized_models.py
@@ -2,7 +2,7 @@
 import torch_tensorrt
 from torch.testing._internal.common_utils import TestCase, run_tests

-from ..testing_utilities import lower_graph_testing
+from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing


 class TestFakeTensors(TestCase):
@@ -157,5 +157,84 @@ def forward(self, x):
         torch._dynamo.reset()


+class TestTensorFreezing(TestCase):
+    def test_tensor_freeze_attr(self):
+        class TensorFreeze(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.const = torch.ones((8, 2), device="cuda")
+
+            def forward(self, x):
+                return x @ self.const
+
+        inputs = [
+            torch.ones(
+                7,
+                8,
+            ).cuda()
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(TensorFreeze())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg="Frozen-Tensor TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+    def test_constant_fold(self):
+        class Arange(torch.nn.Module):
+            def forward(self, x):
+                y = torch.arange(10, device="cuda")
+                return x + y
+
+        inputs = [
+            torch.rand(
+                10,
+                10,
+            ).cuda()
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(Arange())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg="Constant Folded TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py
index 909ded2690..fd834394c1 100644
--- a/tests/py/dynamo/lowering/test_decompositions.py
+++ b/tests/py/dynamo/lowering/test_decompositions.py
@@ -12,6 +12,7 @@ def __init__(self, *args, **kwargs) -> None:
             super().__init__(*args, **kwargs)

         def forward(self, x, y):
+            x += 1
             x = torch.ops.aten.add_.Tensor(x, y)
             x = torch.ops.aten.relu_.default(x)
             return x
diff --git a/tests/py/dynamo/testing_utilities.py b/tests/py/dynamo/testing_utilities.py
index e2607d859b..f311f2db2b 100644
--- a/tests/py/dynamo/testing_utilities.py
+++ b/tests/py/dynamo/testing_utilities.py
@@ -1,18 +1,18 @@
+import unittest
 from copy import deepcopy
 from functools import partial
 from typing import Any, List, Sequence, Set

 import torch
-from torch._dynamo.backends.common import fake_tensor_unsupported
-from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
+from torch._dynamo.utils import detect_fake_mode
 from torch_tensorrt.dynamo import partitioning
+from torch_tensorrt.dynamo.backend.backends import aot_export_for_compile, constant_fold
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions

 DECIMALS_OF_AGREEMENT = 4


-@fake_tensor_unsupported
 def fx_dynamo_testing_backend(
     gm: torch.fx.GraphModule,
     sample_inputs: Sequence[torch.Tensor],
@@ -33,13 +33,26 @@ def fx_dynamo_testing_backend(

     gm = pre_aot_substitutions(gm)

-    # Invoke AOTAutograd to translate operators to aten
-    return aot_module_simplified(
-        gm,
-        sample_inputs,
-        fw_compiler=make_boxed_compiler(custom_backend),
-        decompositions=get_decompositions(),
-    )
+    fake_mode = detect_fake_mode(sample_inputs)
+
+    # Place backend tracing within FakeTensor context allowing non-fake Tensors
+    with unittest.mock.patch.object(
+        fake_mode, "allow_non_fake_inputs", True
+    ), fake_mode:
+        # Invoke AOTAutograd to translate operators to aten
+        graph_module = aot_export_for_compile(
+            gm,
+            sample_inputs,
+            decompositions=get_decompositions(),
+        )
+
+        constant_fold(graph_module)
+
+        trt_compiled = custom_backend(
+            graph_module,
+            sample_inputs,
+        )
+        return trt_compiled


 def compile_module_testing(