
Commit 0ed0447

fix: Implement aten.mean.default + aten.mean.dim converters
- Replace the existing implementation of aten.mean.dim with a general version invoking the centralized `add_reduce_layer` utility
- Add an implementation of aten.mean.default by refactoring the default case into a special invocation of aten.mean.dim (specifically, one with `dim` spanning all dimensions of the tensor and `keepdim` set to False)
- Add defaults for optional arguments in the converters
- Add test cases for combinations of input options to the converters
1 parent 35cf89d commit 0ed0447
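
The core of the change is the identity the commit message describes: taking the mean over every dimension with `keepdim=False` matches the all-elements mean. A quick sanity check of that identity in plain PyTorch (the tensor name and shape here are illustrative):

import torch

x = torch.randn(3, 5, 7)
# aten.mean.default: average over all elements
full = torch.mean(x)
# aten.mean.dim with dim spanning every dimension and keepdim=False
via_dim = torch.mean(x, dim=list(range(x.dim())), keepdim=False)
assert torch.allclose(full, via_dim)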

File tree

2 files changed: +108 -10 lines changed


py/torch_tensorrt/fx/converters/aten_ops_converters.py (+24 -10)
@@ -45,7 +45,6 @@ def aten_ops_add(
     return add_add(network, target, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.mean.dim)
 @tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
 @tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
 def aten_ops_adaptive_avg_poolnd(
@@ -55,23 +54,38 @@ def aten_ops_adaptive_avg_poolnd(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if target == torch.ops.aten.mean.dim:
-        if list(args[1]) != [-1, -2]:
-            raise RuntimeError(f"We do not support {target} has dim={args[1]}")
-        else:
-            output_size = [1, 1]
-    else:
-        output_size = args[1]
-
     kwargs_new = {
         "input": args[0],
-        "output_size": output_size,
+        "output_size": args[1],
     }
     return acc_ops_converters.acc_ops_adaptive_avg_poolnd(
         network, target, None, kwargs_new, name
     )
 
 
+@tensorrt_converter(torch.ops.aten.mean.default)
+@tensorrt_converter(torch.ops.aten.mean.dim)
+def aten_ops_mean(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    # The default invocation of aten.mean uses only the first argument and
+    # averages over all elements (all dimensions).
+    # The aten.mean.dim invocation allows specification of the dimensions to
+    # average over, as well as the option to keep the reduced dimensions or not.
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1] if len(args) >= 2 else list(range(len(args[0].shape))),
+        "keepdim": args[2] if len(args) >= 3 else False,
+    }
+    return add_reduce_layer(
+        network, target, args, kwargs_new, trt.ReduceOperation.AVG, name
+    )
+
+
 @tensorrt_converter(torch.ops.aten.batch_norm)
 def aten_ops_batch_norm(
     network: TRTNetwork,
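
For context on the defaulting in `aten_ops_mean` above: both registered overloads can be exercised directly at the aten level, which shows what `args` contains in each case. A brief illustration in plain PyTorch (tensor values are illustrative):

import torch

x = torch.randn(2, 4, 6)
# Handled by the defaults above: args == (x,), so dim covers all dims, keepdim is False
y_all = torch.ops.aten.mean.default(x)
# Explicit dims and keepdim: args == (x, [0, 1], True)
y_dim = torch.ops.aten.mean.dim(x, [0, 1], True)
assert y_all.ndim == 0 and y_dim.shape == (1, 1, 6)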
New file (+84 -0)
@@ -0,0 +1,84 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestMeanDimConverter(DispatchTestCase):
+    def test_mean_dim_keepdims(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x, dim=[0, 1], keepdim=True)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.dim})
+
+    def test_mean_dim_keepdims_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x, dim=[0, 1, 2], keepdim=True)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.mean.dim}
+        )
+
+    def test_mean_dim_keepdims_false(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x, dim=0, keepdim=False)
+
+        inputs = [torch.randn(3, 5, 7)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.dim})
+
+    def test_mean_dim_keepdims_false_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x, dim=-1, keepdim=False)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.mean.dim}
+        )
+
+
+class TestMeanConverter(DispatchTestCase):
+    def test_mean(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x)
+
+        inputs = [torch.randn(3, 8, 5, 7, 1)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.default})
+
+    def test_mean_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 5, 8), (3, 10, 10))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.mean.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
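
In the dynamic-shape tests, each `shape_ranges` entry supplies a `(min_shape, opt_shape, max_shape)` triple for one TensorRT optimization profile, with `-1` in `shape` marking the dynamic dimensions. The `dim=-1` test also relies on PyTorch's negative-dimension indexing; a quick check of that behavior in plain PyTorch:

import torch

x = torch.randn(3, 5, 7)
y = torch.mean(x, dim=-1, keepdim=False)  # -1 counts from the last dimension
assert y.shape == (3, 5)  # trailing dimension is reduced away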
