[mlir][vector] Extend CreateMaskFolder
#75842
Extends the `CreateMaskFolder` pattern so that the following:

```mlir
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%0 = vector.vscale
%1 = arith.muli %0, %c16 : index
%10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
```

is folded as:

```mlir
%0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
```
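The foldability rule can be sketched outside of MLIR as a small standalone model. This is only an illustration, not the code in the PR: `MaskOperand`, `MaskDim` and `canFoldCreateMask` are hypothetical names, and the sketch assumes scalability is checked per dimension, which is what the example above needs (the fixed dimension takes the positive constant `%c8`, the scalable one takes `vscale * 16`).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a create_mask operand: a known constant,
// `vscale * value`, or something the folder cannot see through.
struct MaskOperand {
  enum class Kind { Constant, VscaleTimes, Unknown } kind;
  int64_t value; // constant value or vscale multiplier; unused for Unknown
};

// One dimension of the result type, e.g. `[16]` is {16, true}.
struct MaskDim {
  int64_t size;
  bool scalable;
};

// Returns true if every operand can be expressed as a constant_mask dim size.
bool canFoldCreateMask(const std::vector<MaskDim> &dims,
                       const std::vector<MaskOperand> &operands) {
  assert(dims.size() == operands.size());
  for (std::size_t i = 0; i < dims.size(); ++i) {
    const MaskOperand &op = operands[i];
    const MaskDim &dim = dims[i];
    switch (op.kind) {
    case MaskOperand::Kind::Constant:
      // On a scalable dim only "none set" (<= 0) is known at compile time.
      if (dim.scalable && op.value > 0)
        return false;
      break;
    case MaskOperand::Kind::VscaleTimes:
      // `vscale * N` on a scalable dim of base size N means "all true".
      if (!dim.scalable || op.value != dim.size)
        return false;
      break;
    case MaskOperand::Kind::Unknown:
      return false;
    }
  }
  return true;
}
```

With dims `{8, fixed}` and `{16, scalable}`, the operands `8` and `vscale * 16` are accepted, while a plain constant `16` on the scalable dimension is rejected.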
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector
Author: Andrzej Warzyński (banach-space)
Changes: Extends the `CreateMaskFolder` pattern as in the description above, so that a `vector.create_mask %c8, %1 : vector<8x[16]xi1>` whose second operand is `arith.muli %vscale, %c16` is folded as `vector.constant_mask [8, 16] : vector<8x[16]xi1>`.
Full diff: https://github.com/llvm/llvm-project/pull/75842.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..3619c1c00f1664 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5657,30 +5657,79 @@ LogicalResult CreateMaskOp::verify() {
namespace {
-// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
+/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
+///
+/// Ex 1:
+/// %c2 = arith.constant 2 : index
+/// %c3 = arith.constant 3 : index
+/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
+/// Becomes:
+/// vector.constant_mask [3, 2] : vector<4x3xi1>
+///
+/// Ex 2:
+/// %c_neg_1 = arith.constant -1 : index
+/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
+/// becomes:
+/// vector.constant_mask [0] : vector<[8]xi1>
+///
+/// Ex 3:
+/// %c8 = arith.constant 8 : index
+/// %c16 = arith.constant 16 : index
+/// %0 = vector.vscale
+/// %1 = arith.muli %0, %c16 : index
+/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
+/// becomes:
+/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
- // Return if any of 'createMaskOp' operands are not defined by a constant.
- auto isNotDefByConstant = [](Value operand) {
- return !getConstantIntValue(operand).has_value();
- };
- if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant))
- return failure();
+ VectorType retTy = createMaskOp.getResult().getType();
+ bool isScalable = retTy.isScalable();
+
+ // Check every mask operand
+ for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
+ // Most basic case - this operand is a constant value. Note that for
+ // scalable dimensions, CreateMaskOp can be folded only if the
+ // corresponding operand is negative or zero.
+ if (auto op = getConstantIntValue(operand)) {
+ APInt intVal;
+ if (isScalable && !(matchPattern(operand, m_ConstantInt(&intVal)) ||
+ intVal.isStrictlyPositive()))
+ return failure();
- // CreateMaskOp for scalable vectors can be folded only if all dimensions
- // are negative or zero.
- if (auto vType = llvm::dyn_cast<VectorType>(createMaskOp.getType())) {
- if (vType.isScalable())
- for (auto opDim : createMaskOp.getOperands()) {
- APInt intVal;
- if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
- intVal.isStrictlyPositive())
- return failure();
- }
+ continue;
+ }
+
+ // Non-constant operands are not allowed for non-scalable vectors.
+ if (!isScalable)
+ return failure();
+
+ // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
+ // true" mask, so can also be treated as constant.
+ auto mul = llvm::dyn_cast_or_null<arith::MulIOp>(operand.getDefiningOp());
+ if (!mul)
+ return failure();
+ auto mulLHS = mul.getOperands()[0];
+ auto mulRHS = mul.getOperands()[1];
+ bool isOneOpVscale =
+ (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
+ isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
+
+ auto isConstantValMatchingDim =
+ [=, dim = createMaskOp.getResult().getType().getShape()[opIdx]](
+ Value operand) {
+ auto constantVal = getConstantIntValue(operand);
+ return (constantVal.has_value() && constantVal.value() == dim);
+ };
+
+ bool isOneOpConstantMatchingDim =
+ isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
+
+ if (!isOneOpVscale || !isOneOpConstantMatchingDim)
+ return failure();
}
// Gather constant mask dimension sizes.
@@ -5688,15 +5737,22 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
maskDimSizes.reserve(createMaskOp->getNumOperands());
for (auto [operand, maxDimSize] : llvm::zip_equal(
createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
- int64_t dimSize = getConstantIntValue(operand).value();
- dimSize = std::min(dimSize, maxDimSize);
+ auto dimSize = getConstantIntValue(operand);
+ if (not dimSize) {
+ // Although not a constant, it is safe to assume that `operand` is
+ // "vscale * maxDimSize".
+ maskDimSizes.push_back(maxDimSize);
+ continue;
+ }
+ int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
// If one of dim sizes is zero, set all dims to zero.
if (dimSize <= 0) {
maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
break;
}
- maskDimSizes.push_back(dimSize);
+ maskDimSizes.push_back(dimSizeVal);
}
+
// Replace 'createMaskOp' with ConstantMaskOp.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
createMaskOp, createMaskOp.getResult().getType(),
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d34..a30016ea857d97 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -53,6 +53,19 @@ func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3x
// -----
+// CHECK-LABEL: create_vector_mask_to_constant_mask_scalable_all_true
+func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x[16]xi1>) {
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %0 = vector.vscale
+ %1 = arith.muli %0, %c16 : index
+ // CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
+ %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
+ return %10 : vector<8x[16]xi1>
+}
+
+// -----
+
// CHECK-LABEL: create_mask_transpose_to_transposed_create_mask
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index
func.func @create_mask_transpose_to_transposed_create_mask(
Left some API suggestions.
One thing that I didn't see tested is what happens when the operand is a constant larger than the dim size. I don't know the semantics of `vector.create_mask`, but I would expect this to fail verification with attributes but not with constant operands. Folds should, in general, not introduce verifier errors by making the information more local.
Address comments from Jakub
Thanks for taking a look Jakub!
For non-scalable dims:
For scalable dims (from https://mlir.llvm.org/docs/Dialects/Vector/#vectorconstant_mask-vectorconstantmaskop):
I check for that here: https://github.com/llvm/llvm-project/pull/75842/files#diff-c62b57552386a2a552ce6e3fe37bc23d399f2508636851b77ec6cd47fee906afR5728-R5729
I could add a negative test, but didn't see any in the test file 🤔.
I'm not sure if we should be fixing broken IR like that. Would be interesting to have a general discussion about this on discourse.
That IR is valid (from docs):
Btw, this PR doesn't change the semantics of
Ah, I missed this part of the spec. Now that you have highlighted it makes perfect sense. Thanks!
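The clamping that this thread is about can be sketched in isolation. The snippet below is a simplified model of the "gather constant mask dimension sizes" loop in the diff, not the MLIR code itself: `gatherMaskDimSizes` is a hypothetical name, and `std::nullopt` is assumed to stand for an operand already recognised as `vscale * dimSize`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Computes constant_mask dim sizes from create_mask operands.
// A std::nullopt operand models "vscale * dimSize", i.e. an all-true dim.
std::vector<int64_t>
gatherMaskDimSizes(const std::vector<std::optional<int64_t>> &operands,
                   const std::vector<int64_t> &shape) {
  assert(operands.size() == shape.size());
  std::vector<int64_t> sizes;
  sizes.reserve(shape.size());
  for (std::size_t i = 0; i < shape.size(); ++i) {
    if (!operands[i]) {
      // Not a constant: treated as covering the whole dimension.
      sizes.push_back(shape[i]);
      continue;
    }
    if (*operands[i] <= 0) {
      // One empty dimension makes the whole mask empty.
      sizes.assign(shape.size(), 0);
      break;
    }
    // Clamp so the folded sizes never exceed the vector shape.
    sizes.push_back(std::min(*operands[i], shape[i]));
  }
  return sizes;
}
```

In this model an operand of 5 on a dimension of size 4 yields 4, so the folded sizes always stay within the shape and the fold cannot produce an out-of-range dimension size.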
LGTM
if (isScalable && !(matchPattern(operand, m_ConstantInt(&intVal)) ||
                    intVal.isStrictlyPositive()))
the RHS is equivalent to `!(matchPattern(operand, m_ConstantInt(&intVal)) && !intVal.isStrictlyPositive())`, it doesn't make sense to check for <= 0 if it's not a constant?
The compiler should be able to figure these things out. That does seem to be the case:
Btw, I will rewrite it so that it's more descriptive.
I think the point is the call to `intVal.isStrictlyPositive()` does not make sense in this check.
Because if `matchPattern(operand, m_ConstantInt(&intVal))` is:
- `true`, then `intVal.isStrictlyPositive()` won't be called
- `false`, then `intVal.isStrictlyPositive()` will be called, but it does not make sense to call `intVal.isStrictlyPositive()` on something that's not a constant integer
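The short-circuit behaviour described in this comment can be reproduced with a small standalone sketch. Everything here is a hypothetical stand-in: `matchConstant` plays the role of `matchPattern(operand, m_ConstantInt(&intVal))`, a `std::optional<long>` models "constant or not", and `rejectsPerComment` is only one reading of the code comment in the diff ("can be folded only if the corresponding operand is negative or zero").

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for matchPattern(..., m_ConstantInt(&intVal)):
// writes to `out` only when the operand is a constant.
bool matchConstant(std::optional<long> operand, long &out) {
  if (!operand)
    return false;
  out = *operand;
  return true;
}

// Shape of the reviewed check: !(match || positive).
// When the match succeeds, `||` short-circuits and positivity is never tested.
bool rejectsAsWritten(bool scalable, std::optional<long> operand) {
  long val = 0; // initialised so the sketch has well-defined behaviour
  return scalable && !(matchConstant(operand, val) || val > 0);
}

// One reading of the stated intent: reject positive constants on scalable
// vectors, and only look at the value after a successful match.
bool rejectsPerComment(bool scalable, std::optional<long> operand) {
  long val = 0;
  return scalable && matchConstant(operand, val) && val > 0;
}
```

For a scalable vector and the constant 5, `rejectsAsWritten` returns false (the positive constant slips through) whereas `rejectsPerComment` returns true, which is the mismatch the reviewers are pointing at.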
Address comments from Cullen
Fix how scalable dims are treated
Left one suggestion but otherwise LGTM, cheers
Refine based on Cullen's suggestion
@MacDue The CI failures are unrelated to this change - are you OK with me landing this?
Yep, LGTM
Move matmul traits
Re-order matvec tests so that the one without masking is always first
Following on from llvm#75842, we can demonstrate that loop peeling combined with masked vectorisation and existing canonicalization for vector.mask operations leads to the following loop structure:

```
// M dimension
scf.for 1:M
  // N dimension (contains vector ops _without_ masking)
  scf.for 1:UB
    // K dimension
    scf.for 1:K
      vector.add
  // N dimension (contains vector ops _with_ masking)
  scf.for UB:N
    // K dimension
    scf.for 1:K
      vector.mask { vector.add }
```

This is particularly beneficial for scalable vectors, which normally require masking. This example demonstrates how to avoid masking in the main loop.
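The split between an unmasked main loop and a masked remainder loop can be illustrated with a scalar C++ sketch. This is not the MLIR test from that change: `peeledAdd` and the vector length `vl` are hypothetical, and the inner lane loops merely stand in for vector operations.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Adds b into a, with the iteration space split into a full-chunk main loop
// and a masked remainder loop. `vl` is an assumed vector length in elements.
void peeledAdd(std::vector<int> &a, const std::vector<int> &b, std::size_t vl) {
  assert(a.size() == b.size() && vl > 0);
  const std::size_t n = a.size();
  const std::size_t ub = n - n % vl; // largest multiple of vl not above n

  // Main loop (1:UB above): every chunk is full, so no mask is needed.
  for (std::size_t i = 0; i < ub; i += vl)
    for (std::size_t lane = 0; lane < vl; ++lane)
      a[i + lane] += b[i + lane];

  // Remainder loop (UB:N above): at most one partial chunk, guarded per lane.
  for (std::size_t i = ub; i < n; i += vl)
    for (std::size_t lane = 0; lane < vl; ++lane) {
      const bool mask = i + lane < n;
      if (mask)
        a[i + lane] += b[i + lane];
    }
}
```

In the main loop the mask would be all true for every iteration, which corresponds to the case that the `CreateMaskFolder` extension turns into a constant mask.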