Skip to content

Commit 843062e

Browse files
committed
[mlir][linalg] Clean up debug prints. NFC.
Use `llvm::interleaved` from llvm#135517 to simplify printing.
1 parent 198c5da commit 843062e

File tree

1 file changed

+61
-57
lines changed

1 file changed

+61
-57
lines changed

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 61 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1515
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1616
#include "mlir/Dialect/Arith/IR/Arith.h"
17+
#include "mlir/Dialect/Arith/Utils/Utils.h"
1718
#include "mlir/Dialect/Func/IR/FuncOps.h"
1819
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1920
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -30,9 +31,12 @@
3031
#include "mlir/Pass/Pass.h"
3132
#include "mlir/Support/LLVM.h"
3233
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
34+
#include "llvm/ADT/STLExtras.h"
3335
#include "llvm/ADT/ScopeExit.h"
3436
#include "llvm/ADT/TypeSwitch.h"
37+
#include "llvm/ADT/iterator.h"
3538
#include "llvm/Support/Debug.h"
39+
#include "llvm/Support/InterleavedRange.h"
3640
#include "llvm/Support/raw_ostream.h"
3741
#include <type_traits>
3842
#include <utility>
@@ -95,6 +99,10 @@ static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
9599
}
96100
return true;
97101
}
102+
103+
static std::string stringifyReassocIndices(ReassociationIndicesRef ri) {
104+
return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/"");
105+
}
98106
#endif // NDEBUG
99107

100108
/// Return the index of the first result of `map` that is a function of
@@ -278,22 +286,21 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
278286
highs, paddingValue, /*nofold=*/false);
279287

280288
LLVM_DEBUG(
281-
DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
282-
DBGS() << "insertPositions: ");
283-
DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
284-
DBGS() << "outerPositions: ");
285-
DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
286-
DBGS() << "packedShape: ");
289+
DBGSNL(); DBGSNL();
290+
DBGS() << "insertPositions: "
291+
<< llvm::interleaved(packingMetadata.insertPositions);
292+
DBGSNL(); DBGS() << "outerPositions: "
293+
<< llvm::interleaved(packingMetadata.outerPositions);
294+
DBGSNL(); DBGS() << "packedShape: "
295+
<< llvm::interleaved(packedTensorType.getShape());
296+
DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
297+
<< llvm::interleaved(packedToStripMinedShapePerm);
287298
DBGSNL();
288-
llvm::interleaveComma(packedToStripMinedShapePerm,
289-
DBGS() << "packedToStripMinedShapePerm: ");
290-
DBGSNL(); llvm::interleaveComma(
291-
packingMetadata.reassociations, DBGS() << "reassociations: ",
292-
[&](ReassociationIndices ri) {
293-
llvm::interleaveComma(ri, llvm::dbgs() << "|");
294-
});
299+
DBGS() << "reassociations: "
300+
<< llvm::interleaved(llvm::map_range(
301+
packingMetadata.reassociations, stringifyReassocIndices));
295302
DBGSNL();
296-
llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
303+
DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
297304
DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
298305

299306
if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
@@ -343,7 +350,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
343350

344351
LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
345352
DBGS() << "reshape op: " << reshapeOp; DBGSNL();
346-
llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
353+
DBGS() << "transpPerm: " << llvm::interleaved(transpPerm);
347354
DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
348355

349356
// 7. Replace packOp by transposeOp.
@@ -412,20 +419,19 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
412419
loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
413420

414421
LLVM_DEBUG(
415-
DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
416-
DBGS() << "insertPositions: ");
417-
DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
418-
DBGS() << "packedShape: ");
422+
DBGSNL(); DBGSNL();
423+
DBGS() << "insertPositions: "
424+
<< llvm::interleaved(packingMetadata.insertPositions);
425+
DBGSNL(); DBGS() << "packedShape: "
426+
<< llvm::interleaved(packedTensorType.getShape());
427+
DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
428+
<< llvm::interleaved(packedToStripMinedShapePerm);
419429
DBGSNL();
420-
llvm::interleaveComma(packedToStripMinedShapePerm,
421-
DBGS() << "packedToStripMinedShapePerm: ");
422-
DBGSNL(); llvm::interleaveComma(
423-
packingMetadata.reassociations, DBGS() << "reassociations: ",
424-
[&](ReassociationIndices ri) {
425-
llvm::interleaveComma(ri, llvm::dbgs() << "|");
426-
});
430+
DBGS() << "reassociations: "
431+
<< llvm::interleaved(llvm::map_range(
432+
packingMetadata.reassociations, stringifyReassocIndices));
427433
DBGSNL();
428-
llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
434+
DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
429435
DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
430436

431437
// 4. Collapse from the stripMinedShape to the padded result.
@@ -488,10 +494,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
488494
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
489495
SmallVector<utils::IteratorType> iteratorTypes =
490496
linalgOp.getIteratorTypesArray();
491-
LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
492-
llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
493-
llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
494-
DBGSNL(););
497+
LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n"
498+
<< "maps: " << llvm::interleaved(indexingMaps) << "\n"
499+
<< "iterators: " << llvm::interleaved(iteratorTypes)
500+
<< "\n");
495501

496502
SmallVector<linalg::PackOp> packOps;
497503
SmallVector<linalg::UnPackOp> unPackOps;
@@ -515,18 +521,18 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
515521

516522
LLVM_DEBUG(
517523
DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
518-
<< "\n";
519-
llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
520-
llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
521-
llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
522-
DBGS() << "packedDimForEachOperand: ");
523-
DBGSNL(););
524+
<< "\n"
525+
<< "maps: " << llvm::interleaved(indexingMaps) << "\n"
526+
<< "iterators: " << llvm::interleaved(iteratorTypes) << "\n"
527+
<< "packedDimForEachOperand: "
528+
<< llvm::interleaved(packedOperandsDims.packedDimForEachOperand)
529+
<< "\n");
524530
}
525531

526532
// Step 2. Propagate packing to all LinalgOp operands.
527533
SmallVector<Value> inputsAndInits, results;
528-
SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
529-
linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
534+
SmallVector<OpOperand *> initOperands =
535+
llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
530536
SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
531537
for (const auto &operandsList : {inputOperands, initOperands}) {
532538
for (OpOperand *opOperand : operandsList) {
@@ -536,11 +542,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
536542
listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
537543
SmallVector<OpFoldResult> innerPackSizes =
538544
listOfPackedOperandsDim.extractPackSizesForOperand(pos);
539-
LLVM_DEBUG(
540-
DBGS() << "operand: " << operand << "\n";
541-
llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL();
542-
llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: ");
543-
DBGSNL(););
545+
LLVM_DEBUG(DBGS() << "operand: " << operand << "\n"
546+
<< "innerPos: " << llvm::interleaved(innerPos) << "\n"
547+
<< "innerPackSizes: "
548+
<< llvm::interleaved(innerPackSizes) << "\n");
544549
if (innerPackSizes.empty()) {
545550
inputsAndInits.push_back(operand);
546551
continue;
@@ -835,7 +840,7 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
835840
// not change the indexings of any operand.
836841
SmallVector<int64_t> permutation =
837842
computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
838-
LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
843+
LLVM_DEBUG(DBGS() << "perm: " << llvm::interleaved(permutation) << "\n");
839844
// Sign .. unsigned pollution.
840845
SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
841846
FailureOr<GenericOp> interchangeResult =
@@ -864,12 +869,12 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
864869

865870
// Add leading zeros to match numLoops, we only pack the last 3 dimensions
866871
// post interchange.
867-
LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
868-
DBGS() << "paddedSizesNextMultipleOf: ");
869-
DBGSNL(););
870-
LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
871-
[](Range r) { llvm::dbgs() << r.size; });
872-
DBGSNL(););
872+
LLVM_DEBUG(DBGS() << "paddedSizesNextMultipleOf: "
873+
<< llvm::interleaved(paddedSizesNextMultipleOf) << "\n");
874+
LLVM_DEBUG(DBGS() << "loopRanges: "
875+
<< llvm::interleaved(llvm::map_range(
876+
loopRanges, [](Range r) { return r.size; }))
877+
<< "\n");
873878
SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
874879
rewriter.getIndexAttr(0));
875880
for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
@@ -885,9 +890,8 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
885890
{loopRanges[adjustedPackedSizes.size()].size,
886891
rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
887892
}
888-
LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
889-
DBGS() << "adjustedPackedSizes: ");
890-
DBGSNL(););
893+
LLVM_DEBUG(DBGS() << "adjustedPackedSizes: "
894+
<< llvm::interleaved(adjustedPackedSizes) << "\n");
891895

892896
// TODO: If we wanted to give the genericOp a name after packing, after
893897
// calling `pack` would be a good time. One would still need to check that
@@ -1202,9 +1206,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12021206

12031207
srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
12041208

1205-
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
1206-
llvm::interleaveComma(srcPermForTranspose, DBGS() << "perm: ");
1207-
DBGSNL(););
1209+
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
1210+
<< "perm: " << llvm::interleaved(srcPermForTranspose)
1211+
<< "\n");
12081212

12091213
// 2.1 Create tensor.empty (init value for TransposeOp)
12101214
SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,

0 commit comments

Comments
 (0)