20
20
#include " mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21
21
#include " mlir/Dialect/Utils/IndexingUtils.h"
22
22
#include " mlir/Dialect/Vector/IR/VectorOps.h"
23
+ #include " mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
23
24
#include " mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
24
25
#include " mlir/IR/BuiltinTypes.h"
25
26
#include " mlir/IR/Operation.h"
26
27
#include " mlir/IR/PatternMatch.h"
28
+ #include " mlir/Pass/Pass.h"
27
29
#include " mlir/Support/LLVM.h"
28
30
#include " mlir/Transforms/DialectConversion.h"
31
+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
29
32
#include " mlir/Transforms/OneToNTypeConversion.h"
30
33
#include " llvm/ADT/STLExtras.h"
31
34
#include " llvm/ADT/SmallVector.h"
32
35
#include " llvm/ADT/StringExtras.h"
33
36
#include " llvm/Support/Debug.h"
37
+ #include " llvm/Support/LogicalResult.h"
34
38
#include " llvm/Support/MathExtras.h"
35
39
36
40
#include < functional>
@@ -46,14 +50,6 @@ namespace {
46
50
// Utility functions
47
51
// ===----------------------------------------------------------------------===//
48
52
49
- static int getComputeVectorSize (int64_t size) {
50
- for (int i : {4 , 3 , 2 }) {
51
- if (size % i == 0 )
52
- return i;
53
- }
54
- return 1 ;
55
- }
56
-
57
53
static std::optional<SmallVector<int64_t >> getTargetShape (VectorType vecType) {
58
54
LLVM_DEBUG (llvm::dbgs () << " Get target shape\n " );
59
55
if (vecType.isScalable ()) {
@@ -62,8 +58,8 @@ static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
62
58
return std::nullopt;
63
59
}
64
60
SmallVector<int64_t > unrollShape = llvm::to_vector<4 >(vecType.getShape ());
65
- std::optional<SmallVector<int64_t >> targetShape =
66
- SmallVector< int64_t >( 1 , getComputeVectorSize (vecType.getShape ().back ()));
61
+ std::optional<SmallVector<int64_t >> targetShape = SmallVector< int64_t >(
62
+ 1 , mlir::spirv:: getComputeVectorSize (vecType.getShape ().back ()));
67
63
if (!targetShape) {
68
64
LLVM_DEBUG (llvm::dbgs () << " --no unrolling target shape defined\n " );
69
65
return std::nullopt;
@@ -1098,13 +1094,20 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
1098
1094
// the original operand of illegal type.
1099
1095
auto originalShape =
1100
1096
llvm::to_vector_of<int64_t , 4 >(origVecType.getShape ());
1101
- SmallVector<int64_t > strides (targetShape->size (), 1 );
1097
+ SmallVector<int64_t > strides (originalShape.size (), 1 );
1098
+ SmallVector<int64_t > extractShape (originalShape.size (), 1 );
1099
+ extractShape.back () = targetShape->back ();
1102
1100
SmallVector<Type> newTypes;
1103
1101
Value returnValue = returnOp.getOperand (origResultNo);
1104
1102
for (SmallVector<int64_t > offsets :
1105
1103
StaticTileOffsetRange (originalShape, *targetShape)) {
1106
1104
Value result = rewriter.create <vector::ExtractStridedSliceOp>(
1107
- loc, returnValue, offsets, *targetShape, strides);
1105
+ loc, returnValue, offsets, extractShape, strides);
1106
+ if (originalShape.size () > 1 ) {
1107
+ SmallVector<int64_t > extractIndices (originalShape.size () - 1 , 0 );
1108
+ result =
1109
+ rewriter.create <vector::ExtractOp>(loc, result, extractIndices);
1110
+ }
1108
1111
newOperands.push_back (result);
1109
1112
newTypes.push_back (unrolledType);
1110
1113
}
@@ -1285,6 +1288,118 @@ Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
1285
1288
builder);
1286
1289
}
1287
1290
1291
+ // ===----------------------------------------------------------------------===//
1292
+ // Public functions for vector unrolling
1293
+ // ===----------------------------------------------------------------------===//
1294
+
1295
+ int mlir::spirv::getComputeVectorSize (int64_t size) {
1296
+ for (int i : {4 , 3 , 2 }) {
1297
+ if (size % i == 0 )
1298
+ return i;
1299
+ }
1300
+ return 1 ;
1301
+ }
1302
+
1303
+ SmallVector<int64_t >
1304
+ mlir::spirv::getNativeVectorShapeImpl (vector::ReductionOp op) {
1305
+ VectorType srcVectorType = op.getSourceVectorType ();
1306
+ assert (srcVectorType.getRank () == 1 ); // Guaranteed by semantics
1307
+ int64_t vectorSize =
1308
+ mlir::spirv::getComputeVectorSize (srcVectorType.getDimSize (0 ));
1309
+ return {vectorSize};
1310
+ }
1311
+
1312
+ SmallVector<int64_t >
1313
+ mlir::spirv::getNativeVectorShapeImpl (vector::TransposeOp op) {
1314
+ VectorType vectorType = op.getResultVectorType ();
1315
+ SmallVector<int64_t > nativeSize (vectorType.getRank (), 1 );
1316
+ nativeSize.back () =
1317
+ mlir::spirv::getComputeVectorSize (vectorType.getShape ().back ());
1318
+ return nativeSize;
1319
+ }
1320
+
1321
+ std::optional<SmallVector<int64_t >>
1322
+ mlir::spirv::getNativeVectorShape (Operation *op) {
1323
+ if (OpTrait::hasElementwiseMappableTraits (op) && op->getNumResults () == 1 ) {
1324
+ if (auto vecType = dyn_cast<VectorType>(op->getResultTypes ()[0 ])) {
1325
+ SmallVector<int64_t > nativeSize (vecType.getRank (), 1 );
1326
+ nativeSize.back () =
1327
+ mlir::spirv::getComputeVectorSize (vecType.getShape ().back ());
1328
+ return nativeSize;
1329
+ }
1330
+ }
1331
+
1332
+ return TypeSwitch<Operation *, std::optional<SmallVector<int64_t >>>(op)
1333
+ .Case <vector::ReductionOp, vector::TransposeOp>(
1334
+ [](auto typedOp) { return getNativeVectorShapeImpl (typedOp); })
1335
+ .Default ([](Operation *) { return std::nullopt; });
1336
+ }
1337
+
1338
+ LogicalResult mlir::spirv::unrollVectorsInSignatures (Operation *op) {
1339
+ MLIRContext *context = op->getContext ();
1340
+ RewritePatternSet patterns (context);
1341
+ populateFuncOpVectorRewritePatterns (patterns);
1342
+ populateReturnOpVectorRewritePatterns (patterns);
1343
+ // We only want to apply signature conversion once to the existing func ops.
1344
+ // Without specifying strictMode, the greedy pattern rewriter will keep
1345
+ // looking for newly created func ops.
1346
+ GreedyRewriteConfig config;
1347
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
1348
+ return applyPatternsAndFoldGreedily (op, std::move (patterns), config);
1349
+ }
1350
+
1351
+ LogicalResult mlir::spirv::unrollVectorsInFuncBodies (Operation *op) {
1352
+ MLIRContext *context = op->getContext ();
1353
+
1354
+ // Unroll vectors in function bodies to native vector size.
1355
+ {
1356
+ RewritePatternSet patterns (context);
1357
+ auto options = vector::UnrollVectorOptions ().setNativeShapeFn (
1358
+ [](auto op) { return mlir::spirv::getNativeVectorShape (op); });
1359
+ populateVectorUnrollPatterns (patterns, options);
1360
+ if (failed (applyPatternsAndFoldGreedily (op, std::move (patterns))))
1361
+ return failure ();
1362
+ }
1363
+
1364
+ // Convert transpose ops into extract and insert pairs, in preparation of
1365
+ // further transformations to canonicalize/cancel.
1366
+ {
1367
+ RewritePatternSet patterns (context);
1368
+ auto options = vector::VectorTransformsOptions ().setVectorTransposeLowering (
1369
+ vector::VectorTransposeLowering::EltWise);
1370
+ vector::populateVectorTransposeLoweringPatterns (patterns, options);
1371
+ vector::populateVectorShapeCastLoweringPatterns (patterns);
1372
+ if (failed (applyPatternsAndFoldGreedily (op, std::move (patterns))))
1373
+ return failure ();
1374
+ }
1375
+
1376
+ // Run canonicalization to cast away leading size-1 dimensions.
1377
+ {
1378
+ RewritePatternSet patterns (context);
1379
+
1380
+ // We need to pull in casting way leading one dims.
1381
+ vector::populateCastAwayVectorLeadingOneDimPatterns (patterns);
1382
+ vector::ReductionOp::getCanonicalizationPatterns (patterns, context);
1383
+ vector::TransposeOp::getCanonicalizationPatterns (patterns, context);
1384
+
1385
+ // Decompose different rank insert_strided_slice and n-D
1386
+ // extract_slided_slice.
1387
+ vector::populateVectorInsertExtractStridedSliceDecompositionPatterns (
1388
+ patterns);
1389
+ vector::InsertOp::getCanonicalizationPatterns (patterns, context);
1390
+ vector::ExtractOp::getCanonicalizationPatterns (patterns, context);
1391
+
1392
+ // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
1393
+ // them up.
1394
+ vector::BroadcastOp::getCanonicalizationPatterns (patterns, context);
1395
+ vector::ShapeCastOp::getCanonicalizationPatterns (patterns, context);
1396
+
1397
+ if (failed (applyPatternsAndFoldGreedily (op, std::move (patterns))))
1398
+ return failure ();
1399
+ }
1400
+ return success ();
1401
+ }
1402
+
1288
1403
// ===----------------------------------------------------------------------===//
1289
1404
// SPIR-V TypeConverter
1290
1405
// ===----------------------------------------------------------------------===//
0 commit comments