diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp index 567735dfbd..c15c0e2025 100644 --- a/core/conversion/converters/impl/select.cpp +++ b/core/conversion/converters/impl/select.cpp @@ -662,8 +662,15 @@ auto select_registrations TORCHTRT_UNUSED = auto self = args[0].ITensorOrFreeze(ctx); auto mask = args[1].ITensorOrFreeze(ctx); mask = addPadding(ctx, n, mask, self->getDimensions().nbDims, false, true); - auto val = args[2].unwrapToScalar().to(); - auto val_t = tensor_to_const(ctx, torch::full(util::toVec(self->getDimensions()), val)); + auto val = args[2].unwrapToScalar(); + + // Tensor type to use for initializing constant tensor used in Select + // value should inherit its type from self + auto val_t_dtype = util::TRTDataTypeToScalarType(self->getType()); + + // Initialize contant tensor for fill with the inherited data type + auto val_t = tensor_to_const( + ctx, torch::full(util::toVec(self->getDimensions()), val, {torch::dtype(val_t_dtype)})); TORCHTRT_CHECK( util::broadcastable(self->getDimensions(), mask->getDimensions(), /*multidirectional=*/false), diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp index 2b70ac3dfc..40c5f11843 100644 --- a/tests/core/conversion/converters/test_select.cpp +++ b/tests/core/conversion/converters/test_select.cpp @@ -804,6 +804,62 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) { torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); } +TEST(Converters, ATenMaskedFillMixedTypesFloatIntConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor, %x.2 : Tensor): + %val : float = prim::Constant[value=4.0]() + %out : Tensor = aten::masked_fill(%x.1, %x.2, %val) + return (%out))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, &*g); + + // Input is a float tensor, filled with an int --> expecting float tensor out + auto in1 = at::rand({2, 3, 5, 7}, {at::kCUDA}).to(torch::kFloat32); + auto in2 = (2 * at::rand({2, 3, 5, 7}, {at::kCUDA})).to(torch::kBool); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); + + // Ensure data types match in outputs + ASSERT_TRUE(jit_results[0].dtype() == trt_results[0].dtype()); +} + +TEST(Converters, ATenMaskedFillMixedTypesIntFloatConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor, %x.2 : Tensor): + %val : int = prim::Constant[value=4]() + %out : Tensor = aten::masked_fill(%x.1, %x.2, %val) + return (%out))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, &*g); + + // Input is an integer tensor, filled with a float --> expecting integer tensor out + auto in1 = at::rand({1, 3, 5, 7}, {at::kCUDA}).to(torch::kInt32); + auto in2 = (2 * at::rand({1, 3, 5, 7}, {at::kCUDA})).to(torch::kBool); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); + + // Ensure data types match in outputs + ASSERT_TRUE(jit_results[0].dtype() == trt_results[0].dtype()); +} + TEST(Converters, ATenIndexTensorOneIndiceConvertsCorrectly) { const auto graph = R"IR( graph(%x.1 : Tensor,