Skip to content

Commit 174cd61

Browse files
authored
[mlir][ArmSME] Add custom vector.print lowering for SME tiles (#66691)
This adds a custom lowering of `vector.print` for SME tiles: it loops over each row of the tile, extracts the row via an SME MOVA, and prints it with a normal 1D `vector.print`. This makes SME integration tests easier to write and less verbose. Depends on: #66910, #66911
1 parent 9555736 commit 174cd61

File tree

10 files changed

+161
-134
lines changed

10 files changed

+161
-134
lines changed

mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ bool isValidSMETileElementType(Type type);
3434
/// otherwise.
3535
bool isValidSMETileVectorType(VectorType vType);
3636

37+
/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
38+
/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
39+
/// integer, to an i32 that can be passed as the `tile` parameter to the SME
40+
/// intrinsics. Or returns `tile` if already i32.
41+
Value castTileIDToI32(Value tile, Location loc, RewriterBase &rewriter);
42+
3743
} // namespace arm_sme
3844
} // namespace mlir
3945

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,94 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
190190
}
191191
};
192192

193+
/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
194+
/// extracting them via a MOVA, then printing with a 1D `vector.print`.
195+
///
196+
/// BEFORE:
197+
/// ```mlir
198+
/// vector.print %tile : vector<[4]x[4]xf32>
199+
/// ```
200+
/// AFTER:
201+
/// ```mlir
202+
/// %c0 = arith.constant 0 : index
203+
/// %c1 = arith.constant 1 : index
204+
/// %c4 = arith.constant 4 : index
205+
/// %ptrue = arith.constant dense<true> : vector<[4]xi1>
206+
/// %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xf32> to i32
207+
/// %vscale = vector.vscale
208+
/// %svl_s = arith.muli %c4, %vscale : index
209+
/// %cst = arith.constant dense<0.000000e+00> : vector<[4]xf32>
210+
/// scf.for %i = %c0 to %svl_s step %c1 {
211+
/// %slice_idx = arith.index_cast %i : index to i32
212+
/// %tile_slice = "arm_sme.intr.read.horiz"
213+
/// (%cst, %ptrue, %tile_id, %slice_idx)
214+
/// : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
215+
/// vector.print %tile_slice : vector<[4]xf32>
216+
/// }
217+
/// ```
218+
struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
219+
using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
220+
221+
LogicalResult matchAndRewrite(vector::PrintOp printOp,
222+
PatternRewriter &rewriter) const override {
223+
if (!printOp.getSource())
224+
return failure();
225+
226+
VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
227+
if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
228+
return failure();
229+
230+
auto loc = printOp.getLoc();
231+
232+
// Create an 'all true' predicate for each tile row.
233+
auto predicateType =
234+
VectorType::get(vectorType.getDimSize(1), rewriter.getI1Type(), true);
235+
auto allTruePredicate = rewriter.create<arith::ConstantOp>(
236+
loc, DenseElementsAttr::get(predicateType, true));
237+
238+
// Cast tile to i32 tile ID.
239+
auto tileId =
240+
rewriter.create<arm_sme::CastVectorToTile>(loc, printOp.getSource());
241+
Value tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
242+
243+
// Zero destination/fallback for tile slice extraction.
244+
auto rowType = VectorType::get(vectorType.getDimSize(1),
245+
vectorType.getElementType(), true);
246+
auto zeroVector = rewriter.create<arith::ConstantOp>(
247+
loc, rowType, rewriter.getZeroAttr(rowType));
248+
249+
// Create a loop over the rows of the tile.
250+
auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
251+
auto minTileRows =
252+
rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
253+
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
254+
auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
255+
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
256+
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
257+
{
258+
// Loop body.
259+
rewriter.setInsertionPointToStart(forOp.getBody());
260+
// Extract the current row from the tile.
261+
Value rowIndex = forOp.getInductionVar();
262+
auto rowIndexI32 = rewriter.create<arith::IndexCastOp>(
263+
loc, rewriter.getI32Type(), rowIndex);
264+
auto tileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
265+
loc, rowType, zeroVector, allTruePredicate, tileIdI32, rowIndexI32);
266+
// Print the row with a 1D vector.print.
267+
rewriter.create<vector::PrintOp>(loc, tileSlice,
268+
printOp.getPunctuation());
269+
}
270+
271+
rewriter.eraseOp(printOp);
272+
return success();
273+
}
274+
};
275+
193276
} // namespace
194277

195278
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
196-
patterns.add<TileLoadOpConversion, TileStoreOpConversion>(
197-
patterns.getContext());
279+
patterns.add<TileLoadOpConversion, TileStoreOpConversion,
280+
TileVectorPrintOpConversion>(patterns.getContext());
198281
}
199282

200283
namespace {
@@ -208,6 +291,12 @@ struct ConvertArmSMEToSCFPass
208291
target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
209292
arith::ArithDialect, scf::SCFDialect>();
210293
target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
294+
target.addDynamicallyLegalOp<vector::PrintOp>([](vector::PrintOp op) {
295+
if (!op.getSource())
296+
return true;
297+
VectorType vectorType = dyn_cast<VectorType>(op.getPrintType());
298+
return !vectorType || !arm_sme::isValidSMETileVectorType(vectorType);
299+
});
211300
if (failed(applyPartialConversion(getOperation(), target,
212301
std::move(patterns))))
213302
signalPassFailure();

mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,6 @@ struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
4949
}
5050
};
5151

52-
/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
53-
/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
54-
/// integer, to an i32 that can be passed as the `tile` parameter to the SME
55-
/// intrinsics. Or returns `tile` if already i32.
56-
Value castTileIDToI32(Value tile, Location loc,
57-
ConversionPatternRewriter &rewriter) {
58-
assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
59-
tile.getDefiningOp())) &&
60-
"expected ArmSME GetTileID or CastVectorToTile op!");
61-
unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
62-
if (tileElementWidth < 32)
63-
return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
64-
if (tileElementWidth > 32)
65-
return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
66-
return tile;
67-
}
68-
6952
/// Lower 'arm_sme.zero' to SME intrinsics.
7053
///
7154
/// BEFORE:

mlir/lib/Dialect/ArmSME/Utils/Utils.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
1414

15+
#include "mlir/Dialect/Arith/IR/Arith.h"
1516
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
1617

1718
using namespace mlir;
@@ -42,3 +43,16 @@ bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
4243

4344
return true;
4445
}
46+
47+
Value mlir::arm_sme::castTileIDToI32(Value tile, Location loc,
48+
RewriterBase &rewriter) {
49+
assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
50+
tile.getDefiningOp())) &&
51+
"expected ArmSME GetTileID or CastVectorToTile op!");
52+
unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
53+
if (tileElementWidth < 32)
54+
return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
55+
if (tileElementWidth > 32)
56+
return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
57+
return tile;
58+
}

mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,25 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
5656
arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
5757
return
5858
}
59+
60+
// -----
61+
62+
func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
63+
{
64+
vector.print %tile : vector<[4]x[4]xf32>
65+
return
66+
}
67+
// CHECK-LABEL: func.func @arm_sme_tile_print(
68+
// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>) {
69+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
70+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
71+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
72+
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
73+
// CHECK-DAG: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
74+
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
75+
// CHECK-DAG: %[[ZERO_VECTOR:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
76+
// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
77+
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
78+
// CHECK-NEXT: %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_cast %[[TILE_SLICE_INDEX]] : index to i32
79+
// CHECK-NEXT: %[[TILE_SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VECTOR]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
80+
// CHECK-NEXT: vector.print %[[TILE_SLICE]] : vector<[4]xf32>

mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
llvm.func @printCString(!llvm.ptr<i8>)
1515

16-
func.func @printTileBegin() {
16+
func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
1717
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
1818
%1 = llvm.mlir.constant(0 : index) : i64
1919
%2 = llvm.getelementptr %0[%1, %1]
@@ -22,7 +22,7 @@ func.func @printTileBegin() {
2222
return
2323
}
2424

25-
func.func @printTileEnd() {
25+
func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
2626
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
2727
%1 = llvm.mlir.constant(0 : index) : i64
2828
%2 = llvm.getelementptr %0[%1, %1]
@@ -44,7 +44,6 @@ func.func @entry() {
4444

4545
// Allocate memory.
4646
%mem1 = memref.alloca(%za_s_size) : memref<?xi32>
47-
%mem2 = memref.alloca(%za_s_size) : memref<?xi32>
4847

4948
// Fill each "row" of "mem1" with row number.
5049
//
@@ -66,11 +65,6 @@ func.func @entry() {
6665
// Load tile from "mem1" vertically.
6766
%0 = arm_sme.tile_load %mem1[%c0, %c0], <vertical> : memref<?xi32>, vector<[4]x[4]xi32>
6867

69-
// Store tile back to "mem2" to print.
70-
// TODO: Support vector.print for 2-D scalable vectors so don't have to spill
71-
// to memory and reload to print.
72-
vector.store %0, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
73-
7468
// 1. ORIGINAL HORIZONTAL LAYOUT
7569
// Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
7670
// 4x4xi32.
@@ -99,10 +93,7 @@ func.func @entry() {
9993
// CHECK-NEXT: ( 0, 1, 2, 3
10094
// CHECK: TILE END
10195
func.call @printTileBegin() : () -> ()
102-
scf.for %i = %c0 to %za_s_size step %svl_s {
103-
%tileslice = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
104-
vector.print %tileslice : vector<[4]xi32>
105-
}
96+
vector.print %0 : vector<[4]x[4]xi32>
10697
func.call @printTileEnd() : () -> ()
10798

10899
return

mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir

Lines changed: 8 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
llvm.func @printCString(!llvm.ptr<i8>)
1818

19-
func.func @printTileBegin() {
19+
func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
2020
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
2121
%1 = llvm.mlir.constant(0 : index) : i64
2222
%2 = llvm.getelementptr %0[%1, %1]
@@ -25,7 +25,7 @@ func.func @printTileBegin() {
2525
return
2626
}
2727

28-
func.func @printTileEnd() {
28+
func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
2929
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
3030
%1 = llvm.mlir.constant(0 : index) : i64
3131
%2 = llvm.getelementptr %0[%1, %1]
@@ -41,20 +41,8 @@ func.func @test_outerproduct_no_accumulator_4x4xf32() {
4141
%vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
4242
%tile = vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32>
4343

44-
// Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
45-
%vscale = vector.vscale
46-
%min_elts_s = arith.constant 4 : index
47-
%svl_s = arith.muli %min_elts_s, %vscale : index
48-
%za_s_size = arith.muli %svl_s, %svl_s : index
49-
50-
// Allocate memory.
51-
%mem = memref.alloca(%za_s_size) : memref<?xf32>
52-
53-
// Store the tile to memory.
54-
vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
55-
56-
// Reload and print. The smallest SVL is 128-bits so the tile will be at
57-
// least 4x4xf32.
44+
// Print the tile. The smallest SVL is 128-bits so the tile will be at least
45+
// 4x4xf32.
5846
//
5947
// WITHOUT-ACC: TILE BEGIN
6048
// WITHOUT-ACC-NEXT: ( 0, 0, 0, 0
@@ -63,10 +51,7 @@ func.func @test_outerproduct_no_accumulator_4x4xf32() {
6351
// WITHOUT-ACC-NEXT: ( 0, 3, 6, 9
6452
// WITHOUT-ACC: TILE END
6553
func.call @printTileBegin() : () -> ()
66-
scf.for %i = %c0 to %za_s_size step %svl_s {
67-
%tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
68-
vector.print %tileslice : vector<[4]xf32>
69-
}
54+
vector.print %tile : vector<[4]x[4]xf32>
7055
func.call @printTileEnd() : () -> ()
7156

7257
return
@@ -81,20 +66,8 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
8166
%vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
8267
%tile = vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>
8368

84-
// Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
85-
%vscale = vector.vscale
86-
%min_elts_s = arith.constant 4 : index
87-
%svl_s = arith.muli %min_elts_s, %vscale : index
88-
%za_s_size = arith.muli %svl_s, %svl_s : index
89-
90-
// Allocate memory.
91-
%mem = memref.alloca(%za_s_size) : memref<?xf32>
92-
93-
// Store the tile to memory.
94-
vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
95-
96-
// Reload and print. The smallest SVL is 128-bits so the tile will be at
97-
// least 4x4xf32.
69+
// Print the tile. The smallest SVL is 128-bits so the tile will be at least
70+
// 4x4xf32.
9871
//
9972
// WITH-ACC: TILE BEGIN
10073
// WITH-ACC-NEXT: ( 10, 10, 10, 10
@@ -103,10 +76,7 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
10376
// WITH-ACC-NEXT: ( 10, 13, 16, 19
10477
// WITH-ACC: TILE END
10578
func.call @printTileBegin() : () -> ()
106-
scf.for %i = %c0 to %za_s_size step %svl_s {
107-
%tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
108-
vector.print %tileslice : vector<[4]xf32>
109-
}
79+
vector.print %tile : vector<[4]x[4]xf32>
11080
func.call @printTileEnd() : () -> ()
11181

11282
return

mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
llvm.func @printCString(!llvm.ptr<i8>)
1515

16-
func.func @printTileBegin() {
16+
func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
1717
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
1818
%1 = llvm.mlir.constant(0 : index) : i64
1919
%2 = llvm.getelementptr %0[%1, %1]
@@ -22,7 +22,7 @@ func.func @printTileBegin() {
2222
return
2323
}
2424

25-
func.func @printTileEnd() {
25+
func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
2626
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
2727
%1 = llvm.mlir.constant(0 : index) : i64
2828
%2 = llvm.getelementptr %0[%1, %1]
@@ -32,7 +32,6 @@ func.func @printTileEnd() {
3232
}
3333

3434
func.func @test_outerproduct_with_accumulator_2x2xf64() {
35-
%c0 = arith.constant 0 : index
3635
%f1 = arith.constant 1.0 : f64
3736
%f2 = arith.constant 2.0 : f64
3837
%f10 = arith.constant 10.0 : f64
@@ -44,30 +43,15 @@ func.func @test_outerproduct_with_accumulator_2x2xf64() {
4443

4544
%tile = vector.outerproduct %a, %b, %c : vector<[2]xf64>, vector<[2]xf64>
4645

47-
// Calculate the size of a 64-bit tile, e.g. ZA{n}.d.
48-
%vscale = vector.vscale
49-
%min_elts_d = arith.constant 2 : index
50-
%svl_d = arith.muli %min_elts_d, %vscale : index
51-
%za_d_size = arith.muli %svl_d, %svl_d : index
52-
53-
// Allocate memory.
54-
%mem = memref.alloca(%za_d_size) : memref<?xf64>
55-
56-
// Store the tile to memory.
57-
vector.store %tile, %mem[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
58-
59-
// Reload and print. The smallest SVL is 128-bits so the tile will be at
60-
// least 2x2xf64.
46+
// Print the tile. The smallest SVL is 128-bits so the tile will be at least
47+
// 2x2xf64.
6148
//
6249
// CHECK: TILE BEGIN
6350
// CHECK-NEXT: ( 12, 12
6451
// CHECK-NEXT: ( 12, 12
6552
// CHECK: TILE END
6653
func.call @printTileBegin() : () -> ()
67-
scf.for %i = %c0 to %za_d_size step %svl_d {
68-
%tileslice = vector.load %mem[%i] : memref<?xf64>, vector<[2]xf64>
69-
vector.print %tileslice : vector<[2]xf64>
70-
}
54+
vector.print %tile : vector<[2]x[2]xf64>
7155
func.call @printTileEnd() : () -> ()
7256

7357
return

0 commit comments

Comments
 (0)