Skip to content

Commit 0473e32

Browse files
authored
[mlir][ArmSME] Add rewrite to lift illegal vector.transposes to memory (#80170)
When unrolling the reduction dimension of something like a matmul for SME, you can end up with transposed reads of illegal types, like so: ```mlir %illegalRead = vector.transfer_read %memref[%a, %b] : memref<?x?xf32>, vector<[8]x4xf32> %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32> ``` Here the `vector<[8]x4xf32>` is an illegal type, there's no way to lower a scalable vector of fixed vectors. However, as the final type `vector<4x[8]xf32>` is legal, we can instead lift the transpose to memory (producing a strided memref), and eliminate all the illegal types. This is shown below. ```mlir %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1] : memref<?x?xf32> to memref<?x?xf32> %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0) : memref<?x?xf32> to memref<?x?xf32> %legalType = vector.transfer_read %transpose[%c0, %c0] : memref<?x?xf32>, vector<4x[8]xf32> ```
1 parent 42b5b72 commit 0473e32

File tree

2 files changed

+218
-1
lines changed

2 files changed

+218
-1
lines changed

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 143 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
1919
#include "mlir/Dialect/Func/IR/FuncOps.h"
2020
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
21+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2122
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
2223
#include "mlir/Dialect/Utils/IndexingUtils.h"
2324
#include "mlir/Transforms/OneToNTypeConversion.h"
@@ -415,6 +416,146 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
415416
}
416417
};
417418

419+
/// Lifts an illegal vector.transpose and vector.transfer_read to a
420+
/// memref.subview + memref.transpose, followed by a legal read.
421+
///
422+
/// 'Illegal' here means a leading scalable dimension and a fixed trailing
423+
/// dimension, which has no valid lowering.
424+
///
425+
/// The memref.transpose is metadata-only transpose that produces a strided
426+
/// memref, which eventually becomes a loop reading individual elements.
427+
///
428+
/// Example:
429+
///
430+
/// BEFORE:
431+
/// ```mlir
432+
/// %illegalRead = vector.transfer_read %memref[%a, %b]
433+
/// : memref<?x?xf32>, vector<[8]x4xf32>
434+
/// %legalType = vector.transpose %illegalRead, [1, 0]
435+
/// : vector<[8]x4xf32> to vector<4x[8]xf32>
436+
/// ```
437+
///
438+
/// AFTER:
439+
/// ```mlir
440+
/// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
441+
/// : memref<?x?xf32> to memref<?x?xf32>
442+
/// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
443+
/// : memref<?x?xf32> to memref<?x?xf32>
444+
/// %legalType = vector.transfer_read %transpose[%c0, %c0]
445+
/// : memref<?x?xf32>, vector<4x[8]xf32>
446+
/// ```
447+
struct LiftIllegalVectorTransposeToMemory
448+
: public OpRewritePattern<vector::TransposeOp> {
449+
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
450+
451+
static bool isIllegalVectorType(VectorType vType) {
452+
bool seenFixedDim = false;
453+
for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
454+
seenFixedDim |= !scalableFlag;
455+
if (seenFixedDim && scalableFlag)
456+
return true;
457+
}
458+
return false;
459+
}
460+
461+
static Value getExtensionSource(Operation *op) {
462+
if (isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
463+
return op->getOperand(0);
464+
return {};
465+
}
466+
467+
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
468+
PatternRewriter &rewriter) const override {
469+
auto sourceType = transposeOp.getSourceVectorType();
470+
auto resultType = transposeOp.getResultVectorType();
471+
if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
472+
return rewriter.notifyMatchFailure(
473+
transposeOp, "expected transpose from illegal type to legal type");
474+
475+
// Look through extend for transfer_read.
476+
Value maybeRead = transposeOp.getVector();
477+
auto *transposeSourceOp = maybeRead.getDefiningOp();
478+
Operation *extendOp = nullptr;
479+
if (Value extendSource = getExtensionSource(transposeSourceOp)) {
480+
maybeRead = extendSource;
481+
extendOp = transposeSourceOp;
482+
}
483+
484+
auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
485+
if (!illegalRead)
486+
return rewriter.notifyMatchFailure(
487+
transposeOp,
488+
"expected source to be (possibly extended) transfer_read");
489+
490+
if (!illegalRead.getPermutationMap().isIdentity())
491+
return rewriter.notifyMatchFailure(
492+
illegalRead, "expected read to have identity permutation map");
493+
494+
auto loc = transposeOp.getLoc();
495+
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
496+
auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
497+
498+
// Create a subview that matches the size of the illegal read vector type.
499+
auto readType = illegalRead.getVectorType();
500+
auto readSizes = llvm::map_to_vector(
501+
llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
502+
[&](auto dim) -> Value {
503+
auto [size, isScalable] = dim;
504+
auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
505+
if (!isScalable)
506+
return dimSize;
507+
auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
508+
return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
509+
});
510+
SmallVector<Value> strides(readType.getRank(), Value(one));
511+
auto readSubview = rewriter.create<memref::SubViewOp>(
512+
loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
513+
strides);
514+
515+
// Apply the transpose to all values/attributes of the transfer_read:
516+
// - The mask
517+
Value mask = illegalRead.getMask();
518+
if (mask) {
519+
// Note: The transpose for the mask should fold into the
520+
// vector.create_mask/constant_mask op, which will then become legal.
521+
mask = rewriter.create<vector::TransposeOp>(loc, mask,
522+
transposeOp.getPermutation());
523+
}
524+
// - The source memref
525+
mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
526+
transposeOp.getPermutation(), getContext());
527+
auto transposedSubview = rewriter.create<memref::TransposeOp>(
528+
loc, readSubview, AffineMapAttr::get(transposeMap));
529+
ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
530+
// - The `in_bounds` attribute
531+
if (inBoundsAttr) {
532+
SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
533+
inBoundsAttr.end());
534+
applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
535+
inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
536+
}
537+
538+
VectorType legalReadType = resultType.clone(readType.getElementType());
539+
// Note: The indices are all zero as the subview is already offset.
540+
SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
541+
auto legalRead = rewriter.create<vector::TransferReadOp>(
542+
loc, legalReadType, transposedSubview, readIndices,
543+
illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
544+
inBoundsAttr);
545+
546+
// Replace the transpose with the new read, extending the result if
547+
// necessary.
548+
rewriter.replaceOp(transposeOp, [&]() -> Operation * {
549+
if (extendOp)
550+
return rewriter.create(loc, extendOp->getName().getIdentifier(),
551+
Value(legalRead), resultType);
552+
return legalRead;
553+
}());
554+
555+
return success();
556+
}
557+
};
558+
418559
struct VectorLegalizationPass
419560
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
420561
void runOnOperation() override {
@@ -434,7 +575,8 @@ struct VectorLegalizationPass
434575
return success();
435576
});
436577

437-
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks>(context);
578+
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
579+
LiftIllegalVectorTransposeToMemory>(context);
438580
// Note: High benefit to ensure masked outer products are lowered first.
439581
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
440582
converter, context, 1024);

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,3 +302,78 @@ func.func @non_constant_extract_from_vector_create_mask_non_constant(%index: ind
302302
%extract = vector.extract %mask[%index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
303303
return %extract : vector<[4]x[4]xi1>
304304
}
305+
306+
// -----
307+
308+
// CHECK-LABEL: @lift_illegal_transpose_to_memory(
309+
// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
310+
// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
311+
// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
312+
func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
313+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
314+
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
315+
// CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
316+
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
317+
// CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
318+
// CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
319+
// CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
320+
// CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
321+
// CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
322+
// CHECK-NEXT: return %[[LEGAL_READ]]
323+
%pad = arith.constant 0.0 : f32
324+
%illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xf32>, vector<[8]x4xf32>
325+
%legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
326+
return %legalType : vector<4x[8]xf32>
327+
}
328+
329+
// -----
330+
331+
// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_mask(
332+
// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
333+
// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
334+
// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>
335+
func.func @lift_illegal_transpose_to_memory_with_mask(%dim0: index, %dim1: index, %memref: memref<?x?xf32>, %a: index, %b: index) -> vector<4x[8]xf32> {
336+
// CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
337+
// CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
338+
// CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
339+
// CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
340+
// CHECK: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
341+
// CHECK-SAME: %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
342+
// CHECK-NEXT: return %[[LEGAL_READ]]
343+
%pad = arith.constant 0.0 : f32
344+
%mask = vector.create_mask %dim0, %dim1 : vector<[8]x4xi1>
345+
%illegalRead = vector.transfer_read %memref[%a, %b], %pad, %mask : memref<?x?xf32>, vector<[8]x4xf32>
346+
%legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
347+
return %legalType : vector<4x[8]xf32>
348+
}
349+
350+
// -----
351+
352+
// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_arith_extop(
353+
// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>
354+
func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: index, %memref: memref<?x?xi8>) -> vector<4x[8]xi32> {
355+
// CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
356+
// CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
357+
// CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
358+
// CHECK: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
359+
// CHECK-NEXT: %[[EXT_TYPE:.*]] = arith.extsi %[[LEGAL_READ]] : vector<4x[8]xi8> to vector<4x[8]xi32>
360+
// CHECK-NEXT: return %[[EXT_TYPE]]
361+
%pad = arith.constant 0 : i8
362+
%illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xi8>, vector<[8]x4xi8>
363+
%extRead = arith.extsi %illegalRead : vector<[8]x4xi8> to vector<[8]x4xi32>
364+
%legalType = vector.transpose %extRead, [1, 0] : vector<[8]x4xi32> to vector<4x[8]xi32>
365+
return %legalType : vector<4x[8]xi32>
366+
}
367+
368+
// -----
369+
370+
// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_in_bounds_attr
371+
func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
372+
// CHECK: vector.transfer_read
373+
// CHECK-SAME: in_bounds = [true, false]
374+
// CHECK-NOT: in_bounds = [false, true]
375+
%pad = arith.constant 0.0 : f32
376+
%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[8]x4xf32>
377+
%legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
378+
return %legalType : vector<4x[8]xf32>
379+
}

0 commit comments

Comments
 (0)