diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index 6154ebe644..f34028c3f3 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -501,7 +501,6 @@ def index_put_converter(
     F = [i for i in range(rank) if indices[i] is None]  # Free dimensions
     I = [i for i in range(rank) if indices[i] is not None]  # Indexed dimensions
     K = len(I)
-
     # Determine the maximum size 'N' among the index tensors
     if K > 0:
         index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None]
@@ -684,16 +683,6 @@ def index_put_converter(
         values_reshaped = impl.shuffle.reshape(
             ctx, target, source_ir, f"{name}_reshape_scalar", values, (1,)
         )
-        num_dims = len(expected_shape)
-        ones_shape = tuple([1] * num_dims)
-        values_reshaped = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_to_ones",
-            values_reshaped,
-            ones_shape,
-        )
         values_expanded = impl.slice.expand(
             ctx,
             target,
@@ -704,40 +693,79 @@ def index_put_converter(
         )
     else:  # Non-scalar case
         values_shape = list(values.shape)
-
-        # Pad dimensions if necessary
-        if len(values_shape) < len(expected_shape):
-            values_shape = [1] * (
-                len(expected_shape) - len(values_shape)
-            ) + values_shape
-
-        # Calculate a broadcastable shape
-        broadcast_shape = []
-        for exp_dim, val_dim in zip(expected_shape, values_shape):
-            if val_dim == 1:
-                broadcast_shape.append(exp_dim)
-            elif val_dim == exp_dim:
-                broadcast_shape.append(val_dim)
+        if K > 0 and N in values_shape:
+            n_idx = values_shape.index(N)
+            permute_order = [n_idx] + [
+                i for i in range(len(values_shape)) if i != n_idx
+            ]
+            values_permuted = impl.permutation.permute(
+                ctx, target, source_ir, f"{name}_permute_values", values, permute_order
+            )
+            remaining_shape = [
+                values_shape[i] for i in range(len(values_shape)) if i != n_idx
+            ]
+            target_f_dims = len(F)
+            current_f_dims = len(remaining_shape)
+            if current_f_dims < target_f_dims:
+                values_expanded_shape = (
+                    [N] + [1] * (target_f_dims - current_f_dims) + remaining_shape
+                )
             else:
-                raise ValueError(f"Cannot broadcast {values_shape} to {expected_shape}")
-
-        # Reshape and then expand
-        values_reshaped = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_values",
-            values,
-            tuple(broadcast_shape),
-        )
-        values_expanded = impl.slice.expand(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_expand_values",
-            values_reshaped,
-            expected_shape,
-        )
+                values_expanded_shape = [N] + remaining_shape[:target_f_dims]
+            values_expanded = impl.shuffle.reshape(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_unsqueeze_values",
+                values_permuted,
+                tuple(values_expanded_shape),
+            )
+            broadcast_shape = []
+            for exp_dim, val_dim in zip(expected_shape, values_expanded_shape):
+                if val_dim == 1:
+                    broadcast_shape.append(exp_dim)
+                elif val_dim == exp_dim:
+                    broadcast_shape.append(val_dim)
+                else:
+                    raise ValueError(
+                        f"Cannot broadcast {values_expanded_shape} to {expected_shape}"
+                    )
+            values_expanded = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_values",
+                values_expanded,
+                tuple(broadcast_shape),
+            )
+        else:
+            values_shape_padded = [1] * (
+                len(expected_shape) - len(values.shape)
+            ) + list(values.shape)
+            broadcast_shape = []
+            for exp_dim, val_dim in zip(expected_shape, values_shape_padded):
+                if val_dim == 1 or exp_dim == val_dim:
+                    broadcast_shape.append(exp_dim)
+                else:
+                    raise ValueError(
+                        f"Cannot broadcast {values.shape} to {expected_shape}"
+                    )
+            values_reshaped = impl.shuffle.reshape(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_reshape_values",
+                values,
+                tuple(broadcast_shape),
+            )
+            values_expanded = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_values",
+                values_reshaped,
+                expected_shape,
+            )
 
     # Flatten values to (N * F_volume,)
     flattened_values = impl.shuffle.reshape(
@@ -749,6 +777,7 @@ def index_put_converter(
         (N * F_volume,),
     )
 
+    indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32")
     # Perform Scatter ND operation
     scatter_layer = ctx.net.add_scatter(
         input_tensor,
diff --git a/tests/py/dynamo/conversion/test_index_put_aten.py b/tests/py/dynamo/conversion/test_index_put_aten.py
index 8413071026..74e38cd0c5 100644
--- a/tests/py/dynamo/conversion/test_index_put_aten.py
+++ b/tests/py/dynamo/conversion/test_index_put_aten.py
@@ -194,6 +194,12 @@ class TestIndexPutConverter(DispatchTestCase):
                 dtype=torch.int32,
             ),
         ),
+            param(
+                test_name="4d_indices_none_none_multiple_idx_broadcast_error",
+                source_tensor=torch.zeros([1, 2, 5, 3], dtype=torch.float32),
+                indices_tensor=(None, None, torch.tensor([0, 1, 2], dtype=torch.int64)),
+                value_tensor=torch.randn([2, 3, 3], dtype=torch.float32),
+            ),
             # param(
             #     test_name="2d_indices_accumulate_True",
             #     source_tensor=torch.zeros([5, 5], dtype=torch.int32),
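
Note on the non-scalar values path: when the values tensor carries the index dimension N on a non-leading axis, the converter now permutes N to the front before unsqueezing and broadcasting, rather than failing the old reshape-then-expand broadcast check (per the new test's name, this case previously raised a broadcast error). A minimal eager-mode sketch of the behavior the new test exercises; the variable names here are illustrative and not taken from the converter:

    import torch

    # indices (None, None, idx) select along dim 2 of a [1, 2, 5, 3] tensor;
    # the gathered layout is [1, 2, 3, 3], so the [2, 3, 3] values must broadcast.
    x = torch.zeros([1, 2, 5, 3], dtype=torch.float32)
    idx = torch.tensor([0, 1, 2], dtype=torch.int64)
    values = torch.randn([2, 3, 3], dtype=torch.float32)

    out = torch.ops.aten.index_put.default(x, [None, None, idx], values)

    # Equivalent advanced-indexing form:
    ref = x.clone()
    ref[:, :, idx] = values
    assert torch.equal(out, ref)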