Skip to content

Commit 346ec80

Browse files
committed
rand converters
rand converters Correcting rand test cases linting fixes adding validators to rand() test moving the error output to validator consolidating the two validators and removing assertion check from evaluator correcting rand test removing device kwargs since not used changing the test to compare size instead of elements changing the data type in interpretor setting Change name of run_test_comparator function in harness.py removing precision in harness.py interpreter compilationsettings and interpreter.run
1 parent 30aff3a commit 346ec80

File tree

3 files changed

+243
-4
lines changed

3 files changed

+243
-4
lines changed

py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,74 @@ def aten_ops_arange_start_step(
4747
name: str,
4848
) -> Union[TRTTensor, Sequence[TRTTensor]]:
4949
return np.arange(*args)
50+
51+
52+
def rand_validator(rand_node: Node) -> bool:
53+
dtype = rand_node.kwargs.get("dtype", None)
54+
layout = rand_node.kwargs.get("layout", None)
55+
if dtype is not None:
56+
_LOGGER.debug(
57+
f"Currently we don't support specifying output dtype, got {dtype}."
58+
)
59+
return False
60+
if layout is not None:
61+
_LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
62+
return False
63+
return True
64+
65+
66+
@dynamo_tensorrt_converter(
67+
torch.ops.aten.rand.default, capability_validator=rand_validator
68+
)
69+
def aten_ops_rand(
70+
ctx: ConversionContext,
71+
target: Target,
72+
args: Tuple[Argument, ...],
73+
kwargs: Dict[str, Argument],
74+
name: str,
75+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
76+
return np.random.rand(*args[0])
77+
78+
79+
@dynamo_tensorrt_converter(
80+
torch.ops.aten.randn.default, capability_validator=rand_validator
81+
)
82+
def aten_ops_randn(
83+
ctx: ConversionContext,
84+
target: Target,
85+
args: Tuple[Argument, ...],
86+
kwargs: Dict[str, Argument],
87+
name: str,
88+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
89+
return np.random.randn(*args[0])
90+
91+
92+
def randperm_validator(randperm_node: Node) -> bool:
93+
dtype = randperm_node.kwargs.get("dtype", None)
94+
layout = randperm_node.kwargs.get("layout", None)
95+
input = randperm_node.args[0]
96+
if not isinstance(input, int):
97+
_LOGGER.error(f"Input should be of type int.")
98+
return False
99+
if dtype is not None:
100+
_LOGGER.debug(
101+
f"Currently we don't support specifying output dtype, got {dtype}."
102+
)
103+
return False
104+
if layout is not None:
105+
_LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
106+
return False
107+
return True
108+
109+
110+
@dynamo_tensorrt_converter(
111+
torch.ops.aten.randperm.default, capability_validator=randperm_validator
112+
)
113+
def aten_ops_randperm(
114+
ctx: ConversionContext,
115+
target: Target,
116+
args: Tuple[Argument, ...],
117+
kwargs: Dict[str, Argument],
118+
name: str,
119+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
120+
return np.random.permutation(args[0])

tests/py/dynamo/conversion/harness.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,7 @@ def run_test_custom_compare_results(
138138
if len(expected_ops):
139139
self.assert_has_op(mod, expected_ops)
140140

141-
interpreter_result = interpreter.run(
142-
precision=torch.half if fp16_mode else torch.float
143-
)
141+
interpreter_result = interpreter.run()
144142
trt_mod = PythonTorchTensorRTModule(
145143
interpreter_result.engine,
146144
interpreter_result.input_names,
@@ -149,7 +147,6 @@ def run_test_custom_compare_results(
149147
res_trt = trt_mod(*cuda_inputs).cpu()
150148
res_cpu = mod(*cuda_inputs).cpu()
151149
assert len(res_trt) == len(res_cpu)
152-
assert len(res_cpu) == len(comparators)
153150
for output_trt, output_cpu, comparator in zip(
154151
res_trt, res_cpu, comparators
155152
):
@@ -270,6 +267,42 @@ def run_test(
270267
check_dtype,
271268
)
272269

270+
def run_test_compare_tensor_attributes_only(
271+
self,
272+
mod,
273+
inputs,
274+
expected_ops,
275+
comparators: List[Tuple[Callable, List]],
276+
precision=torch.float,
277+
output_dtypes=None,
278+
use_dynamo_tracer=False,
279+
enable_passes=False,
280+
):
281+
mod.eval()
282+
mod = self.generate_graph(
283+
mod,
284+
inputs,
285+
use_dynamo_tracer=use_dynamo_tracer,
286+
enable_passes=enable_passes,
287+
)
288+
# Previous instance of the interpreter auto-casted 64-bit inputs
289+
# We replicate this behavior here
290+
compilation_settings = CompilationSettings(
291+
enabled_precisions={dtype._from(precision)},
292+
truncate_long_and_double=True,
293+
debug=True,
294+
)
295+
296+
interp = TRTInterpreter(
297+
mod,
298+
Input.from_tensors(inputs),
299+
output_dtypes=output_dtypes,
300+
compilation_settings=compilation_settings,
301+
)
302+
super().run_test_custom_compare_results(
303+
mod, inputs, expected_ops, interp, comparators
304+
)
305+
273306
def run_test_with_dynamic_shape(
274307
self,
275308
mod,
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch_tensorrt
4+
from parameterized import parameterized
5+
from torch.testing._internal.common_utils import TestCase, run_tests
6+
7+
from .harness import DispatchTestCase
8+
9+
rand_ops = [
10+
(
11+
"rand_one_dimension",
12+
(lambda shape: torch.ops.aten.rand(shape)),
13+
[1],
14+
),
15+
(
16+
"rand_two_dimension",
17+
(lambda shape: torch.ops.aten.rand(shape)),
18+
[1, 2],
19+
),
20+
(
21+
"rand_three_dimension",
22+
(lambda shape: torch.ops.aten.rand(shape)),
23+
[2, 3, 4],
24+
),
25+
(
26+
"randn_one_dimension",
27+
(lambda shape: torch.ops.aten.randn(shape)),
28+
[1],
29+
),
30+
(
31+
"randn_two_dimension",
32+
(lambda shape: torch.ops.aten.randn(shape)),
33+
[2, 3],
34+
),
35+
(
36+
"randn_three_dimension",
37+
(lambda shape: torch.ops.aten.randn(shape)),
38+
[2, 3, 4],
39+
),
40+
]
41+
42+
43+
rand_perm_ops = [
44+
(
45+
"randperm_one_case",
46+
(lambda x: torch.ops.aten.randperm(x)),
47+
[1],
48+
),
49+
(
50+
"randperm_two_case",
51+
(lambda x: torch.ops.aten.randperm(x)),
52+
[150],
53+
),
54+
(
55+
"randperm_three_case",
56+
(lambda x: torch.ops.aten.randperm(x)),
57+
[1500],
58+
),
59+
]
60+
61+
62+
class TestRandConverter(DispatchTestCase):
63+
@parameterized.expand(
64+
[
65+
(
66+
rand_op[0],
67+
rand_op[1],
68+
rand_op[2],
69+
)
70+
for rand_op in rand_ops
71+
]
72+
)
73+
def test_rand(self, name, op, shape_or_input):
74+
class TestModule(nn.Module):
75+
def __init__(self):
76+
super().__init__()
77+
78+
def forward(self, x):
79+
shape_or_input[0] = x.shape[0]
80+
return op(shape_or_input)
81+
82+
rand_model = TestModule()
83+
84+
inputs = [torch.randint(1, 3, shape_or_input, dtype=torch.int32)]
85+
comparator_shape = lambda x, y, check_dtype: x.shape == y.shape and (
86+
x.dtype == y.dtype if check_dtype else True
87+
)
88+
expected_ops = []
89+
self.run_test_compare_tensor_attributes_only(
90+
rand_model,
91+
inputs,
92+
expected_ops,
93+
[(comparator_shape, [True])],
94+
use_dynamo_tracer=True,
95+
)
96+
97+
@parameterized.expand(
98+
[
99+
(
100+
rand_op[0],
101+
rand_op[1],
102+
rand_op[2],
103+
)
104+
for rand_op in rand_perm_ops
105+
]
106+
)
107+
def test_rand(self, name, op, shape_or_input):
108+
class TestModule(nn.Module):
109+
def __init__(self):
110+
super().__init__()
111+
112+
def forward(self, x):
113+
shape_or_input[0] = x.shape[0]
114+
return op(shape_or_input[0])
115+
116+
rand_model = TestModule()
117+
# cannot use self.run_test() since it expects input in form of tensor
118+
119+
inputs = [torch.randint(1, 3, shape_or_input, dtype=torch.int32)]
120+
comparator_shape = lambda x, y, check_dtype: x.shape == y.shape and (
121+
x.dtype == y.dtype if check_dtype else True
122+
)
123+
expected_ops = []
124+
# TRT-TRT returns int32 while torch returns int64
125+
self.run_test_compare_tensor_attributes_only(
126+
rand_model,
127+
inputs,
128+
expected_ops,
129+
[(comparator_shape, [False])],
130+
use_dynamo_tracer=True,
131+
)
132+
133+
134+
if __name__ == "__main__":
135+
run_tests()

0 commit comments

Comments
 (0)