
Commit cd61e54

fix: bug in elementwise base for static inputs (#2819)
1 parent db67cb9 commit cd61e54

File tree

6 files changed: +100 -54 lines changed

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 71 additions & 3 deletions
@@ -4,6 +4,7 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload
 
 import numpy as np
+import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch import SymBool, SymFloat, SymInt
@@ -15,11 +16,12 @@
     ConverterRegistry,
     DynamoConverterImplSignature,
 )
-from torch_tensorrt.fx.converters.converter_utils import get_axes_for_reduce_op
+from torch_tensorrt.fx.converters.converter_utils import (
+    broadcast,
+    get_axes_for_reduce_op,
+)
 from torch_tensorrt.fx.types import TRTDataType, TRTTensor
 
-import tensorrt as trt
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
@@ -205,6 +207,72 @@ def broadcastable(
     return True
 
 
+def broadcast_to_same_shape(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    lhs_val: TRTTensor,
+    rhs_val: TRTTensor,
+) -> Tuple[TRTTensor, TRTTensor]:
+    """Broadcast ITensors `lhs_val` and `rhs_val` to the same shape. If the shapes are already the same,
+    return the original tensors; if the shapes differ, broadcast both tensors to the common shape.
+
+    This helper function is different from fx/converter_utils.broadcast:
+    fx/converter_utils.broadcast only broadcasts two ITensors to the same number of dimensions (rank)
+    by prepending 1s, while this function broadcasts two ITensors to the same shape.
+
+    For example, given original ITensors lhs_val.shape: (2, 3) and rhs_val.shape: (2, 2, 1, 3),
+    fx/converter_utils.broadcast yields lhs_val.shape: (1, 1, 2, 3), rhs_val.shape: (2, 2, 1, 3),
+    while broadcast_to_same_shape yields lhs_val.shape: (2, 2, 2, 3), rhs_val.shape: (2, 2, 2, 3).
+
+    Args:
+        lhs_val (TRTTensor): A TensorRT ITensor.
+        rhs_val (TRTTensor): A TensorRT ITensor.
+
+    Returns:
+        Tuple[TRTTensor, TRTTensor]: Two TensorRT ITensors broadcast to the same shape.
+
+    """
+    lhs_val, rhs_val = broadcast(
+        ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
+    )
+
+    lhs_val_shape = lhs_val.shape
+    rhs_val_shape = rhs_val.shape
+
+    if tuple(lhs_val_shape) != tuple(rhs_val_shape):
+        rank = len(lhs_val_shape)
+        expanded_dims = [-1] * len(lhs_val_shape)
+
+        for dim in range(rank):
+            expanded_dims[dim] = max(lhs_val_shape[dim], rhs_val_shape[dim])
+
+        expanded_shape = tuple(expanded_dims)
+
+        if lhs_val_shape != expanded_shape:
+            lhs_val = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_lhs_val",
+                lhs_val,
+                expanded_shape,
+            )
+
+        if rhs_val_shape != expanded_shape:
+            rhs_val = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_rhs_val",
+                rhs_val,
+                expanded_shape,
+            )
+
+    return lhs_val, rhs_val
+
+
 get_axes_for_reduce_op = functools.partial(
     get_axes_for_reduce_op, has_implicit_batch_dimension=False
 )
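The new `broadcast_to_same_shape` helper composes two steps: rank alignment via the existing fx `broadcast` (left-padding with 1s), then per-dimension expansion to the larger extent. Below is a minimal pure-Python sketch of that shape arithmetic, with no TensorRT dependency; the function name `resolve_broadcast_shape` is hypothetical and exists only to illustrate the rule:

```python
from typing import Tuple


def resolve_broadcast_shape(
    lhs_shape: Tuple[int, ...], rhs_shape: Tuple[int, ...]
) -> Tuple[int, ...]:
    # Step 1: equalize ranks by left-padding with 1s,
    # mirroring fx/converter_utils.broadcast.
    rank = max(len(lhs_shape), len(rhs_shape))
    lhs = (1,) * (rank - len(lhs_shape)) + tuple(lhs_shape)
    rhs = (1,) * (rank - len(rhs_shape)) + tuple(rhs_shape)
    # Step 2: take the per-dimension maximum, as in the loop
    # over range(rank) in broadcast_to_same_shape.
    return tuple(max(l, r) for l, r in zip(lhs, rhs))


# The example from the docstring:
assert resolve_broadcast_shape((2, 3), (2, 2, 1, 3)) == (2, 2, 2, 3)
```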

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 5 additions & 34 deletions
@@ -8,9 +8,9 @@
 from torch.fx.node import Target
 from torch_tensorrt import _enums
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
+    broadcast_to_same_shape,
     cast_trt_tensor,
     get_trt_tensor,
 )
@@ -152,41 +152,12 @@ def convert_binary_elementwise(
 
     if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
         lhs_val, rhs_val = broadcast(
-            ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
+            ctx.net, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs"
         )
     else:
-        lhs_val_shape = lhs_val.shape
-        rhs_val_shape = rhs_val.shape
-        rank_diff = len(lhs_val_shape) - len(rhs_val_shape)
-        if rank_diff > 0:
-            rhs_val = impl.slice.expand(
-                ctx, target, source_ir, f"{name}_expand_rhs_val", rhs_val, lhs_val_shape
-            )
-        elif rank_diff < 0:
-            lhs_val = impl.slice.expand(
-                ctx, target, source_ir, f"{name}_expand_lhs_val", lhs_val, rhs_val_shape
-            )
-        else:
-            if tuple(lhs_val_shape) != tuple(rhs_val_shape):
-                sum_diff = sum(lhs_val_shape) - sum(rhs_val_shape)
-                if sum_diff > 0:
-                    rhs_val = impl.slice.expand(
-                        ctx,
-                        target,
-                        source_ir,
-                        f"{name}_expand_rhs_val",
-                        rhs_val,
-                        lhs_val_shape,
-                    )
-                elif sum_diff < 0:
-                    lhs_val = impl.slice.expand(
-                        ctx,
-                        target,
-                        source_ir,
-                        f"{name}_expand_lhs_val",
-                        lhs_val,
-                        rhs_val_shape,
-                    )
+        lhs_val, rhs_val = broadcast_to_same_shape(
+            ctx, target, source_ir, f"{name}_broadcast_to_same_shape", lhs_val, rhs_val
+        )
 
     layer = ctx.net.add_elementwise(lhs_val, rhs_val, op_type)
     set_layer_name(layer, target, name, source_ir)
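For context on the bug this replaces: the removed static-shape path compared equal-rank operands by the sum of their dimensions, which is not a valid broadcast test. A small plain-Python illustration (not Torch-TensorRT API) of a shape pair the sum heuristic mishandles:

```python
# Shapes (1, 3) and (3, 1) broadcast to (3, 3), but their dimension
# sums are equal, so the removed code expanded neither operand and
# handed mismatched shapes to add_elementwise.
lhs_shape, rhs_shape = (1, 3), (3, 1)

sum_diff = sum(lhs_shape) - sum(rhs_shape)
assert sum_diff == 0           # old heuristic: "no expansion needed"
assert lhs_shape != rhs_shape  # yet the shapes still differ

# broadcast_to_same_shape instead computes the true broadcast target:
expanded = tuple(max(l, r) for l, r in zip(lhs_shape, rhs_shape))
assert expanded == (3, 3)      # both operands must be expanded to this
```

Even with unequal sums, e.g. (2, 1) vs (1, 3), the old code expanded one operand to the other operand's shape, here (2, 1) to (1, 3), which is not a valid expansion; the new helper expands both sides to the joint shape (2, 3).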

tests/py/dynamo/conversion/test_binary_ops_aten.py

Lines changed: 0 additions & 2 deletions
@@ -1,4 +1,3 @@
-import unittest
 from typing import Callable
 
 import torch
@@ -59,7 +58,6 @@ def forward(self, x):
         self.run_test(m, inputs)
 
     @parameterized.expand([(op[0].__name__, op[0]) for op in elementwise_ops])
-    @unittest.skip("Pending reimplementation of all binary converters in Dynamo")
     def test_elementwise_ops_mismatched_dtypes(self, name, orig_op: Callable):
         class TestModule(nn.Module):
             def __init__(self, orig_op):

tests/py/dynamo/conversion/test_bitwise_and_aten.py

Lines changed: 8 additions & 5 deletions
@@ -9,18 +9,21 @@
 class TestBitwiseAndConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ("2d", (5, 3)),
-            ("3d", (5, 3, 2)),
+            ("2d", (2, 3), (2, 3)),
+            ("3d", (5, 3, 2), (5, 3, 2)),
+            ("3d_broadcast", (2, 3), (2, 1, 3)),
+            ("4d_broadcast_1", (2, 3), (1, 2, 1, 3)),
+            ("4d_broadcast_2", (2, 3), (2, 2, 2, 3)),
         ]
     )
-    def test_bitwise_and_tensor(self, _, shape):
+    def test_bitwise_and_tensor(self, _, lhs_shape, rhs_shape):
         class bitwise_and(nn.Module):
             def forward(self, lhs_val, rhs_val):
                 return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val)
 
         inputs = [
-            torch.randint(0, 2, shape, dtype=bool),
-            torch.randint(0, 2, shape, dtype=bool),
+            torch.randint(0, 2, lhs_shape, dtype=bool),
+            torch.randint(0, 2, rhs_shape, dtype=bool),
         ]
         self.run_test(
             bitwise_and(),
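The same three broadcast cases are added to the bitwise_or and bitwise_xor suites below. As a quick cross-check (assumed here, not part of the test files), `torch.broadcast_shapes` gives the elementwise output shape each new case should produce once both operands are expanded:

```python
import torch

cases = {
    "3d_broadcast": ((2, 3), (2, 1, 3)),
    "4d_broadcast_1": ((2, 3), (1, 2, 1, 3)),
    "4d_broadcast_2": ((2, 3), (2, 2, 2, 3)),
}
for name, (lhs, rhs) in cases.items():
    print(name, torch.broadcast_shapes(lhs, rhs))
# 3d_broadcast torch.Size([2, 2, 3])
# 4d_broadcast_1 torch.Size([1, 2, 2, 3])
# 4d_broadcast_2 torch.Size([2, 2, 2, 3])
```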

tests/py/dynamo/conversion/test_bitwise_or_aten.py

Lines changed: 8 additions & 5 deletions
@@ -9,18 +9,21 @@
 class TestBitwiseOrConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ("2d", (5, 3)),
-            ("3d", (5, 3, 2)),
+            ("2d", (2, 3), (2, 3)),
+            ("3d", (5, 3, 2), (5, 3, 2)),
+            ("3d_broadcast", (2, 3), (2, 1, 3)),
+            ("4d_broadcast_1", (2, 3), (1, 2, 1, 3)),
+            ("4d_broadcast_2", (2, 3), (2, 2, 2, 3)),
         ]
     )
-    def test_bitwise_or_tensor(self, _, shape):
+    def test_bitwise_or_tensor(self, _, lhs_shape, rhs_shape):
         class bitwise_or(nn.Module):
             def forward(self, lhs_val, rhs_val):
                 return torch.ops.aten.bitwise_or.Tensor(lhs_val, rhs_val)
 
         inputs = [
-            torch.randint(0, 2, shape, dtype=bool),
-            torch.randint(0, 2, shape, dtype=bool),
+            torch.randint(0, 2, lhs_shape, dtype=bool),
+            torch.randint(0, 2, rhs_shape, dtype=bool),
         ]
         self.run_test(
             bitwise_or(),

tests/py/dynamo/conversion/test_bitwise_xor_aten.py

Lines changed: 8 additions & 5 deletions
@@ -9,18 +9,21 @@
 class TestBitwiseXorConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ("2d", (5, 3)),
-            ("3d", (5, 3, 2)),
+            ("2d", (2, 3), (2, 3)),
+            ("3d", (5, 3, 2), (5, 3, 2)),
+            ("3d_broadcast", (2, 3), (2, 1, 3)),
+            ("4d_broadcast_1", (2, 3), (1, 2, 1, 3)),
+            ("4d_broadcast_2", (2, 3), (2, 2, 2, 3)),
         ]
     )
-    def test_bitwise_xor_tensor(self, _, shape):
+    def test_bitwise_xor_tensor(self, _, lhs_shape, rhs_shape):
         class bitwise_xor(nn.Module):
             def forward(self, lhs_val, rhs_val):
                 return torch.ops.aten.bitwise_xor.Tensor(lhs_val, rhs_val)
 
         inputs = [
-            torch.randint(0, 2, shape, dtype=bool),
-            torch.randint(0, 2, shape, dtype=bool),
+            torch.randint(0, 2, lhs_shape, dtype=bool),
+            torch.randint(0, 2, rhs_shape, dtype=bool),
         ]
         self.run_test(
             bitwise_xor(),
