Skip to content

Commit b1db33a

Browse files
authored
fix: Ensure proper type inheritance in aten::masked_fill (#1430)
* fix: Ensure proper type inheritance in `aten::masked_fill` - Ensure `value` input inherits type from `self`, avoiding a bug where mismatched types cause TRT `Select` to throw an error - Test type-mismatched inputs (regression tests) to ensure type-casting is handled correctly * Edit test case to include a dimension with size 1
1 parent ffabdae commit b1db33a

File tree

2 files changed

+65
-2
lines changed

2 files changed

+65
-2
lines changed

core/conversion/converters/impl/select.cpp

+9-2
Original file line numberDiff line numberDiff line change
@@ -662,8 +662,15 @@ auto select_registrations TORCHTRT_UNUSED =
662662
auto self = args[0].ITensorOrFreeze(ctx);
663663
auto mask = args[1].ITensorOrFreeze(ctx);
664664
mask = addPadding(ctx, n, mask, self->getDimensions().nbDims, false, true);
665-
auto val = args[2].unwrapToScalar().to<float>();
666-
auto val_t = tensor_to_const(ctx, torch::full(util::toVec(self->getDimensions()), val));
665+
auto val = args[2].unwrapToScalar();
666+
667+
// Tensor type to use for initializing constant tensor used in Select
668+
// value should inherit its type from self
669+
auto val_t_dtype = util::TRTDataTypeToScalarType(self->getType());
670+
671+
// Initialize contant tensor for fill with the inherited data type
672+
auto val_t = tensor_to_const(
673+
ctx, torch::full(util::toVec(self->getDimensions()), val, {torch::dtype(val_t_dtype)}));
667674

668675
TORCHTRT_CHECK(
669676
util::broadcastable(self->getDimensions(), mask->getDimensions(), /*multidirectional=*/false),

tests/core/conversion/converters/test_select.cpp

+56
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,62 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
804804
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
805805
}
806806

807+
TEST(Converters, ATenMaskedFillMixedTypesFloatIntConvertsCorrectly) {
808+
const auto graph = R"IR(
809+
graph(%x.1 : Tensor, %x.2 : Tensor):
810+
%val : float = prim::Constant[value=4.0]()
811+
%out : Tensor = aten::masked_fill(%x.1, %x.2, %val)
812+
return (%out))IR";
813+
814+
auto g = std::make_shared<torch::jit::Graph>();
815+
816+
torch::jit::parseIR(graph, &*g);
817+
818+
// Input is a float tensor, filled with an int --> expecting float tensor out
819+
auto in1 = at::rand({2, 3, 5, 7}, {at::kCUDA}).to(torch::kFloat32);
820+
auto in2 = (2 * at::rand({2, 3, 5, 7}, {at::kCUDA})).to(torch::kBool);
821+
822+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
823+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
824+
825+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
826+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
827+
828+
ASSERT_TRUE(
829+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
830+
831+
// Ensure data types match in outputs
832+
ASSERT_TRUE(jit_results[0].dtype() == trt_results[0].dtype());
833+
}
834+
835+
TEST(Converters, ATenMaskedFillMixedTypesIntFloatConvertsCorrectly) {
836+
const auto graph = R"IR(
837+
graph(%x.1 : Tensor, %x.2 : Tensor):
838+
%val : int = prim::Constant[value=4]()
839+
%out : Tensor = aten::masked_fill(%x.1, %x.2, %val)
840+
return (%out))IR";
841+
842+
auto g = std::make_shared<torch::jit::Graph>();
843+
844+
torch::jit::parseIR(graph, &*g);
845+
846+
// Input is an integer tensor, filled with a float --> expecting integer tensor out
847+
auto in1 = at::rand({1, 3, 5, 7}, {at::kCUDA}).to(torch::kInt32);
848+
auto in2 = (2 * at::rand({1, 3, 5, 7}, {at::kCUDA})).to(torch::kBool);
849+
850+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
851+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
852+
853+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
854+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
855+
856+
ASSERT_TRUE(
857+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
858+
859+
// Ensure data types match in outputs
860+
ASSERT_TRUE(jit_results[0].dtype() == trt_results[0].dtype());
861+
}
862+
807863
TEST(Converters, ATenIndexTensorOneIndiceConvertsCorrectly) {
808864
const auto graph = R"IR(
809865
graph(%x.1 : Tensor,

0 commit comments

Comments
 (0)