diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
index f270ce3ea8..57c720ffba 100644
--- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -1040,11 +1040,14 @@ def acc_ops_elu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    alpha = kwargs["alpha"]
-    operation_type = trt.ActivationType.ELU
-    return activation.convert_activation(
-        network, target, SourceIR.ACC, name, operation_type, input_val, alpha=alpha
+
+    return activation.elu(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        kwargs["input"],
+        kwargs["alpha"],
     )
 
 
@@ -1056,15 +1059,13 @@ def acc_ops_selu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.ActivationType.SELU
-    return activation.convert_activation(
+
+    return activation.selu(
         network,
         target,
         SourceIR.ACC,
         name,
-        operation_type,
-        input_val,
+        kwargs["input"],
     )
 
 
diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index 4f93d98a26..82847cc760 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -170,6 +170,35 @@ def aten_ops_div(
     )
 
 
+@tensorrt_converter(torch.ops.aten.elu.default)
+def aten_ops_elu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    # F.selu decomposes to aten.elu.default with SELU's fixed alpha and
+    # scale appended as extra scalar args, so more than two args means SELU.
+    if len(args) > 2:
+        return activation.selu(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+        )
+    return activation.elu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1] if len(args) > 1 else 1.0,  # alpha defaults to 1 in the schema
+    )
+
+
 @tensorrt_converter(torch.ops.aten.floor_divide.default)
 def aten_ops_floor_div(
     network: TRTNetwork,
diff --git a/py/torch_tensorrt/fx/converters/impl/activation.py b/py/torch_tensorrt/fx/converters/impl/activation.py
index 793d0f90c9..66c16b0892 100644
--- a/py/torch_tensorrt/fx/converters/impl/activation.py
+++ b/py/torch_tensorrt/fx/converters/impl/activation.py
@@ -202,3 +202,60 @@ def leaky_relu_dyn_range_fn(dyn_range):
         alpha,
         dyn_range_fn=leaky_relu_dyn_range_fn,
     )
+
+
+def elu(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    alpha: Optional[Any],
+):
+    operation_type = trt.ActivationType.ELU
+
+    def elu_dyn_range_fn(dyn_range):
+        # ELU is monotonic, so the output range is the input range mapped
+        # through the activation.
+        return (
+            torch.nn.functional.elu(torch.tensor(dyn_range[0]), alpha).item(),
+            torch.nn.functional.elu(torch.tensor(dyn_range[1]), alpha).item(),
+        )
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        alpha,
+        dyn_range_fn=elu_dyn_range_fn,
+    )
+
+
+def selu(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+):
+    operation_type = trt.ActivationType.SELU
+
+    def selu_dyn_range_fn(dyn_range):
+        # Same monotonic-range argument as ELU, with SELU's fixed constants.
+        return (
+            torch.nn.functional.selu(torch.tensor(dyn_range[0])).item(),
+            torch.nn.functional.selu(torch.tensor(dyn_range[1])).item(),
+        )
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        dyn_range_fn=selu_dyn_range_fn,
+    )
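A note on the `dyn_range_fn` helpers added above: TensorRT propagates a per-tensor dynamic range through layers, and since ELU and SELU are monotonically increasing, mapping the two endpoints of the input range through the activation yields the output range. A minimal standalone sketch of that assumption (the alpha and range values here are arbitrary, not taken from the patch):

```python
import torch

# ELU is monotonically increasing, so a tensor whose values lie in
# [lo, hi] produces outputs in [elu(lo), elu(hi)].
alpha = 1.0
lo, hi = -2.0, 3.0
out_lo = torch.nn.functional.elu(torch.tensor(lo), alpha).item()
out_hi = torch.nn.functional.elu(torch.tensor(hi), alpha).item()
print(out_lo, out_hi)  # ~-0.8647, 3.0: negatives saturate toward -alpha
```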
input_val=kwargs["input"], alpha=kwargs["negative_slope"], ) + + +@tensorrt_converter(torch.nn.functional.elu) +@tensorrt_converter(torch.nn.modules.activation.ELU) +def elu(network, submod, args, kwargs, layer_name): + # args/kwargs should have already been normalized to kwargs + assert len(args) == 0 + + return activation.elu( + network=network, + target="torch.nn.functional.elu", + source_ir=SourceIR.NN, + name=layer_name, + input_val=kwargs["input"], + ) + + +@tensorrt_converter(torch.nn.functional.selu) +@tensorrt_converter(torch.nn.modules.activation.SELU) +def selu(network, submod, args, kwargs, layer_name): + # args/kwargs should have already been normalized to kwargs + assert len(args) == 0 + + return activation.selu( + network=network, + target="torch.nn.functional.selu", + source_ir=SourceIR.NN, + name=layer_name, + input_val=kwargs["input"], + alpha=kwargs["alpha"], + ) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_elu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_elu_aten.py new file mode 100644 index 0000000000..cd8ef1b48a --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_elu_aten.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestELUConverter(DispatchTestCase): + def test_elu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.elu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default}) + + def test_elu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.elu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default} + ) + + def test_elu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.elu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py new file mode 100644 index 0000000000..a6e501daa0 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSeLUConverter(DispatchTestCase): + def test_selu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default}) + + def test_selu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), 
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py
new file mode 100644
index 0000000000..a6e501daa0
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_selu_aten.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSELUConverter(DispatchTestCase):
+    def test_selu(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.selu(x)
+
+        inputs = [torch.randn(1, 10)]
+        # F.selu decomposes to aten.elu.default, hence the expected op
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default})
+
+    def test_selu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.selu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+    def test_selu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.selu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
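For reference, the identity behind that decomposition, using PyTorch's published SELU constants (a sanity check, not part of the patch):

```python
import torch

# SELU is ELU with a fixed alpha plus an output scale.
alpha, scale = 1.6732632423543772, 1.0507009873554805
x = torch.randn(8)
assert torch.allclose(
    torch.nn.functional.selu(x),
    scale * torch.nn.functional.elu(x, alpha),
)
```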