|
1 | 1 | import numpy as np
|
2 | 2 | import torch
|
| 3 | +from parameterized import parameterized |
3 | 4 | from torch.testing._internal.common_utils import TestCase, run_tests
|
4 |
| -from torch_tensorrt.dynamo.conversion.converter_utils import enforce_tensor_types |
| 5 | +from torch_tensorrt.dynamo.conversion.converter_utils import ( |
| 6 | + enforce_tensor_types, |
| 7 | + flatten_dims, |
| 8 | +) |
5 | 9 | from torch_tensorrt.fx.types import TRTTensor
|
6 | 10 |
|
7 |
| -from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing |
| 11 | +# from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing |
8 | 12 |
|
9 | 13 |
|
10 | 14 | class TestTensorTypeEnforcement(TestCase):
|
@@ -37,5 +41,39 @@ def test_invalid_invocation_type(self):
|
37 | 41 | enforce_tensor_types({0: (int, bool)})
|
38 | 42 |
|
39 | 43 |
|
| 44 | +class TestFlattenDimsEnforcement(TestCase): |
| 45 | + @parameterized.expand( |
| 46 | + [ |
| 47 | + ((1, 2), 0, 0, (1, 2)), |
| 48 | + ((1, 2), 0, 1, (2,)), |
| 49 | + ((2, 3, 4), 1, 2, (2, 12)), |
| 50 | + ((2, 3, 4), 0, 1, (6, 4)), |
| 51 | + ((2, 3, 4), -3, 2, (24,)), |
| 52 | + ((2, 3, 4, 5), 0, -2, (24, 5)), |
| 53 | + ((2, 3, 4, 5), -4, -1, (120,)), |
| 54 | + ] |
| 55 | + ) |
| 56 | + def test_numpy_array(self, input_shape, start_dim, end_dim, true_shape): |
| 57 | + inputs = np.random.randn(*input_shape) |
| 58 | + new_shape = flatten_dims(inputs, start_dim, end_dim) |
| 59 | + self.assertEqual(new_shape, true_shape) |
| 60 | + |
| 61 | + @parameterized.expand( |
| 62 | + [ |
| 63 | + ((1, 2), 0, 0, (1, 2)), |
| 64 | + ((1, 2), 0, 1, (2,)), |
| 65 | + ((2, 3, 4), 1, 2, (2, 12)), |
| 66 | + ((2, 3, 4), 0, 1, (6, 4)), |
| 67 | + ((2, 3, 4), -3, 2, (24,)), |
| 68 | + ((2, 3, 4, 5), 0, -2, (24, 5)), |
| 69 | + ((2, 3, 4, 5), -4, -1, (120,)), |
| 70 | + ] |
| 71 | + ) |
| 72 | + def test_torch_tensor(self, input_shape, start_dim, end_dim, true_shape): |
| 73 | + inputs = torch.randn(input_shape) |
| 74 | + new_shape = flatten_dims(inputs, start_dim, end_dim) |
| 75 | + self.assertEqual(new_shape, true_shape) |
| 76 | + |
| 77 | + |
40 | 78 | if __name__ == "__main__":
|
41 | 79 | run_tests()
|
0 commit comments