Skip to content

Commit 892b92e

Browse files
committed
added helper function
1 parent 6bf3852 commit 892b92e

File tree

2 files changed

+59
-8
lines changed

2 files changed

+59
-8
lines changed

clang/lib/CIR/CodeGen/CIRGenCall.cpp

+15-8
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,18 @@ static cir::CIRCallOpInterface emitCallLikeOp(
551551
extraFnAttrs);
552552
}
553553

554+
static RValue getRValueThroughMemory(mlir::Location loc,
555+
CIRGenBuilderTy &builder,
556+
mlir::Value val,
557+
Address addr) {
558+
auto ip = builder.saveInsertionPoint();
559+
builder.setInsertionPointAfterValue(val);
560+
builder.createStore(loc, val, addr);
561+
builder.restoreInsertionPoint(ip);
562+
auto load = builder.createLoad(loc, addr);
563+
return RValue::get(load);
564+
}
565+
554566
RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &CallInfo,
555567
const CIRGenCallee &Callee,
556568
ReturnValueSlot ReturnValue,
@@ -890,19 +902,14 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &CallInfo,
890902
assert(Results.size() <= 1 && "multiple returns NYI");
891903
assert(Results[0].getType() == RetCIRTy && "Bitcast support NYI");
892904

893-
auto reg = builder.getBlock()->getParent();
894-
if (reg != theCall->getParentRegion()) {
905+
auto region = builder.getBlock()->getParent();
906+
if (region != theCall->getParentRegion()) {
895907
Address DestPtr = ReturnValue.getValue();
896908

897909
if (!DestPtr.isValid())
898910
DestPtr = CreateMemTemp(RetTy, callLoc, "tmp");
899911

900-
auto ip = builder.saveInsertionPoint();
901-
builder.setInsertionPointAfter(theCall);
902-
builder.createStore(callLoc, Results[0], DestPtr);
903-
builder.restoreInsertionPoint(ip);
904-
auto load = builder.createLoad(callLoc, DestPtr);
905-
return RValue::get(load);
912+
return getRValueThroughMemory(callLoc, builder, Results[0], DestPtr);
906913
}
907914

908915
return RValue::get(Results[0]);

clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp

+44
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include "clang/CIR/Dialect/IR/CIRDialect.h"
2020
#include "clang/CIR/Dialect/Passes.h"
2121

22+
#include <iostream>
23+
2224
using namespace mlir;
2325
using namespace cir;
2426

@@ -910,6 +912,42 @@ void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
910912
patterns.getContext());
911913
}
912914

915+
void removeTempAllocas(DominanceInfo& dom, FuncOp fun) {
916+
917+
fun.walk([&](AllocaOp op) {
918+
if (op.getName().str().find("tmp") == std::string::npos)
919+
return;
920+
921+
StoreOp store;
922+
LoadOp load;
923+
int total = 0;
924+
925+
for (auto* u : op->getUsers()) {
926+
total++;
927+
if (auto ld = dyn_cast<LoadOp>(u))
928+
load = ld;
929+
if (auto st = dyn_cast<StoreOp>(u))
930+
if (st.getAddr() == op.getResult())
931+
store = st;
932+
}
933+
934+
if (total == 2 && load && store && dom.dominates(store, load)) {
935+
if (load->hasOneUse()) {
936+
if (auto st = dyn_cast<StoreOp>(*load->user_begin())) {
937+
if (auto al = dyn_cast<AllocaOp>(st.getAddr().getDefiningOp())) {
938+
llvm::SmallVector<mlir::Value> vals;
939+
vals.push_back(al.getResult());
940+
op->replaceAllUsesWith(vals);
941+
op->erase();
942+
}
943+
}
944+
}
945+
}
946+
947+
});
948+
949+
}
950+
913951
void FlattenCFGPass::runOnOperation() {
914952
RewritePatternSet patterns(&getContext());
915953
populateFlattenCFGPatterns(patterns);
@@ -924,6 +962,12 @@ void FlattenCFGPass::runOnOperation() {
924962
// Apply patterns.
925963
if (applyOpPatternsAndFold(ops, std::move(patterns)).failed())
926964
signalPassFailure();
965+
966+
auto &dom = getAnalysis<DominanceInfo>();
967+
968+
getOperation()->walk([&](FuncOp fun) {
969+
removeTempAllocas(dom, fun);
970+
});
927971
}
928972

929973
} // namespace

0 commit comments

Comments
 (0)