Skip to content

Commit a27987e

Browse files
committed
[Lowering][DirectToLLVM] Fix calling variadic functions (llvm#945)
After 5da4310, the LLVM dialect requires the variadic callee type to be present for variadic calls. The op builders take care of this automatically if you pass the function type, so change our lowering logic to do so. Add tests for this as well as a missing test for indirect function call lowering. Fixes llvm#913 Fixes llvm#933
1 parent 7dcb8b2 commit a27987e

File tree

3 files changed

+82
-26
lines changed

3 files changed

+82
-26
lines changed

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -937,17 +937,14 @@ rewriteToCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
937937

938938
auto cconv = convertCallingConv(callIf.getCallingConv());
939939

940+
mlir::LLVM::LLVMFunctionType llvmFnTy;
940941
if (calleeAttr) { // direct call
941-
if (landingPadBlock) {
942-
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
943-
op, llvmResults, calleeAttr, callOperands, continueBlock,
944-
mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
945-
newOp.setCConv(cconv);
946-
} else {
947-
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
948-
op, llvmResults, calleeAttr, callOperands);
949-
newOp.setCConv(cconv);
950-
}
942+
auto fn =
943+
mlir::SymbolTable::lookupNearestSymbolFrom<mlir::FunctionOpInterface>(
944+
op, calleeAttr);
945+
assert(fn && "Did not find function for call");
946+
llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>(
947+
converter->convertType(fn.getFunctionType()));
951948
} else { // indirect call
952949
assert(op->getOperands().size() &&
953950
"operands list must no be empty for the indirect call");
@@ -956,21 +953,18 @@ rewriteToCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
956953
auto ptyp = dyn_cast<mlir::cir::PointerType>(typ);
957954
auto ftyp = dyn_cast<mlir::cir::FuncType>(ptyp.getPointee());
958955
assert(ftyp && "expected a pointer to a function as the first operand");
956+
llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>(converter->convertType(ftyp));
957+
}
959958

960-
if (landingPadBlock) {
961-
auto llvmFnTy =
962-
dyn_cast<mlir::LLVM::LLVMFunctionType>(converter->convertType(ftyp));
963-
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
964-
op, llvmFnTy, mlir::FlatSymbolRefAttr{}, callOperands, continueBlock,
965-
mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
966-
newOp.setCConv(cconv);
967-
} else {
968-
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
969-
op,
970-
dyn_cast<mlir::LLVM::LLVMFunctionType>(converter->convertType(ftyp)),
971-
callOperands);
972-
newOp.setCConv(cconv);
973-
}
959+
if (landingPadBlock) {
960+
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
961+
op, llvmFnTy, calleeAttr, callOperands, continueBlock,
962+
mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
963+
newOp.setCConv(cconv);
964+
} else {
965+
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
966+
op, llvmFnTy, calleeAttr, callOperands);
967+
newOp.setCConv(cconv);
974968
}
975969
return mlir::success();
976970
}

clang/test/CIR/Lowering/call.cir

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: cir-opt %s -cir-to-llvm -o - | FileCheck %s -check-prefix=MLIR
22
// RUN: cir-translate %s -cir-to-llvmir | FileCheck %s -check-prefix=LLVM
33

4+
!s32i = !cir.int<s, 32>
45
module {
56
cir.func @a() {
67
cir.return
@@ -36,4 +37,66 @@ module {
3637
cir.return %0 : !cir.ptr<i32>
3738
}
3839

40+
// check indirect call lowering
41+
cir.global "private" external @fp : !cir.ptr<!cir.func<!s32i (!s32i)>>
42+
cir.func @callIndirect(%arg: !s32i) -> !s32i {
43+
%fpp = cir.get_global @fp : !cir.ptr<!cir.ptr<!cir.func<!s32i (!s32i)>>>
44+
%fp = cir.load %fpp : !cir.ptr<!cir.ptr<!cir.func<!s32i (!s32i)>>>, !cir.ptr<!cir.func<!s32i (!s32i)>>
45+
%retval = cir.call %fp(%arg) : (!cir.ptr<!cir.func<!s32i (!s32i)>>, !s32i) -> !s32i
46+
cir.return %retval : !s32i
47+
}
48+
49+
// MLIR: llvm.mlir.global external @fp() {addr_space = 0 : i32} : !llvm.ptr
50+
// MLIR: llvm.func @callIndirect(%arg0: i32) -> i32
51+
// MLIR-NEXT: %0 = llvm.mlir.addressof @fp : !llvm.ptr
52+
// MLIR-NEXT: %1 = llvm.load %0 {{.*}} : !llvm.ptr -> !llvm.ptr
53+
// MLIR-NEXT: %2 = llvm.call %1(%arg0) : !llvm.ptr, (i32) -> i32
54+
// MLIR-NEXT: llvm.return %2 : i32
55+
56+
// LLVM: define i32 @callIndirect(i32 %0)
57+
// LLVM-NEXT: %2 = load ptr, ptr @fp
58+
// LLVM-NEXT: %3 = call i32 %2(i32 %0)
59+
// LLVM-NEXT: ret i32 %3
60+
61+
// check direct vararg call lowering
62+
cir.func private @varargCallee(!s32i, ...) -> !s32i
63+
cir.func @varargCaller() -> !s32i {
64+
%zero = cir.const #cir.int<0> : !s32i
65+
%retval = cir.call @varargCallee(%zero, %zero) : (!s32i, !s32i) -> !s32i
66+
cir.return %retval : !s32i
67+
}
68+
69+
// MLIR: llvm.func @varargCallee(i32, ...) -> i32
70+
// MLIR: llvm.func @varargCaller() -> i32
71+
// MLIR-NEXT: %0 = llvm.mlir.constant(0 : i32) : i32
72+
// MLIR-NEXT: %1 = llvm.call @varargCallee(%0, %0) vararg(!llvm.func<i32 (i32, ...)>) : (i32, i32) -> i32
73+
// MLIR-NEXT: llvm.return %1 : i32
74+
75+
// LLVM: define i32 @varargCaller()
76+
// LLVM-NEXT: %1 = call i32 (i32, ...) @varargCallee(i32 0, i32 0)
77+
// LLVM-NEXT: ret i32 %1
78+
79+
// check indirect vararg call lowering
80+
cir.global "private" external @varargfp : !cir.ptr<!cir.func<!s32i (!s32i, ...)>>
81+
cir.func @varargCallIndirect() -> !s32i {
82+
%fpp = cir.get_global @varargfp : !cir.ptr<!cir.ptr<!cir.func<!s32i (!s32i, ...)>>>
83+
%fp = cir.load %fpp : !cir.ptr<!cir.ptr<!cir.func<!s32i (!s32i, ...)>>>, !cir.ptr<!cir.func<!s32i (!s32i, ...)>>
84+
%zero = cir.const #cir.int<0> : !s32i
85+
%retval = cir.call %fp(%zero, %zero) : (!cir.ptr<!cir.func<!s32i (!s32i, ...)>>, !s32i, !s32i) -> !s32i
86+
cir.return %retval : !s32i
87+
}
88+
89+
// MLIR: llvm.mlir.global external @varargfp() {addr_space = 0 : i32} : !llvm.ptr
90+
// MLIR: llvm.func @varargCallIndirect() -> i32
91+
// MLIR-NEXT: %0 = llvm.mlir.addressof @varargfp : !llvm.ptr
92+
// MLIR-NEXT: %1 = llvm.load %0 {{.*}} : !llvm.ptr -> !llvm.ptr
93+
// MLIR-NEXT: %2 = llvm.mlir.constant(0 : i32) : i32
94+
// MLIR-NEXT: %3 = llvm.call %1(%2, %2) vararg(!llvm.func<i32 (i32, ...)>) : !llvm.ptr, (i32, i32) -> i32
95+
// MLIR-NEXT: llvm.return %3 : i32
96+
97+
// LLVM: define i32 @varargCallIndirect()
98+
// LLVM-NEXT: %1 = load ptr, ptr @varargfp
99+
// LLVM-NEXT: %2 = call i32 (i32, ...) %1(i32 0, i32 0)
100+
// LLVM-NEXT: ret i32 %2
101+
39102
} // end module

clang/test/CIR/Lowering/hello.cir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// RUN: cir-opt %s -cir-to-llvm -o %t.mlir
22
// RUN: FileCheck --input-file=%t.mlir %s
3-
// XFAIL: *
43

54
!s32i = !cir.int<s, 32>
65
!s8i = !cir.int<s, 8>
@@ -28,7 +27,7 @@ module @"/tmp/test.raw" attributes {cir.lang = #cir.lang<c>, cir.sob = #cir.sign
2827
// CHECK: %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i64) -> !llvm.ptr
2928
// CHECK: %2 = llvm.mlir.addressof @".str" : !llvm.ptr
3029
// CHECK: %3 = llvm.getelementptr %2[0] : (!llvm.ptr) -> !llvm.ptr, i8
31-
// CHECK: %4 = llvm.call @printf(%3) : (!llvm.ptr) -> i32
30+
// CHECK: %4 = llvm.call @printf(%3) vararg(!llvm.func<i32 (ptr, ...)>) : (!llvm.ptr) -> i32
3231
// CHECK: %5 = llvm.mlir.constant(0 : i32) : i32
3332
// CHECK: llvm.store %5, %1 {{.*}} : i32, !llvm.ptr
3433
// CHECK: %6 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32

0 commit comments

Comments
 (0)