@@ -688,9 +688,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
688
688
UnresolvedMaterializationRewrite (
689
689
ConversionPatternRewriterImpl &rewriterImpl,
690
690
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr ,
691
- MaterializationKind kind = MaterializationKind::Target)
692
- : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
693
- converterAndKind (converter, kind) {}
691
+ MaterializationKind kind = MaterializationKind::Target);
694
692
695
693
static bool classof (const IRRewrite *rewrite) {
696
694
return rewrite->getKind () == Kind::UnresolvedMaterialization;
@@ -730,26 +728,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
730
728
});
731
729
}
732
730
733
- // / Find the single rewrite object of the specified type and block among the
734
- // / given rewrites. In debug mode, asserts that there is mo more than one such
735
- // / object. Return "nullptr" if no object was found.
736
- template <typename RewriteTy, typename R>
737
- static RewriteTy *findSingleRewrite (R &&rewrites, Block *block) {
738
- RewriteTy *result = nullptr ;
739
- for (auto &rewrite : rewrites) {
740
- auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get ());
741
- if (rewriteTy && rewriteTy->getBlock () == block) {
742
- #ifndef NDEBUG
743
- assert (!result && " expected single matching rewrite" );
744
- result = rewriteTy;
745
- #else
746
- return rewriteTy;
747
- #endif // NDEBUG
748
- }
749
- }
750
- return result;
751
- }
752
-
753
731
// ===----------------------------------------------------------------------===//
754
732
// ConversionPatternRewriterImpl
755
733
// ===----------------------------------------------------------------------===//
@@ -892,10 +870,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
892
870
893
871
bool wasErased (void *ptr) const { return erased.contains (ptr); }
894
872
895
- bool wasErased (OperationRewrite *rewrite) const {
896
- return wasErased (rewrite->getOperation ());
897
- }
898
-
899
873
void notifyOperationErased (Operation *op) override { erased.insert (op); }
900
874
901
875
void notifyBlockErased (Block *block) override { erased.insert (block); }
@@ -935,8 +909,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
935
909
// / to modify/access them is invalid rewriter API usage.
936
910
SetVector<Operation *> replacedOps;
937
911
938
- // / A set of all unresolved materializations.
939
- DenseSet<Operation *> unresolvedMaterializations;
912
+ // / A mapping of all unresolved materializations (UnrealizedConversionCastOp)
913
+ // / to the corresponding rewrite objects.
914
+ DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
915
+ unresolvedMaterializations;
940
916
941
917
// / The current type converter, or nullptr if no type converter is currently
942
918
// / active.
@@ -1058,12 +1034,20 @@ void CreateOperationRewrite::rollback() {
1058
1034
op->erase ();
1059
1035
}
1060
1036
1037
+ UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite (
1038
+ ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1039
+ const TypeConverter *converter, MaterializationKind kind)
1040
+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1041
+ converterAndKind(converter, kind) {
1042
+ rewriterImpl.unresolvedMaterializations [op] = this ;
1043
+ }
1044
+
1061
1045
void UnresolvedMaterializationRewrite::rollback () {
1062
1046
if (getMaterializationKind () == MaterializationKind::Target) {
1063
1047
for (Value input : op->getOperands ())
1064
1048
rewriterImpl.mapping .erase (input);
1065
1049
}
1066
- rewriterImpl.unresolvedMaterializations .erase (op );
1050
+ rewriterImpl.unresolvedMaterializations .erase (getOperation () );
1067
1051
op->erase ();
1068
1052
}
1069
1053
@@ -1345,7 +1329,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1345
1329
builder.setInsertionPoint (ip.getBlock (), ip.getPoint ());
1346
1330
auto convertOp =
1347
1331
builder.create <UnrealizedConversionCastOp>(loc, outputType, inputs);
1348
- unresolvedMaterializations.insert (convertOp);
1349
1332
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1350
1333
return convertOp.getResult (0 );
1351
1334
}
@@ -1382,10 +1365,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
1382
1365
for (auto [newValue, result] : llvm::zip (newValues, op->getResults ())) {
1383
1366
if (!newValue) {
1384
1367
// This result was dropped and no replacement value was provided.
1385
- if (unresolvedMaterializations.contains (op)) {
1386
- // Do not create another materializations if we are erasing a
1387
- // materialization.
1388
- continue ;
1368
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1369
+ if (unresolvedMaterializations.contains (castOp)) {
1370
+ // Do not create another materializations if we are erasing a
1371
+ // materialization.
1372
+ continue ;
1373
+ }
1389
1374
}
1390
1375
1391
1376
// Materialize a replacement value "out of thin air".
@@ -2499,15 +2484,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2499
2484
2500
2485
// Gather all unresolved materializations.
2501
2486
SmallVector<UnrealizedConversionCastOp> allCastOps;
2502
- DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
2503
- for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites ) {
2504
- auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get ());
2505
- if (!mat)
2506
- continue ;
2507
- if (rewriterImpl.eraseRewriter .wasErased (mat))
2487
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
2488
+ &materializations = rewriterImpl.unresolvedMaterializations ;
2489
+ for (auto it : materializations) {
2490
+ if (rewriterImpl.eraseRewriter .wasErased (it.first ))
2508
2491
continue ;
2509
- allCastOps.push_back (mat->getOperation ());
2510
- rewriteMap[mat->getOperation ()] = mat;
2492
+ allCastOps.push_back (it.first );
2511
2493
}
2512
2494
2513
2495
// Reconcile all UnrealizedConversionCastOps that were inserted by the
@@ -2520,8 +2502,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2520
2502
if (config.buildMaterializations ) {
2521
2503
IRRewriter rewriter (rewriterImpl.context , config.listener );
2522
2504
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2523
- auto it = rewriteMap .find (castOp. getOperation () );
2524
- assert (it != rewriteMap .end () && " inconsistent state" );
2505
+ auto it = materializations .find (castOp);
2506
+ assert (it != materializations .end () && " inconsistent state" );
2525
2507
if (failed (legalizeUnresolvedMaterialization (rewriter, it->second )))
2526
2508
return failure ();
2527
2509
}
0 commit comments