Skip to content

Commit d48d611

Browse files
committed
clean flatten converter and add tests
1 parent 1502cf5 commit d48d611

File tree

2 files changed

+41
-23
lines changed

2 files changed

+41
-23
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+2-22
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def aten_ops_slice(
483483
{
484484
0: (TRTTensor,),
485485
}
486-
)
486+
) # type: ignore[misc]
487487
def aten_ops_permute(
488488
ctx: ConversionContext,
489489
target: Target,
@@ -1371,7 +1371,7 @@ def conv_param_validator(conv_node: Node) -> bool:
13711371
1: (np.ndarray, torch.Tensor, TRTTensor),
13721372
2: (np.ndarray, torch.Tensor, TRTTensor),
13731373
}
1374-
)
1374+
) # type: ignore[misc]
13751375
def aten_ops_convolution(
13761376
ctx: ConversionContext,
13771377
target: Target,
@@ -1551,23 +1551,3 @@ def aten_ops_reshape(
15511551
input=args[0],
15521552
shape=args[1],
15531553
)
1554-
1555-
1556-
# # TODO: need tests for this converter
1557-
# @dynamo_tensorrt_converter(torch.ops.aten.flatten.using_ints) # type: ignore[misc]
1558-
# def aten_ops_flatten(
1559-
# ctx: ConversionContext,
1560-
# target: Target,
1561-
# args: Tuple[Argument, ...],
1562-
# kwargs: Dict[str, Argument],
1563-
# name: str,
1564-
# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
1565-
# return impl.shuffle.flatten(
1566-
# ctx,
1567-
# target,
1568-
# SourceIR.ATEN,
1569-
# name,
1570-
# input=args[0],
1571-
# start_dim=args_bounds_check(args, 1, 0),
1572-
# end_dim=args_bounds_check(args, 2, -1),
1573-
# )

tests/py/dynamo/conversion/test_converter_utils.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import numpy as np
22
import torch
3+
from parameterized import parameterized
34
from torch.testing._internal.common_utils import TestCase, run_tests
4-
from torch_tensorrt.dynamo.conversion.converter_utils import enforce_tensor_types
5+
from torch_tensorrt.dynamo.conversion.converter_utils import (
6+
enforce_tensor_types,
7+
flatten_dims,
8+
)
59
from torch_tensorrt.fx.types import TRTTensor
610

711
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
@@ -37,5 +41,39 @@ def test_invalid_invocation_type(self):
3741
enforce_tensor_types({0: (int, bool)})
3842

3943

44+
class TestFlattenDimsEnforcement(TestCase):
45+
@parameterized.expand(
46+
[
47+
((1, 2), 0, 0, (1, 2)),
48+
((1, 2), 0, 1, (2,)),
49+
((2, 3, 4), 1, 2, (2, 12)),
50+
((2, 3, 4), 0, 1, (6, 4)),
51+
((2, 3, 4), -3, 2, (24,)),
52+
((2, 3, 4, 5), 0, -2, (24, 5)),
53+
((2, 3, 4, 5), -4, -1, (120,)),
54+
]
55+
)
56+
def test_numpy_array(self, input_shape, start_dim, end_dim, true_shape):
57+
inputs = np.random.randn(*input_shape)
58+
new_shape = flatten_dims(inputs, start_dim, end_dim)
59+
self.assertEqual(new_shape, true_shape)
60+
61+
@parameterized.expand(
62+
[
63+
((1, 2), 0, 0, (1, 2)),
64+
((1, 2), 0, 1, (2,)),
65+
((2, 3, 4), 1, 2, (2, 12)),
66+
((2, 3, 4), 0, 1, (6, 4)),
67+
((2, 3, 4), -3, 2, (24,)),
68+
((2, 3, 4, 5), 0, -2, (24, 5)),
69+
((2, 3, 4, 5), -4, -1, (120,)),
70+
]
71+
)
72+
def test_torch_tensor(self, input_shape, start_dim, end_dim, true_shape):
73+
inputs = torch.randn(input_shape)
74+
new_shape = flatten_dims(inputs, start_dim, end_dim)
75+
self.assertEqual(new_shape, true_shape)
76+
77+
4078
if __name__ == "__main__":
4179
run_tests()

0 commit comments

Comments
 (0)