Skip to content

Commit 00a0e39

Browse files
committed
refactor: Centralize sigmoid implementation
Signed-off-by: Naren Dasan <[email protected]>
1 parent 9864d96 commit 00a0e39

File tree

6 files changed

+119
-47
lines changed

6 files changed

+119
-47
lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -3192,21 +3192,13 @@ def acc_ops_sigmoid(
31923192
kwargs: Dict[str, Argument],
31933193
name: str,
31943194
) -> Union[TRTTensor, Sequence[TRTTensor]]:
3195-
input_val = kwargs["input"]
31963195

3197-
if not isinstance(input_val, TRTTensor):
3198-
raise RuntimeError(
3199-
f"Sigmoid received input {input_val} that is not part "
3200-
"of the TensorRT region!"
3201-
)
3202-
3203-
return activation.convert_activation(
3196+
return activation.convert_sigmoid(
32043197
network,
32053198
target,
32063199
SourceIR.ACC,
32073200
name,
3208-
trt.ActivationType.SIGMOID,
3209-
input_val,
3201+
kwargs,
32103202
)
32113203

32123204

py/torch_tensorrt/fx/converters/activation.py

-37
This file was deleted.

py/torch_tensorrt/fx/converters/aten_ops_converters.py

+22
Original file line numberDiff line numberDiff line change
@@ -486,3 +486,25 @@ def aten_ops_sym_size(
486486
)
487487
set_layer_name(slice_layer, target, "_slice_layer")
488488
return slice_layer.get_output(0)
489+
490+
491+
@tensorrt_converter(torch.ops.aten.sigmoid.default)
def aten_ops_sigmoid(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.sigmoid.default`` via the shared sigmoid implementation.

    The aten schema passes the input positionally; repackage it under the
    ``input`` key expected by ``activation.convert_sigmoid`` and delegate.
    """
    # Shared converter reads the tensor from kwargs["input"].
    normalized_kwargs = {"input": args[0]}

    return activation.convert_sigmoid(
        network,
        target,
        SourceIR.ATEN,
        name,
        normalized_kwargs,
    )

py/torch_tensorrt/fx/converters/impl/activation.py

+28
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,31 @@ def relu_dyn_range_fn(dyn_range):
9191
input_val,
9292
dyn_range_fn=relu_dyn_range_fn,
9393
)
94+
95+
96+
def convert_sigmoid(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    kwargs: Dict[str, Any],
):
    """Add a sigmoid activation for the tensor held in ``kwargs["input"]``.

    Thin wrapper over ``convert_activation`` that pins the operation type to
    ``trt.ActivationType.SIGMOID`` and supplies the matching dynamic-range
    propagation function.
    """

    def sigmoid_dyn_range_fn(dyn_range):
        # Map an explicit (low, high) dynamic range through the activation
        # by applying sigmoid to each bound.
        def sigmoid_fn(x):
            # TODO: Can this just call torch.nn.functional.sigmoid?
            return 1 / (1 + np.exp(-x))

        return sigmoid_fn(dyn_range[0]), sigmoid_fn(dyn_range[1])

    return convert_activation(
        network,
        target,
        source_ir,
        name,
        trt.ActivationType.SIGMOID,
        kwargs["input"],
        dyn_range_fn=sigmoid_dyn_range_fn,
    )

py/torch_tensorrt/fx/converters/nn_ops_converters.py

+14
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,17 @@ def relu(network, submod, args, kwargs, layer_name):
2222
name=layer_name,
2323
kwargs=kwargs,
2424
)
25+
26+
27+
@tensorrt_converter(torch.nn.modules.activation.Sigmoid)
def sigmoid(network, submod, args, kwargs, layer_name):
    """Convert an ``nn.Sigmoid`` module call via the shared implementation.

    Delegates to ``activation.convert_sigmoid`` with ``source_ir=SourceIR.NN``.
    """
    # args/kwargs should have already been normalized to kwargs
    assert len(args) == 0

    # Return the converted output, consistent with acc_ops_sigmoid and
    # aten_ops_sigmoid — the original discarded convert_sigmoid's result.
    return activation.convert_sigmoid(
        network=network,
        target="torch.nn.modules.activation.Sigmoid",
        source_ir=SourceIR.NN,
        name=layer_name,
        kwargs=kwargs,
    )
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import torch
2+
import torch.nn as nn
3+
from torch.testing._internal.common_utils import run_tests
4+
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
5+
6+
7+
class TestSigmoidConverter(DispatchTestCase):
    """Exercise the aten sigmoid converter with static and dynamic shapes."""

    class _SigmoidModule(nn.Module):
        # Single shared module under test: elementwise sigmoid.
        def forward(self, x):
            return nn.functional.sigmoid(x)

    def test_sigmoid(self):
        # Static-shape path.
        self.run_test(
            self._SigmoidModule(),
            [torch.randn(1, 10)],
            expected_ops={torch.ops.aten.sigmoid.default},
        )

    def test_sigmoid_with_dynamic_shape(self):
        # Fully dynamic 3-D input.
        specs = [
            InputTensorSpec(
                shape=(-1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
            ),
        ]
        self.run_test_with_dynamic_shape(
            self._SigmoidModule(),
            specs,
            expected_ops={torch.ops.aten.sigmoid.default},
        )

    def test_sigmoid_with_dynamic_shape_four_dimensions(self):
        # Dynamic 4-D input with a fixed last dimension.
        specs = [
            InputTensorSpec(
                shape=(-1, -1, -1, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
            ),
        ]
        self.run_test_with_dynamic_shape(
            self._SigmoidModule(),
            specs,
            expected_ops={torch.ops.aten.sigmoid.default},
        )
)
50+
51+
52+
# Allow running this test file directly with `python <file>`.
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)