Skip to content

Commit 5259e14

Browse files
authored
[mlir][spirv] Add folding for [S|U]Mod, [S|U]Div, SRem (#73341)
Add missing constant propogation folder for [S|U]Mod, [S|U]Div, SRem Implement additional folding when rhs is 1 for all ops. This helps for readability of lowered code into SPIR-V. Part of work for #70704
1 parent 749d595 commit 5259e14

File tree

3 files changed

+465
-0
lines changed

3 files changed

+465
-0
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,8 @@ def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
481481

482482
```
483483
}];
484+
485+
let hasFolder = 1;
484486
}
485487

486488
// -----
@@ -513,6 +515,8 @@ def SPIRV_SModOp : SPIRV_ArithmeticBinaryOp<"SMod",
513515

514516
```
515517
}];
518+
519+
let hasFolder = 1;
516520
}
517521

518522
// -----
@@ -606,6 +610,8 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
606610

607611
```
608612
}];
613+
614+
let hasFolder = 1;
609615
}
610616

611617
// -----
@@ -632,6 +638,8 @@ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
632638
%5 = spirv.UDiv %2, %3 : vector<4xi32>
633639
```
634640
}];
641+
642+
let hasFolder = 1;
635643
}
636644

637645
// -----
@@ -728,6 +736,7 @@ def SPIRV_UModOp : SPIRV_ArithmeticBinaryOp<"UMod",
728736
```
729737
}];
730738

739+
let hasFolder = 1;
731740
let hasCanonicalizer = 1;
732741
}
733742

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@ static Attribute extractCompositeElement(Attribute composite,
6969
return {};
7070
}
7171

72+
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
73+
bool div0 = b.isZero();
74+
bool overflow = a.isMinSignedValue() && b.isAllOnes();
75+
76+
return div0 || overflow;
77+
}
78+
7279
//===----------------------------------------------------------------------===//
7380
// TableGen'erated canonicalizers
7481
//===----------------------------------------------------------------------===//
@@ -290,6 +297,158 @@ OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
290297
[](APInt a, const APInt &b) { return std::move(a) - b; });
291298
}
292299

300+
//===----------------------------------------------------------------------===//
301+
// spirv.SDiv
302+
//===----------------------------------------------------------------------===//
303+
304+
OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
305+
// sdiv (x, 1) = x
306+
if (matchPattern(getOperand2(), m_One()))
307+
return getOperand1();
308+
309+
// According to the SPIR-V spec:
310+
//
311+
// Signed-integer division of Operand 1 divided by Operand 2.
312+
// Results are computed per component. Behavior is undefined if Operand 2 is
313+
// 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
314+
// representable value for the operands' type, causing signed overflow.
315+
//
316+
// So don't fold during undefined behavior.
317+
bool div0OrOverflow = false;
318+
auto res = constFoldBinaryOp<IntegerAttr>(
319+
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
320+
if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
321+
div0OrOverflow = true;
322+
return a;
323+
}
324+
return a.sdiv(b);
325+
});
326+
return div0OrOverflow ? Attribute() : res;
327+
}
328+
329+
//===----------------------------------------------------------------------===//
330+
// spirv.SMod
331+
//===----------------------------------------------------------------------===//
332+
333+
OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
334+
// smod (x, 1) = 0
335+
if (matchPattern(getOperand2(), m_One()))
336+
return Builder(getContext()).getZeroAttr(getType());
337+
338+
// According to SPIR-V spec:
339+
//
340+
// Signed remainder operation for the remainder whose sign matches the sign
341+
// of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
342+
// undefined if Operand 2 is -1 and Operand 1 is the minimum representable
343+
// value for the operands' type, causing signed overflow. Otherwise, the
344+
// result is the remainder r of Operand 1 divided by Operand 2 where if
345+
// r ≠ 0, the sign of r is the same as the sign of Operand 2.
346+
//
347+
// So don't fold during undefined behavior
348+
bool div0OrOverflow = false;
349+
auto res = constFoldBinaryOp<IntegerAttr>(
350+
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
351+
if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
352+
div0OrOverflow = true;
353+
return a;
354+
}
355+
APInt c = a.abs().urem(b.abs());
356+
if (c.isZero())
357+
return c;
358+
if (b.isNegative()) {
359+
APInt zero = APInt::getZero(c.getBitWidth());
360+
return a.isNegative() ? (zero - c) : (b + c);
361+
}
362+
return a.isNegative() ? (b - c) : c;
363+
});
364+
return div0OrOverflow ? Attribute() : res;
365+
}
366+
367+
//===----------------------------------------------------------------------===//
368+
// spirv.SRem
369+
//===----------------------------------------------------------------------===//
370+
371+
OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
372+
// x % 1 = 0
373+
if (matchPattern(getOperand2(), m_One()))
374+
return Builder(getContext()).getZeroAttr(getType());
375+
376+
// According to SPIR-V spec:
377+
//
378+
// Signed remainder operation for the remainder whose sign matches the sign
379+
// of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
380+
// undefined if Operand 2 is -1 and Operand 1 is the minimum representable
381+
// value for the operands' type, causing signed overflow. Otherwise, the
382+
// result is the remainder r of Operand 1 divided by Operand 2 where if
383+
// r ≠ 0, the sign of r is the same as the sign of Operand 1.
384+
385+
// Don't fold if it would do undefined behavior.
386+
bool div0OrOverflow = false;
387+
auto res = constFoldBinaryOp<IntegerAttr>(
388+
adaptor.getOperands(), [&](APInt a, const APInt &b) {
389+
if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
390+
div0OrOverflow = true;
391+
return a;
392+
}
393+
return a.srem(b);
394+
});
395+
return div0OrOverflow ? Attribute() : res;
396+
}
397+
398+
//===----------------------------------------------------------------------===//
399+
// spirv.UDiv
400+
//===----------------------------------------------------------------------===//
401+
402+
OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
403+
// udiv (x, 1) = x
404+
if (matchPattern(getOperand2(), m_One()))
405+
return getOperand1();
406+
407+
// According to the SPIR-V spec:
408+
//
409+
// Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
410+
// undefined if Operand 2 is 0.
411+
//
412+
// So don't fold during undefined behavior.
413+
bool div0 = false;
414+
auto res = constFoldBinaryOp<IntegerAttr>(
415+
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
416+
if (div0 || b.isZero()) {
417+
div0 = true;
418+
return a;
419+
}
420+
return a.udiv(b);
421+
});
422+
return div0 ? Attribute() : res;
423+
}
424+
425+
//===----------------------------------------------------------------------===//
426+
// spirv.UMod
427+
//===----------------------------------------------------------------------===//
428+
429+
OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
430+
// umod (x, 1) = 0
431+
if (matchPattern(getOperand2(), m_One()))
432+
return Builder(getContext()).getZeroAttr(getType());
433+
434+
// According to the SPIR-V spec:
435+
//
436+
// Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
437+
// undefined if Operand 2 is 0.
438+
//
439+
// So don't fold during undefined behavior.
440+
bool div0 = false;
441+
auto res = constFoldBinaryOp<IntegerAttr>(
442+
adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
443+
if (div0 || b.isZero()) {
444+
div0 = true;
445+
return a;
446+
}
447+
return a.urem(b);
448+
});
449+
return div0 ? Attribute() : res;
450+
}
451+
293452
//===----------------------------------------------------------------------===//
294453
// spirv.LogicalAnd
295454
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)