-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][ArmSME] Add rewrite to lift illegal vector.transposes to memory #80170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) Changes: When unrolling the reduction dimension of something like a matmul for SME, you can end up with transposed reads of illegal types, like so: %illegalRead = vector.transfer_read %memref[%a, %b]
: memref<?x?xf32>, vector<[8]x4xf32>
%legalType = vector.transpose %illegalRead, [1, 0]
  : vector<[8]x4xf32> to vector<4x[8]xf32>
Here the vector<[8]x4xf32> is an illegal type; there's no way to lower a scalable vector of fixed vectors. However, as the final type vector<4x[8]xf32> is legal, we can instead lift the transpose to memory (producing a strided memref), and eliminate all the illegal types. This is shown below:
%readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
: memref<?x?xf32> to memref<?x?xf32>
%transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
: memref<?x?xf32> to memref<?x?xf32>
%legalType = vector.transfer_read %transpose[%c0, %c0]
  : memref<?x?xf32>, vector<4x[8]xf32> Full diff: https://github.com/llvm/llvm-project/pull/80170.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 85ec53c2618aa..a3db2d2395528 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -7,8 +7,6 @@
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
-// Currently, this only implements the decomposition of vector operations that
-// use vector sizes larger than an SME tile, into multiple SME-sized operations.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
@@ -19,6 +17,7 @@
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
@@ -35,6 +34,10 @@ using namespace mlir::arm_sme;
namespace {
+//===----------------------------------------------------------------------===//
+// Decomposition of vector operations larger than an SME tile
+//===----------------------------------------------------------------------===//
+
// Common match failure reasons.
static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
"op vector size is not multiple of SME tiles");
@@ -338,13 +341,166 @@ struct LegalizeTransferWriteOpsByDecomposition
}
};
+//===----------------------------------------------------------------------===//
+// ArmSME-specific fixup canonicalizations/folds
+//===----------------------------------------------------------------------===//
+
+/// Lifts an illegal vector.transpose and vector.transfer_read to a
+/// memref.subview + memref.transpose, followed by a legal read.
+///
+/// 'Illegal' here means a leading scalable dimension and a fixed trailing
+/// dimension, which has no valid lowering.
+///
+/// The memref.transpose is a metadata-only transpose that produces a strided
+/// memref, which eventually becomes a loop reading individual elements.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %illegalRead = vector.transfer_read %memref[%a, %b]
+/// : memref<?x?xf32>, vector<[8]x4xf32>
+/// %legalType = vector.transpose %illegalRead, [1, 0]
+/// : vector<[8]x4xf32> to vector<4x[8]xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
+/// : memref<?x?xf32> to memref<?x?xf32>
+/// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
+/// : memref<?x?xf32> to memref<?x?xf32>
+/// %legalType = vector.transfer_read %transpose[%c0, %c0]
+/// : memref<?x?xf32>, vector<4x[8]xf32>
+/// ```
+struct LiftIllegalVectorTransposeToMemory
+ : public OpRewritePattern<vector::TransposeOp> {
+ using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+ static bool isIllegalVectorType(VectorType vType) {
+ bool seenFixedDim = false;
+ for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
+ seenFixedDim |= !scalableFlag;
+ if (seenFixedDim && scalableFlag)
+ return true;
+ }
+ return false;
+ }
+
+ static Value getExtensionSource(Operation *op) {
+ if (auto signExtend = dyn_cast<arith::ExtSIOp>(op))
+ return signExtend.getIn();
+ if (auto zeroExtend = dyn_cast<arith::ExtUIOp>(op))
+ return zeroExtend.getIn();
+ if (auto floatExtend = dyn_cast<arith::ExtFOp>(op))
+ return floatExtend.getIn();
+ return {};
+ }
+
+ LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = transposeOp.getSourceVectorType();
+ auto resultType = transposeOp.getResultVectorType();
+ if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
+ return rewriter.notifyMatchFailure(
+ transposeOp, "expected transpose from illegal type to legal type");
+
+ Value maybeRead = transposeOp.getVector();
+ auto *transposeSourceOp = maybeRead.getDefiningOp();
+ Operation *extendOp = nullptr;
+ if (Value extendSource = getExtensionSource(transposeSourceOp)) {
+ maybeRead = extendSource;
+ extendOp = transposeSourceOp;
+ }
+
+ auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
+ if (!illegalRead)
+      return rewriter.notifyMatchFailure(
+          transposeOp,
+          "expected source to be (possibly extended) transfer_read");
+
+ if (!illegalRead.getPermutationMap().isIdentity())
+ return rewriter.notifyMatchFailure(
+ illegalRead, "expected read to have identity permutation map");
+
+ auto loc = transposeOp.getLoc();
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+ // Create a subview that matches the size of the illegal read vector type.
+ auto readType = illegalRead.getVectorType();
+ auto readSizes = llvm::map_to_vector(
+ llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
+ [&](auto dim) -> Value {
+ auto [size, isScalable] = dim;
+ auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
+ if (!isScalable)
+ return dimSize;
+ auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
+ });
+ SmallVector<Value> strides(readType.getRank(), Value(one));
+ auto readSubview = rewriter.create<memref::SubViewOp>(
+ loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
+ strides);
+
+ // Apply the transpose to all values/attributes of the transfer_read.
+ // The mask.
+ Value mask = illegalRead.getMask();
+ if (mask) {
+ // Note: The transpose for the mask should fold into the
+ // vector.create_mask/constant_mask op, which will then become legal.
+ mask = rewriter.create<vector::TransposeOp>(loc, mask,
+ transposeOp.getPermutation());
+ }
+ // The source memref.
+ mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
+ transposeOp.getPermutation(), getContext());
+ auto transposedSubview = rewriter.create<memref::TransposeOp>(
+ loc, readSubview, AffineMapAttr::get(transposeMap));
+ ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
+ // The `in_bounds` attribute.
+ if (inBoundsAttr) {
+ SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
+ inBoundsAttr.end());
+ applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
+ inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
+ }
+
+ VectorType legalReadType =
+ VectorType::Builder(resultType)
+ .setElementType(illegalRead.getVectorType().getElementType());
+ // Note: The indices are all zero as the subview is already offset.
+ SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
+ Value legalRead = rewriter.create<vector::TransferReadOp>(
+ loc, legalReadType, transposedSubview, readIndices,
+ illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
+ inBoundsAttr);
+
+ // Replace the transpose with the new read, extending the result if
+ // necessary.
+ rewriter.replaceOp(transposeOp, [&]() -> Value {
+ if (!extendOp)
+ return legalRead;
+ if (isa<arith::ExtSIOp>(extendOp))
+ return rewriter.create<arith::ExtSIOp>(loc, resultType, legalRead);
+ if (isa<arith::ExtUIOp>(extendOp))
+ return rewriter.create<arith::ExtUIOp>(loc, resultType, legalRead);
+ if (isa<arith::ExtFOp>(extendOp))
+ return rewriter.create<arith::ExtFOp>(loc, resultType, legalRead);
+ return legalRead;
+ }());
+
+ return success();
+ }
+};
+
struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
auto *context = &getContext();
OneToNTypeConverter converter;
RewritePatternSet patterns(context);
-
converter.addConversion([](Type type) { return type; });
converter.addConversion(
[](VectorType vectorType,
@@ -358,6 +514,7 @@ struct VectorLegalizationPass
return success();
});
+ patterns.add<LiftIllegalVectorTransposeToMemory>(context);
// Note: High benefit to ensure masked outer products are lowered first.
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
converter, context, 1024);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index a20abeefedcfd..2317930d3d061 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -266,3 +266,78 @@ func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: m
vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_no_mask(
+// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory_no_mask(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+ // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+ // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
+ // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+ // CHECK-NEXT: return %[[LEGAL_READ]]
+ %pad = arith.constant 0.0 : f32
+ %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xf32>, vector<[8]x4xf32>
+ %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+ return %legalType : vector<4x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory(
+// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %dim0: index, %dim1: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+ // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+ // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
+ // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
+ // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]], %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+ // CHECK-NEXT: return %[[LEGAL_READ]]
+ %pad = arith.constant 0.0 : f32
+ %mask = vector.create_mask %dim0, %dim1 : vector<[8]x4xi1>
+ %illegalRead = vector.transfer_read %memref[%a, %b], %pad, %mask : memref<?x?xf32>, vector<[8]x4xf32>
+ %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+ return %legalType : vector<4x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_arith_extop(
+// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>)
+func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: index, %memref: memref<?x?xi8>) -> vector<4x[8]xi32> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+ // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+ // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xi8> to memref<?x4xi8, strided<[?, 1], offset: ?>>
+ // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xi8, strided<[?, 1], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
+ // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xi8, strided<[?, ?], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
+ // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_I8]] : memref<?x?xi8, strided<[?, ?], offset: ?>>, vector<4x[8]xi8>
+ // CHECK-NEXT: %[[EXT_TYPE:.*]] = arith.extsi %[[LEGAL_READ]] : vector<4x[8]xi8> to vector<4x[8]xi32>
+ // CHECK-NEXT: return %[[EXT_TYPE]]
+ %pad = arith.constant 0 : i8
+ %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xi8>, vector<[8]x4xi8>
+ %extRead = arith.extsi %illegalRead : vector<[8]x4xi8> to vector<[8]x4xi32>
+ %legalType = vector.transpose %extRead, [1, 0] : vector<[8]x4xi32> to vector<4x[8]xi32>
+ return %legalType : vector<4x[8]xi32>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM cheers
When unrolling the reduction dimension of something like a matmul for SME, you can end up with transposed reads of illegal types, like so: ```mlir %illegalRead = vector.transfer_read %memref[%a, %b] : memref<?x?xf32>, vector<[8]x4xf32> %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32> ``` Here the `vector<[8]x4xf32>` is an illegal type, there's no way to lower a scalable vector of fixed vectors. However, as the final type `vector<4x[8]xf32>` is legal, we can instead lift the transpose to memory (producing a strided memref), and eliminate all the illegal types. This is shown below. ```mlir %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1] : memref<?x?xf32> to memref<?x?xf32> %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32> %legalType = vector.transfer_read %transpose[%c0, %c0] : memref<?x?xf32>, vector<4x[8]xf32> ```
ab8c3d7
to
08dd0e4
Compare
…mlir This tests both llvm#80148 and llvm#80170 work together to allow unrolling the reduction dimension of a matmul.
When unrolling the reduction dimension of something like a matmul for SME, you can end up with transposed reads of illegal types, like so:
Here the
vector<[8]x4xf32>
is an illegal type, there's no way to lower a scalable vector of fixed vectors. However, as the final typevector<4x[8]xf32>
is legal, we can instead lift the transpose to memory (producing a strided memref), and eliminate all the illegal types. This is shown below.