diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index 943eb203b3..c86f2bd228 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -41,7 +41,6 @@ def aten_ops_add(
     return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.mean.dim)
 @tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
 @tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
 def aten_ops_adaptive_avg_poolnd(
@@ -51,24 +50,38 @@ def aten_ops_adaptive_avg_poolnd(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if target == torch.ops.aten.mean.dim:
-
-        if list(args[1]) != [-1, -2]:
-            raise RuntimeError(f"We do not support {target} has dim={args[1]}")
-        else:
-            output_size = [1, 1]
-    else:
-        output_size = args[1]
-
     kwargs_new = {
         "input": args[0],
-        "output_size": output_size,
+        "output_size": args[1],
     }
     return acc_ops_converters.acc_ops_adaptive_avg_poolnd(
         network, target, None, kwargs_new, name
     )
 
 
+@tensorrt_converter(torch.ops.aten.mean.default)
+@tensorrt_converter(torch.ops.aten.mean.dim)
+def aten_ops_mean(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    # The default invocation of aten.mean uses only the first argument and
+    # averages over all elements (all dimensions).
+    # The aten.mean.dim invocation allows specification of the dimensions to
+    # average over, as well as the option to keep the reduced dimensions or not.
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1] if len(args) >= 2 else list(range(len(args[0].shape))),
+        "keepdim": args[2] if len(args) >= 3 else False,
+    }
+    return add_reduce_layer(
+        network, target, args, kwargs_new, trt.ReduceOperation.AVG, name
+    )
+
+
 @tensorrt_converter(torch.ops.aten.batch_norm)
 def aten_ops_batch_norm(
     network: TRTNetwork,
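For reviewers, a quick sanity check of the overload semantics the new `aten_ops_mean` converter relies on. This is a sketch in eager mode, not part of the patch: it only uses public `torch.ops.aten` entry points to confirm that `aten.mean.default` averages every element (hence the all-dimensions default for `dim`) and that `aten.mean.dim` takes `(input, dim, keepdim)` positionally, matching the `args` unpacking above.

```python
# Sketch only, not part of the patch: check the assumptions behind the
# kwargs_new mapping in aten_ops_mean using eager-mode PyTorch.
import torch

x = torch.randn(3, 5, 7)

# aten.mean.default averages all elements, i.e. it is equivalent to
# reducing over every dimension with keepdim=False.
assert torch.allclose(
    torch.ops.aten.mean.default(x),
    torch.ops.aten.mean.dim(x, list(range(x.ndim)), False),
)

# aten.mean.dim exposes (input, dim, keepdim) positionally, so args[1]
# carries the dim list and args[2] the keepdim flag in the converter.
assert torch.allclose(
    torch.ops.aten.mean.dim(x, [0, 1], True),
    torch.mean(x, dim=[0, 1], keepdim=True),
)
```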
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_mean_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_mean_aten.py
new file mode 100644
index 0000000000..23ec89e56a
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_mean_aten.py
@@ -0,0 +1,84 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestMeanDimConverter(DispatchTestCase):
+    def test_mean_dim_keepdims(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x, dim=[0, 1], keepdim=True)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.dim})
+
+    def test_mean_dim_keepdims_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x, dim=[0, 1, 2], keepdim=True)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.mean.dim}
+        )
+
+    def test_mean_dim_keepdims_false(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x, dim=0, keepdim=False)
+
+        inputs = [torch.randn(3, 5, 7)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.dim})
+
+    def test_mean_dim_keepdims_false_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x, dim=-1, keepdim=False)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.mean.dim}
+        )
+
+
+class TestMeanConverter(DispatchTestCase):
+    def test_mean(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x)
+
+        inputs = [torch.randn(3, 8, 5, 7, 1)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.default})
+
+    def test_mean_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 5, 8), (3, 10, 10))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.mean.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
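The tests above cover both `keepdim` settings, negative dims, and the no-`dim` overload; the TRT engine outputs are compared against eager-mode results, so the expected output shapes follow standard `torch.mean` semantics. As a reference, a small eager-mode sketch (not part of the patch) of the shapes these reductions produce:

```python
# Sketch only, not part of the patch: the eager-mode shape semantics
# the converter's TRT reduce layer must reproduce.
import torch

x = torch.randn(3, 5, 7)

# keepdim=True retains each reduced dim as size 1 (test_mean_dim_keepdims).
assert torch.mean(x, dim=[0, 1], keepdim=True).shape == (1, 1, 7)

# keepdim=False drops the reduced dim (test_mean_dim_keepdims_false).
assert torch.mean(x, dim=0, keepdim=False).shape == (5, 7)

# Negative dims count from the end (the dynamic-shape keepdim=False test).
assert torch.equal(torch.mean(x, dim=-1), torch.mean(x, dim=2))

# Plain torch.mean dispatches to aten.mean.default, as the expected_ops in
# TestMeanConverter assert, and returns a 0-dim tensor.
assert torch.mean(x).shape == ()
```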