Skip to content

Commit e724e44

Browse files
[mlir][Transforms][NFC] Dialect conversion: Cache UnresolvedMaterializationRewrite
The dialect conversion already maintains a set of unresolved materializations (`UnrealizedConversionCastOp`). Turn that set into a map that maps from ops to `UnresolvedMaterializationRewrite *`. This improves efficiency a bit, because an iteration over `ConversionPatternRewriterImpl::rewrites` can be avoided. Also delete some dead code.
1 parent 0351dc5 commit e724e44

File tree

1 file changed

+27
-45
lines changed

1 file changed

+27
-45
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

+27-45
Original file line numberDiff line numberDiff line change
@@ -688,9 +688,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
688688
UnresolvedMaterializationRewrite(
689689
ConversionPatternRewriterImpl &rewriterImpl,
690690
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
691-
MaterializationKind kind = MaterializationKind::Target)
692-
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
693-
converterAndKind(converter, kind) {}
691+
MaterializationKind kind = MaterializationKind::Target);
694692

695693
static bool classof(const IRRewrite *rewrite) {
696694
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -730,26 +728,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
730728
});
731729
}
732730

733-
/// Find the single rewrite object of the specified type and block among the
734-
/// given rewrites. In debug mode, asserts that there is mo more than one such
735-
/// object. Return "nullptr" if no object was found.
736-
template <typename RewriteTy, typename R>
737-
static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
738-
RewriteTy *result = nullptr;
739-
for (auto &rewrite : rewrites) {
740-
auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
741-
if (rewriteTy && rewriteTy->getBlock() == block) {
742-
#ifndef NDEBUG
743-
assert(!result && "expected single matching rewrite");
744-
result = rewriteTy;
745-
#else
746-
return rewriteTy;
747-
#endif // NDEBUG
748-
}
749-
}
750-
return result;
751-
}
752-
753731
//===----------------------------------------------------------------------===//
754732
// ConversionPatternRewriterImpl
755733
//===----------------------------------------------------------------------===//
@@ -892,10 +870,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
892870

893871
bool wasErased(void *ptr) const { return erased.contains(ptr); }
894872

895-
bool wasErased(OperationRewrite *rewrite) const {
896-
return wasErased(rewrite->getOperation());
897-
}
898-
899873
void notifyOperationErased(Operation *op) override { erased.insert(op); }
900874

901875
void notifyBlockErased(Block *block) override { erased.insert(block); }
@@ -935,8 +909,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
935909
/// to modify/access them is invalid rewriter API usage.
936910
SetVector<Operation *> replacedOps;
937911

938-
/// A set of all unresolved materializations.
939-
DenseSet<Operation *> unresolvedMaterializations;
912+
/// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
913+
/// to the corresponding rewrite objects.
914+
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
915+
unresolvedMaterializations;
940916

941917
/// The current type converter, or nullptr if no type converter is currently
942918
/// active.
@@ -1058,12 +1034,20 @@ void CreateOperationRewrite::rollback() {
10581034
op->erase();
10591035
}
10601036

1037+
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1038+
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1039+
const TypeConverter *converter, MaterializationKind kind)
1040+
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1041+
converterAndKind(converter, kind) {
1042+
rewriterImpl.unresolvedMaterializations[op] = this;
1043+
}
1044+
10611045
void UnresolvedMaterializationRewrite::rollback() {
10621046
if (getMaterializationKind() == MaterializationKind::Target) {
10631047
for (Value input : op->getOperands())
10641048
rewriterImpl.mapping.erase(input);
10651049
}
1066-
rewriterImpl.unresolvedMaterializations.erase(op);
1050+
rewriterImpl.unresolvedMaterializations.erase(getOperation());
10671051
op->erase();
10681052
}
10691053

@@ -1345,7 +1329,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
13451329
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
13461330
auto convertOp =
13471331
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1348-
unresolvedMaterializations.insert(convertOp);
13491332
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
13501333
return convertOp.getResult(0);
13511334
}
@@ -1382,10 +1365,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
13821365
for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
13831366
if (!newValue) {
13841367
// This result was dropped and no replacement value was provided.
1385-
if (unresolvedMaterializations.contains(op)) {
1386-
// Do not create another materializations if we are erasing a
1387-
// materialization.
1388-
continue;
1368+
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1369+
if (unresolvedMaterializations.contains(castOp)) {
1370+
// Do not create another materializations if we are erasing a
1371+
// materialization.
1372+
continue;
1373+
}
13891374
}
13901375

13911376
// Materialize a replacement value "out of thin air".
@@ -2499,15 +2484,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
24992484

25002485
// Gather all unresolved materializations.
25012486
SmallVector<UnrealizedConversionCastOp> allCastOps;
2502-
DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
2503-
for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites) {
2504-
auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
2505-
if (!mat)
2506-
continue;
2507-
if (rewriterImpl.eraseRewriter.wasErased(mat))
2487+
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
2488+
&materializations = rewriterImpl.unresolvedMaterializations;
2489+
for (auto it : materializations) {
2490+
if (rewriterImpl.eraseRewriter.wasErased(it.first))
25082491
continue;
2509-
allCastOps.push_back(mat->getOperation());
2510-
rewriteMap[mat->getOperation()] = mat;
2492+
allCastOps.push_back(it.first);
25112493
}
25122494

25132495
// Reconcile all UnrealizedConversionCastOps that were inserted by the
@@ -2520,8 +2502,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25202502
if (config.buildMaterializations) {
25212503
IRRewriter rewriter(rewriterImpl.context, config.listener);
25222504
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2523-
auto it = rewriteMap.find(castOp.getOperation());
2524-
assert(it != rewriteMap.end() && "inconsistent state");
2505+
auto it = materializations.find(castOp);
2506+
assert(it != materializations.end() && "inconsistent state");
25252507
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
25262508
return failure();
25272509
}

0 commit comments

Comments
 (0)