Commit dd2da5a

E2E support for AtenRemainderScalarOp (#1200)
1 parent: 79b9cf9

8 files changed, +151 −11 lines changed

e2e_testing/torchscript/xfail_sets.py

Lines changed: 5 additions & 1 deletion
@@ -179,7 +179,7 @@
     "TransposeIntNegDimsModule_basic",
     "ArgmaxModule_keepDim",
     "ArgmaxModule_with_dim",
-    "_LogSoftmaxModuleStable_basic",
+    "_LogSoftmaxModuleStable_basic",
 }
 
 LTC_XFAIL_SET = {
@@ -338,4 +338,8 @@
     "ViewCollapseDynamicWithAtenSizeIntModule_basic",
     "AtenEmbeddingBagSumExample_basic",
     "Aten_EmbeddingBagExample_basic",
+    "ElementwiseRemainderScalarModule_Int_Float_basic",
+    "ElementwiseRemainderScalarModule_Float_basic",
+    "ElementwiseRemainderScalarModule_Int_basic",
+    "ElementwiseRemainderScalarModule_Bool_basic",
 }

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 24 additions & 0 deletions
@@ -7910,6 +7910,30 @@ def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [
   let hasFolder = 1;
 }
 
+def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRemainderScalarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenRemainderScalarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [
   AllowsTypeRefinement,
   HasValueSemantics,
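
The summary line records the upstream ATen schema that this generated op mirrors. As a quick cross-check outside torch-mlir, the registered overloads of aten::remainder can be listed from plain PyTorch (a sketch that assumes the private torch._C._jit_get_schemas_for_operator helper is available in your build):

import torch

# Print every registered overload of aten::remainder; the output should include
# "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor".
for schema in torch._C._jit_get_schemas_for_operator("aten::remainder"):
    print(schema)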

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 29 additions & 9 deletions
@@ -803,6 +803,26 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value other = convertScalarToDtype(b, loc, operands[1], dtype);
     return b.create<arith::DivFOp>(loc, self, other);
   }
+  if (auto remScalar = dyn_cast<AtenRemainderScalarOp>(op)) {
+    Type newResultType = converter->convertType(remScalar.getType())
+                             .cast<RankedTensorType>()
+                             .getElementType();
+
+    Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
+    Value other = convertScalarToDtype(b, loc, operands[1], newResultType);
+    Value result;
+
+    if (newResultType.isa<mlir::FloatType>()) {
+      result = b.create<arith::RemFOp>(loc, self, other);
+    } else if (newResultType.isa<mlir::IntegerType>()) {
+      result = b.create<arith::RemSIOp>(loc, self, other);
+    } else {
+      remScalar.emitError(
+          "Unsupported type encountered for AtenRemainderScalarOp.");
+    }
+
+    return result;
+  }
   if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
     Type dtype = converter->convertType(reciprocal.getType())
                      .cast<RankedTensorType>()
@@ -943,14 +963,14 @@ class ConvertElementwiseOp : public ConversionPattern {
           AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
           AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
           AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
-          AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp,
-          AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
-          AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
-          AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
-          AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
-          AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
-          AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp, AtenLogicalOrOp,
-          AtenTriuOp>(op))
+          AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp,
+          AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
+          AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
+          AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
+          AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
+          AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
+          AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
+          AtenLogicalOrOp, AtenTriuOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1688,7 +1708,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
       AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
       AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
-      AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp>();
+      AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp, AtenRemainderScalarOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);
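
The payload maps float element types to arith.remf and integer element types to arith.remsi. One subtlety worth noting: arith.remsi and arith.remf compute a truncated remainder whose sign follows the dividend, whereas torch.remainder follows Python's modulo convention (sign of the divisor); the two agree on the non-negative inputs used by the e2e tests in this commit. A small eager-mode illustration of the difference (plain PyTorch, runnable as-is):

import torch

x = torch.tensor([7, 3, -3], dtype=torch.int32)
# Python-style modulo, which torch.remainder specifies: result takes the sign of the divisor.
print(torch.remainder(x, 2))  # tensor([1, 1, 1], dtype=torch.int32)
# C-style truncated remainder, the semantics of arith.remsi: result takes the sign of the dividend.
print(torch.fmod(x, 2))       # tensor([1, 1, -1], dtype=torch.int32)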

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 1 addition & 1 deletion
@@ -770,7 +770,7 @@ void TypeAnalysis::visitOperation(Operation *op,
   // Promote LHS with scalar RHS.
   if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp,
           AtenFmodScalarOp, AtenFloorDivideScalarOp, AtenPowTensorScalarOp,
-          AtenRsubScalarOp, AtenLeakyReluOp>(op)) {
+          AtenRsubScalarOp, AtenLeakyReluOp, AtenRemainderScalarOp>(op)) {
     auto lhs = operands[0]->getValue();
     Value scalar = op->getOperand(1);
     auto knowledge =
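
Adding AtenRemainderScalarOp to this isa<> list routes it through the standard "promote LHS with scalar RHS" rule: an int scalar leaves the tensor dtype unchanged, while a float scalar promotes the result to a float type. Eager PyTorch shows the dtypes the analysis is expected to infer:

import torch

x = torch.randint(10, (3,), dtype=torch.int32)
print(torch.remainder(x, 2).dtype)    # torch.int32  -- an int scalar keeps the tensor dtype
print(torch.remainder(x, 2.0).dtype)  # torch.float32 -- a float scalar promotes the result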

lib/Dialect/Torch/Transforms/ShapeLibrary.cpp

Lines changed: 4 additions & 0 deletions
@@ -5441,6 +5441,10 @@ module {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.remainder.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
+    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.to.dtype"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py

Lines changed: 3 additions & 0 deletions
@@ -481,6 +481,9 @@ def aten〇mul〇Scalar(self: List[int], other: float) -> List[int]:
 def aten〇div〇Scalar(self: List[int], other: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇remainder〇Scalar(self: List[int], other: float) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇floor_divide〇Scalar(self: List[int], other: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
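
The ShapeLibrary.cpp entry above is regenerated from this Python definition. Both encode the same fact: remainder.Scalar is elementwise with a scalar RHS, so it is shape-preserving. A minimal sketch of the property (unary_shape is a hypothetical simplification; the real helper is upstream_shape_functions.unary from torch.jit._shape_functions):

from typing import List

def unary_shape(self: List[int]) -> List[int]:
    # An elementwise op with a scalar operand cannot change the tensor's shape.
    return list(self)

assert unary_shape([3, 4]) == [3, 4]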

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -550,6 +550,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::eq.int : (int, int) -> (bool)", has_folder=True)
     emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True)
     emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
+    emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)")
     emit("aten::add.int : (int, int) -> (int)", has_folder=True)
     emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
     emit("aten::mul.int : (int, int) -> (int)", has_folder=True)

python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 84 additions & 0 deletions
@@ -1304,6 +1304,90 @@ def forward(self, x):
 def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+# ==============================================================================
+
+
+class ElementwiseRemainderScalarModule_Int_Float(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.int32, True),
+    ])
+    def forward(self, x):
+        return torch.remainder(x, 2.0)
+
+
+@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float())
+def ElementwiseRemainderScalarModule_Int_Float_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3,), dtype=torch.int32))
+
+
+# ==============================================================================
+
+
+class ElementwiseRemainderScalarModule_Float(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.remainder(x, 2.0)
+
+
+@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Float())
+def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils):
+    module.forward(torch.rand(10, 3))
+
+
+# ==============================================================================
+
+class ElementwiseRemainderScalarModule_Int(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, x):
+        return torch.remainder(x, 2)
+
+
+@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int())
+def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 2), dtype=torch.int32))
+
+# ==============================================================================
+
+class ElementwiseRemainderScalarModule_Bool(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.bool, True),
+    ])
+    def forward(self, x):
+        return torch.remainder(x, 2)
+
+
+@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Bool())
+def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([True, False, True, True, True]))
+
 
 # ==============================================================================
 
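
Together the four modules exercise the float, int, int-tensor-with-float-scalar, and bool paths through the new lowering; all four are also added to LTC_XFAIL_SET above, presumably because the op is not yet supported on that backend. For reference, their expected eager-mode numerics (plain PyTorch; the comments note the type promotion each test relies on):

import torch

# Int_Float: an int32 tensor with a float scalar promotes the result to float32.
print(torch.remainder(torch.tensor([3, 8], dtype=torch.int32), 2.0))  # tensor([1., 0.])
# Int: an int32 tensor with an int scalar keeps its dtype.
print(torch.remainder(torch.tensor([[9, 4]], dtype=torch.int32), 2))  # tensor([[1, 0]], dtype=torch.int32)
# Bool: the bool input is promoted to an integer dtype before the remainder is taken.
print(torch.remainder(torch.tensor([True, False, True]), 2))          # tensor([1, 0, 1])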
