
Commit 02c7da4

[feat] support converter for torch.log2

Parent: ab08c63

File tree: 3 files changed, +82 -0 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+17)

@@ -2695,6 +2695,23 @@ def aten_ops_flip(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.log2.default)
+def log2(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.log2(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.scalar_tensor.default)
 def aten_ops_scalar_tensor(
     ctx: ConversionContext,
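
The dynamo_tensorrt_converter decorator registers the function against torch.ops.aten.log2.default, so the Dynamo frontend dispatches any matching graph node to it during lowering. As a minimal sketch of where that target op comes from (assuming PyTorch 2.1+ for torch.export; Log2Module is a hypothetical name, not part of this commit):

    import torch

    class Log2Module(torch.nn.Module):
        def forward(self, x):
            return torch.log2(x)

    # Exporting lowers torch.log2 to the aten.log2.default call
    # that the new converter is registered against
    ep = torch.export.export(Log2Module(), (torch.rand(2, 3),))
    print(ep.graph)  # the graph contains torch.ops.aten.log2.default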

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py (+16)

@@ -77,6 +77,22 @@ def log10(
     )
 
 
+def log2(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    log_layer_output = log(ctx, target, source_ir, f"{name}_log", input_val)
+
+    ln2 = 0.693147180559945309
+
+    return impl.elementwise.div(
+        ctx, target, source_ir, f"{name}_div", log_layer_output, ln2
+    )
+
+
 def sqrt(
     ctx: ConversionContext,
     target: Target,
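
The implementation uses the change-of-base identity log2(x) = ln(x) / ln(2): TensorRT's unary log layer computes the natural logarithm, so an elementwise division by the constant ln 2 converts the result to base 2. A minimal sketch of that identity in plain PyTorch (illustration only, independent of TensorRT):

    import torch

    x = torch.rand(4) + 0.1          # positive inputs so the log is finite
    ln2 = 0.693147180559945309       # same constant as in the converter

    # torch.log is the natural logarithm, like TensorRT's unary log layer
    assert torch.allclose(torch.log(x) / ln2, torch.log2(x))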

(new test file; path not shown in this capture) (+49)

@@ -0,0 +1,49 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestLogConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((10,), torch.float),
+            ((1, 20), torch.float),
+            ((2, 3, 4), torch.float),
+            ((2, 3, 4, 5), torch.float),
+        ]
+    )
+    def test_log_float(self, input_shape, dtype):
+        class log2(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.log2.default(input)
+
+        inputs = [torch.randn(input_shape, dtype=dtype)]
+        self.run_test(
+            log2(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((10,), torch.int, 0, 5),
+            ((1, 20), torch.int32, -10, 10),
+            ((2, 3, 4), torch.int, -5, 5),
+        ]
+    )
+    def test_log_int(self, input_shape, dtype, low, high):
+        class log2(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.log2.default(input)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            log2(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
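
With the converter and tests in place, a model that calls torch.log2 can be compiled end to end through the Dynamo frontend. A usage sketch, not part of this commit (assuming a CUDA machine with TensorRT and torch_tensorrt installed; Log2Module is a hypothetical name):

    import torch
    import torch_tensorrt

    class Log2Module(torch.nn.Module):
        def forward(self, x):
            return torch.log2(x)

    model = Log2Module().eval().cuda()
    inputs = [torch.rand(2, 3, 4, device="cuda")]

    # ir="dynamo" routes compilation through the Dynamo frontend, where
    # aten.log2.default is handled by the converter added in this commit
    trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
    print(trt_model(*inputs))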
