From 87149572f24ef07660328e147873329cb1b3154d Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Mon, 18 Dec 2023 18:43:08 +0000 Subject: [PATCH 1/5] [mlir][vector] Extend `CreateMaskFolder` Extends `CreateMaskFolder` pattern so that the following: ```mlir %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %0 = vector.vscale %1 = arith.muli %0, %c16 : index %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1> ``` is folded as: ```mlir %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1> ``` --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 96 +++++++++++++++++----- mlir/test/Dialect/Vector/canonicalize.mlir | 13 +++ 2 files changed, 89 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 540959b486db9..3619c1c00f166 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5657,30 +5657,79 @@ LogicalResult CreateMaskOp::verify() { namespace { -// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp. +/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp. +/// +/// Ex 1: +/// %c2 = arith.constant 2 : index +/// %c3 = arith.constant 3 : index +/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1> +/// Becomes: +/// vector.constant_mask [3, 2] : vector<4x3xi1> +/// +/// Ex 2: +/// %c_neg_1 = arith.constant -1 : index +/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1> +/// becomes: +/// vector.constant_mask [0] : vector<[8]xi1> +/// +/// Ex 3: +/// %c8 = arith.constant 8 : index +/// %c16 = arith.constant 16 : index +/// %0 = vector.vscale +/// %1 = arith.muli %0, %c16 : index +/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1> +/// becomes: +/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1> class CreateMaskFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CreateMaskOp createMaskOp, PatternRewriter &rewriter) const override { - // Return if any of 'createMaskOp' operands are not defined by a constant. - auto isNotDefByConstant = [](Value operand) { - return !getConstantIntValue(operand).has_value(); - }; - if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant)) - return failure(); + VectorType retTy = createMaskOp.getResult().getType(); + bool isScalable = retTy.isScalable(); + + // Check every mask operand + for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) { + // Most basic case - this operand is a constant value. Note that for + // scalable dimensions, CreateMaskOp can be folded only if the + // corresponding operand is negative or zero. + if (auto op = getConstantIntValue(operand)) { + APInt intVal; + if (isScalable && !(matchPattern(operand, m_ConstantInt(&intVal)) || + intVal.isStrictlyPositive())) + return failure(); - // CreateMaskOp for scalable vectors can be folded only if all dimensions - // are negative or zero. - if (auto vType = llvm::dyn_cast(createMaskOp.getType())) { - if (vType.isScalable()) - for (auto opDim : createMaskOp.getOperands()) { - APInt intVal; - if (matchPattern(opDim, m_ConstantInt(&intVal)) && - intVal.isStrictlyPositive()) - return failure(); - } + continue; + } + + // Non-constant operands are not allowed for non-scalable vectors. + if (!isScalable) + return failure(); + + // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all + // true" mask, so can also be treated as constant. + auto mul = llvm::dyn_cast_or_null(operand.getDefiningOp()); + if (!mul) + return failure(); + auto mulLHS = mul.getOperands()[0]; + auto mulRHS = mul.getOperands()[1]; + bool isOneOpVscale = + (isa(mulLHS.getDefiningOp()) || + isa(mulRHS.getDefiningOp())); + + auto isConstantValMatchingDim = + [=, dim = createMaskOp.getResult().getType().getShape()[opIdx]]( + Value operand) { + auto constantVal = getConstantIntValue(operand); + return (constantVal.has_value() && constantVal.value() == dim); + }; + + bool isOneOpConstantMatchingDim = + isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS); + + if (!isOneOpVscale || !isOneOpConstantMatchingDim) + return failure(); } // Gather constant mask dimension sizes. @@ -5688,15 +5737,22 @@ class CreateMaskFolder final : public OpRewritePattern { maskDimSizes.reserve(createMaskOp->getNumOperands()); for (auto [operand, maxDimSize] : llvm::zip_equal( createMaskOp.getOperands(), createMaskOp.getType().getShape())) { - int64_t dimSize = getConstantIntValue(operand).value(); - dimSize = std::min(dimSize, maxDimSize); + auto dimSize = getConstantIntValue(operand); + if (not dimSize) { + // Although not a constant, it is safe to assume that `operand` is + // "vscale * maxDimSize". + maskDimSizes.push_back(maxDimSize); + continue; + } + int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize); // If one of dim sizes is zero, set all dims to zero. if (dimSize <= 0) { maskDimSizes.assign(createMaskOp.getType().getRank(), 0); break; } - maskDimSizes.push_back(dimSize); + maskDimSizes.push_back(dimSizeVal); } + // Replace 'createMaskOp' with ConstantMaskOp. rewriter.replaceOpWithNewOp( createMaskOp, createMaskOp.getResult().getType(), diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 1021c73cc57d3..a30016ea857d9 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -53,6 +53,19 @@ func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3x // ----- +// CHECK-LABEL: create_vector_mask_to_constant_mask_scalable_all_true +func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x[16]xi1>) { + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %0 = vector.vscale + %1 = arith.muli %0, %c16 : index + // CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1> + %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1> + return %10 : vector<8x[16]xi1> +} + +// ----- + // CHECK-LABEL: create_mask_transpose_to_transposed_create_mask // CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index func.func @create_mask_transpose_to_transposed_create_mask( From 60ec419f816d777e8fa10b93ec972b7462d2df6e Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Mon, 18 Dec 2023 20:22:54 +0000 Subject: [PATCH 2/5] fixup! Use StringRef::{starts,ends}_with (NFC) Address comments from Jakub --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 3619c1c00f166..5eef2f4c3271b 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5694,7 +5694,7 @@ class CreateMaskFolder final : public OpRewritePattern { // Most basic case - this operand is a constant value. Note that for // scalable dimensions, CreateMaskOp can be folded only if the // corresponding operand is negative or zero. - if (auto op = getConstantIntValue(operand)) { + if (getConstantIntValue(operand)) { APInt intVal; if (isScalable && !(matchPattern(operand, m_ConstantInt(&intVal)) || intVal.isStrictlyPositive())) @@ -5709,11 +5709,11 @@ class CreateMaskFolder final : public OpRewritePattern { // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all // true" mask, so can also be treated as constant. - auto mul = llvm::dyn_cast_or_null(operand.getDefiningOp()); + auto mul = operand.getDefiningOp(); if (!mul) return failure(); - auto mulLHS = mul.getOperands()[0]; - auto mulRHS = mul.getOperands()[1]; + auto mulLHS = mul.getRhs(); + auto mulRHS = mul.getLhs(); bool isOneOpVscale = (isa(mulLHS.getDefiningOp()) || isa(mulRHS.getDefiningOp())); @@ -5737,8 +5737,8 @@ class CreateMaskFolder final : public OpRewritePattern { maskDimSizes.reserve(createMaskOp->getNumOperands()); for (auto [operand, maxDimSize] : llvm::zip_equal( createMaskOp.getOperands(), createMaskOp.getType().getShape())) { - auto dimSize = getConstantIntValue(operand); - if (not dimSize) { + std::optional dimSize = getConstantIntValue(operand); + if (!dimSize) { // Although not a constant, it is safe to assume that `operand` is // "vscale * maxDimSize". maskDimSizes.push_back(maxDimSize); From 08cd6617ef412746088424e94da9ff575d4218d9 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Tue, 19 Dec 2023 12:56:51 +0000 Subject: [PATCH 3/5] fixup! [mlir][vector] Extend `CreateMaskFolder` Address comments from Cullen --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 5eef2f4c3271b..66da2e6ac4a1f 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5719,8 +5719,7 @@ class CreateMaskFolder final : public OpRewritePattern { isa(mulRHS.getDefiningOp())); auto isConstantValMatchingDim = - [=, dim = createMaskOp.getResult().getType().getShape()[opIdx]]( - Value operand) { + [=, dim = retTy.getShape()[opIdx]](Value operand) { auto constantVal = getConstantIntValue(operand); return (constantVal.has_value() && constantVal.value() == dim); }; @@ -5755,7 +5754,7 @@ class CreateMaskFolder final : public OpRewritePattern { // Replace 'createMaskOp' with ConstantMaskOp. rewriter.replaceOpWithNewOp( - createMaskOp, createMaskOp.getResult().getType(), + createMaskOp, retTy, vector::getVectorSubscriptAttr(rewriter, maskDimSizes)); return success(); } From 1e0d190d884572e3120d9a3dcb6133156db2b9ed Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Tue, 19 Dec 2023 17:01:33 +0000 Subject: [PATCH 4/5] fixup! [mlir][vector] Extend `CreateMaskFolder` Fix how scalable dims are treated --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 66da2e6ac4a1f..0e584ae2b19a0 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5691,13 +5691,14 @@ class CreateMaskFolder final : public OpRewritePattern { // Check every mask operand for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) { - // Most basic case - this operand is a constant value. Note that for - // scalable dimensions, CreateMaskOp can be folded only if the - // corresponding operand is negative or zero. if (getConstantIntValue(operand)) { + // Most basic case - this operand is a constant value. Note that for + // scalable dimensions, CreateMaskOp can be folded only if the + // corresponding operand is negative or zero. APInt intVal; - if (isScalable && !(matchPattern(operand, m_ConstantInt(&intVal)) || - intVal.isStrictlyPositive())) + if (retTy.getScalableDims()[opIdx] && + (!matchPattern(operand, m_ConstantInt(&intVal)) || + intVal.isStrictlyPositive())) return failure(); continue; From 521c45a02404126409a29395020bbbd049def9c3 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Tue, 19 Dec 2023 21:00:31 +0000 Subject: [PATCH 5/5] fixup! [mlir][vector] Extend `CreateMaskFolder` Refine based on Cullen's suggestion --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 0e584ae2b19a0..c7092ac465c3c 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5691,14 +5691,11 @@ class CreateMaskFolder final : public OpRewritePattern { // Check every mask operand for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) { - if (getConstantIntValue(operand)) { + if (auto cst = getConstantIntValue(operand)) { // Most basic case - this operand is a constant value. Note that for // scalable dimensions, CreateMaskOp can be folded only if the // corresponding operand is negative or zero. - APInt intVal; - if (retTy.getScalableDims()[opIdx] && - (!matchPattern(operand, m_ConstantInt(&intVal)) || - intVal.isStrictlyPositive())) + if (retTy.getScalableDims()[opIdx] && *cst > 0) return failure(); continue;