
Commit 67cac89

apbose authored and narendasan committed
Converter reorg: elu
Adding selu converter; Python linting correction
1 parent 3fc3c6d commit 67cac89

6 files changed: +219 −10


py/torch_tensorrt/fx/converters/acc_ops_converters.py

+11 −10
@@ -1048,11 +1048,14 @@ def acc_ops_elu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    alpha = kwargs["alpha"]
-    operation_type = trt.ActivationType.ELU
-    return activation.convert_activation(
-        network, target, SourceIR.ACC, name, operation_type, input_val, alpha=alpha
+
+    return activation.elu(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        kwargs["input"],
+        kwargs["alpha"],
     )

@@ -1064,15 +1067,13 @@ def acc_ops_selu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.ActivationType.SELU
-    return activation.convert_activation(
+
+    return activation.selu(
         network,
         target,
         SourceIR.ACC,
         name,
-        operation_type,
-        input_val,
+        kwargs["input"],
     )
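As a quick illustration of the call pattern this reorg settles on, a hypothetical converter for another alpha-style activation would delegate to the shared helper the same way. The converter name below is made up for illustration; only activation.elu and activation.selu are added in this commit.

# Hypothetical converter following the same delegation pattern (illustration only).
# `network`, `target`, `kwargs`, and `name` are supplied by the fx2trt dispatcher,
# as for the registered converters in acc_ops_converters.py.
def my_elu_style_converter(network, target, args, kwargs, name):
    return activation.elu(
        network,
        target,
        SourceIR.ACC,        # tags which frontend IR produced the node
        name,
        kwargs["input"],     # acc tracing normalizes the tensor argument into kwargs
        kwargs["alpha"],     # ELU's alpha coefficient (1.0 by default in torch)
    )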

py/torch_tensorrt/fx/converters/aten_ops_converters.py

+27
@@ -170,6 +170,33 @@ def aten_ops_div(
     )


+@tensorrt_converter(torch.ops.aten.elu.default)
+def aten_ops_elu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    if len(args) > 2:
+        return activation.selu(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+        )
+    return activation.elu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 @tensorrt_converter(torch.ops.aten.floor_divide.default)
 def aten_ops_floor_div(
     network: TRTNetwork,
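A note on the len(args) > 2 branch above: the ATen schema for aten.elu is elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1), and torch.nn.functional.selu lowers to aten.elu with SELU's fixed constants, so an elu call that arrives with more than two positional arguments is treated as SELU here. A small sketch of that relationship (the constants are the standard SELU values; this snippet is illustrative and not part of the commit):

import torch
import torch.nn.functional as F

# SELU(x) == scale * ELU(x, alpha) with SELU's fixed constants, which is why
# F.selu shows up at the aten level as an elu call carrying extra scale arguments.
x = torch.randn(4, 8)
alpha = 1.6732632423543772
scale = 1.0507009873554805
assert torch.allclose(F.selu(x), scale * F.elu(x, alpha=alpha))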

py/torch_tensorrt/fx/converters/impl/activation.py

+48
@@ -148,3 +148,51 @@ def sigmoid_fn(x):
         input_val,
         dyn_range_fn=sigmoid_dyn_range_fn,
     )
+
+
+def elu(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    alpha: Optional[Any],
+):
+    operation_type = trt.ActivationType.ELU
+
+    def elu_dyn_range_fn(dyn_range):
+        return (torch.nn.ELU(dyn_range[0]), torch.nn.ELU(dyn_range[1]))
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        alpha,
+        dyn_range_fn=elu_dyn_range_fn,
+    )
+
+
+def selu(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+):
+    operation_type = trt.ActivationType.SELU
+
+    def elu_dyn_range_fn(dyn_range):
+        return (torch.nn.SELU(dyn_range[0]), torch.nn.ELU(dyn_range[1]))
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        dyn_range_fn=elu_dyn_range_fn,
+    )
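For context, convert_activation (defined earlier in this file) is what both new helpers ultimately call into: it adds a TensorRT activation layer of the requested type and applies alpha when provided. A simplified sketch of that mapping, assuming the standard TensorRT Python API; this is an illustration, not the file's exact implementation:

import tensorrt as trt

def convert_activation_sketch(network, input_val, operation_type, alpha=None, name=""):
    # INetworkDefinition.add_activation creates an IActivationLayer of the given type
    layer = network.add_activation(input_val, operation_type)
    if alpha is not None:
        layer.alpha = alpha   # for trt.ActivationType.ELU this is the alpha coefficient
    layer.name = name
    return layer.get_output(0)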

py/torch_tensorrt/fx/converters/nn_ops_converters.py

+31
@@ -51,3 +51,34 @@ def hardtanh(network, submod, args, kwargs, layer_name):
         name=layer_name,
         input_val=kwargs["input"],
     )
+
+
+@tensorrt_converter(torch.nn.functional.elu)
+@tensorrt_converter(torch.nn.modules.activation.ELU)
+def elu(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    return activation.elu(
+        network=network,
+        target="torch.nn.functional.elu",
+        source_ir=SourceIR.NN,
+        name=layer_name,
+        input_val=kwargs["input"],
+    )
+
+
+@tensorrt_converter(torch.nn.functional.selu)
+@tensorrt_converter(torch.nn.modules.activation.SELU)
+def selu(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    return activation.selu(
+        network=network,
+        target="torch.nn.functional.selu",
+        source_ir=SourceIR.NN,
+        name=layer_name,
+        input_val=kwargs["input"],
+        alpha=kwargs["alpha"],
+    )
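For reference, these NN-level registrations fire when the traced graph still contains the nn.Module or functional call itself rather than its acc/aten lowering. A module like the following contains both patterns these registrations target (illustrative only; the lowering entry point is not shown here):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ActExample(nn.Module):
    def __init__(self):
        super().__init__()
        # module form: matches the torch.nn.modules.activation.ELU registration
        self.elu = nn.ELU(alpha=0.5)

    def forward(self, x):
        # functional form: matches the torch.nn.functional.selu registration
        return F.selu(self.elu(x))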
New file: ELU converter tests (+51)
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestELUConverter(DispatchTestCase):
+    def test_elu(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.elu(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default})
+
+    def test_elu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.elu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+    def test_elu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.elu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
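The new tests only exercise the default alpha; a variant passing an explicit alpha would follow the same DispatchTestCase pattern. The test class and method below are an illustrative addition, not part of this commit:

import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase

class TestELUAlphaConverter(DispatchTestCase):
    def test_elu_with_alpha(self):
        class TestModule(nn.Module):
            def forward(self, x):
                # explicit alpha exercises the alpha argument path of the converter
                return nn.functional.elu(x, alpha=1.5)

        inputs = [torch.randn(1, 10)]
        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default})

if __name__ == "__main__":
    run_tests()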
New file: SELU converter tests (+51)
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSeLUConverter(DispatchTestCase):
+    def test_selu(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.selu(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default})
+
+    def test_selu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.selu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+    def test_selu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.selu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
