Skip to content

Commit 952f520

Browse files
gitoleglanza
authored andcommitted
[CIR][ABI][Lowering] Supports call by function pointer in the calling convention lowering pass (#1034)
This PR adds a support for calls by function pointers. @sitio-couto I think would be great if you'll also take a look
1 parent 8c9efae commit 952f520

File tree

4 files changed

+57
-7
lines changed

4 files changed

+57
-7
lines changed

clang/lib/CIR/Dialect/Transforms/CallConvLowering.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ struct CallConvLowering {
103103
rewriter.setInsertionPoint(op);
104104
auto typ = op.getIndirectCall().getType();
105105
if (isFuncPointerTy(typ)) {
106-
cir_cconv_unreachable("Indirect calls NYI");
106+
lowerModule->rewriteFunctionCall(op);
107107
}
108108
}
109109

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp

+26-5
Original file line numberDiff line numberDiff line change
@@ -630,8 +630,16 @@ LogicalResult LowerFunction::rewriteCallOp(CallOp op,
630630
// NOTE(cir): There is no direct way to fetch the function type from the
631631
// CallOp, so we fetch it from the source function. This assumes the
632632
// function definition has not yet been lowered.
633-
cir_cconv_assert(SrcFn && "No source function");
634-
auto fnType = SrcFn.getFunctionType();
633+
634+
FuncType fnType;
635+
if (SrcFn) {
636+
fnType = SrcFn.getFunctionType();
637+
} else if (op.isIndirect()) {
638+
if (auto ptrTy = dyn_cast<PointerType>(op.getIndirectCall().getType()))
639+
fnType = dyn_cast<FuncType>(ptrTy.getPointee());
640+
}
641+
642+
cir_cconv_assert(fnType && "No callee function type");
635643

636644
// Rewrite the call operation to abide to the ABI calling convention.
637645
auto Ret = rewriteCallOp(fnType, SrcFn, op, retValSlot);
@@ -687,7 +695,7 @@ Value LowerFunction::rewriteCallOp(FuncType calleeTy, FuncOp origCallee,
687695
//
688696
// Chain calls use this same code path to add the invisible chain parameter
689697
// to the function type.
690-
if (origCallee.getNoProto() || Chain) {
698+
if ((origCallee && origCallee.getNoProto()) || Chain) {
691699
cir_cconv_assert_or_abort(::cir::MissingFeatures::ABINoProtoFunctions(),
692700
"NYI");
693701
}
@@ -870,8 +878,21 @@ Value LowerFunction::rewriteCallOp(const LowerFunctionInfo &CallInfo,
870878
// NOTE(cir): We don't know if the callee was already lowered, so we only
871879
// fetch the name from the callee, while the return type is fetch from the
872880
// lowering types manager.
873-
CallOp newCallOp = rewriter.create<CallOp>(
874-
loc, Caller.getCalleeAttr(), IRFuncTy.getReturnType(), IRCallArgs);
881+
882+
CallOp newCallOp;
883+
884+
if (Caller.isIndirect()) {
885+
rewriter.setInsertionPoint(Caller);
886+
auto val = Caller.getIndirectCall();
887+
auto ptrTy = PointerType::get(val.getContext(), IRFuncTy);
888+
auto callee =
889+
rewriter.create<CastOp>(val.getLoc(), ptrTy, CastKind::bitcast, val);
890+
newCallOp = rewriter.create<CallOp>(loc, callee, IRFuncTy, IRCallArgs);
891+
} else {
892+
newCallOp = rewriter.create<CallOp>(loc, Caller.getCalleeAttr(),
893+
IRFuncTy.getReturnType(), IRCallArgs);
894+
}
895+
875896
auto extraAttrs =
876897
rewriter.getAttr<ExtraFuncAttributesAttr>(rewriter.getDictionaryAttr({}));
877898
newCallOp->setAttr("extra_attrs", extraAttrs);

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerModule.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class LowerModule {
9696
LogicalResult rewriteFunctionDefinition(FuncOp op);
9797

9898
// Rewrite CIR CallOp to match the target ABI.
99-
LogicalResult rewriteFunctionCall(CallOp callOp, FuncOp funcOp);
99+
LogicalResult rewriteFunctionCall(CallOp callOp, FuncOp funcOp = {});
100100
};
101101

102102
std::unique_ptr<LowerModule> createLowerModule(ModuleOp module,

clang/test/CIR/CallConvLowering/x86_64/fptrs.c

+29
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir-flat -fclangir-call-conv-lowering %s -o - | FileCheck %s
2+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm -fclangir-call-conv-lowering %s -o -| FileCheck %s -check-prefix=LLVM
23

34
typedef struct {
45
int a;
@@ -16,3 +17,31 @@ int foo(S s) { return 42 + s.a; }
1617
void bar() {
1718
myfptr a = foo;
1819
}
20+
21+
// CHECK: cir.func {{.*@baz}}(%arg0: !s32i
22+
// CHECK: %[[#V0:]] = cir.alloca !ty_S, !cir.ptr<!ty_S>, [""] {alignment = 4 : i64}
23+
// CHECK: %[[#V1:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_S>), !cir.ptr<!s32i>
24+
// CHECK: cir.store %arg0, %[[#V1]] : !s32i, !cir.ptr<!s32i>
25+
// CHECK: %[[#V2:]] = cir.alloca !cir.ptr<!cir.func<!s32i (!ty_S)>>, !cir.ptr<!cir.ptr<!cir.func<!s32i (!ty_S)>>>, ["a", init]
26+
// CHECK: %[[#V3:]] = cir.get_global @foo : !cir.ptr<!cir.func<!s32i (!s32i)>>
27+
// CHECK: %[[#V4:]] = cir.cast(bitcast, %[[#V3]] : !cir.ptr<!cir.func<!s32i (!s32i)>>), !cir.ptr<!cir.func<!s32i (!ty_S)>>
28+
// CHECK: cir.store %[[#V4]], %[[#V2]] : !cir.ptr<!cir.func<!s32i (!ty_S)>>, !cir.ptr<!cir.ptr<!cir.func<!s32i (!ty_S)>>>
29+
// CHECK: %[[#V5:]] = cir.load %[[#V2]] : !cir.ptr<!cir.ptr<!cir.func<!s32i (!ty_S)>>>, !cir.ptr<!cir.func<!s32i (!ty_S)>>
30+
// CHECK: %[[#V6:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_S>), !cir.ptr<!s32i>
31+
// CHECK: %[[#V7:]] = cir.load %[[#V6]] : !cir.ptr<!s32i>, !s32i
32+
// CHECK: %[[#V8:]] = cir.cast(bitcast, %[[#V5]] : !cir.ptr<!cir.func<!s32i (!ty_S)>>), !cir.ptr<!cir.func<!s32i (!s32i)>>
33+
// CHECK: %[[#V9:]] = cir.call %[[#V8]](%[[#V7]]) : (!cir.ptr<!cir.func<!s32i (!s32i)>>, !s32i) -> !s32i
34+
35+
// LLVM: define dso_local void @baz(i32 %0)
36+
// LLVM: %[[#V1:]] = alloca %struct.S, i64 1
37+
// LLVM: store i32 %0, ptr %[[#V1]]
38+
// LLVM: %[[#V2:]] = alloca ptr, i64 1
39+
// LLVM: store ptr @foo, ptr %[[#V2]]
40+
// LLVM: %[[#V3:]] = load ptr, ptr %[[#V2]]
41+
// LLVM: %[[#V4:]] = load i32, ptr %[[#V1]]
42+
// LLVM: %[[#V5:]] = call i32 %[[#V3]](i32 %[[#V4]])
43+
44+
void baz(S s) {
45+
myfptr a = foo;
46+
a(s);
47+
}

0 commit comments

Comments
 (0)