Skip to content

Commit 4a5c28f

Browse files
committed
fix(//tests): use right type for masked_fill test
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 0647d17 commit 4a5c28f

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

Diff for: tests/core/conversion/converters/test_select.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -433,13 +433,14 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
433433
%44 : Device = prim::Constant[value="cuda"]()
434434
%8 : bool = prim::Constant[value=0]()
435435
%7 : None = prim::Constant()
436+
%f32_dtype: int = prim::Constant[value=11]()
436437
%1 : int = prim::Constant[value=0]() # bert.py:5:26
437438
%2 : int = prim::Constant[value=1]() # bert.py:5:32
438439
%33 : int = prim::Constant[value=2]() # bert.py:6:31
439440
%3 : int[] = prim::ListConstruct(%1, %1, %2)
440441
%4 : int[] = prim::ListConstruct(%2, %2, %1)
441442
%5 : int[][] = prim::ListConstruct(%3, %4)
442-
%9 : Tensor = aten::tensor(%5, %1, %7, %8) # bert.py:5:11
443+
%9 : Tensor = aten::tensor(%5, %f32_dtype, %7, %8) # bert.py:5:11
443444
%mask.1 : Tensor = aten::to(%9, %44, %7, %8, %8) # bert.py:5:11
444445
%mask.2 : Tensor = trt::const(%mask.1)
445446
%34 : Tensor = aten::masked_fill(%x.1, %mask.1, %33) # bert.py:6:11

0 commit comments

Comments
 (0)