Skip to content

Commit 0af9b84

Browse files
gitoleglanza
authored andcommitted
[CIR][ABI][Lowering] Supports function pointers in the calling convention lowering pass (#1003)
This PR adds initial function pointers support for the calling convention lowering pass. This is a suggestion, so any other ideas are welcome. Several ideas was described in the #995 and basically what I'm trying to do is to generate a clean CIR code without additional `bitcast` operations for function pointers and without mix of lowered and initial function types. Looks like we can not just lower the function type and cast the value since too many operations are involved. For instance, for the next simple code: ``` typedef struct { int a; } S; typedef int (*myfptr)(S); int foo(S s) { return 42 + s.a; } void bar() { myfptr a = foo; } ``` we get the next CIR for the function `bar` , before the calling convention lowering pass: ``` cir.func no_proto @bar() extra(#fn_attr) { %0 = cir.alloca !cir.ptr<!cir.func<!s32i (!ty_S)>>, !cir.ptr<!cir.ptr<!cir.func<!s32i (!ty_S)>>>, ["a", init] %1 = cir.get_global @foo : !cir.ptr<!cir.func<!s32i (!ty_S)>> cir.store %1, %0 : !cir.ptr<!cir.func<!s32i (!ty_S)>>, !cir.ptr<!cir.ptr<!cir.func<!s32i (!ty_S)>>> cir.return } ``` As one can see, first three operations depend on the function type. Once `foo` is lowered, we need to fix `GetGlobalOp`: otherwise the code will fail with the verification error since actual `foo` type (lowered) differs from the one currently expected by the `GetGlobalOp`. First idea would just rewrite only the `GetGlobalOp` and insert a bitcast after, so both `AllocaOp` and `StoreOp` would work witth proper types. Once the code will be more complex, we will need to take care about possible use cases, e.g. if we use arrays, we will need to track array accesses to it as well in order to insert this bitcast every time the array element is needed. One workaround I can think of: we fix the `GetGlobalOp` type and cast from the lowered type to the initial, and cast back before the actual call happens - but it doesn't sound as a good and clean approach (from my point of view, of course). So I suggest to use type converter and rewrite any operation that may deal with function pointers and make sure it has a proper type, and we don't have any unlowered function type in the program after the calling convention lowering pass. I added lowering for `AllocaOp`, `GetGlobalOp`, and split the lowering for `FuncOp` (former `CallConvLoweringPattern`) and lower `CallOp` separately. Frankly speaking, I tried to implement a pattern for each operation, but for some reasons the tests are not passed for windows and macOs in this case - something weird happens inside `applyPatternsAndFold` function. I suspect it's due to two different rewriters used - one in the `LoweringModule` and one in the mentioned function. So I decided to follow the same approach as it's done for the `LoweringPrepare` pass and don't involve this complex rewriting framework. Next I will add a type converter for the struct type, patterns for `ConstantOp` (for const arrays and `GlobalViewAttr`) In the end of the day we'll have (at least I hope so) a clean CIR code without any bitcasts for function pointers. cc @sitio-couto @bcardosolopes
1 parent 1b01311 commit 0af9b84

File tree

2 files changed

+94
-50
lines changed

2 files changed

+94
-50
lines changed

clang/lib/CIR/Dialect/Transforms/CallConvLowering.cpp

+76-50
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8-
98
#include "TargetLowering/LowerModule.h"
109
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1110
#include "mlir/IR/BuiltinOps.h"
1211
#include "mlir/IR/PatternMatch.h"
1312
#include "mlir/Pass/Pass.h"
13+
#include "mlir/Transforms/DialectConversion.h"
1414
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1515
#include "clang/CIR/Dialect/IR/CIRDialect.h"
1616
#include "clang/CIR/MissingFeatures.h"
@@ -23,50 +23,93 @@
2323
namespace mlir {
2424
namespace cir {
2525

26-
//===----------------------------------------------------------------------===//
27-
// Rewrite Patterns
28-
//===----------------------------------------------------------------------===//
29-
30-
struct CallConvLoweringPattern : public OpRewritePattern<FuncOp> {
31-
using OpRewritePattern<FuncOp>::OpRewritePattern;
26+
FuncType getFuncPointerTy(mlir::Type typ) {
27+
if (auto ptr = dyn_cast<PointerType>(typ))
28+
return dyn_cast<FuncType>(ptr.getPointee());
29+
return {};
30+
}
3231

33-
LogicalResult matchAndRewrite(FuncOp op,
34-
PatternRewriter &rewriter) const final {
35-
llvm::TimeTraceScope scope("Call Conv Lowering Pass", op.getSymName().str());
32+
bool isFuncPointerTy(mlir::Type typ) { return (bool)getFuncPointerTy(typ); }
3633

37-
const auto module = op->getParentOfType<mlir::ModuleOp>();
34+
struct CallConvLowering {
3835

39-
auto modOp = op->getParentOfType<ModuleOp>();
40-
std::unique_ptr<LowerModule> lowerModule =
41-
createLowerModule(modOp, rewriter);
36+
CallConvLowering(ModuleOp module)
37+
: rewriter(module.getContext()),
38+
lowerModule(createLowerModule(module, rewriter)) {}
4239

43-
// Rewrite function calls before definitions. This should be done before
44-
// lowering the definition.
40+
void lower(FuncOp op) {
41+
// Fail the pass on unimplemented function users
42+
const auto module = op->getParentOfType<mlir::ModuleOp>();
4543
auto calls = op.getSymbolUses(module);
4644
if (calls.has_value()) {
4745
for (auto call : calls.value()) {
48-
// FIXME(cir): Function pointers are ignored.
49-
if (isa<GetGlobalOp>(call.getUser())) {
46+
if (auto g = dyn_cast<GetGlobalOp>(call.getUser()))
47+
rewriteGetGlobalOp(g);
48+
else if (auto c = dyn_cast<CallOp>(call.getUser()))
49+
lowerDirectCallOp(c, op);
50+
else {
5051
cir_cconv_assert_or_abort(!::cir::MissingFeatures::ABIFuncPtr(),
5152
"NYI");
52-
continue;
5353
}
54-
55-
auto callOp = dyn_cast_or_null<CallOp>(call.getUser());
56-
if (!callOp)
57-
cir_cconv_unreachable("NYI empty callOp");
58-
if (lowerModule->rewriteFunctionCall(callOp, op).failed())
59-
return failure();
6054
}
6155
}
6256

63-
// TODO(cir): Instead of re-emmiting every load and store, bitcast arguments
64-
// and return values to their ABI-specific counterparts when possible.
65-
if (lowerModule->rewriteFunctionDefinition(op).failed())
66-
return failure();
57+
op.walk([&](CallOp c) {
58+
if (c.isIndirect())
59+
lowerIndirectCallOp(c);
60+
});
6761

68-
return success();
62+
lowerModule->rewriteFunctionDefinition(op);
6963
}
64+
65+
private:
66+
FuncType convert(FuncType t) {
67+
auto &typs = lowerModule->getTypes();
68+
return typs.getFunctionType(typs.arrangeFreeFunctionType(t));
69+
}
70+
71+
mlir::Type convert(mlir::Type t) {
72+
if (auto fTy = getFuncPointerTy(t))
73+
return PointerType::get(rewriter.getContext(), convert(fTy));
74+
return t;
75+
}
76+
77+
void bitcast(Value src, Type newTy) {
78+
if (src.getType() != newTy) {
79+
auto cast =
80+
rewriter.create<CastOp>(src.getLoc(), newTy, CastKind::bitcast, src);
81+
rewriter.replaceAllUsesExcept(src, cast, cast);
82+
}
83+
}
84+
85+
void rewriteGetGlobalOp(GetGlobalOp op) {
86+
auto resTy = op.getResult().getType();
87+
if (isFuncPointerTy(resTy)) {
88+
rewriter.setInsertionPoint(op);
89+
auto newOp = rewriter.replaceOpWithNewOp<GetGlobalOp>(op, convert(resTy),
90+
op.getName());
91+
rewriter.setInsertionPointAfter(newOp);
92+
bitcast(newOp, resTy);
93+
}
94+
}
95+
96+
void lowerDirectCallOp(CallOp op, FuncOp callee) {
97+
lowerModule->rewriteFunctionCall(op, callee);
98+
}
99+
100+
void lowerIndirectCallOp(CallOp op) {
101+
cir_cconv_assert(op.isIndirect());
102+
103+
rewriter.setInsertionPoint(op);
104+
auto typ = op.getIndirectCall().getType();
105+
if (isFuncPointerTy(typ)) {
106+
cir_cconv_unreachable("Indirect calls NYI");
107+
}
108+
}
109+
110+
private:
111+
mlir::PatternRewriter rewriter;
112+
std::unique_ptr<LowerModule> lowerModule;
70113
};
71114

72115
//===----------------------------------------------------------------------===//
@@ -81,27 +124,10 @@ struct CallConvLoweringPass
81124
StringRef getArgument() const override { return "cir-call-conv-lowering"; };
82125
};
83126

84-
void populateCallConvLoweringPassPatterns(RewritePatternSet &patterns) {
85-
patterns.add<CallConvLoweringPattern>(patterns.getContext());
86-
}
87-
88127
void CallConvLoweringPass::runOnOperation() {
89-
90-
// Collect rewrite patterns.
91-
RewritePatternSet patterns(&getContext());
92-
populateCallConvLoweringPassPatterns(patterns);
93-
94-
// Collect operations to be considered by the pass.
95-
SmallVector<Operation *, 16> ops;
96-
getOperation()->walk([&](FuncOp op) { ops.push_back(op); });
97-
98-
// Configure rewrite to ignore new ops created during the pass.
99-
GreedyRewriteConfig config;
100-
config.strictMode = GreedyRewriteStrictness::ExistingOps;
101-
102-
// Apply patterns.
103-
if (failed(applyOpPatternsGreedily(ops, std::move(patterns), config)))
104-
signalPassFailure();
128+
auto module = dyn_cast<ModuleOp>(getOperation());
129+
CallConvLowering cc(module);
130+
module.walk([&](FuncOp op) { cc.lower(op); });
105131
}
106132

107133
} // namespace cir
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir-flat -fclangir-call-conv-lowering %s -o - | FileCheck %s
2+
3+
typedef struct {
4+
int a;
5+
} S;
6+
7+
typedef int (*myfptr)(S);
8+
9+
int foo(S s) { return 42 + s.a; }
10+
11+
// CHECK: cir.func {{.*@bar}}
12+
// CHECK: %[[#V0:]] = cir.alloca !cir.ptr<!cir.func<!s32i (!ty_S)>>, !cir.ptr<!cir.ptr<!cir.func<!s32i (!ty_S)>>>, ["a", init]
13+
// CHECK: %[[#V1:]] = cir.get_global @foo : !cir.ptr<!cir.func<!s32i (!s32i)>>
14+
// CHECK: %[[#V2:]] = cir.cast(bitcast, %[[#V1]] : !cir.ptr<!cir.func<!s32i (!s32i)>>), !cir.ptr<!cir.func<!s32i (!ty_S)>>
15+
// CHECK: cir.store %[[#V2]], %[[#V0]] : !cir.ptr<!cir.func<!s32i (!ty_S)>>, !cir.ptr<!cir.ptr<!cir.func<!s32i (!ty_S)>>>
16+
void bar() {
17+
myfptr a = foo;
18+
}

0 commit comments

Comments
 (0)