Skip to content

Commit c0cfeb8

Browse files
committed
clean flatten converter and add tests
1 parent 0f24a52 commit c0cfeb8

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ def aten_ops_slice(
506506
{
507507
0: (TRTTensor,),
508508
}
509-
)
509+
) # type: ignore[misc]
510510
def aten_ops_permute(
511511
ctx: ConversionContext,
512512
target: Target,
@@ -1394,7 +1394,7 @@ def conv_param_validator(conv_node: Node) -> bool:
13941394
1: (np.ndarray, torch.Tensor, TRTTensor),
13951395
2: (np.ndarray, torch.Tensor, TRTTensor),
13961396
}
1397-
)
1397+
) # type: ignore[misc]
13981398
def aten_ops_convolution(
13991399
ctx: ConversionContext,
14001400
target: Target,

tests/py/dynamo/conversion/test_converter_utils.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import numpy as np
22
import torch
3+
from parameterized import parameterized
34
from torch.testing._internal.common_utils import TestCase, run_tests
4-
from torch_tensorrt.dynamo.conversion.converter_utils import enforce_tensor_types
5+
from torch_tensorrt.dynamo.conversion.converter_utils import (
6+
enforce_tensor_types,
7+
flatten_dims,
8+
)
59
from torch_tensorrt.fx.types import TRTTensor
610

711
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
@@ -37,5 +41,39 @@ def test_invalid_invocation_type(self):
3741
enforce_tensor_types({0: (int, bool)})
3842

3943

44+
class TestFlattenDimsEnforcement(TestCase):
45+
@parameterized.expand(
46+
[
47+
((1, 2), 0, 0, (1, 2)),
48+
((1, 2), 0, 1, (2,)),
49+
((2, 3, 4), 1, 2, (2, 12)),
50+
((2, 3, 4), 0, 1, (6, 4)),
51+
((2, 3, 4), -3, 2, (24,)),
52+
((2, 3, 4, 5), 0, -2, (24, 5)),
53+
((2, 3, 4, 5), -4, -1, (120,)),
54+
]
55+
)
56+
def test_numpy_array(self, input_shape, start_dim, end_dim, true_shape):
57+
inputs = np.random.randn(*input_shape)
58+
new_shape = flatten_dims(inputs, start_dim, end_dim)
59+
self.assertEqual(new_shape, true_shape)
60+
61+
@parameterized.expand(
62+
[
63+
((1, 2), 0, 0, (1, 2)),
64+
((1, 2), 0, 1, (2,)),
65+
((2, 3, 4), 1, 2, (2, 12)),
66+
((2, 3, 4), 0, 1, (6, 4)),
67+
((2, 3, 4), -3, 2, (24,)),
68+
((2, 3, 4, 5), 0, -2, (24, 5)),
69+
((2, 3, 4, 5), -4, -1, (120,)),
70+
]
71+
)
72+
def test_torch_tensor(self, input_shape, start_dim, end_dim, true_shape):
73+
inputs = torch.randn(input_shape)
74+
new_shape = flatten_dims(inputs, start_dim, end_dim)
75+
self.assertEqual(new_shape, true_shape)
76+
77+
4078
if __name__ == "__main__":
4179
run_tests()

0 commit comments

Comments
 (0)