@@ -804,6 +804,62 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
 
+TEST(Converters, ATenMaskedFillMixedTypesFloatIntConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor, %x.2 : Tensor):
+        %val : int = prim::Constant[value=4]()
+        %out : Tensor = aten::masked_fill(%x.1, %x.2, %val)
+        return (%out))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, &*g);
+
+  // Input is a float tensor, filled with an int --> expecting float tensor out
+  auto in1 = at::rand({2, 3, 5, 7}, {at::kCUDA}).to(torch::kFloat32);
+  auto in2 = (2 * at::rand({2, 3, 5, 7}, {at::kCUDA})).to(torch::kBool);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+
+  // Ensure data types match in outputs
+  ASSERT_TRUE(jit_results[0].dtype() == trt_results[0].dtype());
+}
+
+TEST(Converters, ATenMaskedFillMixedTypesIntFloatConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor, %x.2 : Tensor):
+        %val : float = prim::Constant[value=4.0]()
+        %out : Tensor = aten::masked_fill(%x.1, %x.2, %val)
+        return (%out))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, &*g);
+
+  // Input is an integer tensor, filled with a float --> expecting integer tensor out
+  auto in1 = at::rand({1, 3, 5, 7}, {at::kCUDA}).to(torch::kInt32);
+  auto in2 = (2 * at::rand({1, 3, 5, 7}, {at::kCUDA})).to(torch::kBool);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+
+  // Ensure data types match in outputs
+  ASSERT_TRUE(jit_results[0].dtype() == trt_results[0].dtype());
+}
+
 TEST(Converters, ATenIndexTensorOneIndiceConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor,