Skip to content

Commit 3354b58

Browse files
committed
aten::where mypy changes- cases where x_shape!=y_shape. ToDO: to add more such test cases
1 parent 3a034e1 commit 3354b58

File tree

1 file changed

+3
-3
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl/condition

1 file changed

+3
-3
lines changed

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,15 @@ def where(
6565
condition_val = condition_layer.get_output(0)
6666
else:
6767
assert condition.dtype == trt.bool, "mask dtype is not bool!"
68-
if condition_shape != condition_dim: # TODO: What is this checking?
68+
if len(condition_shape) != condition_dim: # TODO: What is this checking?
6969
condition_val = expand(
7070
network, target, source_ir, f"{name}_expand", condition, output_shape
7171
)
7272
else:
7373
condition_val = condition
7474

7575
if type(input) != TRTTensor:
76-
if x_shape != input_dim: # TODO: What is this checking?
76+
if x_shape != output_shape: # TODO: What is this checking?
7777
# special case where 1 element in input
7878
if len(input.shape) == 0:
7979
input = input.unsqueeze(0)
@@ -95,7 +95,7 @@ def where(
9595
y_val = get_trt_tensor(network, other, f"{name}_y")
9696
else:
9797
y_val = other
98-
if y_shape != other_dim: # TODO: What is this checking?
98+
if y_shape != output_shape: # TODO: What is this checking?
9999
y_val = expand(
100100
network, target, source_ir, f"{name}_y_expand", y_val, output_shape
101101
)

0 commit comments

Comments
 (0)