-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][Transforms][NFC] Dialect conversion: Cache UnresolvedMaterializationRewrite
#108359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThe dialect conversion maintains a set of unresolved materializations ( Also delete some dead code. Full diff: https://github.com/llvm/llvm-project/pull/108359.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b58a95c3baf70a..ed15b571f01883 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -688,9 +688,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
- MaterializationKind kind = MaterializationKind::Target)
- : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
- converterAndKind(converter, kind) {}
+ MaterializationKind kind = MaterializationKind::Target);
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -730,26 +728,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
});
}
-/// Find the single rewrite object of the specified type and block among the
-/// given rewrites. In debug mode, asserts that there is mo more than one such
-/// object. Return "nullptr" if no object was found.
-template <typename RewriteTy, typename R>
-static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
- RewriteTy *result = nullptr;
- for (auto &rewrite : rewrites) {
- auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
- if (rewriteTy && rewriteTy->getBlock() == block) {
-#ifndef NDEBUG
- assert(!result && "expected single matching rewrite");
- result = rewriteTy;
-#else
- return rewriteTy;
-#endif // NDEBUG
- }
- }
- return result;
-}
-
//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
@@ -892,10 +870,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
bool wasErased(void *ptr) const { return erased.contains(ptr); }
- bool wasErased(OperationRewrite *rewrite) const {
- return wasErased(rewrite->getOperation());
- }
-
void notifyOperationErased(Operation *op) override { erased.insert(op); }
void notifyBlockErased(Block *block) override { erased.insert(block); }
@@ -935,8 +909,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// to modify/access them is invalid rewriter API usage.
SetVector<Operation *> replacedOps;
- /// A set of all unresolved materializations.
- DenseSet<Operation *> unresolvedMaterializations;
+ /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
+ /// to the corresponding rewrite objects.
+ DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+ unresolvedMaterializations;
/// The current type converter, or nullptr if no type converter is currently
/// active.
@@ -1058,6 +1034,14 @@ void CreateOperationRewrite::rollback() {
op->erase();
}
+UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
+ ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
+ const TypeConverter *converter, MaterializationKind kind)
+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+ converterAndKind(converter, kind) {
+ rewriterImpl.unresolvedMaterializations[op] = this;
+}
+
void UnresolvedMaterializationRewrite::rollback() {
if (getMaterializationKind() == MaterializationKind::Target) {
for (Value input : op->getOperands())
@@ -1345,7 +1329,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
- unresolvedMaterializations.insert(convertOp);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
@@ -2499,15 +2482,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// Gather all unresolved materializations.
SmallVector<UnrealizedConversionCastOp> allCastOps;
- DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
- for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites) {
- auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
- if (!mat)
- continue;
- if (rewriterImpl.eraseRewriter.wasErased(mat))
+ const DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+ &materializations = rewriterImpl.unresolvedMaterializations;
+ for (auto it : materializations) {
+ if (rewriterImpl.eraseRewriter.wasErased(it.first))
continue;
- allCastOps.push_back(mat->getOperation());
- rewriteMap[mat->getOperation()] = mat;
+ allCastOps.push_back(cast<UnrealizedConversionCastOp>(it.first));
}
// Reconcile all UnrealizedConversionCastOps that were inserted by the
@@ -2520,8 +2500,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (config.buildMaterializations) {
IRRewriter rewriter(rewriterImpl.context, config.listener);
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
- auto it = rewriteMap.find(castOp.getOperation());
- assert(it != rewriteMap.end() && "inconsistent state");
+ auto it = materializations.find(castOp.getOperation());
+ assert(it != materializations.end() && "inconsistent state");
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
return failure();
}
|
4cb4bcf
to
066359e
Compare
…izationRewrite` The dialect conversion already maintains a set of unresolved materializations (`UnrealizedConversionCastOp`). Turn that set into a map that maps from ops to `UnresolvedMaterializationRewrite *`. This improves efficiency a bit, because an iteration over `ConversionPatternRewriterImpl::rewrites` can be avoided. Also delete some dead code.
066359e
to
e724e44
Compare
The dialect conversion maintains a set of unresolved materializations (
UnrealizedConversionCastOp
). Turn that set into aDenseMap
that maps from ops toUnresolvedMaterializationRewrite *
. This improves efficiency a bit, because an iteration overConversionPatternRewriterImpl::rewrites
can be avoided.Also delete some dead code.