From 9f7d3043240c70332bd37510aa5df00984bf9ce8 Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Fri, 1 Dec 2023 18:24:32 -0800
Subject: [PATCH 1/3] move arange converter to ops_evaluators.py

---
 .../dynamo/conversion/ops_evaluators.py      |  3 ++
 .../py/dynamo/conversion/test_arange_aten.py | 33 +++++++++++++++++++
 2 files changed, 36 insertions(+)
 create mode 100644 tests/py/dynamo/conversion/test_arange_aten.py

diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
index 3a67c47fa3..ee67911754 100644
--- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
+++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
@@ -1,3 +1,4 @@
+import builtins
 import logging
 import operator
 from typing import Dict, Sequence, Tuple, Union
@@ -23,7 +24,9 @@ def getitem_validator(getitem_node: Node) -> bool:
 
 # TODO: Subsequent evaluators should be registered here with their own validators
 @dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
+@dynamo_tensorrt_converter(builtins.getattr)
 @dynamo_tensorrt_converter(torch.ops.aten.detach.default)
+@dynamo_tensorrt_converter(torch.ops.aten.arange.start_step)
 def generic_evaluator(
     ctx: ConversionContext,
     target: Target,
diff --git a/tests/py/dynamo/conversion/test_arange_aten.py b/tests/py/dynamo/conversion/test_arange_aten.py
new file mode 100644
index 0000000000..1daa32850b
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_arange_aten.py
@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestArangeConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (0, 5, 1),
+            (1, 5, 2),
+            (3, 5, 3),
+            (5, 0, -1),
+            (5, 1, -2),
+            (5, 3, -3),
+        ]
+    )
+    def test_arange(self, start, end, step):
+        class Arange(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.arange.start_step(start, x.shape[0], step)
+
+        inputs = [torch.randn(end, 1)]
+        self.run_test(
+            Arange(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()

From 80699d96eebdaf070b814d2a22fdb315de959ab7 Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Fri, 1 Dec 2023 18:34:19 -0800
Subject: [PATCH 2/3] fix potential FakeTensor bug

---
 .../dynamo/conversion/ops_evaluators.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
index ee67911754..163b6399d8 100644
--- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
+++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
@@ -3,6 +3,7 @@
 import operator
 from typing import Dict, Sequence, Tuple, Union
 
+import numpy as np
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -26,7 +27,6 @@ def getitem_validator(getitem_node: Node) -> bool:
 @dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
 @dynamo_tensorrt_converter(builtins.getattr)
 @dynamo_tensorrt_converter(torch.ops.aten.detach.default)
-@dynamo_tensorrt_converter(torch.ops.aten.arange.start_step)
 def generic_evaluator(
     ctx: ConversionContext,
     target: Target,
@@ -38,3 +38,14 @@ def generic_evaluator(
         f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
     )
     return target(*args)
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.arange.start_step)
+def aten_ops_arange_start_step(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return np.arange(*args)

From 8c2ff9995900beb9165dc57ff5208889b2e29734 Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Tue, 5 Dec 2023 17:41:10 -0800
Subject: [PATCH 3/3] remove unnecessary builtins.getattr

---
 py/torch_tensorrt/dynamo/conversion/ops_evaluators.py | 2 --
 tests/py/dynamo/conversion/test_arange_aten.py        | 1 +
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
index 163b6399d8..f83e0e5008 100644
--- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
+++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
@@ -1,4 +1,3 @@
-import builtins
 import logging
 import operator
 from typing import Dict, Sequence, Tuple, Union
@@ -25,7 +24,6 @@ def getitem_validator(getitem_node: Node) -> bool:
 
 # TODO: Subsequent evaluators should be registered here with their own validators
 @dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
-@dynamo_tensorrt_converter(builtins.getattr)
 @dynamo_tensorrt_converter(torch.ops.aten.detach.default)
 def generic_evaluator(
     ctx: ConversionContext,
diff --git a/tests/py/dynamo/conversion/test_arange_aten.py b/tests/py/dynamo/conversion/test_arange_aten.py
index 1daa32850b..035b957865 100644
--- a/tests/py/dynamo/conversion/test_arange_aten.py
+++ b/tests/py/dynamo/conversion/test_arange_aten.py
@@ -26,6 +26,7 @@ def forward(self, x):
         self.run_test(
             Arange(),
             inputs,
+            use_dynamo_tracer=True,
         )
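
As a sanity check on the semantics PATCH 2 relies on — the evaluator resolves aten.arange.start_step at conversion time by returning np.arange(*args) — the standalone sketch below (illustrative only, not part of the patch series) compares np.arange against torch.arange for the same (start, end, step) triples exercised in test_arange_aten.py. It assumes the evaluator receives concrete Python ints, i.e. that x.shape[0] has already been resolved to a static value by the dynamo tracer.

    # Standalone sanity check (illustrative, not part of the patches):
    # the evaluator returns np.arange(*args) for aten.arange.start_step,
    # so the numpy result should match eager torch.arange for these ranges.
    import numpy as np
    import torch

    cases = [(0, 5, 1), (1, 5, 2), (3, 5, 3), (5, 0, -1), (5, 1, -2), (5, 3, -3)]

    for start, end, step in cases:
        np_out = np.arange(start, end, step)        # what aten_ops_arange_start_step returns
        torch_out = torch.arange(start, end, step)  # reference eager-mode result
        assert np.array_equal(np_out, torch_out.numpy()), (start, end, step)
        print((start, end, step), "->", np_out.tolist())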