-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][vector] Extend CreateMaskFolder
#75842
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
8714957
60ec419
08cd661
1e0d190
521c45a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5657,46 +5657,102 @@ LogicalResult CreateMaskOp::verify() { | |
|
||
namespace { | ||
|
||
// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp. | ||
/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp. | ||
/// | ||
/// Ex 1: | ||
/// %c2 = arith.constant 2 : index | ||
/// %c3 = arith.constant 3 : index | ||
/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1> | ||
/// Becomes: | ||
/// vector.constant_mask [3, 2] : vector<4x3xi1> | ||
/// | ||
/// Ex 2: | ||
/// %c_neg_1 = arith.constant -1 : index | ||
/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1> | ||
/// becomes: | ||
/// vector.constant_mask [0] : vector<[8]xi1> | ||
/// | ||
/// Ex 3: | ||
/// %c8 = arith.constant 8 : index | ||
/// %c16 = arith.constant 16 : index | ||
/// %0 = vector.vscale | ||
/// %1 = arith.muli %0, %c16 : index | ||
/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1> | ||
/// becomes: | ||
/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1> | ||
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> { | ||
public: | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(CreateMaskOp createMaskOp, | ||
PatternRewriter &rewriter) const override { | ||
// Return if any of 'createMaskOp' operands are not defined by a constant. | ||
auto isNotDefByConstant = [](Value operand) { | ||
return !getConstantIntValue(operand).has_value(); | ||
}; | ||
if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant)) | ||
return failure(); | ||
VectorType retTy = createMaskOp.getResult().getType(); | ||
bool isScalable = retTy.isScalable(); | ||
|
||
// Check every mask operand | ||
for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) { | ||
// Most basic case - this operand is a constant value. Note that for | ||
// scalable dimensions, CreateMaskOp can be folded only if the | ||
// corresponding operand is negative or zero. | ||
if (getConstantIntValue(operand)) { | ||
APInt intVal; | ||
if (isScalable && !(matchPattern(operand, m_ConstantInt(&intVal)) || | ||
intVal.isStrictlyPositive())) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the RHS is equivalent to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The compiler should be able to figure these things out. That does seem to be the case: Btw, I will rewrite it so that it's more descriptive. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the point is the call to Because if
banach-space marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return failure(); | ||
|
||
// CreateMaskOp for scalable vectors can be folded only if all dimensions | ||
// are negative or zero. | ||
if (auto vType = llvm::dyn_cast<VectorType>(createMaskOp.getType())) { | ||
if (vType.isScalable()) | ||
for (auto opDim : createMaskOp.getOperands()) { | ||
APInt intVal; | ||
if (matchPattern(opDim, m_ConstantInt(&intVal)) && | ||
intVal.isStrictlyPositive()) | ||
return failure(); | ||
} | ||
continue; | ||
} | ||
|
||
// Non-constant operands are not allowed for non-scalable vectors. | ||
if (!isScalable) | ||
return failure(); | ||
|
||
// For scalable vectors, "arith.muli %vscale, %dimSize" means an "all | ||
// true" mask, so can also be treated as constant. | ||
auto mul = operand.getDefiningOp<arith::MulIOp>(); | ||
if (!mul) | ||
return failure(); | ||
auto mulLHS = mul.getRhs(); | ||
auto mulRHS = mul.getLhs(); | ||
bool isOneOpVscale = | ||
(isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) || | ||
isa<vector::VectorScaleOp>(mulRHS.getDefiningOp())); | ||
|
||
auto isConstantValMatchingDim = | ||
[=, dim = createMaskOp.getResult().getType().getShape()[opIdx]]( | ||
banach-space marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Value operand) { | ||
auto constantVal = getConstantIntValue(operand); | ||
return (constantVal.has_value() && constantVal.value() == dim); | ||
}; | ||
|
||
bool isOneOpConstantMatchingDim = | ||
isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS); | ||
|
||
if (!isOneOpVscale || !isOneOpConstantMatchingDim) | ||
return failure(); | ||
} | ||
|
||
// Gather constant mask dimension sizes. | ||
SmallVector<int64_t, 4> maskDimSizes; | ||
maskDimSizes.reserve(createMaskOp->getNumOperands()); | ||
for (auto [operand, maxDimSize] : llvm::zip_equal( | ||
createMaskOp.getOperands(), createMaskOp.getType().getShape())) { | ||
banach-space marked this conversation as resolved.
Show resolved
Hide resolved
|
||
int64_t dimSize = getConstantIntValue(operand).value(); | ||
dimSize = std::min(dimSize, maxDimSize); | ||
std::optional dimSize = getConstantIntValue(operand); | ||
if (!dimSize) { | ||
// Although not a constant, it is safe to assume that `operand` is | ||
// "vscale * maxDimSize". | ||
maskDimSizes.push_back(maxDimSize); | ||
continue; | ||
} | ||
int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize); | ||
// If one of dim sizes is zero, set all dims to zero. | ||
if (dimSize <= 0) { | ||
maskDimSizes.assign(createMaskOp.getType().getRank(), 0); | ||
banach-space marked this conversation as resolved.
Show resolved
Hide resolved
|
||
break; | ||
} | ||
maskDimSizes.push_back(dimSize); | ||
maskDimSizes.push_back(dimSizeVal); | ||
} | ||
|
||
// Replace 'createMaskOp' with ConstantMaskOp. | ||
rewriter.replaceOpWithNewOp<ConstantMaskOp>( | ||
createMaskOp, createMaskOp.getResult().getType(), | ||
banach-space marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
Uh oh!
There was an error while loading. Please reload this page.