Skip to content

Commit 4b4a1b4

Browse files
apbosenarendasan
authored and committed
Activation Converter reorg and adding tanh in aten operations
Corrected linting errors in nn_ops_converters; undid accidental setup.py change
1 parent 6d28bba commit 4b4a1b4

File tree

5 files changed

+113
-5
lines changed

5 files changed

+113
-5
lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1169,15 +1169,12 @@ def acc_ops_tanh(
11691169
kwargs: Dict[str, Argument],
11701170
name: str,
11711171
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1172-
input_val = kwargs["input"]
1173-
operation_type = trt.ActivationType.TANH
1174-
return activation.convert_activation(
1172+
return activation.tanh(
11751173
network,
11761174
target,
11771175
SourceIR.ACC,
11781176
name,
1179-
operation_type,
1180-
input_val,
1177+
kwargs["input"],
11811178
)
11821179

11831180

py/torch_tensorrt/fx/converters/aten_ops_converters.py

+18
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,24 @@ def aten_ops_reshape(
353353
return layer.get_output(0)
354354

355355

356+
@tensorrt_converter(torch.ops.aten.tanh.default)
def aten_ops_tanh(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert an ``aten.tanh`` FX node into a TensorRT activation layer.

    Delegates to the shared tanh implementation, tagging the call as
    originating from the ATen IR so the layer is named consistently.
    """
    # ATen converters receive the input tensor positionally, not via kwargs.
    input_val = args[0]
    return activation.tanh(network, target, SourceIR.ATEN, name, input_val)
372+
373+
356374
@tensorrt_converter(torch.ops.aten.cat.default)
357375
def aten_ops_cat(
358376
network: TRTNetwork,

py/torch_tensorrt/fx/converters/impl/activation.py

+27
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,30 @@ def sigmoid_fn(x):
117117
input_val,
118118
dyn_range_fn=sigmoid_dyn_range_fn,
119119
)
120+
121+
122+
def tanh(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
):
    """Add a TensorRT TANH activation layer for ``input_val``.

    Args:
        network: TensorRT network being built.
        target: FX node target this layer was lowered from.
        source_ir: IR the conversion originated from (ACC/ATEN/NN), used
            for layer naming; may be None.
        name: Base name for the created layer.
        input_val: Input TensorRT tensor.

    Returns:
        The output tensor of the activation layer (via convert_activation).
    """
    operation_type = trt.ActivationType.TANH

    def tanh_dyn_range_fn(dyn_range):
        # np.tanh instead of (e^x - e^-x)/(e^x + e^-x): the explicit form
        # overflows for large |x| (exp(x) -> inf, yielding inf/inf -> nan),
        # while np.tanh saturates correctly to +/-1.
        return np.tanh(dyn_range[0]), np.tanh(dyn_range[1])

    return convert_activation(
        network,
        target,
        source_ir,
        name,
        operation_type,
        input_val,
        dyn_range_fn=tanh_dyn_range_fn,
    )

py/torch_tensorrt/fx/converters/nn_ops_converters.py

+15
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,18 @@ def sigmoid(network, submod, args, kwargs, layer_name):
3636
name=layer_name,
3737
input_val=kwargs["input"],
3838
)
39+
40+
41+
@tensorrt_converter(torch.nn.functional.tanh)
@tensorrt_converter(torch.nn.modules.activation.Tanh)
def tanh(network, submod, args, kwargs, layer_name):
    """Map ``nn.Tanh`` modules and ``F.tanh`` calls onto the shared tanh impl."""
    # By this point arg normalization has folded every positional arg
    # into kwargs, so args must be empty.
    assert len(args) == 0

    input_val = kwargs["input"]
    return activation.tanh(
        network=network,
        target="torch.nn.modules.activation.Tanh",
        source_ir=SourceIR.NN,
        name=layer_name,
        input_val=input_val,
    )
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import torch
2+
import torch.nn as nn
3+
from torch.testing._internal.common_utils import run_tests
4+
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
5+
6+
7+
class TestTanhConverter(DispatchTestCase):
    """Check that ``nn.functional.tanh`` lowers through the ``aten.tanh`` converter."""

    def test_tanh(self):
        # Static-shape path: plain 2-D input.
        class Tanh(nn.Module):
            def forward(self, x):
                return nn.functional.tanh(x)

        self.run_test(
            Tanh(),
            [torch.randn(1, 10)],
            expected_ops={torch.ops.aten.tanh.default},
        )

    def test_tanh_with_dynamic_shape(self):
        # Fully dynamic 3-D input.
        class Tanh(nn.Module):
            def forward(self, x):
                return nn.functional.tanh(x)

        specs = [
            InputTensorSpec(
                shape=(-1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
            ),
        ]
        self.run_test_with_dynamic_shape(
            Tanh(), specs, expected_ops={torch.ops.aten.tanh.default}
        )

    def test_tanh_with_dynamic_shape_four_dimensions(self):
        # 4-D input; the last dimension is pinned at 5 across the range.
        class Tanh(nn.Module):
            def forward(self, x):
                return nn.functional.tanh(x)

        specs = [
            InputTensorSpec(
                shape=(-1, -1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
            ),
        ]
        self.run_test_with_dynamic_shape(
            Tanh(), specs, expected_ops={torch.ops.aten.tanh.default}
        )
48+
49+
50+
# Allow running this test file directly (e.g. `python test_tanh_aten.py`).
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)