14
14
#include " mlir/Dialect/Linalg/Transforms/Transforms.h"
15
15
#include " mlir/Dialect/Affine/IR/AffineOps.h"
16
16
#include " mlir/Dialect/Arith/IR/Arith.h"
17
+ #include " mlir/Dialect/Arith/Utils/Utils.h"
17
18
#include " mlir/Dialect/Func/IR/FuncOps.h"
18
19
#include " mlir/Dialect/Linalg/IR/Linalg.h"
19
20
#include " mlir/Dialect/Linalg/Utils/Utils.h"
30
31
#include " mlir/Pass/Pass.h"
31
32
#include " mlir/Support/LLVM.h"
32
33
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
34
+ #include " llvm/ADT/STLExtras.h"
33
35
#include " llvm/ADT/ScopeExit.h"
34
36
#include " llvm/ADT/TypeSwitch.h"
37
+ #include " llvm/ADT/iterator.h"
35
38
#include " llvm/Support/Debug.h"
39
+ #include " llvm/Support/InterleavedRange.h"
36
40
#include " llvm/Support/raw_ostream.h"
37
41
#include < type_traits>
38
42
#include < utility>
@@ -95,6 +99,10 @@ static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
95
99
}
96
100
return true ;
97
101
}
102
+
103
+ static std::string stringifyReassocIndices (ReassociationIndicesRef ri) {
104
+ return llvm::interleaved (ri, " , " , /* Prefix=*/ " |" , /* Suffix=*/ " " );
105
+ }
98
106
#endif // NDEBUG
99
107
100
108
// / Return the index of the first result of `map` that is a function of
@@ -278,22 +286,21 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
278
286
highs, paddingValue, /* nofold=*/ false );
279
287
280
288
LLVM_DEBUG (
281
- DBGSNL (); DBGSNL (); llvm::interleaveComma (packingMetadata.insertPositions ,
282
- DBGS () << " insertPositions: " );
283
- DBGSNL (); llvm::interleaveComma (packingMetadata.outerPositions ,
284
- DBGS () << " outerPositions: " );
285
- DBGSNL (); llvm::interleaveComma (packedTensorType.getShape (),
286
- DBGS () << " packedShape: " );
289
+ DBGSNL (); DBGSNL ();
290
+ DBGS () << " insertPositions: "
291
+ << llvm::interleaved (packingMetadata.insertPositions );
292
+ DBGSNL (); DBGS () << " outerPositions: "
293
+ << llvm::interleaved (packingMetadata.outerPositions );
294
+ DBGSNL (); DBGS () << " packedShape: "
295
+ << llvm::interleaved (packedTensorType.getShape ());
296
+ DBGSNL (); DBGS () << " packedToStripMinedShapePerm: "
297
+ << llvm::interleaved (packedToStripMinedShapePerm);
287
298
DBGSNL ();
288
- llvm::interleaveComma (packedToStripMinedShapePerm,
289
- DBGS () << " packedToStripMinedShapePerm: " );
290
- DBGSNL (); llvm::interleaveComma (
291
- packingMetadata.reassociations , DBGS () << " reassociations: " ,
292
- [&](ReassociationIndices ri) {
293
- llvm::interleaveComma (ri, llvm::dbgs () << " |" );
294
- });
299
+ DBGS () << " reassociations: "
300
+ << llvm::interleaved (llvm::map_range (
301
+ packingMetadata.reassociations , stringifyReassocIndices));
295
302
DBGSNL ();
296
- llvm::interleaveComma (stripMinedShape, DBGS () << " stripMinedShape: " );
303
+ DBGS () << " stripMinedShape: " << llvm::interleaved (stripMinedShape );
297
304
DBGSNL (); DBGS () << " collapsed type: " << collapsed; DBGSNL (););
298
305
299
306
if (lowerPadLikeWithInsertSlice && packOp.isLikePad ()) {
@@ -343,7 +350,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
343
350
344
351
LLVM_DEBUG (DBGSNL (); DBGSNL (); DBGSNL ();
345
352
DBGS () << " reshape op: " << reshapeOp; DBGSNL ();
346
- llvm::interleaveComma (transpPerm, DBGS () << " transpPerm: " );
353
+ DBGS () << " transpPerm: " << llvm::interleaved (transpPerm );
347
354
DBGSNL (); DBGS () << " transpose op: " << transposeOp; DBGSNL (););
348
355
349
356
// 7. Replace packOp by transposeOp.
@@ -412,20 +419,19 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
412
419
loc, unPackOp.getSource (), emptyOp, packedToStripMinedShapePerm);
413
420
414
421
LLVM_DEBUG (
415
- DBGSNL (); DBGSNL (); llvm::interleaveComma (packingMetadata.insertPositions ,
416
- DBGS () << " insertPositions: " );
417
- DBGSNL (); llvm::interleaveComma (packedTensorType.getShape (),
418
- DBGS () << " packedShape: " );
422
+ DBGSNL (); DBGSNL ();
423
+ DBGS () << " insertPositions: "
424
+ << llvm::interleaved (packingMetadata.insertPositions );
425
+ DBGSNL (); DBGS () << " packedShape: "
426
+ << llvm::interleaved (packedTensorType.getShape ());
427
+ DBGSNL (); DBGS () << " packedToStripMinedShapePerm: "
428
+ << llvm::interleaved (packedToStripMinedShapePerm);
419
429
DBGSNL ();
420
- llvm::interleaveComma (packedToStripMinedShapePerm,
421
- DBGS () << " packedToStripMinedShapePerm: " );
422
- DBGSNL (); llvm::interleaveComma (
423
- packingMetadata.reassociations , DBGS () << " reassociations: " ,
424
- [&](ReassociationIndices ri) {
425
- llvm::interleaveComma (ri, llvm::dbgs () << " |" );
426
- });
430
+ DBGS () << " reassociations: "
431
+ << llvm::interleaved (llvm::map_range (
432
+ packingMetadata.reassociations , stringifyReassocIndices));
427
433
DBGSNL ();
428
- llvm::interleaveComma (stripMinedShape, DBGS () << " stripMinedShape: " );
434
+ DBGS () << " stripMinedShape: " << llvm::interleaved (stripMinedShape );
429
435
DBGSNL (); DBGS () << " collapsed type: " << collapsedType; DBGSNL (););
430
436
431
437
// 4. Collapse from the stripMinedShape to the padded result.
@@ -488,10 +494,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
488
494
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray ();
489
495
SmallVector<utils::IteratorType> iteratorTypes =
490
496
linalgOp.getIteratorTypesArray ();
491
- LLVM_DEBUG (DBGS () << " Start packing: " << linalgOp << " \n " ;
492
- llvm::interleaveComma (indexingMaps, DBGS () << " maps: " ); DBGSNL ();
493
- llvm::interleaveComma (iteratorTypes, DBGS () << " iterators: " );
494
- DBGSNL (); );
497
+ LLVM_DEBUG (DBGS () << " Start packing: " << linalgOp << " \n "
498
+ << " maps: " << llvm::interleaved (indexingMaps) << " \n "
499
+ << " iterators: " << llvm::interleaved (iteratorTypes)
500
+ << " \n " );
495
501
496
502
SmallVector<linalg::PackOp> packOps;
497
503
SmallVector<linalg::UnPackOp> unPackOps;
@@ -515,18 +521,18 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
515
521
516
522
LLVM_DEBUG (
517
523
DBGS () << " ++++ After pack size #" << i << " : " << packedSizes[i]
518
- << " \n " ;
519
- llvm::interleaveComma (indexingMaps, DBGS () << " maps: " ); DBGSNL ();
520
- llvm::interleaveComma (iteratorTypes, DBGS () << " iterators: " ); DBGSNL ();
521
- llvm::interleaveComma (packedOperandsDims. packedDimForEachOperand ,
522
- DBGS () << " packedDimForEachOperand: " );
523
- DBGSNL (); );
524
+ << " \n "
525
+ << " maps: " << llvm::interleaved (indexingMaps) << " \n "
526
+ << " iterators: " << llvm::interleaved (iteratorTypes) << " \n "
527
+ << " packedDimForEachOperand: "
528
+ << llvm::interleaved (packedOperandsDims. packedDimForEachOperand )
529
+ << " \n " );
524
530
}
525
531
526
532
// Step 2. Propagate packing to all LinalgOp operands.
527
533
SmallVector<Value> inputsAndInits, results;
528
- SmallVector<OpOperand *> initOperands = llvm::to_vector ( llvm::map_range (
529
- linalgOp.getDpsInitsMutable (), [](OpOperand &o) { return &o; } ));
534
+ SmallVector<OpOperand *> initOperands =
535
+ llvm::to_vector ( llvm::make_pointer_range ( linalgOp.getDpsInitsMutable ()));
530
536
SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands ();
531
537
for (const auto &operandsList : {inputOperands, initOperands}) {
532
538
for (OpOperand *opOperand : operandsList) {
@@ -536,11 +542,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
536
542
listOfPackedOperandsDim.extractPackedDimsForOperand (pos);
537
543
SmallVector<OpFoldResult> innerPackSizes =
538
544
listOfPackedOperandsDim.extractPackSizesForOperand (pos);
539
- LLVM_DEBUG (
540
- DBGS () << " operand: " << operand << " \n " ;
541
- llvm::interleaveComma (innerPos, DBGS () << " innerPos: " ); DBGSNL ();
542
- llvm::interleaveComma (innerPackSizes, DBGS () << " innerPackSizes: " );
543
- DBGSNL (););
545
+ LLVM_DEBUG (DBGS () << " operand: " << operand << " \n "
546
+ << " innerPos: " << llvm::interleaved (innerPos) << " \n "
547
+ << " innerPackSizes: "
548
+ << llvm::interleaved (innerPackSizes) << " \n " );
544
549
if (innerPackSizes.empty ()) {
545
550
inputsAndInits.push_back (operand);
546
551
continue ;
@@ -835,7 +840,7 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
835
840
// not change the indexings of any operand.
836
841
SmallVector<int64_t > permutation =
837
842
computePermutationVector (numLoops, {mPos , nPos, kPos }, mmnnkkPos);
838
- LLVM_DEBUG (llvm::interleaveComma (permutation, DBGS () << " perm: " ); DBGSNL (); );
843
+ LLVM_DEBUG (DBGS () << " perm: " << llvm::interleaved (permutation) << " \n " );
839
844
// Sign .. unsigned pollution.
840
845
SmallVector<unsigned > unsignedPerm (permutation.begin (), permutation.end ());
841
846
FailureOr<GenericOp> interchangeResult =
@@ -864,12 +869,12 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
864
869
865
870
// Add leading zeros to match numLoops, we only pack the last 3 dimensions
866
871
// post interchange.
867
- LLVM_DEBUG (llvm::interleaveComma ( paddedSizesNextMultipleOf,
868
- DBGS ( ) << " paddedSizesNextMultipleOf: " );
869
- DBGSNL (););
870
- LLVM_DEBUG ( llvm::interleaveComma (loopRanges, DBGS () << " loopRanges: " ,
871
- [](Range r) { llvm::dbgs () << r.size ; });
872
- DBGSNL (); );
872
+ LLVM_DEBUG (DBGS () << " paddedSizesNextMultipleOf: "
873
+ << llvm::interleaved (paddedSizesNextMultipleOf ) << " \n " );
874
+ LLVM_DEBUG ( DBGS () << " loopRanges: "
875
+ << llvm::interleaved ( llvm::map_range (
876
+ loopRanges, [](Range r) { return r.size ; }))
877
+ << " \n " );
873
878
SmallVector<OpFoldResult> adjustedPackedSizes (numLoops - packedSizes.size (),
874
879
rewriter.getIndexAttr (0 ));
875
880
for (int64_t i = 0 , e = numPackedDims; i < e; ++i) {
@@ -885,9 +890,8 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
885
890
{loopRanges[adjustedPackedSizes.size ()].size ,
886
891
rewriter.getIndexAttr (paddedSizesNextMultipleOf[i])}));
887
892
}
888
- LLVM_DEBUG (llvm::interleaveComma (adjustedPackedSizes,
889
- DBGS () << " adjustedPackedSizes: " );
890
- DBGSNL (););
893
+ LLVM_DEBUG (DBGS () << " adjustedPackedSizes: "
894
+ << llvm::interleaved (adjustedPackedSizes) << " \n " );
891
895
892
896
// TODO: If we wanted to give the genericOp a name after packing, after
893
897
// calling `pack` would be a good time. One would still need to check that
@@ -1202,9 +1206,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
1202
1206
1203
1207
srcPermForTranspose.append (SmallVector<int64_t >(packOp.getInnerDimsPos ()));
1204
1208
1205
- LLVM_DEBUG (DBGS () << " Pack permutation: " << packOp << " \n " ;
1206
- llvm::interleaveComma (srcPermForTranspose, DBGS () << " perm: " );
1207
- DBGSNL (); );
1209
+ LLVM_DEBUG (DBGS () << " Pack permutation: " << packOp << " \n "
1210
+ << " perm: " << llvm::interleaved (srcPermForTranspose)
1211
+ << " \n " );
1208
1212
1209
1213
// 2.1 Create tensor.empty (init value for TransposeOp)
1210
1214
SmallVector<OpFoldResult> transShapeForEmptyOp (srcRank - numTiles,
0 commit comments