33
33
#include " llvm/ADT/ScopeExit.h"
34
34
#include " llvm/ADT/TypeSwitch.h"
35
35
#include " llvm/Support/Debug.h"
36
+ #include " llvm/Support/InterleavedRange.h"
36
37
#include " llvm/Support/raw_ostream.h"
37
38
#include < type_traits>
38
39
#include < utility>
@@ -95,6 +96,10 @@ static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
95
96
}
96
97
return true ;
97
98
}
99
+
100
+ static std::string stringifyReassocIndices (ReassociationIndicesRef ri) {
101
+ return llvm::interleaved (ri, " , " , /* Prefix=*/ " |" , /* Suffix=*/ " " );
102
+ }
98
103
#endif // NDEBUG
99
104
100
105
// / Return the index of the first result of `map` that is a function of
@@ -278,22 +283,21 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
278
283
highs, paddingValue, /* nofold=*/ false );
279
284
280
285
LLVM_DEBUG (
281
- DBGSNL (); DBGSNL (); llvm::interleaveComma (packingMetadata.insertPositions ,
282
- DBGS () << " insertPositions: " );
283
- DBGSNL (); llvm::interleaveComma (packingMetadata.outerPositions ,
284
- DBGS () << " outerPositions: " );
285
- DBGSNL (); llvm::interleaveComma (packedTensorType.getShape (),
286
- DBGS () << " packedShape: " );
286
+ DBGSNL (); DBGSNL ();
287
+ DBGS () << " insertPositions: "
288
+ << llvm::interleaved (packingMetadata.insertPositions );
289
+ DBGSNL (); DBGS () << " outerPositions: "
290
+ << llvm::interleaved (packingMetadata.outerPositions );
291
+ DBGSNL (); DBGS () << " packedShape: "
292
+ << llvm::interleaved (packedTensorType.getShape ());
293
+ DBGSNL (); DBGS () << " packedToStripMinedShapePerm: "
294
+ << llvm::interleaved (packedToStripMinedShapePerm);
287
295
DBGSNL ();
288
- llvm::interleaveComma (packedToStripMinedShapePerm,
289
- DBGS () << " packedToStripMinedShapePerm: " );
290
- DBGSNL (); llvm::interleaveComma (
291
- packingMetadata.reassociations , DBGS () << " reassociations: " ,
292
- [&](ReassociationIndices ri) {
293
- llvm::interleaveComma (ri, llvm::dbgs () << " |" );
294
- });
296
+ DBGS () << " reassociations: "
297
+ << llvm::interleaved (llvm::map_range (
298
+ packingMetadata.reassociations , stringifyReassocIndices));
295
299
DBGSNL ();
296
- llvm::interleaveComma (stripMinedShape, DBGS () << " stripMinedShape: " );
300
+ DBGS () << " stripMinedShape: " << llvm::interleaved (stripMinedShape );
297
301
DBGSNL (); DBGS () << " collapsed type: " << collapsed; DBGSNL (););
298
302
299
303
if (lowerPadLikeWithInsertSlice && packOp.isLikePad ()) {
@@ -343,7 +347,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
343
347
344
348
LLVM_DEBUG (DBGSNL (); DBGSNL (); DBGSNL ();
345
349
DBGS () << " reshape op: " << reshapeOp; DBGSNL ();
346
- llvm::interleaveComma (transpPerm, DBGS () << " transpPerm: " );
350
+ DBGS () << " transpPerm: " << llvm::interleaved (transpPerm );
347
351
DBGSNL (); DBGS () << " transpose op: " << transposeOp; DBGSNL (););
348
352
349
353
// 7. Replace packOp by transposeOp.
@@ -412,20 +416,19 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
412
416
loc, unPackOp.getSource (), emptyOp, packedToStripMinedShapePerm);
413
417
414
418
LLVM_DEBUG (
415
- DBGSNL (); DBGSNL (); llvm::interleaveComma (packingMetadata.insertPositions ,
416
- DBGS () << " insertPositions: " );
417
- DBGSNL (); llvm::interleaveComma (packedTensorType.getShape (),
418
- DBGS () << " packedShape: " );
419
+ DBGSNL (); DBGSNL ();
420
+ DBGS () << " insertPositions: "
421
+ << llvm::interleaved (packingMetadata.insertPositions );
422
+ DBGSNL (); DBGS () << " packedShape: "
423
+ << llvm::interleaved (packedTensorType.getShape ());
424
+ DBGSNL (); DBGS () << " packedToStripMinedShapePerm: "
425
+ << llvm::interleaved (packedToStripMinedShapePerm);
419
426
DBGSNL ();
420
- llvm::interleaveComma (packedToStripMinedShapePerm,
421
- DBGS () << " packedToStripMinedShapePerm: " );
422
- DBGSNL (); llvm::interleaveComma (
423
- packingMetadata.reassociations , DBGS () << " reassociations: " ,
424
- [&](ReassociationIndices ri) {
425
- llvm::interleaveComma (ri, llvm::dbgs () << " |" );
426
- });
427
+ DBGS () << " reassociations: "
428
+ << llvm::interleaved (llvm::map_range (
429
+ packingMetadata.reassociations , stringifyReassocIndices));
427
430
DBGSNL ();
428
- llvm::interleaveComma (stripMinedShape, DBGS () << " stripMinedShape: " );
431
+ DBGS () << " stripMinedShape: " << llvm::interleaved (stripMinedShape );
429
432
DBGSNL (); DBGS () << " collapsed type: " << collapsedType; DBGSNL (););
430
433
431
434
// 4. Collapse from the stripMinedShape to the padded result.
@@ -488,10 +491,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
488
491
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray ();
489
492
SmallVector<utils::IteratorType> iteratorTypes =
490
493
linalgOp.getIteratorTypesArray ();
491
- LLVM_DEBUG (DBGS () << " Start packing: " << linalgOp << " \n " ;
492
- llvm::interleaveComma (indexingMaps, DBGS () << " maps: " ); DBGSNL ();
493
- llvm::interleaveComma (iteratorTypes, DBGS () << " iterators: " );
494
- DBGSNL (); );
494
+ LLVM_DEBUG (DBGS () << " Start packing: " << linalgOp << " \n "
495
+ << " maps: " << llvm::interleaved (indexingMaps) << " \n "
496
+ << " iterators: " << llvm::interleaved (iteratorTypes)
497
+ << " \n " );
495
498
496
499
SmallVector<linalg::PackOp> packOps;
497
500
SmallVector<linalg::UnPackOp> unPackOps;
@@ -515,18 +518,18 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
515
518
516
519
LLVM_DEBUG (
517
520
DBGS () << " ++++ After pack size #" << i << " : " << packedSizes[i]
518
- << " \n " ;
519
- llvm::interleaveComma (indexingMaps, DBGS () << " maps: " ); DBGSNL ();
520
- llvm::interleaveComma (iteratorTypes, DBGS () << " iterators: " ); DBGSNL ();
521
- llvm::interleaveComma (packedOperandsDims. packedDimForEachOperand ,
522
- DBGS () << " packedDimForEachOperand: " );
523
- DBGSNL (); );
521
+ << " \n "
522
+ << " maps: " << llvm::interleaved (indexingMaps) << " \n "
523
+ << " iterators: " << llvm::interleaved (iteratorTypes) << " \n "
524
+ << " packedDimForEachOperand: "
525
+ << llvm::interleaved (packedOperandsDims. packedDimForEachOperand )
526
+ << " \n " );
524
527
}
525
528
526
529
// Step 2. Propagate packing to all LinalgOp operands.
527
530
SmallVector<Value> inputsAndInits, results;
528
- SmallVector<OpOperand *> initOperands = llvm::to_vector ( llvm::map_range (
529
- linalgOp.getDpsInitsMutable (), [](OpOperand &o) { return &o; } ));
531
+ SmallVector<OpOperand *> initOperands =
532
+ llvm::to_vector ( llvm::make_pointer_range ( linalgOp.getDpsInitsMutable ()));
530
533
SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands ();
531
534
for (const auto &operandsList : {inputOperands, initOperands}) {
532
535
for (OpOperand *opOperand : operandsList) {
@@ -536,11 +539,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
536
539
listOfPackedOperandsDim.extractPackedDimsForOperand (pos);
537
540
SmallVector<OpFoldResult> innerPackSizes =
538
541
listOfPackedOperandsDim.extractPackSizesForOperand (pos);
539
- LLVM_DEBUG (
540
- DBGS () << " operand: " << operand << " \n " ;
541
- llvm::interleaveComma (innerPos, DBGS () << " innerPos: " ); DBGSNL ();
542
- llvm::interleaveComma (innerPackSizes, DBGS () << " innerPackSizes: " );
543
- DBGSNL (););
542
+ LLVM_DEBUG (DBGS () << " operand: " << operand << " \n "
543
+ << " innerPos: " << llvm::interleaved (innerPos) << " \n "
544
+ << " innerPackSizes: "
545
+ << llvm::interleaved (innerPackSizes) << " \n " );
544
546
if (innerPackSizes.empty ()) {
545
547
inputsAndInits.push_back (operand);
546
548
continue ;
@@ -835,7 +837,7 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
835
837
// not change the indexings of any operand.
836
838
SmallVector<int64_t > permutation =
837
839
computePermutationVector (numLoops, {mPos , nPos, kPos }, mmnnkkPos);
838
- LLVM_DEBUG (llvm::interleaveComma (permutation, DBGS () << " perm: " ); DBGSNL (); );
840
+ LLVM_DEBUG (DBGS () << " perm: " << llvm::interleaved (permutation) << " \n " );
839
841
// Sign .. unsigned pollution.
840
842
SmallVector<unsigned > unsignedPerm (permutation.begin (), permutation.end ());
841
843
FailureOr<GenericOp> interchangeResult =
@@ -864,12 +866,12 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
864
866
865
867
// Add leading zeros to match numLoops, we only pack the last 3 dimensions
866
868
// post interchange.
867
- LLVM_DEBUG (llvm::interleaveComma ( paddedSizesNextMultipleOf,
868
- DBGS ( ) << " paddedSizesNextMultipleOf: " );
869
- DBGSNL (););
870
- LLVM_DEBUG ( llvm::interleaveComma (loopRanges, DBGS () << " loopRanges: " ,
871
- [](Range r) { llvm::dbgs () << r.size ; });
872
- DBGSNL (); );
869
+ LLVM_DEBUG (DBGS () << " paddedSizesNextMultipleOf: "
870
+ << llvm::interleaved (paddedSizesNextMultipleOf ) << " \n "
871
+ << " loopRanges: "
872
+ << llvm::interleaved ( llvm::map_range (
873
+ loopRanges, [](Range r) { return r.size ; }))
874
+ << " \n " );
873
875
SmallVector<OpFoldResult> adjustedPackedSizes (numLoops - packedSizes.size (),
874
876
rewriter.getIndexAttr (0 ));
875
877
for (int64_t i = 0 , e = numPackedDims; i < e; ++i) {
@@ -885,9 +887,8 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
885
887
{loopRanges[adjustedPackedSizes.size ()].size ,
886
888
rewriter.getIndexAttr (paddedSizesNextMultipleOf[i])}));
887
889
}
888
- LLVM_DEBUG (llvm::interleaveComma (adjustedPackedSizes,
889
- DBGS () << " adjustedPackedSizes: " );
890
- DBGSNL (););
890
+ LLVM_DEBUG (DBGS () << " adjustedPackedSizes: "
891
+ << llvm::interleaved (adjustedPackedSizes) << " \n " );
891
892
892
893
// TODO: If we wanted to give the genericOp a name after packing, after
893
894
// calling `pack` would be a good time. One would still need to check that
@@ -1202,9 +1203,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
1202
1203
1203
1204
srcPermForTranspose.append (SmallVector<int64_t >(packOp.getInnerDimsPos ()));
1204
1205
1205
- LLVM_DEBUG (DBGS () << " Pack permutation: " << packOp << " \n " ;
1206
- llvm::interleaveComma (srcPermForTranspose, DBGS () << " perm: " );
1207
- DBGSNL (); );
1206
+ LLVM_DEBUG (DBGS () << " Pack permutation: " << packOp << " \n "
1207
+ << " perm: " << llvm::interleaved (srcPermForTranspose)
1208
+ << " \n " );
1208
1209
1209
1210
// 2.1 Create tensor.empty (init value for TransposeOp)
1210
1211
SmallVector<OpFoldResult> transShapeForEmptyOp (srcRank - numTiles,
0 commit comments