Skip to content

[mlir][ArmSME] Add rewrite to lift illegal vector.transposes to memory #80170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 6, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Jan 31, 2024

When unrolling the reduction dimension of something like a matmul for SME, you can end up with transposed reads of illegal types, like so:

%illegalRead = vector.transfer_read %memref[%a, %b]
                : memref<?x?xf32>, vector<[8]x4xf32>
%legalType = vector.transpose %illegalRead, [1, 0]
                : vector<[8]x4xf32> to vector<4x[8]xf32>

Here the vector<[8]x4xf32> is an illegal type, there's no way to lower a scalable vector of fixed vectors. However, as the final type vector<4x[8]xf32> is legal, we can instead lift the transpose to memory (producing a strided memref), and eliminate all the illegal types. This is shown below.

%readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
                : memref<?x?xf32> to memref<?x?xf32>
%transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
                : memref<?x?xf32> to memref<?x?xf32>
%legalType = vector.transfer_read %transpose[%c0, %c0]
                : memref<?x?xf32>, vector<4x[8]xf32>

@llvmbot
Copy link
Member

llvmbot commented Jan 31, 2024

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

When unrolling the reduction dimension of something like a matmul for SME, you can end up with transposed reads of illegal types, like so:

%illegalRead = vector.transfer_read %memref[%a, %b]
                : memref&lt;?x?xf32&gt;, vector&lt;[8]x4xf32&gt;
%legalType = vector.transpose %illegalRead, [1, 0]
                : vector&lt;[8]x4xf32&gt; to vector&lt;4x[8]xf32&gt;

Here the vector&lt;[8]x4xf32&gt; is an illegal type, there's no way to lower a scalable vector of fixed vectors. However, as the final type vector&lt;4x[8]xf32&gt; is legal, we can instead lift the transpose to memory (producing a strided memref), and eliminate all the illegal types. This is shown below.

%readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
                : memref&lt;?x?xf32&gt; to memref&lt;?x?xf32&gt;
%transpose = memref.transpose %readSubview (d0, d1) -&gt; (d1, d0)
                : memref&lt;?x?xf32&gt; to memref&lt;?x?xf32&gt;
%legalType = vector.transfer_read %transpose[%c0, %c0]
                : memref&lt;?x?xf32&gt;, vector&lt;4x[8]xf32&gt;

Full diff: https://github.com/llvm/llvm-project/pull/80170.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+160-3)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+75)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 85ec53c2618aa..a3db2d2395528 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -7,8 +7,6 @@
 //===----------------------------------------------------------------------===//
 //
 // This pass legalizes vector operations so they can be lowered to ArmSME.
-// Currently, this only implements the decomposition of vector operations that
-// use vector sizes larger than an SME tile, into multiple SME-sized operations.
 //
 // Note: In the context of this pass 'tile' always refers to an SME tile.
 //
@@ -19,6 +17,7 @@
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
@@ -35,6 +34,10 @@ using namespace mlir::arm_sme;
 
 namespace {
 
+//===----------------------------------------------------------------------===//
+// Decomposition of vector operations larger than an SME tile
+//===----------------------------------------------------------------------===//
+
 // Common match failure reasons.
 static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
     "op vector size is not multiple of SME tiles");
@@ -338,13 +341,166 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ArmSME-specific fixup canonicalizations/folds
+//===----------------------------------------------------------------------===//
+
+/// Lifts an illegal vector.transpose and vector.transfer_read to a
+/// memref.subview + memref.transpose, followed by a legal read.
+///
+/// 'Illegal' here means a leading scalable dimension and a fixed trailing
+/// dimension, which has no valid lowering.
+///
+/// The memref.transpose is metadata-only transpose that produces a strided
+/// memref, which eventually becomes a loop reading individual elements.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %illegalRead = vector.transfer_read %memref[%a, %b]
+///                  : memref<?x?xf32>, vector<[8]x4xf32>
+///  %legalType = vector.transpose %illegalRead, [1, 0]
+///                  : vector<[8]x4xf32> to vector<4x[8]xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
+///                  : memref<?x?xf32> to memref<?x?xf32>
+///  %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
+///                  : memref<?x?xf32> to memref<?x?xf32>
+///  %legalType = vector.transfer_read %transpose[%c0, %c0]
+///                  : memref<?x?xf32>, vector<4x[8]xf32>
+///  ```
+struct LiftIllegalVectorTransposeToMemory
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  static bool isIllegalVectorType(VectorType vType) {
+    bool seenFixedDim = false;
+    for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
+      seenFixedDim |= !scalableFlag;
+      if (seenFixedDim && scalableFlag)
+        return true;
+    }
+    return false;
+  }
+
+  static Value getExtensionSource(Operation *op) {
+    if (auto signExtend = dyn_cast<arith::ExtSIOp>(op))
+      return signExtend.getIn();
+    if (auto zeroExtend = dyn_cast<arith::ExtUIOp>(op))
+      return zeroExtend.getIn();
+    if (auto floatExtend = dyn_cast<arith::ExtFOp>(op))
+      return floatExtend.getIn();
+    return {};
+  }
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = transposeOp.getSourceVectorType();
+    auto resultType = transposeOp.getResultVectorType();
+    if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(
+          transposeOp, "expected transpose from illegal type to legal type");
+
+    Value maybeRead = transposeOp.getVector();
+    auto *transposeSourceOp = maybeRead.getDefiningOp();
+    Operation *extendOp = nullptr;
+    if (Value extendSource = getExtensionSource(transposeSourceOp)) {
+      maybeRead = extendSource;
+      extendOp = transposeSourceOp;
+    }
+
+    auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
+    if (!illegalRead)
+      return rewriter.notifyMatchFailure(
+          transposeOp,
+          "expected source to be (possibility extended) transfer_read");
+
+    if (!illegalRead.getPermutationMap().isIdentity())
+      return rewriter.notifyMatchFailure(
+          illegalRead, "expected read to have identity permutation map");
+
+    auto loc = transposeOp.getLoc();
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+    // Create a subview that matches the size of the illegal read vector type.
+    auto readType = illegalRead.getVectorType();
+    auto readSizes = llvm::map_to_vector(
+        llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
+        [&](auto dim) -> Value {
+          auto [size, isScalable] = dim;
+          auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
+          if (!isScalable)
+            return dimSize;
+          auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+          return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
+        });
+    SmallVector<Value> strides(readType.getRank(), Value(one));
+    auto readSubview = rewriter.create<memref::SubViewOp>(
+        loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
+        strides);
+
+    // Apply the transpose to all values/attributes of the transfer_read.
+    // The mask.
+    Value mask = illegalRead.getMask();
+    if (mask) {
+      // Note: The transpose for the mask should fold into the
+      // vector.create_mask/constant_mask op, which will then become legal.
+      mask = rewriter.create<vector::TransposeOp>(loc, mask,
+                                                  transposeOp.getPermutation());
+    }
+    // The source memref.
+    mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
+        transposeOp.getPermutation(), getContext());
+    auto transposedSubview = rewriter.create<memref::TransposeOp>(
+        loc, readSubview, AffineMapAttr::get(transposeMap));
+    ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
+    // The `in_bounds` attribute.
+    if (inBoundsAttr) {
+      SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
+                                            inBoundsAttr.end());
+      applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
+      inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
+    }
+
+    VectorType legalReadType =
+        VectorType::Builder(resultType)
+            .setElementType(illegalRead.getVectorType().getElementType());
+    // Note: The indices are all zero as the subview is already offset.
+    SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
+    Value legalRead = rewriter.create<vector::TransferReadOp>(
+        loc, legalReadType, transposedSubview, readIndices,
+        illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
+        inBoundsAttr);
+
+    // Replace the transpose with the new read, extending the result if
+    // necessary.
+    rewriter.replaceOp(transposeOp, [&]() -> Value {
+      if (!extendOp)
+        return legalRead;
+      if (isa<arith::ExtSIOp>(extendOp))
+        return rewriter.create<arith::ExtSIOp>(loc, resultType, legalRead);
+      if (isa<arith::ExtUIOp>(extendOp))
+        return rewriter.create<arith::ExtUIOp>(loc, resultType, legalRead);
+      if (isa<arith::ExtFOp>(extendOp))
+        return rewriter.create<arith::ExtFOp>(loc, resultType, legalRead);
+      return legalRead;
+    }());
+
+    return success();
+  }
+};
+
 struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
     auto *context = &getContext();
     OneToNTypeConverter converter;
     RewritePatternSet patterns(context);
-
     converter.addConversion([](Type type) { return type; });
     converter.addConversion(
         [](VectorType vectorType,
@@ -358,6 +514,7 @@ struct VectorLegalizationPass
           return success();
         });
 
+    patterns.add<LiftIllegalVectorTransposeToMemory>(context);
     // Note: High benefit to ensure masked outer products are lowered first.
     patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
         converter, context, 1024);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index a20abeefedcfd..2317930d3d061 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -266,3 +266,78 @@ func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: m
   vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_no_mask(
+// CHECK-SAME:                                            %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                            %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                            %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory_no_mask(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
+  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[LEGAL_READ:.*]]  = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+  // CHECK-NEXT: return %[[LEGAL_READ]]
+  %pad = arith.constant 0.0 : f32
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xf32>, vector<[8]x4xf32>
+  %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+  return %legalType : vector<4x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory(
+// CHECK-SAME:                                    %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %dim0: index, %dim1: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
+  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
+  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[LEGAL_READ:.*]]  = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]], %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+  // CHECK-NEXT: return %[[LEGAL_READ]]
+  %pad = arith.constant 0.0 : f32
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x4xi1>
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad, %mask : memref<?x?xf32>, vector<[8]x4xf32>
+  %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+  return %legalType : vector<4x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_arith_extop(
+// CHECK-SAME:                                                     %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                     %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                     %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>)
+func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: index, %memref: memref<?x?xi8>) -> vector<4x[8]xi32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xi8> to memref<?x4xi8, strided<[?, 1], offset: ?>>
+  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xi8, strided<[?, 1], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xi8, strided<[?, ?], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_I8]] : memref<?x?xi8, strided<[?, ?], offset: ?>>, vector<4x[8]xi8>
+  // CHECK-NEXT: %[[EXT_TYPE:.*]] = arith.extsi %[[LEGAL_READ]] : vector<4x[8]xi8> to vector<4x[8]xi32>
+  // CHECK-NEXT: return %[[EXT_TYPE]]
+  %pad = arith.constant 0 : i8
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xi8>, vector<[8]x4xi8>
+  %extRead = arith.extsi %illegalRead : vector<[8]x4xi8> to vector<[8]x4xi32>
+  %legalType = vector.transpose %extRead, [1, 0] : vector<[8]x4xi32> to vector<4x[8]xi32>
+  return %legalType : vector<4x[8]xi32>
+}

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM cheers

When unrolling the reduction dimension of something like a matmul for
SME, you can end up with transposed reads of illegal types, like so:

```mlir
%illegalRead = vector.transfer_read %memref[%a, %b]
                : memref<?x?xf32>, vector<[8]x4xf32>
%legalType = vector.transpose %illegalRead, [1, 0]
                : vector<[8]x4xf32> to vector<4x[8]xf32>
```

Here the `vector<[8]x4xf32>` is an illegal type, there's no way to
lower a scalable vector of fixed vectors. However, as the final type
`vector<4x[8]xf32>` is legal, we can instead lift the transpose to
memory (producing a strided memref), and eliminate all the illegal
types. This is shown below.

```mlir
%readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
                : memref<?x?xf32> to memref<?x?xf32>
%transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
                : memref<?x?xf32> to memref<?x?xf32>
%legalType = vector.transfer_read %transpose[%c0, %c0]
                : memref<?x?xf32>, vector<4x[8]xf32>
```
@MacDue MacDue force-pushed the transpose_to_memory branch from ab8c3d7 to 08dd0e4 Compare February 2, 2024 10:26
@MacDue MacDue merged commit 0473e32 into llvm:main Feb 6, 2024
@MacDue MacDue deleted the transpose_to_memory branch February 6, 2024 09:30
MacDue added a commit to MacDue/llvm-project that referenced this pull request Feb 8, 2024
…mlir

This tests both llvm#80148 and llvm#80170 work together to allow unrolling the
reduction dimension of a matmul.
MacDue added a commit that referenced this pull request Feb 12, 2024
…mlir (#81160)

This tests both #80148 and #80170 work together to allow unrolling the
reduction dimension of a matmul.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants