Skip to content

[mlir][Transforms][NFC] Dialect conversion: Cache UnresolvedMaterializationRewrite #108359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 13, 2024

Conversation

matthias-springer
Copy link
Member

The dialect conversion maintains a set of unresolved materializations (UnrealizedConversionCastOp). Turn that set into a DenseMap that maps from ops to UnresolvedMaterializationRewrite *. This improves efficiency a bit, because an iteration over ConversionPatternRewriterImpl::rewrites can be avoided.

Also delete some dead code.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Sep 12, 2024
@llvmbot
Copy link
Member

llvmbot commented Sep 12, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

The dialect conversion maintains a set of unresolved materializations (UnrealizedConversionCastOp). Turn that set into a DenseMap that maps from ops to UnresolvedMaterializationRewrite *. This improves efficiency a bit, because an iteration over ConversionPatternRewriterImpl::rewrites can be avoided.

Also delete some dead code.


Full diff: https://github.com/llvm/llvm-project/pull/108359.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-40)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b58a95c3baf70a..ed15b571f01883 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -688,9 +688,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   UnresolvedMaterializationRewrite(
       ConversionPatternRewriterImpl &rewriterImpl,
       UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
-      MaterializationKind kind = MaterializationKind::Target)
-      : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-        converterAndKind(converter, kind) {}
+      MaterializationKind kind = MaterializationKind::Target);
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -730,26 +728,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
   });
 }
 
-/// Find the single rewrite object of the specified type and block among the
-/// given rewrites. In debug mode, asserts that there is mo more than one such
-/// object. Return "nullptr" if no object was found.
-template <typename RewriteTy, typename R>
-static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
-  RewriteTy *result = nullptr;
-  for (auto &rewrite : rewrites) {
-    auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
-    if (rewriteTy && rewriteTy->getBlock() == block) {
-#ifndef NDEBUG
-      assert(!result && "expected single matching rewrite");
-      result = rewriteTy;
-#else
-      return rewriteTy;
-#endif // NDEBUG
-    }
-  }
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriterImpl
 //===----------------------------------------------------------------------===//
@@ -892,10 +870,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
     bool wasErased(void *ptr) const { return erased.contains(ptr); }
 
-    bool wasErased(OperationRewrite *rewrite) const {
-      return wasErased(rewrite->getOperation());
-    }
-
     void notifyOperationErased(Operation *op) override { erased.insert(op); }
 
     void notifyBlockErased(Block *block) override { erased.insert(block); }
@@ -935,8 +909,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// to modify/access them is invalid rewriter API usage.
   SetVector<Operation *> replacedOps;
 
-  /// A set of all unresolved materializations.
-  DenseSet<Operation *> unresolvedMaterializations;
+  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
+  /// to the corresponding rewrite objects.
+  DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+      unresolvedMaterializations;
 
   /// The current type converter, or nullptr if no type converter is currently
   /// active.
@@ -1058,6 +1034,14 @@ void CreateOperationRewrite::rollback() {
   op->erase();
 }
 
+UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
+    ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
+    const TypeConverter *converter, MaterializationKind kind)
+    : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+      converterAndKind(converter, kind) {
+  rewriterImpl.unresolvedMaterializations[op] = this;
+}
+
 void UnresolvedMaterializationRewrite::rollback() {
   if (getMaterializationKind() == MaterializationKind::Target) {
     for (Value input : op->getOperands())
@@ -1345,7 +1329,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
-  unresolvedMaterializations.insert(convertOp);
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
   return convertOp.getResult(0);
 }
@@ -2499,15 +2482,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
 
   // Gather all unresolved materializations.
   SmallVector<UnrealizedConversionCastOp> allCastOps;
-  DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
-  for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites) {
-    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
-    if (!mat)
-      continue;
-    if (rewriterImpl.eraseRewriter.wasErased(mat))
+  const DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+      &materializations = rewriterImpl.unresolvedMaterializations;
+  for (auto it : materializations) {
+    if (rewriterImpl.eraseRewriter.wasErased(it.first))
       continue;
-    allCastOps.push_back(mat->getOperation());
-    rewriteMap[mat->getOperation()] = mat;
+    allCastOps.push_back(cast<UnrealizedConversionCastOp>(it.first));
   }
 
   // Reconcile all UnrealizedConversionCastOps that were inserted by the
@@ -2520,8 +2500,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   if (config.buildMaterializations) {
     IRRewriter rewriter(rewriterImpl.context, config.listener);
     for (UnrealizedConversionCastOp castOp : remainingCastOps) {
-      auto it = rewriteMap.find(castOp.getOperation());
-      assert(it != rewriteMap.end() && "inconsistent state");
+      auto it = materializations.find(castOp.getOperation());
+      assert(it != materializations.end() && "inconsistent state");
       if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
         return failure();
     }

Base automatically changed from users/matthias-springer/replace_op_source_mat to main September 12, 2024 13:30
@matthias-springer matthias-springer force-pushed the users/matthias-springer/mat_cache branch 2 times, most recently from 4cb4bcf to 066359e Compare September 12, 2024 13:36
…izationRewrite`

The dialect conversion already maintains a set of unresolved materializations (`UnrealizedConversionCastOp`). Turn that set into a map that maps from ops to `UnresolvedMaterializationRewrite *`. This improves efficiency a bit, because an iteration over `ConversionPatternRewriterImpl::rewrites` can be avoided.

Also delete some dead code.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/mat_cache branch from 066359e to e724e44 Compare September 13, 2024 17:55
@matthias-springer matthias-springer merged commit d588e49 into main Sep 13, 2024
6 of 7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/mat_cache branch September 13, 2024 18:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants