Skip to content

[mlir][Transforms] Dialect conversion: Assert when accessing erased ops #83132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 55 additions & 38 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -798,13 +798,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
PatternRewriter &rewriter, ValueRange values,
SmallVectorImpl<Value> &remapped);

/// Returns true if the given operation is ignored, and does not need to be
/// Return "true" if the given operation is ignored, and does not need to be
/// converted.
bool isOpIgnored(Operation *op) const;

/// Recursively marks the nested operations under 'op' as ignored. This
/// removes them from being considered for legalization.
void markNestedOpsIgnored(Operation *op);
/// Return "true" if the given operation was replaced or erased.
bool wasOpReplaced(Operation *op) const;

//===--------------------------------------------------------------------===//
// Type Conversion
Expand Down Expand Up @@ -946,18 +945,15 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Ordered list of block operations (creations, splits, motions).
SmallVector<std::unique_ptr<IRRewrite>> rewrites;

/// A set of operations that should no longer be considered for legalization,
/// but were not directly replace/erased/etc. by a pattern. These are
/// generally child operations of other operations who were
/// replaced/erased/etc. This is not meant to be an exhaustive list of all
/// operations, but the minimal set that can be used to detect if a given
/// operation should be `ignored`. For example, we may add the operations that
/// define non-empty regions to the set, but not any of the others. This
/// simplifies the amount of memory needed as we can query if the parent
/// operation was ignored.
/// A set of operations that should no longer be considered for legalization.
/// E.g., ops that are recursively legal. Ops that were replaced/erased are
/// tracked separately.
SetVector<Operation *> ignoredOps;

// A set of operations that were erased.
/// A set of operations that were replaced/erased. Such ops are not erased
/// immediately but only when the dialect conversion succeeds. In the mean
/// time, they should no longer be considered for legalization and any attempt
/// to modify/access them is invalid rewriter API usage.
SetVector<Operation *> replacedOps;

/// The current type converter, or nullptr if no type converter is currently
Expand Down Expand Up @@ -1237,24 +1233,14 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
return success();
}

// TODO: This function is a misnomer. It does not actually check if `op` is in
// `ignoredOps`.
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
// Check to see if this operation or the parent operation is ignored.
return ignoredOps.count(op->getParentOp()) || replacedOps.count(op);
// Check to see if this operation is ignored or was replaced.
return replacedOps.count(op) || ignoredOps.count(op);
}

void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
// Walk this operation and collect nested operations that define non-empty
// regions. We mark such operations as 'ignored' so that we know we don't have
// to convert them, or their nested ops.
if (op->getNumRegions() == 0)
return;
op->walk([&](Operation *op) {
if (llvm::any_of(op->getRegions(),
[](Region &region) { return !region.empty(); }))
ignoredOps.insert(op);
});
bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
// Check to see if this operation was replaced.
return replacedOps.count(op);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1476,6 +1462,9 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
assert(!wasOpReplaced(op->getParentOp()) &&
"attempting to insert into a block within a replaced/erased op");

if (!previous.isSet()) {
// This is a newly created op.
appendRewrite<CreateOperationRewrite>(op);
Expand All @@ -1490,7 +1479,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
assert(!replacedOps.contains(op) && "operation was already replaced");
assert(!ignoredOps.contains(op) && "operation was already replaced");

// Track if any of the results changed, e.g. erased and replaced with null.
bool resultChanged = false;
Expand All @@ -1509,10 +1498,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
resultChanged);

// Mark this operation as recursively ignored so that we don't need to
// convert any nested operations.
replacedOps.insert(op);
markNestedOpsIgnored(op);
// Mark this operation and all nested ops as replaced.
op->walk([&](Operation *op) { replacedOps.insert(op); });
}

void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
Expand All @@ -1523,6 +1510,9 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {

void ConversionPatternRewriterImpl::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
assert(!wasOpReplaced(block->getParentOp()) &&
"attempting to insert into a region within a replaced/erased op");

if (!previous) {
// This is a newly created block.
appendRewrite<CreateBlockRewrite>(block);
Expand Down Expand Up @@ -1604,6 +1594,9 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
}

void ConversionPatternRewriter::eraseBlock(Block *block) {
assert(!impl->wasOpReplaced(block->getParentOp()) &&
"attempting to erase a block within a replaced/erased op");

// Mark all ops for erasure.
for (Operation &op : *block)
eraseOp(&op);
Expand All @@ -1619,18 +1612,27 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
Block *ConversionPatternRewriter::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter) {
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
return impl->applySignatureConversion(region, conversion, converter);
}

FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
return impl->convertRegionTypes(region, converter, entryConversion);
}

LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
}

Expand Down Expand Up @@ -1665,6 +1667,8 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,

Block *ConversionPatternRewriter::splitBlock(Block *block,
Block::iterator before) {
assert(!impl->wasOpReplaced(block->getParentOp()) &&
"attempting to split a block within a replaced/erased op");
auto *continuation = block->splitBlock(before);
impl->notifySplitBlock(block, continuation);
return continuation;
Expand All @@ -1673,15 +1677,19 @@ Block *ConversionPatternRewriter::splitBlock(Block *block,
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
Block::iterator before,
ValueRange argValues) {
#ifndef NDEBUG
assert(argValues.size() == source->getNumArguments() &&
"incorrect # of argument replacement values");
#ifndef NDEBUG
assert(!impl->wasOpReplaced(source->getParentOp()) &&
"attempting to inline a block from a replaced/erased op");
assert(!impl->wasOpReplaced(dest->getParentOp()) &&
"attempting to inline a block into a replaced/erased op");
auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
#endif // NDEBUG
// The source block will be deleted, so it should not have any users (i.e.,
// there should be no predecessors).
assert(llvm::all_of(source->getUsers(), opIgnored) &&
"expected 'source' to have no predecessors");
#endif // NDEBUG

impl->notifyBlockBeingInlined(dest, source, before);
for (auto it : llvm::zip(source->getArguments(), argValues))
Expand All @@ -1691,13 +1699,17 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
}

void ConversionPatternRewriter::startOpModification(Operation *op) {
assert(!impl->wasOpReplaced(op) &&
"attempting to modify a replaced/erased op");
#ifndef NDEBUG
impl->pendingRootUpdates.insert(op);
#endif
impl->appendRewrite<ModifyOperationRewrite>(op);
}

void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
assert(!impl->wasOpReplaced(op) &&
"attempting to modify a replaced/erased op");
PatternRewriter::finalizeOpModification(op);
// There is nothing to do here, we only need to track the operation at the
// start of the update.
Expand Down Expand Up @@ -1912,8 +1924,13 @@ OperationLegalizer::legalize(Operation *op,

// If this operation is recursively legal, mark its children as ignored so
// that we don't consider them for legalization.
if (legalityInfo->isRecursivelyLegal)
rewriter.getImpl().markNestedOpsIgnored(op);
if (legalityInfo->isRecursivelyLegal) {
op->walk([&](Operation *nested) {
if (op != nested)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

op is not visited by the walk, this test seems not necessary to me?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Operation::walk also enumerates the op itself as far as I know.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you point me to the code where this happens?

rewriter.getImpl().ignoredOps.insert(nested);
});
}

return success();
}

Expand Down
1 change: 0 additions & 1 deletion mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1768,7 +1768,6 @@ struct TestMergeSingleBlockOps
rewriter.inlineBlockBefore(&innerBlock, op);
rewriter.eraseOp(innerTerminator);
rewriter.eraseOp(op);
rewriter.modifyOpInPlace(op, [] {});
return success();
}
};
Expand Down