
Commit b1a5066

Author: Prashant Kumar

Add support for the aten.masked_fill.Tensor op.

The `aten.masked_fill.Tensor` op is handled by reusing the elementwise
lowering of the `aten.masked_fill.Scalar` op in TorchToLinalg.

1 parent d96ec64

File tree: 7 files changed, +99 −6 lines changed
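For context, `aten::masked_fill.Tensor` behaves exactly like the existing Scalar overload except that the fill value arrives as a 0-d tensor. A minimal eager-PyTorch sketch of the semantics being added (variable names are illustrative):

```python
import torch

x = torch.rand(2, 3)
mask = torch.tensor([[True, False, True], [False, True, False]])
value = torch.tensor(7.0)  # the Tensor overload takes a 0-d tensor value

# Fill the positions where the mask is set; keep the rest of `x`.
out = torch.ops.aten.masked_fill(x, mask, value)
assert torch.equal(out, torch.where(mask, value, x))
```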

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 49 additions & 0 deletions
```diff
@@ -1928,6 +1928,55 @@ def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [
   }];
 }
 
+def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$mask,
+    AnyTorchTensorType:$value
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
+def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$mask,
+    AnyTorchTensorType:$value
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenClampOp : Torch_Op<"aten.clamp", [
     AllowsTypeRefinement,
     HasValueSemantics,
```

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 16 additions & 4 deletions
```diff
@@ -884,9 +884,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
                                           threshold);
     return b.create<arith::SelectOp>(loc, predicate, constantZero, grad);
   }
-  if (auto maskedFill = dyn_cast<AtenMaskedFillScalarOp>(op)) {
+  if (auto maskedFillScalar = dyn_cast<AtenMaskedFillScalarOp>(op)) {
     AtenMaskedFillScalarOp::Adaptor adaptor(operands);
-    Type dtype = converter->convertType(maskedFill.getType())
+    Type dtype = converter->convertType(maskedFillScalar.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
 
@@ -896,6 +896,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
 
     return b.create<arith::SelectOp>(loc, mask, fillValue, input);
   }
+  if (auto maskedFillTensor = dyn_cast<AtenMaskedFillTensorOp>(op)) {
+    AtenMaskedFillScalarOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(maskedFillTensor.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+
+    Value input = payloadArgs[0];
+    Value mask = payloadArgs[1];
+    Value fillValue = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
+    return b.create<arith::SelectOp>(loc, mask, fillValue, input);
+  }
 
   if (auto triu = dyn_cast<AtenTriuOp>(op)) {
     // Check if the rank of the input tensor is valid.
@@ -970,7 +981,7 @@ class ConvertElementwiseOp : public ConversionPattern {
              AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
              AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
              AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
-             AtenLogicalOrOp, AtenTriuOp>(op))
+             AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1708,7 +1719,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
       AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
       AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
-      AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp, AtenRemainderScalarOp>();
+      AtenNeScalarOp, AtenMaskedFillScalarOp, AtenMaskedFillTensorOp,
+      AtenLogicalOrOp, AtenTriuOp, AtenRemainderScalarOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);
```
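Both `masked_fill` branches above emit the same payload: the fill value (the third payload argument) is converted to the result element type and an elementwise `arith.select` picks between it and the input. A rough NumPy-flavored sketch of what the generated `linalg.generic` computes, with illustrative names not taken from the source:

```python
import numpy as np

def masked_fill_payload(inp: np.ndarray, mask: np.ndarray,
                        fill_value: float) -> np.ndarray:
    # Mirrors the arith.select in the payload: take the (dtype-converted)
    # fill value where the mask is set, otherwise the input element.
    fill = np.asarray(fill_value, dtype=inp.dtype)
    return np.where(mask, fill, inp)
```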

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -658,7 +658,8 @@ void TypeAnalysis::visitOperation(Operation *op,
           AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
           AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroOp,
           AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
-          PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) {
+          PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp>(
+          op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
```

lib/Dialect/Torch/Transforms/ShapeLibrary.cpp

Lines changed: 4 additions & 0 deletions
```diff
@@ -6214,6 +6214,10 @@ module {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.masked_fill.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
+    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.zero"(%arg0: !torch.list<int>) -> !torch.list<int> {
     return %arg0 : !torch.list<int>
   }
```

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -777,6 +777,9 @@ def aten〇_to_copy(self: List[int], dtype: Optional[int] = None, layout: Option
 def aten〇masked_fill〇Scalar(self: List[int], mask: List[int], value: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇masked_fill〇Tensor(self: List[int], mask: List[int], value: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇zero(self: List[int]) -> List[int]:
     return self
 
```
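`masked_fill` never changes the shape of `self`, so both overloads reuse `upstream_shape_functions.unary`. A hedged sketch of what that helper reduces to here (its real definition lives upstream in `torch.jit._shape_functions`):

```python
from typing import List

def masked_fill_tensor_shape(self: List[int], mask: List[int],
                             value: List[int]) -> List[int]:
    # unary(self): the result shape is just a copy of the input shape.
    return list(self)
```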

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -279,6 +279,7 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::le.Scalar : (Tensor, Scalar) -> (Tensor)",
         "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)",
         "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
+        "aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
         "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",
         "aten::clamp_min : (Tensor, Scalar) -> (Tensor)",
         "aten::clamp_max : (Tensor, Scalar) -> (Tensor)",
```

python/torch_mlir_e2e_test/test_suite/constant_alloc.py

Lines changed: 24 additions & 1 deletion
```diff
@@ -342,7 +342,8 @@ def __init__(self):
         ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
-        return torch.empty_like(a, memory_format=torch.preserve_format).fill_(0)
+        return torch.empty_like(a,
+                                memory_format=torch.preserve_format).fill_(0)
 
 
 @register_test_case(module_factory=lambda: EmptyLikeMemoryFormatModule())
@@ -1421,3 +1422,25 @@ def forward(self, x, mask):
 def MaskedFillScalarFloatValueModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(-10, 10, (2, 3)),
                    torch.randint(0, 2, (2, 3)).to(dtype=torch.bool))
+
+
+class MaskedFillTensorFloatValueModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.bool, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, x, mask, value):
+        return torch.ops.aten.masked_fill(x, mask, value=value)
+
+
+@register_test_case(module_factory=lambda: MaskedFillTensorFloatValueModule())
+def MaskedFillTensorFloatValueModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-10, 10, (2, 3)),
+                   torch.randint(0, 2, (2, 3)).to(dtype=torch.bool), tu.rand())
```
