[mlir][vector] Extend CreateMaskFolder #75842

Merged: banach-space merged 5 commits into llvm:main from andrzej/update_mask_folder on Dec 20, 2023

banach-space

Extends `CreateMaskFolder` pattern so that the following:
```mlir
  %c8 = arith.constant 8 : index
  %c16 = arith.constant 16 : index
  %0 = vector.vscale
  %1 = arith.muli %0, %c16 : index
  %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
```

is folded as:

```mlir
  %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
```
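
For readers who do not have MLIR at hand, the folding rule described above can be modelled in plain C++. This is only a sketch under stated assumptions: `MaskOperand`, `Dim` and `foldCreateMask` are hypothetical names that do not exist in MLIR, and the sketch only accepts a `vscale * dimSize` operand on a scalable dimension, which is a simplification rather than a statement about the actual pattern.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one vector.create_mask operand.
struct MaskOperand {
  enum class Kind { Constant, VscaleTimesConstant, Unknown };
  Kind kind;
  int64_t value; // the constant itself, or the constant multiplied by vscale
};

// Hypothetical stand-in for one dimension of the result vector type.
struct Dim {
  int64_t size;
  bool scalable;
};

// Writes the vector.constant_mask sizes to `sizes` and returns true when the
// fold applies; returns false (leaving `sizes` untouched) otherwise.
bool foldCreateMask(const std::vector<MaskOperand> &operands,
                    const std::vector<Dim> &dims,
                    std::vector<int64_t> &sizes) {
  if (operands.size() != dims.size())
    return false;

  std::vector<int64_t> result;
  bool anyEmpty = false;
  for (std::size_t i = 0; i < operands.size(); ++i) {
    const MaskOperand &op = operands[i];
    const Dim &dim = dims[i];
    switch (op.kind) {
    case MaskOperand::Kind::Constant:
      if (op.value <= 0) {
        anyEmpty = true; // negative or zero: the whole mask is empty
      } else if (dim.scalable) {
        return false; // positive constant on a scalable dim: not foldable
      } else {
        result.push_back(std::min(op.value, dim.size)); // clamp to dim size
      }
      break;
    case MaskOperand::Kind::VscaleTimesConstant:
      // "vscale * dimSize" means "all true". Assumption of this sketch: it is
      // only accepted on a scalable dimension.
      if (!dim.scalable || op.value != dim.size)
        return false;
      result.push_back(dim.size);
      break;
    case MaskOperand::Kind::Unknown:
      return false;
    }
  }

  if (anyEmpty)
    result.assign(dims.size(), 0); // one empty dim zeroes every dim
  sizes = result;
  return true;
}
```

With dims `8 x [16]` and operands `8` and `vscale * 16`, the sketch yields sizes `[8, 16]`, mirroring the example in the description.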

llvmbot commented Dec 18, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Extends CreateMaskFolder pattern so that the following:

  %c8 = arith.constant 8 : index
  %c16 = arith.constant 16 : index
  %0 = vector.vscale
  %1 = arith.muli %0, %c16 : index
  %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>

is folded as:

  %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>

Full diff: https://github.com/llvm/llvm-project/pull/75842.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+76-20)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+13)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..3619c1c00f1664 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5657,30 +5657,79 @@ LogicalResult CreateMaskOp::verify() {
 
 namespace {
 
-// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
+/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
+///
+/// Ex 1:
+///   %c2 = arith.constant 2 : index
+///   %c3 = arith.constant 3 : index
+///   %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
+/// Becomes:
+///    vector.constant_mask [3, 2] : vector<4x3xi1>
+///
+/// Ex 2:
+///   %c_neg_1 = arith.constant -1 : index
+///   %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
+/// becomes:
+///   vector.constant_mask [0] : vector<[8]xi1>
+///
+/// Ex 3:
+///   %c8 = arith.constant 8 : index
+///   %c16 = arith.constant 16 : index
+///   %0 = vector.vscale
+///   %1 = arith.muli %0, %c16 : index
+///   %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
+/// becomes:
+///   %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                 PatternRewriter &rewriter) const override {
-    // Return if any of 'createMaskOp' operands are not defined by a constant.
-    auto isNotDefByConstant = [](Value operand) {
-      return !getConstantIntValue(operand).has_value();
-    };
-    if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant))
-      return failure();
+    VectorType retTy = createMaskOp.getResult().getType();
+    bool isScalable = retTy.isScalable();
+
+    // Check every mask operand
+    for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
+      // Most basic case - this operand is a constant value. Note that for
+      // scalable dimensions, CreateMaskOp can be folded only if the
+      // corresponding operand is negative or zero.
+      if (auto op = getConstantIntValue(operand)) {
+        APInt intVal;
+        if (isScalable && !(matchPattern(operand, m_ConstantInt(&intVal)) ||
+                            intVal.isStrictlyPositive()))
+          return failure();
 
-    // CreateMaskOp for scalable vectors can be folded only if all dimensions
-    // are negative or zero.
-    if (auto vType = llvm::dyn_cast<VectorType>(createMaskOp.getType())) {
-      if (vType.isScalable())
-        for (auto opDim : createMaskOp.getOperands()) {
-          APInt intVal;
-          if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
-              intVal.isStrictlyPositive())
-            return failure();
-        }
+        continue;
+      }
+
+      // Non-constant operands are not allowed for non-scalable vectors.
+      if (!isScalable)
+        return failure();
+
+      // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
+      // true" mask, so can also be treated as constant.
+      auto mul = llvm::dyn_cast_or_null<arith::MulIOp>(operand.getDefiningOp());
+      if (!mul)
+        return failure();
+      auto mulLHS = mul.getOperands()[0];
+      auto mulRHS = mul.getOperands()[1];
+      bool isOneOpVscale =
+          (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
+           isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
+
+      auto isConstantValMatchingDim =
+          [=, dim = createMaskOp.getResult().getType().getShape()[opIdx]](
+              Value operand) {
+            auto constantVal = getConstantIntValue(operand);
+            return (constantVal.has_value() && constantVal.value() == dim);
+          };
+
+      bool isOneOpConstantMatchingDim =
+          isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
+
+      if (!isOneOpVscale || !isOneOpConstantMatchingDim)
+        return failure();
     }
 
     // Gather constant mask dimension sizes.
@@ -5688,15 +5737,22 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
     maskDimSizes.reserve(createMaskOp->getNumOperands());
     for (auto [operand, maxDimSize] : llvm::zip_equal(
              createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
-      int64_t dimSize = getConstantIntValue(operand).value();
-      dimSize = std::min(dimSize, maxDimSize);
+      auto dimSize = getConstantIntValue(operand);
+      if (not dimSize) {
+        // Although not a constant, it is safe to assume that `operand` is
+        // "vscale * maxDimSize".
+        maskDimSizes.push_back(maxDimSize);
+        continue;
+      }
+      int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
       // If one of dim sizes is zero, set all dims to zero.
       if (dimSize <= 0) {
         maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
         break;
       }
-      maskDimSizes.push_back(dimSize);
+      maskDimSizes.push_back(dimSizeVal);
     }
+
     // Replace 'createMaskOp' with ConstantMaskOp.
     rewriter.replaceOpWithNewOp<ConstantMaskOp>(
         createMaskOp, createMaskOp.getResult().getType(),
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d34..a30016ea857d97 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -53,6 +53,19 @@ func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3x
 
 // -----
 
+// CHECK-LABEL: create_vector_mask_to_constant_mask_scalable_all_true
+func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x[16]xi1>) {
+  %c8 = arith.constant 8 : index
+  %c16 = arith.constant 16 : index
+  %0 = vector.vscale
+  %1 = arith.muli %0, %c16 : index
+  // CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
+  %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
+  return %10 : vector<8x[16]xi1>
+}
+
+// -----
+
 // CHECK-LABEL: create_mask_transpose_to_transposed_create_mask
 //  CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index
 func.func @create_mask_transpose_to_transposed_create_mask(

@kuhar kuhar left a comment

Left some API suggestions.

One thing that I didn't see tested is what happens when the operand is a constant larger than the dim size. I don't know the semantics of vector.create_mask, but I would expect this to fail verification with attributes but not with constant operands. Folds should, in general, not introduce verifier errors by making the information more local.

@banach-space

Thanks for taking a look Jakub!

> One thing that I didn't see tested is what happens when the operand is a constant larger than the dim size.

For non-scalable dims:

  // CHECK-LABEL: create_vector_mask_to_constant_mask_truncation
  func.func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>) {
    %c2 = arith.constant 2 : index
    %c5 = arith.constant 5 : index
    // CHECK: vector.constant_mask [4, 2] : vector<4x3xi1>
    %0 = vector.create_mask %c5, %c2 : vector<4x3xi1>
    return %0 : vector<4x3xi1>
  }

For scalable dims (from https://mlir.llvm.org/docs/Dialects/Vector/#vectorconstant_mask-vectorconstantmaskop):

> Sizes that correspond to scalable dimensions are implicitly multiplied by vscale, though currently only zero (none set) or the size of the dim/vscale (all set) are supported.

I check for that here: https://github.com/llvm/llvm-project/pull/75842/files#diff-c62b57552386a2a552ce6e3fe37bc23d399f2508636851b77ec6cd47fee906afR5728-R5729

I could add a negative test, but didn't see any in the test file 🤔 .

kuhar commented Dec 18, 2023

> For non-scalable dims:

I'm not sure if we should be fixing broken IR like that. Would be interesting to have a general discussion about this on discourse.

@banach-space

> For non-scalable dims:
>
> I'm not sure if we should be fixing broken IR like that.

That IR is valid (from docs):

> If operand-value is negative, it is treated as if it were zero, and if it is greater than the corresponding dimension size, it is treated as if it were equal to the dimension size.

Btw, this PR doesn't change the semantics of vector.create_mask.
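
The documented clamping rule quoted above can be written as a one-line helper. This is a minimal sketch for a fixed-size dimension only; `effectiveMaskSize` is a hypothetical name used for illustration and is not part of MLIR.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helper: the mask size that vector.create_mask effectively uses
// along one fixed-size dimension, per the quoted documentation.
// Negative operands behave like zero; operands above the dimension size
// behave like the dimension size. Assumes dimSize >= 0.
int64_t effectiveMaskSize(int64_t operandValue, int64_t dimSize) {
  return std::min(std::max<int64_t>(operandValue, 0), dimSize);
}
```

For the truncation test shown earlier in this thread, `effectiveMaskSize(5, 4)` gives 4 and `effectiveMaskSize(2, 3)` gives 2, which matches the expected `vector.constant_mask [4, 2]`.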

kuhar commented Dec 18, 2023

> If operand-value is negative, it is treated as if it were zero, and if it is greater than the corresponding dimension size, it is treated as if it were equal to the dimension size.
>
> Btw, this PR doesn't change the semantics of vector.create_mask.

Ah, I missed this part of the spec. Now that you have highlighted it, it makes perfect sense. Thanks!

@kuhar kuhar left a comment

LGTM

Comment on lines 5699 to 5700
    if (isScalable && !(matchPattern(operand, m_ConstantInt(&intVal)) ||
                        intVal.isStrictlyPositive()))
Collaborator

The RHS is equivalent to !matchPattern(operand, m_ConstantInt(&intVal)) && !intVal.isStrictlyPositive(). Does it make sense to check for <= 0 if it's not a constant?

@banach-space banach-space Dec 19, 2023

The compiler should be able to figure these things out. That does seem to be the case:

Btw, I will rewrite it so that it's more descriptive.

Member

I think the point is the call to intVal.isStrictlyPositive() does not make sense in this check.

Because if matchPattern(operand, m_ConstantInt(&intVal)) is:

  • true
    • then intVal.isStrictlyPositive() won't be called
  • false
    • then intVal.isStrictlyPositive() will be called
    • but it does not make sense to call intVal.isStrictlyPositive() on something that's not a constant integer
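
The short-circuit argument above can be checked with two small predicates. This is a sketch with hypothetical names: `isConstant` stands in for the result of `matchPattern(operand, m_ConstantInt(&intVal))` and `isPositive` for `intVal.isStrictlyPositive()`. The common `isScalable &&` prefix is left out, and the "intended" form is taken from the condition that the diff removes.

```cpp
// Bail-out condition as first posted in this PR: !(isConstant || isPositive).
// By De Morgan this equals !isConstant && !isPositive, so it can only fire
// when the operand is NOT a constant.
bool bailsOutAsPosted(bool isConstant, bool isPositive) {
  return !(isConstant || isPositive);
}

// Bail-out condition used by the code being replaced in the diff: fail only
// for a strictly positive constant.
bool bailsOutAsIntended(bool isConstant, bool isPositive) {
  return isConstant && isPositive;
}
```

For a strictly positive constant (`isConstant` and `isPositive` both true), the posted form does not bail out while the removed form does, which is the mismatch the reviewers point at.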

Fix how scalable dims are treated
@c-rhodes c-rhodes left a comment

Left one suggestion but otherwise LGTM, cheers

Refine based on Cullen's suggestion
@banach-space

@MacDue The CI failures are unrelated to this change - are you OK with me landing this?

@MacDue MacDue left a comment

Yep, LGTM

@banach-space banach-space merged commit 354adb4 into llvm:main Dec 20, 2023
banach-space added a commit to banach-space/llvm-project that referenced this pull request Dec 20, 2023
Re-order matvec tests so that the one without masking is always first
banach-space added a commit to banach-space/llvm-project that referenced this pull request Jan 10, 2024
Following on from llvm#75842, we can demonstrate that loop peeling combined
with masked vectorisation and existing canonicalization for vector.mask
operations leads to the following loop structure:

```
// M dimension
scf.for 1:M
  // N dimension (contains vector ops _without_ masking)
  scf.for 1:UB
    // K dimension
    scf.for 1:K
      vector.add

  // N dimension (contains vector ops _with_ masking)
  scf.for UB:N
    // K dimension
    scf.for 1:K
      vector.mask { vector.add }
```

This is particularly beneficial for scalable vectors, which normally
require masking. This example demonstrates how to avoid it.
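
The peeled loop structure can be illustrated with a one-dimensional scalar model in C++. This is only a sketch: `addWithPeeling`, `PeelStats` and the `vl` parameter (a stand-in for the vector length) are hypothetical names, the inner lane loop stands in for a single vector operation, and the M and K loops from the structure above are omitted.

```cpp
#include <cstddef>
#include <vector>

// Counts how many steps ran without and with a mask.
struct PeelStats {
  std::size_t unmaskedSteps;
  std::size_t maskedSteps;
};

// Adds b into a, `vl` elements per step.
// Preconditions (assumed, not checked): vl > 0 and a.size() == b.size().
PeelStats addWithPeeling(std::vector<int> &a, const std::vector<int> &b,
                         std::size_t vl) {
  PeelStats stats{0, 0};
  const std::size_t n = a.size();
  const std::size_t ub = n - n % vl; // largest multiple of vl that is <= n

  // Main loop (1:UB): every lane is in bounds, so no mask is needed.
  for (std::size_t i = 0; i < ub; i += vl) {
    for (std::size_t lane = 0; lane < vl; ++lane)
      a[i + lane] += b[i + lane];
    ++stats.unmaskedSteps;
  }

  // Remainder loop (UB:N): at most one step, lanes past n are masked off.
  for (std::size_t i = ub; i < n; i += vl) {
    for (std::size_t lane = 0; lane < vl; ++lane) {
      const bool active = i + lane < n; // the per-lane mask bit
      if (active)
        a[i + lane] += b[i + lane];
    }
    ++stats.maskedSteps;
  }
  return stats;
}
```

With 10 elements and `vl = 4`, the main loop runs two unmasked steps and the remainder loop runs one masked step covering the last two elements.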
banach-space added a commit that referenced this pull request Jan 10, 2024
@banach-space banach-space deleted the andrzej/update_mask_folder branch January 16, 2024 18:46
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…m#77590)
