Commit 003853f

Add float support for masked_fill
1 parent 48418b9 commit 003853f

2 files changed: +24 -0 lines changed

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 2 additions & 0 deletions
@@ -906,6 +906,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
 
     Value input = payloadArgs[0];
     Value mask = payloadArgs[1];
+    if (mask.getType().isa<mlir::FloatType>())
+      mask = b.create<arith::ConstantOp>(loc, b.getBoolAttr(false));
     Value fillValue = convertScalarToDtype(b, loc, adaptor.value(), dtype);
 
     return b.create<arith::SelectOp>(loc, mask, fillValue, input);
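
For reference, the elementwise payload above computes out = select(mask, fillValue, input); with this change, a floating-point mask is replaced by a constant `false`, so the input passes through unchanged. A minimal sketch of those semantics in plain PyTorch (the function name masked_fill_reference and the concrete shapes are illustrative, not part of the commit):

import torch

def masked_fill_reference(x, mask, value):
    # Elementwise select, mirroring the arith.select payload:
    # out[i] = value if mask[i] else x[i].
    if mask.dtype.is_floating_point:
        # With this commit, a floating-point mask is lowered as a constant
        # `false`, so every element falls through to the input.
        mask = torch.zeros_like(mask, dtype=torch.bool)
    return torch.where(mask, torch.full_like(x, value), x)

x = torch.rand(2, 3)
float_mask = torch.randint(0, 2, (2, 3)).to(torch.float)
assert torch.equal(masked_fill_reference(x, float_mask, 0.1), x)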

python/torch_mlir_e2e_test/test_suite/constant_alloc.py

Lines changed: 22 additions & 0 deletions
@@ -1462,3 +1462,25 @@ def forward(self, x, mask, value):
 def MaskedFillTensorFloatValueModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(2, 3, low=-10, high=10),
                    tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.rand())
+
+
+class MaskedFillTensorFloatMaskModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.float, True),
+    ])
+    def forward(self, x, mask):
+        return torch.ops.aten.masked_fill(x, mask, value=0.1)
+
+
+@register_test_case(module_factory=lambda: MaskedFillTensorFloatMaskModule())
+def MaskedFillTensorFloatMaskModule_basic(module, tu: TestUtils):
+    mask = tu.randint(2, 3, low=0, high=2).to(torch.float)
+    module.forward(tu.randint(2, 3, low=-10, high=10),
+                   mask)
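
The new test builds its mask from 0.0/1.0 float values. Under the lowering above that mask is not compared against zero; it is simply replaced by `false`, so the compiled module is expected to return x unchanged. A self-contained sketch of that expectation in plain PyTorch (input construction loosely mirrors the test's tu.randint calls, which are assumed here to be thin wrappers over torch.randint; x is cast to float only to keep the dtypes in torch.where consistent):

import torch

x = torch.randint(-10, 10, (2, 3)).to(torch.float)
mask = torch.randint(0, 2, (2, 3)).to(torch.float)  # 0.0 / 1.0 values

# The payload replaces a float mask with a constant `false`, so the select
# always takes the input branch and the fill value 0.1 is never written.
collapsed = torch.zeros_like(mask, dtype=torch.bool)
out = torch.where(collapsed, torch.full_like(x, 0.1), x)
assert torch.equal(out, x)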
