Skip to content

Commit 2735856

Browse files
committed
feat: support aten.arange.start_step dynamo converter
1 parent cd158b6 commit 2735856

File tree

4 files changed

+82
-0
lines changed

4 files changed

+82
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+19
Original file line numberDiff line numberDiff line change
@@ -2054,3 +2054,22 @@ def aten_ops_addmm(
20542054
beta=kwargs.get("beta", 1),
20552055
alpha=kwargs.get("alpha", 1),
20562056
)
2057+
2058+
2059+
@dynamo_tensorrt_converter(torch.ops.aten.arange.start_step)
2060+
def aten_ops_arange_start_step(
2061+
ctx: ConversionContext,
2062+
target: Target,
2063+
args: Tuple[Argument, ...],
2064+
kwargs: Dict[str, Argument],
2065+
name: str,
2066+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2067+
return impl.sequence.arange(
2068+
ctx,
2069+
target,
2070+
SourceIR.ATEN,
2071+
name,
2072+
args[0],
2073+
args[1],
2074+
args_bounds_check(args, 2, 1),
2075+
)

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
pool,
2121
reduce,
2222
select,
23+
sequence,
2324
shape,
2425
shuffle,
2526
slice,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import Optional, Union
2+
3+
import numpy as np
4+
from torch.fx.node import Target
5+
from torch_tensorrt.dynamo._SourceIR import SourceIR
6+
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
7+
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
8+
from torch_tensorrt.fx.types import TRTTensor
9+
10+
11+
def arange(
12+
ctx: ConversionContext,
13+
target: Target,
14+
source_ir: Optional[SourceIR],
15+
name: str,
16+
start: Union[int, float],
17+
end: Union[int, float],
18+
step: Union[int, float] = 1,
19+
) -> TRTTensor:
20+
values = np.arange(start, end, step)
21+
return get_trt_tensor(ctx, values, f"{name}_arange")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestArangeConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
# int
13+
(0, 5, 1),
14+
(1, 5, 2),
15+
(3, 5, 3),
16+
(5, 0, -1),
17+
(5, 1, -2),
18+
(5, 3, -3),
19+
# float and mixed
20+
(0.0, 5.0, 1.0),
21+
(1.0, 5.0, 2.2),
22+
(2, 10, 3.3),
23+
(5.0, 0.0, -3),
24+
(5.0, 1, -2.0001),
25+
(5, 3.0, -3.9999),
26+
]
27+
)
28+
def test_arange(self, start, end, step):
29+
class Arange(nn.Module):
30+
def forward(self):
31+
return torch.ops.aten.arange.start_step(start, end, step)
32+
33+
inputs = []
34+
self.run_test(
35+
Arange(),
36+
inputs,
37+
)
38+
39+
40+
if __name__ == "__main__":
41+
run_tests()

0 commit comments

Comments
 (0)