-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][spirv] Add folding for [S|U]Mod, [S|U]Div, SRem #73341
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Add missing constant propogation folder for [S|U]Mod, [S|U]Div, SRem Implement additional folding when rhs is 1 for all ops. This helps for readability of lowered code into SPIR-V. Part of work for llvm#70704
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Finn Plummer (inbelic) ChangesAdd missing constant propogation folder for [S|U]Mod, [S|U]Div, SRem Implement additional folding when rhs is 1 for all ops. This helps for readability of lowered code into SPIR-V. Part of work for #70704 Full diff: https://github.com/llvm/llvm-project/pull/73341.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..16bf173cb7971e0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -534,6 +534,8 @@ def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -573,6 +575,8 @@ def SPIRV_SModOp : SPIRV_ArithmeticBinaryOp<"SMod",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -673,6 +677,8 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -707,6 +713,8 @@ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -811,6 +819,7 @@ def SPIRV_UModOp : SPIRV_ArithmeticBinaryOp<"UMod",
```
}];
+ let hasFolder = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6d..8144a100dab3495 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -69,6 +69,14 @@ static Attribute extractCompositeElement(Attribute composite,
return {};
}
+static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
+ bool div0 = b.isZero();
+
+ bool overflow = a.isMinSignedValue() && b.isAllOnes();
+
+ return div0 || overflow;
+}
+
//===----------------------------------------------------------------------===//
// TableGen'erated canonicalizers
//===----------------------------------------------------------------------===//
@@ -290,6 +298,158 @@ OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
[](APInt a, const APInt &b) { return std::move(a) - b; });
}
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
+ // sdiv (x, 1) = x
+ if (matchPattern(getOperand2(), m_One()))
+ return getOperand1();
+
+ // According to the SPIR-V spec:
+ //
+ // Signed-integer division of Operand 1 divided by Operand 2.
+ // Results are computed per component. Behavior is undefined if Operand 2 is
+ // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
+ // representable value for the operands' type, causing signed overflow.
+ //
+ // So don't fold during undefined behaviour.
+ bool div0OrOverflow = false;
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+ if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+ div0OrOverflow = true;
+ return a;
+ }
+ return a.sdiv(b);
+ });
+ return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
+ // smod (x, 1) = 0
+ if (matchPattern(getOperand2(), m_One()))
+ return Builder(getContext()).getZeroAttr(getType());
+
+ // According to SPIR-V spec:
+ //
+ // Signed remainder operation for the remainder whose sign matches the sign
+ // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
+ // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+ // value for the operands' type, causing signed overflow. Otherwise, the
+ // result is the remainder r of Operand 1 divided by Operand 2 where if
+ // r ≠ 0, the sign of r is the same as the sign of Operand 2.
+ //
+ // So don't fold during undefined behaviour
+ bool div0OrOverflow = false;
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+ if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+ div0OrOverflow = true;
+ return a;
+ }
+ APInt c = a.abs().urem(b.abs());
+ if (c.isZero())
+ return c;
+ if (b.isNegative()) {
+ APInt zero = APInt::getZero(c.getBitWidth());
+ return a.isNegative() ? (zero - c) : (b + c);
+ }
+ return a.isNegative() ? (b - c) : c;
+ });
+ return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
+ // x % 1 = 0
+ if (matchPattern(getOperand2(), m_One()))
+ return Builder(getContext()).getZeroAttr(getType());
+
+ // According to SPIR-V spec:
+ //
+ // Signed remainder operation for the remainder whose sign matches the sign
+ // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
+ // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+ // value for the operands' type, causing signed overflow. Otherwise, the
+ // result is the remainder r of Operand 1 divided by Operand 2 where if
+ // r ≠ 0, the sign of r is the same as the sign of Operand 1.
+
+ // Don't fold if it would do undefined behaviour.
+ bool div0OrOverflow = false;
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](APInt a, const APInt &b) {
+ if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+ div0OrOverflow = true;
+ return a;
+ }
+ return a.srem(b);
+ });
+ return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
+ // udiv (x, 1) = x
+ if (matchPattern(getOperand2(), m_One()))
+ return getOperand1();
+
+ // According to the SPIR-V spec:
+ //
+ // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
+ // undefined if Operand 2 is 0.
+ //
+ // So don't fold during undefined behaviour.
+ bool div0 = false;
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+ if (div0 || b.isZero()) {
+ div0 = true;
+ return a;
+ }
+ return a.udiv(b);
+ });
+ return div0 ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
+ // umod (x, 1) = 0
+ if (matchPattern(getOperand2(), m_One()))
+ return Builder(getContext()).getZeroAttr(getType());
+
+ // According to the SPIR-V spec:
+ //
+ // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
+ // undefined if Operand 2 is 0.
+ //
+ // So don't fold during undefined behaviour.
+ bool div0 = false;
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+ if (div0 || b.isZero()) {
+ div0 = true;
+ return a;
+ }
+ return a.urem(b);
+ });
+ return div0 ? Attribute() : res;
+}
+
//===----------------------------------------------------------------------===//
// spirv.LogicalAnd
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a444397a..7b1163601e1b427 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -462,10 +462,272 @@ func.func @const_fold_vector_isub() -> vector<3xi32> {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @sdiv_x_1
+func.func @sdiv_x_1(%arg0 : i32) -> i32 {
+ // CHECK-NEXT: return %arg0 : i32
+ %c1 = spirv.Constant 1 : i32
+ %2 = spirv.SDiv %arg0, %c1: i32
+ return %2 : i32
+}
+
+// CHECK-LABEL: @sdiv_div_0_or_overflow
+func.func @sdiv_div_0_or_overflow() -> (i32, i32) {
+ // CHECK: spirv.SDiv
+ // CHECK: spirv.SDiv
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %min_i32 = spirv.Constant -2147483648 : i32
+
+ %0 = spirv.SDiv %cn1, %c0 : i32
+ %1 = spirv.SDiv %min_i32, %cn1 : i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_sdiv
+func.func @const_fold_scalar_sdiv() -> (i32, i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: spirv.Constant -18
+ // CHECK-DAG: spirv.Constant -2
+ // CHECK-DAG: spirv.Constant -7
+ // CHECK-DAG: spirv.Constant 8
+ %0 = spirv.SDiv %c56, %c7 : i32
+ %1 = spirv.SDiv %c56, %cn8 : i32
+ %2 = spirv.SDiv %cn8, %c3 : i32
+ %3 = spirv.SDiv %c56, %cn3 : i32
+ return %0, %1, %2, %3: i32, i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_sdiv
+func.func @const_fold_vector_sdiv() -> vector<3xi32> {
+ // CHECK: spirv.Constant dense<[0, -1, -3]>
+
+ %cv_num = spirv.Constant dense<[42, 24, -16]> : vector<3xi32>
+ %cv_denom = spirv.Constant dense<[76, -24, 5]> : vector<3xi32>
+ %0 = spirv.SDiv %cv_num, %cv_denom : vector<3xi32>
+ return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smod_x_1
+func.func @smod_x_1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant dense<0>
+ %c1 = spirv.Constant 1 : i32
+ %cv1 = spirv.Constant dense<1> : vector<3xi32>
+ %0 = spirv.SMod %arg0, %c1: i32
+ %1 = spirv.SMod %arg1, %cv1: vector<3xi32>
+ return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @smod_div_0_or_overflow
+func.func @smod_div_0_or_overflow() -> (i32, i32) {
+ // CHECK: spirv.SMod
+ // CHECK: spirv.SMod
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %min_i32 = spirv.Constant -2147483648 : i32
+
+ %0 = spirv.SMod %cn1, %c0 : i32
+ %1 = spirv.SMod %min_i32, %cn1 : i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_smod
+func.func @const_fold_scalar_smod() -> (i32, i32, i32, i32, i32, i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %cn56 = spirv.Constant -56 : i32
+ %c59 = spirv.Constant 59 : i32
+ %cn59 = spirv.Constant -59 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+ // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+ // CHECK-DAG: %[[FIFTYTHREE:.*]] = spirv.Constant 53 : i32
+ // CHECK-DAG: %[[NFIFTYTHREE:.*]] = spirv.Constant -53 : i32
+ // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+ // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+ %0 = spirv.SMod %c56, %c7 : i32
+ %1 = spirv.SMod %c56, %cn8 : i32
+ %2 = spirv.SMod %c56, %c3 : i32
+ %3 = spirv.SMod %cn3, %c56 : i32
+ %4 = spirv.SMod %cn3, %cn56 : i32
+ %5 = spirv.SMod %c59, %c56 : i32
+ %6 = spirv.SMod %c59, %cn56 : i32
+ %7 = spirv.SMod %cn59, %cn56 : i32
+
+ // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[FIFTYTHREE]], %[[NTHREE]], %[[THREE]], %[[NFIFTYTHREE]], %[[NTHREE]]
+ return %0, %1, %2, %3, %4, %5, %6, %7 : i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_smod
+func.func @const_fold_vector_smod() -> vector<3xi32> {
+ // CHECK: spirv.Constant dense<[42, -4, 4]>
+
+ %cv = spirv.Constant dense<[42, 24, -16]> : vector<3xi32>
+ %cv_mod = spirv.Constant dense<[76, -7, 5]> : vector<3xi32>
+ %0 = spirv.SMod %cv, %cv_mod : vector<3xi32>
+ return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @srem_x_1
+func.func @srem_x_1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant dense<0>
+ %c1 = spirv.Constant 1 : i32
+ %cv1 = spirv.Constant dense<1> : vector<3xi32>
+ %0 = spirv.SRem %arg0, %c1: i32
+ %1 = spirv.SRem %arg1, %cv1: vector<3xi32>
+ return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @srem_div_0_or_overflow
+func.func @srem_div_0_or_overflow() -> (i32, i32) {
+ // CHECK: spirv.SRem
+ // CHECK: spirv.SRem
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %min_i32 = spirv.Constant -2147483648 : i32
+
+ %0 = spirv.SRem %cn1, %c0 : i32
+ %1 = spirv.SRem %min_i32, %cn1 : i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_srem
+func.func @const_fold_scalar_srem() -> (i32, i32, i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: %[[ONE:.*]] = spirv.Constant 1 : i32
+ // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+ // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+ // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+ %0 = spirv.SRem %c56, %c7 : i32
+ %1 = spirv.SRem %c56, %cn8 : i32
+ %2 = spirv.SRem %c56, %c3 : i32
+ %3 = spirv.SRem %cn3, %c56 : i32
+ %4 = spirv.SRem %c7, %cn3 : i32
+ // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[NTHREE]], %[[ONE]]
+ return %0, %1, %2, %3, %4 : i32, i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @udiv_x_1
+func.func @udiv_x_1(%arg0 : i32) -> i32 {
+ // CHECK-NEXT: return %arg0 : i32
+ %c1 = spirv.Constant 1 : i32
+ %2 = spirv.UDiv %arg0, %c1: i32
+ return %2 : i32
+}
+
+// CHECK-LABEL: @udiv_div_0
+func.func @udiv_div_0() -> i32 {
+ // CHECK: spirv.UDiv
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %0 = spirv.UDiv %cn1, %c0 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_udiv
+func.func @const_fold_scalar_udiv() -> (i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant 1431655762
+ // CHECK-DAG: spirv.Constant 8
+ %0 = spirv.UDiv %c56, %c7 : i32
+ %1 = spirv.UDiv %cn8, %c3 : i32
+ %2 = spirv.UDiv %c56, %cn8 : i32
+ return %0, %1, %2 : i32, i32, i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//
+// CHECK-LABEL: @umod_x_1
+func.func @umod_x_1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant dense<0>
+ %c1 = spirv.Constant 1 : i32
+ %cv1 = spirv.Constant dense<1> : vector<3xi32>
+ %0 = spirv.UMod %arg0, %c1: i32
+ %1 = spirv.UMod %arg1, %cv1: vector<3xi32>
+ return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @umod_div_0
+func.func @umod_div_0() -> i32 {
+ // CHECK: spirv.UMod
+ %c0 = spirv.Constant 0 : i32
+ %cn1 = spirv.Constant -1 : i32
+ %0 = spirv.UMod %cn1, %c0 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_umod
+func.func @const_fold_scalar_umod() -> (i32, i32, i32) {
+ %c56 = spirv.Constant 56 : i32
+ %c7 = spirv.Constant 7 : i32
+ %cn8 = spirv.Constant -8 : i32
+ %c3 = spirv.Constant 3 : i32
+
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant 2
+ // CHECK-DAG: spirv.Constant 56
+ %0 = spirv.UMod %c56, %c7 : i32
+ %1 = spirv.UMod %cn8, %c3 : i32
+ %2 = spirv.UMod %c56, %cn8 : i32
+ return %0, %1, %2 : i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_umod
+func.func @const_fold_vector_umod() -> vector<3xi32> {
+ // CHECK: spirv.Constant dense<[42, 24, 0]>
+
+ %cv = spirv.Constant dense<[42, 24, -16]> : vector<3xi32>
+ %cv_mod = spirv.Constant dense<[76, -7, 5]> : vector<3xi32>
+ %0 = spirv.UMod %cv, %cv_mod : vector<3xi32>
+ return %0 : vector<3xi32>
+}
+
// CHECK-LABEL: @umod_fold
// CHECK-SAME: (%[[ARG:.*]]: i32)
func.func @umod_fold(%arg0: i32) -> (i32, i32) {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, thanks for adding these @inbelic! The implementation looks good to me, but I think we should make the tests match the expected results more precisely.
- fix spacing issue - correct spelling - make testcases more strict when matching the return values to ensure proper order
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the contribution!
Add missing constant propogation folder for [S|U]Mod, [S|U]Div, SRem
Implement additional folding when rhs is 1 for all ops.
This helps for readability of lowered code into SPIR-V.
Part of work for #70704