18
18
#include " mlir/Dialect/ArmSME/Utils/Utils.h"
19
19
#include " mlir/Dialect/Func/IR/FuncOps.h"
20
20
#include " mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
21
+ #include " mlir/Dialect/MemRef/IR/MemRef.h"
21
22
#include " mlir/Dialect/SCF/Transforms/Patterns.h"
22
23
#include " mlir/Dialect/Utils/IndexingUtils.h"
23
24
#include " mlir/Transforms/OneToNTypeConversion.h"
@@ -415,6 +416,146 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
415
416
}
416
417
};
417
418
419
+ // / Lifts an illegal vector.transpose and vector.transfer_read to a
420
+ // / memref.subview + memref.transpose, followed by a legal read.
421
+ // /
422
+ // / 'Illegal' here means a leading scalable dimension and a fixed trailing
423
+ // / dimension, which has no valid lowering.
424
+ // /
425
+ // / The memref.transpose is metadata-only transpose that produces a strided
426
+ // / memref, which eventually becomes a loop reading individual elements.
427
+ // /
428
+ // / Example:
429
+ // /
430
+ // / BEFORE:
431
+ // / ```mlir
432
+ // / %illegalRead = vector.transfer_read %memref[%a, %b]
433
+ // / : memref<?x?xf32>, vector<[8]x4xf32>
434
+ // / %legalType = vector.transpose %illegalRead, [1, 0]
435
+ // / : vector<[8]x4xf32> to vector<4x[8]xf32>
436
+ // / ```
437
+ // /
438
+ // / AFTER:
439
+ // / ```mlir
440
+ // / %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
441
+ // / : memref<?x?xf32> to memref<?x?xf32>
442
+ // / %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
443
+ // / : memref<?x?xf32> to memref<?x?xf32>
444
+ // / %legalType = vector.transfer_read %transpose[%c0, %c0]
445
+ // / : memref<?x?xf32>, vector<4x[8]xf32>
446
+ // / ```
447
+ struct LiftIllegalVectorTransposeToMemory
448
+ : public OpRewritePattern<vector::TransposeOp> {
449
+ using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
450
+
451
+ static bool isIllegalVectorType (VectorType vType) {
452
+ bool seenFixedDim = false ;
453
+ for (bool scalableFlag : llvm::reverse (vType.getScalableDims ())) {
454
+ seenFixedDim |= !scalableFlag;
455
+ if (seenFixedDim && scalableFlag)
456
+ return true ;
457
+ }
458
+ return false ;
459
+ }
460
+
461
+ static Value getExtensionSource (Operation *op) {
462
+ if (isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
463
+ return op->getOperand (0 );
464
+ return {};
465
+ }
466
+
467
+ LogicalResult matchAndRewrite (vector::TransposeOp transposeOp,
468
+ PatternRewriter &rewriter) const override {
469
+ auto sourceType = transposeOp.getSourceVectorType ();
470
+ auto resultType = transposeOp.getResultVectorType ();
471
+ if (!isIllegalVectorType (sourceType) || isIllegalVectorType (resultType))
472
+ return rewriter.notifyMatchFailure (
473
+ transposeOp, " expected transpose from illegal type to legal type" );
474
+
475
+ // Look through extend for transfer_read.
476
+ Value maybeRead = transposeOp.getVector ();
477
+ auto *transposeSourceOp = maybeRead.getDefiningOp ();
478
+ Operation *extendOp = nullptr ;
479
+ if (Value extendSource = getExtensionSource (transposeSourceOp)) {
480
+ maybeRead = extendSource;
481
+ extendOp = transposeSourceOp;
482
+ }
483
+
484
+ auto illegalRead = maybeRead.getDefiningOp <vector::TransferReadOp>();
485
+ if (!illegalRead)
486
+ return rewriter.notifyMatchFailure (
487
+ transposeOp,
488
+ " expected source to be (possibly extended) transfer_read" );
489
+
490
+ if (!illegalRead.getPermutationMap ().isIdentity ())
491
+ return rewriter.notifyMatchFailure (
492
+ illegalRead, " expected read to have identity permutation map" );
493
+
494
+ auto loc = transposeOp.getLoc ();
495
+ auto zero = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
496
+ auto one = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
497
+
498
+ // Create a subview that matches the size of the illegal read vector type.
499
+ auto readType = illegalRead.getVectorType ();
500
+ auto readSizes = llvm::map_to_vector (
501
+ llvm::zip_equal (readType.getShape (), readType.getScalableDims ()),
502
+ [&](auto dim) -> Value {
503
+ auto [size, isScalable] = dim;
504
+ auto dimSize = rewriter.create <arith::ConstantIndexOp>(loc, size);
505
+ if (!isScalable)
506
+ return dimSize;
507
+ auto vscale = rewriter.create <vector::VectorScaleOp>(loc);
508
+ return rewriter.create <arith::MulIOp>(loc, vscale, dimSize);
509
+ });
510
+ SmallVector<Value> strides (readType.getRank (), Value (one));
511
+ auto readSubview = rewriter.create <memref::SubViewOp>(
512
+ loc, illegalRead.getSource (), illegalRead.getIndices (), readSizes,
513
+ strides);
514
+
515
+ // Apply the transpose to all values/attributes of the transfer_read:
516
+ // - The mask
517
+ Value mask = illegalRead.getMask ();
518
+ if (mask) {
519
+ // Note: The transpose for the mask should fold into the
520
+ // vector.create_mask/constant_mask op, which will then become legal.
521
+ mask = rewriter.create <vector::TransposeOp>(loc, mask,
522
+ transposeOp.getPermutation ());
523
+ }
524
+ // - The source memref
525
+ mlir::AffineMap transposeMap = AffineMap::getPermutationMap (
526
+ transposeOp.getPermutation (), getContext ());
527
+ auto transposedSubview = rewriter.create <memref::TransposeOp>(
528
+ loc, readSubview, AffineMapAttr::get (transposeMap));
529
+ ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr ();
530
+ // - The `in_bounds` attribute
531
+ if (inBoundsAttr) {
532
+ SmallVector<Attribute> inBoundsValues (inBoundsAttr.begin (),
533
+ inBoundsAttr.end ());
534
+ applyPermutationToVector (inBoundsValues, transposeOp.getPermutation ());
535
+ inBoundsAttr = rewriter.getArrayAttr (inBoundsValues);
536
+ }
537
+
538
+ VectorType legalReadType = resultType.clone (readType.getElementType ());
539
+ // Note: The indices are all zero as the subview is already offset.
540
+ SmallVector<Value> readIndices (illegalRead.getIndices ().size (), zero);
541
+ auto legalRead = rewriter.create <vector::TransferReadOp>(
542
+ loc, legalReadType, transposedSubview, readIndices,
543
+ illegalRead.getPermutationMapAttr (), illegalRead.getPadding (), mask,
544
+ inBoundsAttr);
545
+
546
+ // Replace the transpose with the new read, extending the result if
547
+ // necessary.
548
+ rewriter.replaceOp (transposeOp, [&]() -> Operation * {
549
+ if (extendOp)
550
+ return rewriter.create (loc, extendOp->getName ().getIdentifier (),
551
+ Value (legalRead), resultType);
552
+ return legalRead;
553
+ }());
554
+
555
+ return success ();
556
+ }
557
+ };
558
+
418
559
struct VectorLegalizationPass
419
560
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
420
561
void runOnOperation () override {
@@ -434,7 +575,8 @@ struct VectorLegalizationPass
434
575
return success ();
435
576
});
436
577
437
- patterns.add <FoldExtractFromVectorOfSMELikeCreateMasks>(context);
578
+ patterns.add <FoldExtractFromVectorOfSMELikeCreateMasks,
579
+ LiftIllegalVectorTransposeToMemory>(context);
438
580
// Note: High benefit to ensure masked outer products are lowered first.
439
581
patterns.add <LegalizeMaskedVectorOuterProductOpsByDecomposition>(
440
582
converter, context, 1024 );
0 commit comments