Skip to content

Commit 8527861

Browse files
[mlir][Transforms] Dialect conversion: Unify materialization of value replacements (llvm#108381)
PR llvm#106760 aligned the handling of dropped block arguments and dropped op results. The two helper functions that insert source materializations for uses of replaced block arguments / op results that survived the conversion are now almost identical (`legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes`). This PR merges the two functions and moves the implementation directly into `finalize`. This PR simplifies the code base and improves the efficiency a bit: previously, `finalize` iterated over `ConversionPatternRewriterImpl::rewrites` twice. Now, only one iteration is needed. --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent c57b9f5 commit 8527861

File tree

1 file changed

+41
-92
lines changed

1 file changed

+41
-92
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 41 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -2338,17 +2338,6 @@ struct OperationConverter {
23382338
/// remaining artifacts and complete the conversion.
23392339
LogicalResult finalize(ConversionPatternRewriter &rewriter);
23402340

2341-
/// Legalize the types of converted block arguments.
2342-
LogicalResult
2343-
legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
2344-
ConversionPatternRewriterImpl &rewriterImpl);
2345-
2346-
/// Legalize the types of converted op results.
2347-
LogicalResult legalizeConvertedOpResultTypes(
2348-
ConversionPatternRewriter &rewriter,
2349-
ConversionPatternRewriterImpl &rewriterImpl,
2350-
DenseMap<Value, SmallVector<Value>> &inverseMapping);
2351-
23522341
/// Dialect conversion configuration.
23532342
ConversionConfig config;
23542343

@@ -2512,19 +2501,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25122501
return success();
25132502
}
25142503

2515-
LogicalResult
2516-
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2517-
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2518-
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2519-
return failure();
2520-
DenseMap<Value, SmallVector<Value>> inverseMapping =
2521-
rewriterImpl.mapping.getInverse();
2522-
if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
2523-
inverseMapping)))
2524-
return failure();
2525-
return success();
2526-
}
2527-
25282504
/// Finds a user of the given value, or of any other value that the given value
25292505
/// replaced, that was not replaced in the conversion process.
25302506
static Operation *findLiveUserOfReplaced(
@@ -2548,87 +2524,60 @@ static Operation *findLiveUserOfReplaced(
25482524
return nullptr;
25492525
}
25502526

2551-
LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
2552-
ConversionPatternRewriter &rewriter,
2553-
ConversionPatternRewriterImpl &rewriterImpl,
2554-
DenseMap<Value, SmallVector<Value>> &inverseMapping) {
2555-
// Process requested operation replacements.
2556-
for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
2557-
auto *opReplacement =
2558-
dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
2559-
if (!opReplacement)
2560-
continue;
2561-
Operation *op = opReplacement->getOperation();
2562-
for (OpResult result : op->getResults()) {
2563-
// If the type of this op result changed and the result is still live,
2564-
// we need to materialize a conversion.
2565-
if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
2527+
/// Helper function that returns the replaced values and the type converter if
2528+
/// the given rewrite object is an "operation replacement" or a "block type
2529+
/// conversion" (which corresponds to a "block replacement"). Otherwise, return
2530+
/// an empty ValueRange and a null type converter pointer.
2531+
static std::pair<ValueRange, const TypeConverter *>
2532+
getReplacedValues(IRRewrite *rewrite) {
2533+
if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
2534+
return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
2535+
if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
2536+
return {blockRewrite->getOrigBlock()->getArguments(),
2537+
blockRewrite->getConverter()};
2538+
return {};
2539+
}
2540+
2541+
LogicalResult
2542+
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2543+
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2544+
DenseMap<Value, SmallVector<Value>> inverseMapping =
2545+
rewriterImpl.mapping.getInverse();
2546+
2547+
// Process requested value replacements.
2548+
for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
2549+
ValueRange replacedValues;
2550+
const TypeConverter *converter;
2551+
std::tie(replacedValues, converter) =
2552+
getReplacedValues(rewriterImpl.rewrites[i].get());
2553+
for (Value originalValue : replacedValues) {
2554+
// If the type of this value changed and the value is still live, we need
2555+
// to materialize a conversion.
2556+
if (rewriterImpl.mapping.lookupOrNull(originalValue,
2557+
originalValue.getType()))
25662558
continue;
25672559
Operation *liveUser =
2568-
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
2560+
findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
25692561
if (!liveUser)
25702562
continue;
25712563

2572-
// Legalize this result.
2573-
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
2564+
// Legalize this value replacement.
2565+
Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
25742566
assert(newValue && "replacement value not found");
25752567
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
2576-
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
2577-
/*inputs=*/newValue, /*outputType=*/result.getType(),
2578-
opReplacement->getConverter());
2579-
rewriterImpl.mapping.map(result, castValue);
2580-
inverseMapping[castValue].push_back(result);
2581-
llvm::erase(inverseMapping[newValue], result);
2568+
MaterializationKind::Source, computeInsertPoint(newValue),
2569+
originalValue.getLoc(),
2570+
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
2571+
converter);
2572+
rewriterImpl.mapping.map(originalValue, castValue);
2573+
inverseMapping[castValue].push_back(originalValue);
2574+
llvm::erase(inverseMapping[newValue], originalValue);
25822575
}
25832576
}
25842577

25852578
return success();
25862579
}
25872580

2588-
LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
2589-
ConversionPatternRewriter &rewriter,
2590-
ConversionPatternRewriterImpl &rewriterImpl) {
2591-
// Functor used to check if all users of a value will be dead after
2592-
// conversion.
2593-
// TODO: This should probably query the inverse mapping, same as in
2594-
// `legalizeConvertedOpResultTypes`.
2595-
auto findLiveUser = [&](Value val) {
2596-
auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
2597-
return rewriterImpl.isOpIgnored(user);
2598-
});
2599-
return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
2600-
};
2601-
// Note: `rewrites` may be reallocated as the loop is running.
2602-
for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
2603-
++i) {
2604-
auto &rewrite = rewriterImpl.rewrites[i];
2605-
if (auto *blockTypeConversionRewrite =
2606-
dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
2607-
// Process the remapping for each of the original arguments.
2608-
for (Value origArg :
2609-
blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
2610-
// If the type of this argument changed and the argument is still live,
2611-
// we need to materialize a conversion.
2612-
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
2613-
continue;
2614-
Operation *liveUser = findLiveUser(origArg);
2615-
if (!liveUser)
2616-
continue;
2617-
2618-
Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
2619-
assert(replacementValue && "replacement value not found");
2620-
Value repl = rewriterImpl.buildUnresolvedMaterialization(
2621-
MaterializationKind::Source, computeInsertPoint(replacementValue),
2622-
origArg.getLoc(), /*inputs=*/replacementValue,
2623-
/*outputType=*/origArg.getType(),
2624-
blockTypeConversionRewrite->getConverter());
2625-
rewriterImpl.mapping.map(origArg, repl);
2626-
}
2627-
}
2628-
}
2629-
return success();
2630-
}
2631-
26322581
//===----------------------------------------------------------------------===//
26332582
// Reconcile Unrealized Casts
26342583
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)