diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index bcb8495c67..3edcbad2dd 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -344,6 +344,10 @@ def create_constant(
     with unset_fake_temporarily():

         torch_value = to_torch(value, dtype)
+        if torch_value is None:
+            raise ValueError(
+                f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None."
+            )
         if torch_value.dtype == torch.float64:
             raise ValueError(
                 "TensorRT does not support float64 (double) precision. To resolve this, please set truncate_double=True in your compilation settings and re-run the model."
@@ -1065,3 +1069,50 @@ def load_tensorrt_llm() -> bool:
             )
             return False
     return False
+
+
+def promote_trt_tensors_to_same_dtype(
+    ctx: ConversionContext, lhs: TRTTensor, rhs: TRTTensor, name_prefix: str
+) -> tuple[TRTTensor, TRTTensor]:
+    """
+    Promotes two TensorRT tensors to a common data type to ensure type compatibility
+    during operations (e.g., select, where, etc.), following simplified PyTorch promotion rules.
+
+    Tensors that already share a dtype are returned untouched, so matching
+    operands (e.g., bool, int64) are never silently cast to int32.
+
+    Args:
+        ctx: Conversion context containing the TRT network definition.
+        lhs: The left-hand-side TensorRT tensor.
+        rhs: The right-hand-side TensorRT tensor.
+        name_prefix: A prefix string used to name any cast operations.
+
+    Returns:
+        A tuple of (lhs_cast, rhs_cast) TensorRT tensors, both cast to the promoted dtype.
+    """
+
+    # Nothing to promote when the dtypes already agree; falling through would
+    # force any dtype outside {float16, float32, int32} down to int32.
+    if lhs.dtype == rhs.dtype:
+        return lhs, rhs
+
+    # Define supported float types (TensorRT supports float16 and float32)
+    float_types = {trt.float16, trt.float32}
+
+    # Case 1: If either tensor is a float, promote to the wider float type
+    if lhs.dtype in float_types or rhs.dtype in float_types:
+        # Prefer float32 if either tensor is float32
+        if lhs.dtype == trt.float32 or rhs.dtype == trt.float32:
+            promoted_dtype = trt.float32
+        else:
+            promoted_dtype = trt.float16
+    else:
+        # Case 2: If both tensors are int types (e.g., int32, int64), promote to int32
+        # (Note: TensorRT does not support int64 for many ops like select/where)
+        promoted_dtype = trt.int32
+
+    # Cast both tensors to the promoted dtype
+    lhs_cast = cast_trt_tensor(ctx, lhs, promoted_dtype, f"{name_prefix}lhs_cast")
+    rhs_cast = cast_trt_tensor(ctx, rhs, promoted_dtype, f"{name_prefix}rhs_cast")
+
+    return lhs_cast, rhs_cast
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
index bad671899b..e21e7f32a1 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
@@ -11,6 +11,7 @@
     cast_trt_tensor,
     get_trt_tensor,
     prepend_ones,
+    promote_trt_tensors_to_same_dtype,
     set_layer_name,
 )
 from torch_tensorrt.dynamo.conversion.impl.elementwise import ne
@@ -57,6 +58,9 @@ def where(
     if diff > 0:
         other = prepend_ones(ctx, other, f"{name}_other_broadcast", diff)

+    # Ensure that input and other have the same TRT dtype
+    input, other = promote_trt_tensors_to_same_dtype(ctx, input, other, name)
+
     return select(ctx, target, source_ir, name, input, other, condition)


diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index 6154ebe644..cf7808fff4 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -462,6 +462,7 @@ def gather(
 ) -> TRTTensor:
     input_shape = input.shape
     dim = get_positive_dim(dim, len(input_shape))
+    index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
     gather_layer = ctx.net.add_gather(input, index, axis=dim)
     gather_layer.mode = trt.GatherMode.ELEMENT
     set_layer_name(gather_layer, target, name + "_gather_layer_element", source_ir)
diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index b4165477ed..8037858151 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -196,8 +196,10 @@ def slice_scatter_decomposition(
 ) -> torch.Tensor:
     dim_size = input_tensor.shape[dim]
     device_input_tensor = input_tensor.device
+
+    start = 0 if start is None else start  # Ensure start is int
     start = get_positive_dim(start, input_tensor.shape[dim])
-    if end is None:
+    if end is None:  # Ensure end is int
         end = dim_size
     end = get_positive_dim(end, input_tensor.shape[dim])
     if step is None:
@@ -575,6 +577,53 @@ def cudnn_grid_sampler_decomposition(
     return torch.grid_sampler_2d(x, grid, 0, 0, True)


+@register_torch_trt_decomposition(
+    aten.masked_scatter, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def masked_scatter_decomposition(
+    input: torch.Tensor,
+    mask: torch.Tensor,
+    source: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Decomposition of `aten.masked_scatter` for TensorRT.
+
+    Emulates the behavior of `input[mask] = source` using only TensorRT-compatible ops.
+
+    Steps:
+      1) Broadcast `input` and `mask` to a common shape.
+      2) Flatten all tensors for uniform indexing.
+      3) Compute gather indices for `source` by applying cumsum to the boolean mask.
+         - Use `masked_fill` to avoid invalid indices in positions where `mask` is False.
+      4) Gather values from `source` at valid positions.
+      5) Use `torch.where` to insert gathered values into `input` where `mask` is True.
+      6) Reshape the result back to the original broadcasted shape.
+    """
+
+    # 1) Broadcast input and mask to the same shape
+    input_b, mask_b = aten.broadcast_tensors([input, mask])
+
+    # 2) Flatten tensors for element-wise operations
+    input_flat = input_b.flatten()
+    mask_flat = mask_b.flatten()
+    source_flat = source.flatten()
+
+    # 3) Compute gather indices from cumsum of the mask
+    # Subtract 1 so that the first True position maps to index 0 in source
+    source_idx = mask_flat.cumsum(0) - 1
+    # Set gather index to 0 where mask is False (these will be ignored later)
+    safe_idx = source_idx.masked_fill(~mask_flat, 0)
+
+    # 4) Gather values from source using computed indices
+    gathered = source_flat.gather(0, safe_idx)
+
+    # 5) Replace masked positions in input with gathered values
+    replaced = torch.where(mask_flat, gathered, input_flat)
+
+    # 6) Reshape the result to match the original broadcasted shape
+    return replaced.view(input_b.shape)
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py
index 49baaa5db2..b63e0f3bf7 100644
--- a/tests/py/dynamo/lowering/test_decompositions.py
+++ b/tests/py/dynamo/lowering/test_decompositions.py
@@ -2167,6 +2167,87 @@ def forward(self, x, grid):
             msg="Cudnn_grid_sampler TRT outputs don't match with the original model.",
         )

+    @parameterized.expand(
+        [
+            ("float32_2d", torch.float32, (4, 4)),
+            ("float16_3d", torch.float16, (2, 3, 4)),
+        ]
+    )
+    def test_masked_scatter(self, _, dtype, shape):
+        """
+        Test that masked_scatter.default is correctly decomposed into
+        (cumsum, gather, where, etc.) and that final TRT results match PyTorch.
+        """
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, mask, source):
+                return torch.ops.aten.masked_scatter.default(x, mask, source)
+
+        x = torch.randn(*shape, dtype=dtype, device="cuda")
+
+        mask = torch.rand(*shape, device="cuda") > 0.5
+        num_trues = mask.sum().item()
+        if num_trues == 0:
+            # Flip exactly one element; `mask[0] = True` would set a whole slice
+            mask.view(-1)[0] = True
+            num_trues = 1
+        source = torch.arange(num_trues, dtype=dtype, device="cuda")
+
+        inputs = [x, mask, source]
+
+        fx_graph = torch.fx.symbolic_trace(TestModule())
+
+        expected_ops = {
+            torch.ops.aten.where.self,
+            torch.ops.aten.gather.default,
+            torch.ops.aten.cumsum.default,
+        }
+        unexpected_ops = {torch.ops.aten.masked_scatter.default}
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        trt_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        with torch.no_grad():
+            trt_results = trt_model(*inputs).detach().cpu()
+            torch_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(torch.max(torch.abs(trt_results - torch_results)))
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Masked_scatter TRT outputs don't match with the original model. (diff={max_diff})",
+        )
+

 if __name__ == "__main__":
     run_tests()