Skip to content

Commit ea2d938

Browse files
[mlir][Transforms][NFC] Improve listener layering in dialect conversion (#81236)
Context: Conversion patterns provide a `ConversionPatternRewriter` to modify the IR. `ConversionPatternRewriter` provides the public API. Most function calls are forwarded/handled by `ConversionPatternRewriterImpl`. The dialect conversion uses the listener infrastructure to get notified about op/block insertions. In the current design, `ConversionPatternRewriter` inherits from both `PatternRewriter` and `Listener`. The conversion rewriter registers itself as a listener. This is problematic because listener functions such as `notifyOperationInserted` are now part of the public API and can be called from conversion patterns; that would bring the dialect conversion into an inconsistent state. With this commit, `ConversionPatternRewriter` no longer inherits from `Listener`. Instead `ConversionPatternRewriterImpl` inherits from `Listener`. This removes the problematic public API and also simplifies the code a bit: block/op insertion notifications were previously forwarded to the `ConversionPatternRewriterImpl`. This is no longer needed.
1 parent 995c906 commit ea2d938

File tree

4 files changed

+29
-49
lines changed

4 files changed

+29
-49
lines changed

flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,12 +739,12 @@ struct HLFIRListener : public mlir::OpBuilder::Listener {
739739
void notifyOperationInserted(mlir::Operation *op,
740740
mlir::OpBuilder::InsertPoint previous) override {
741741
builder.notifyOperationInserted(op, previous);
742-
rewriter.notifyOperationInserted(op, previous);
742+
rewriter.getListener()->notifyOperationInserted(op, previous);
743743
}
744744
virtual void notifyBlockInserted(mlir::Block *block, mlir::Region *previous,
745745
mlir::Region::iterator previousIt) override {
746746
builder.notifyBlockInserted(block, previous, previousIt);
747-
rewriter.notifyBlockInserted(block, previous, previousIt);
747+
rewriter.getListener()->notifyBlockInserted(block, previous, previousIt);
748748
}
749749
fir::FirOpBuilder &builder;
750750
mlir::ConversionPatternRewriter &rewriter;

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -655,8 +655,7 @@ struct ConversionPatternRewriterImpl;
655655
/// This class implements a pattern rewriter for use with ConversionPatterns. It
656656
/// extends the base PatternRewriter and provides special conversion specific
657657
/// hooks.
658-
class ConversionPatternRewriter final : public PatternRewriter,
659-
public RewriterBase::Listener {
658+
class ConversionPatternRewriter final : public PatternRewriter {
660659
public:
661660
explicit ConversionPatternRewriter(MLIRContext *ctx);
662661
~ConversionPatternRewriter() override;
@@ -735,10 +734,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
735734
/// implemented for dialect conversion.
736735
void eraseBlock(Block *block) override;
737736

738-
/// PatternRewriter hook creating a new block.
739-
void notifyBlockInserted(Block *block, Region *previous,
740-
Region::iterator previousIt) override;
741-
742737
/// PatternRewriter hook for splitting a block into two parts.
743738
Block *splitBlock(Block *block, Block::iterator before) override;
744739

@@ -747,9 +742,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
747742
ValueRange argValues = std::nullopt) override;
748743
using PatternRewriter::inlineBlockBefore;
749744

750-
/// PatternRewriter hook for inserting a new operation.
751-
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
752-
753745
/// PatternRewriter hook for updating the given operation in-place.
754746
/// Note: These methods only track updates to the given operation itself,
755747
/// and not nested regions. Updates to regions will still require notification
@@ -762,18 +754,11 @@ class ConversionPatternRewriter final : public PatternRewriter,
762754
/// PatternRewriter hook for updating the given operation in-place.
763755
void cancelOpModification(Operation *op) override;
764756

765-
/// PatternRewriter hook for notifying match failure reasons.
766-
void
767-
notifyMatchFailure(Location loc,
768-
function_ref<void(Diagnostic &)> reasonCallback) override;
769-
using PatternRewriter::notifyMatchFailure;
770-
771757
/// Return a reference to the internal implementation.
772758
detail::ConversionPatternRewriterImpl &getImpl();
773759

774760
private:
775761
// Hide unsupported pattern rewriter API.
776-
using OpBuilder::getListener;
777762
using OpBuilder::setListener;
778763

779764
void moveOpBefore(Operation *op, Block *block,

mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
582582
// Inside regular functions we use the blocking wait operation to wait for
583583
// the async object (token, value or group) to become available.
584584
if (!isInCoroutine) {
585-
ImplicitLocOpBuilder builder(loc, op, &rewriter);
585+
ImplicitLocOpBuilder builder(loc, rewriter);
586586
builder.create<RuntimeAwaitOp>(loc, operand);
587587

588588
// Assert that the awaited operands is not in the error state.
@@ -601,7 +601,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
601601
CoroMachinery &coro = funcCoro->getSecond();
602602
Block *suspended = op->getBlock();
603603

604-
ImplicitLocOpBuilder builder(loc, op, &rewriter);
604+
ImplicitLocOpBuilder builder(loc, rewriter);
605605
MLIRContext *ctx = op->getContext();
606606

607607
// Save the coroutine state and resume on a runtime managed thread when

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ void ArgConverter::insertConversion(Block *newBlock,
825825
//===----------------------------------------------------------------------===//
826826
namespace mlir {
827827
namespace detail {
828-
struct ConversionPatternRewriterImpl {
828+
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
829829
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
830830
: argConverter(rewriter, unresolvedMaterializations),
831831
notifyCallback(nullptr) {}
@@ -903,15 +903,19 @@ struct ConversionPatternRewriterImpl {
903903
// Rewriter Notification Hooks
904904
//===--------------------------------------------------------------------===//
905905

906-
/// PatternRewriter hook for replacing the results of an operation.
906+
//// Notifies that an op was inserted.
907+
void notifyOperationInserted(Operation *op,
908+
OpBuilder::InsertPoint previous) override;
909+
910+
/// Notifies that an op is about to be replaced with the given values.
907911
void notifyOpReplaced(Operation *op, ValueRange newValues);
908912

909913
/// Notifies that a block is about to be erased.
910914
void notifyBlockIsBeingErased(Block *block);
911915

912-
/// Notifies that a block was created.
913-
void notifyInsertedBlock(Block *block, Region *previous,
914-
Region::iterator previousIt);
916+
/// Notifies that a block was inserted.
917+
void notifyBlockInserted(Block *block, Region *previous,
918+
Region::iterator previousIt) override;
915919

916920
/// Notifies that a block was split.
917921
void notifySplitBlock(Block *block, Block *continuation);
@@ -921,8 +925,9 @@ struct ConversionPatternRewriterImpl {
921925
Block::iterator before);
922926

923927
/// Notifies that a pattern match failed for the given reason.
924-
void notifyMatchFailure(Location loc,
925-
function_ref<void(Diagnostic &)> reasonCallback);
928+
void
929+
notifyMatchFailure(Location loc,
930+
function_ref<void(Diagnostic &)> reasonCallback) override;
926931

927932
//===--------------------------------------------------------------------===//
928933
// State
@@ -1363,6 +1368,16 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
13631368
//===----------------------------------------------------------------------===//
13641369
// Rewriter Notification Hooks
13651370

1371+
void ConversionPatternRewriterImpl::notifyOperationInserted(
1372+
Operation *op, OpBuilder::InsertPoint previous) {
1373+
assert(!previous.isSet() && "expected newly created op");
1374+
LLVM_DEBUG({
1375+
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
1376+
<< ")\n";
1377+
});
1378+
createdOps.push_back(op);
1379+
}
1380+
13661381
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
13671382
ValueRange newValues) {
13681383
assert(newValues.size() == op->getNumResults());
@@ -1398,7 +1413,7 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
13981413
blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
13991414
}
14001415

1401-
void ConversionPatternRewriterImpl::notifyInsertedBlock(
1416+
void ConversionPatternRewriterImpl::notifyBlockInserted(
14021417
Block *block, Region *previous, Region::iterator previousIt) {
14031418
if (!previous) {
14041419
// This is a newly created block.
@@ -1437,7 +1452,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
14371452
ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
14381453
: PatternRewriter(ctx),
14391454
impl(new detail::ConversionPatternRewriterImpl(*this)) {
1440-
setListener(this);
1455+
setListener(impl.get());
14411456
}
14421457

14431458
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
@@ -1540,11 +1555,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
15401555
results);
15411556
}
15421557

1543-
void ConversionPatternRewriter::notifyBlockInserted(
1544-
Block *block, Region *previous, Region::iterator previousIt) {
1545-
impl->notifyInsertedBlock(block, previous, previousIt);
1546-
}
1547-
15481558
Block *ConversionPatternRewriter::splitBlock(Block *block,
15491559
Block::iterator before) {
15501560
auto *continuation = block->splitBlock(before);
@@ -1572,16 +1582,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
15721582
eraseBlock(source);
15731583
}
15741584

1575-
void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
1576-
InsertPoint previous) {
1577-
assert(!previous.isSet() && "expected newly created op");
1578-
LLVM_DEBUG({
1579-
impl->logger.startLine()
1580-
<< "** Insert : '" << op->getName() << "'(" << op << ")\n";
1581-
});
1582-
impl->createdOps.push_back(op);
1583-
}
1584-
15851585
void ConversionPatternRewriter::startOpModification(Operation *op) {
15861586
#ifndef NDEBUG
15871587
impl->pendingRootUpdates.insert(op);
@@ -1614,11 +1614,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
16141614
rootUpdates.erase(rootUpdates.begin() + updateIdx);
16151615
}
16161616

1617-
void ConversionPatternRewriter::notifyMatchFailure(
1618-
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1619-
impl->notifyMatchFailure(loc, reasonCallback);
1620-
}
1621-
16221617
void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
16231618
Block::iterator iterator) {
16241619
llvm_unreachable(

0 commit comments

Comments
 (0)