Skip to content

[mlir][vector] Extend CreateMaskFolder #75842

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 76 additions & 20 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5657,46 +5657,102 @@ LogicalResult CreateMaskOp::verify() {

namespace {

// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
///
/// Ex 1:
/// %c2 = arith.constant 2 : index
/// %c3 = arith.constant 3 : index
/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
/// Becomes:
/// vector.constant_mask [3, 2] : vector<4x3xi1>
///
/// Ex 2:
/// %c_neg_1 = arith.constant -1 : index
/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
/// becomes:
/// vector.constant_mask [0] : vector<[8]xi1>
///
/// Ex 3:
/// %c8 = arith.constant 8 : index
/// %c16 = arith.constant 16 : index
/// %0 = vector.vscale
/// %1 = arith.muli %0, %c16 : index
/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
/// becomes:
/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
// Return if any of 'createMaskOp' operands are not defined by a constant.
auto isNotDefByConstant = [](Value operand) {
return !getConstantIntValue(operand).has_value();
};
if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant))
return failure();
VectorType retTy = createMaskOp.getResult().getType();
bool isScalable = retTy.isScalable();

// Check every mask operand
for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
// Most basic case - this operand is a constant value. Note that for
// scalable dimensions, CreateMaskOp can be folded only if the
// corresponding operand is negative or zero.
if (getConstantIntValue(operand)) {
APInt intVal;
if (isScalable && !(matchPattern(operand, m_ConstantInt(&intVal)) ||
intVal.isStrictlyPositive()))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the RHS is equivalent to !(matchPattern(operand, m_ConstantInt(&intVal)) && !intVal.isStrictlyPositive()), it doesn't make sense to check for <= 0 if it's not a constant?

Copy link
Contributor Author

@banach-space banach-space Dec 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The compiler should be able to figure these things out. That does seem to be the case:

Btw, I will rewrite it so that it's more descriptive.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the point is the call to intVal.isStrictlyPositive() does not make sense in this check.

Because if matchPattern(operand, m_ConstantInt(&intVal)) is:

  • true
    • then intVal.isStrictlyPositive() won't be called
  • false
    • then intVal.isStrictlyPositive() will be called
    • but it does not make sense to call intVal.isStrictlyPositive() on something that's not an constant integer

return failure();

// CreateMaskOp for scalable vectors can be folded only if all dimensions
// are negative or zero.
if (auto vType = llvm::dyn_cast<VectorType>(createMaskOp.getType())) {
if (vType.isScalable())
for (auto opDim : createMaskOp.getOperands()) {
APInt intVal;
if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
intVal.isStrictlyPositive())
return failure();
}
continue;
}

// Non-constant operands are not allowed for non-scalable vectors.
if (!isScalable)
return failure();

// For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
// true" mask, so can also be treated as constant.
auto mul = operand.getDefiningOp<arith::MulIOp>();
if (!mul)
return failure();
auto mulLHS = mul.getRhs();
auto mulRHS = mul.getLhs();
bool isOneOpVscale =
(isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));

auto isConstantValMatchingDim =
[=, dim = createMaskOp.getResult().getType().getShape()[opIdx]](
Value operand) {
auto constantVal = getConstantIntValue(operand);
return (constantVal.has_value() && constantVal.value() == dim);
};

bool isOneOpConstantMatchingDim =
isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);

if (!isOneOpVscale || !isOneOpConstantMatchingDim)
return failure();
}

// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
maskDimSizes.reserve(createMaskOp->getNumOperands());
for (auto [operand, maxDimSize] : llvm::zip_equal(
createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
int64_t dimSize = getConstantIntValue(operand).value();
dimSize = std::min(dimSize, maxDimSize);
std::optional dimSize = getConstantIntValue(operand);
if (!dimSize) {
// Although not a constant, it is safe to assume that `operand` is
// "vscale * maxDimSize".
maskDimSizes.push_back(maxDimSize);
continue;
}
int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
// If one of dim sizes is zero, set all dims to zero.
if (dimSize <= 0) {
maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
break;
}
maskDimSizes.push_back(dimSize);
maskDimSizes.push_back(dimSizeVal);
}

// Replace 'createMaskOp' with ConstantMaskOp.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
createMaskOp, createMaskOp.getResult().getType(),
Expand Down
13 changes: 13 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@ func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3x

// -----

// CHECK-LABEL: create_vector_mask_to_constant_mask_scalable_all_true
func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x[16]xi1>) {
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%0 = vector.vscale
%1 = arith.muli %0, %c16 : index
// CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
%10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
return %10 : vector<8x[16]xi1>
}

// -----

// CHECK-LABEL: create_mask_transpose_to_transposed_create_mask
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index
func.func @create_mask_transpose_to_transposed_create_mask(
Expand Down