Skip to content

Commit b9b7ee9

Browse files
seven-milebcardosolopes
authored andcommitted
[CIR][Dialect][Lowering] Add calling convention attribute to FuncOp (#760)
This PR simply adds the calling convention attribute to FuncOp with LLVM Lowering support. The overall approach follows `GlobalLinkageKind`: Extend the ODS, parser, printer and lowering pass. When the call conv is C call conv, it's omitted in the output assembly. --------- Co-authored-by: Bruno Cardoso Lopes <[email protected]>
1 parent 0d73e57 commit b9b7ee9

File tree

6 files changed

+137
-2
lines changed

6 files changed

+137
-2
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2794,6 +2794,19 @@ def BaseClassAddrOp : CIR_Op<"base_class_addr"> {
27942794
// FuncOp
27952795
//===----------------------------------------------------------------------===//
27962796

2797+
// The enumeration values are not necessarily in sync with `clang::CallingConv`
2798+
// or `llvm::CallingConv`.
2799+
def CC_C : I32EnumAttrCase<"C", 1, "c">;
2800+
def CC_SpirKernel : I32EnumAttrCase<"SpirKernel", 2, "spir_kernel">;
2801+
def CC_SpirFunction : I32EnumAttrCase<"SpirFunction", 3, "spir_function">;
2802+
2803+
def CallingConv : I32EnumAttr<
2804+
"CallingConv",
2805+
"calling convention",
2806+
[CC_C, CC_SpirKernel, CC_SpirFunction]> {
2807+
let cppNamespace = "::mlir::cir";
2808+
}
2809+
27972810
def FuncOp : CIR_Op<"func", [
27982811
AutomaticAllocationScope, CallableOpInterface, FunctionOpInterface,
27992812
DeclareOpInterfaceMethods<CIRGlobalValueInterface>,
@@ -2819,6 +2832,9 @@ def FuncOp : CIR_Op<"func", [
28192832
The function linkage information is specified by `linkage`, as defined by
28202833
`GlobalLinkageKind` attribute.
28212834

2835+
The `calling_conv` attribute specifies the calling convention of the function.
2836+
The default calling convention is `CallingConv::C`.
2837+
28222838
A compiler builtin function must be marked as `builtin` for further
28232839
processing when lowering from CIR.
28242840

@@ -2857,6 +2873,9 @@ def FuncOp : CIR_Op<"func", [
28572873
// Linkage information
28582874
cir.func linkonce_odr @some_method(...)
28592875

2876+
// Calling convention information
2877+
cir.func @another_func(...) cc(spir_kernel) extra(#fn_attr)
2878+
28602879
// Builtin function
28612880
cir.func builtin @__builtin_coro_end(!cir.ptr<i8>, !cir.bool) -> !cir.bool
28622881

@@ -2878,6 +2897,8 @@ def FuncOp : CIR_Op<"func", [
28782897
UnitAttr:$dsolocal,
28792898
DefaultValuedAttr<GlobalLinkageKind,
28802899
"GlobalLinkageKind::ExternalLinkage">:$linkage,
2900+
DefaultValuedAttr<CallingConv,
2901+
"CallingConv::C">:$calling_conv,
28812902
ExtraFuncAttr:$extra_attrs,
28822903
OptionalAttr<StrAttr>:$sym_visibility,
28832904
UnitAttr:$comdat,
@@ -2893,6 +2914,7 @@ def FuncOp : CIR_Op<"func", [
28932914
let builders = [OpBuilder<(ins
28942915
"StringRef":$name, "FuncType":$type,
28952916
CArg<"GlobalLinkageKind", "GlobalLinkageKind::ExternalLinkage">:$linkage,
2917+
CArg<"CallingConv", "CallingConv::C">:$callingConv,
28962918
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
28972919
CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)
28982920
>];

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ template <typename Ty> struct EnumTraits {};
157157
}
158158

159159
REGISTER_ENUM_TYPE(GlobalLinkageKind);
160+
REGISTER_ENUM_TYPE(CallingConv);
160161
REGISTER_ENUM_TYPE_WITH_NS(sob, SignedOverflowBehavior);
161162
} // namespace
162163

@@ -176,6 +177,20 @@ static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
176177
return static_cast<RetTy>(index);
177178
}
178179

180+
/// Parse an enum from the keyword, return failure if the keyword is not found.
181+
template <typename EnumTy, typename RetTy = EnumTy>
182+
static ParseResult parseCIRKeyword(AsmParser &parser, RetTy &result) {
183+
SmallVector<StringRef, 10> names;
184+
for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
185+
names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
186+
187+
int index = parseOptionalKeywordAlternative(parser, names);
188+
if (index == -1)
189+
return failure();
190+
result = static_cast<RetTy>(index);
191+
return success();
192+
}
193+
179194
// Check if a region's termination omission is valid and, if so, creates and
180195
// inserts the omitted terminator into the region.
181196
LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
@@ -1874,7 +1889,7 @@ static StringRef getLinkageAttrNameString() { return "linkage"; }
18741889

18751890
void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
18761891
StringRef name, cir::FuncType type,
1877-
GlobalLinkageKind linkage,
1892+
GlobalLinkageKind linkage, CallingConv callingConv,
18781893
ArrayRef<NamedAttribute> attrs,
18791894
ArrayRef<DictionaryAttr> argAttrs) {
18801895
result.addRegion();
@@ -1885,6 +1900,8 @@ void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
18851900
result.addAttribute(
18861901
getLinkageAttrNameString(),
18871902
GlobalLinkageKindAttr::get(builder.getContext(), linkage));
1903+
result.addAttribute(getCallingConvAttrName(result.name),
1904+
CallingConvAttr::get(builder.getContext(), callingConv));
18881905
result.attributes.append(attrs.begin(), attrs.end());
18891906
if (argAttrs.empty())
18901907
return;
@@ -1991,6 +2008,20 @@ ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
19912008
hasAlias = true;
19922009
}
19932010

2011+
// Default to C calling convention if no keyword is provided.
2012+
auto callConvNameAttr = getCallingConvAttrName(state.name);
2013+
CallingConv callConv = CallingConv::C;
2014+
if (parser.parseOptionalKeyword("cc").succeeded()) {
2015+
if (parser.parseLParen().failed())
2016+
return failure();
2017+
if (parseCIRKeyword<CallingConv>(parser, callConv).failed())
2018+
return parser.emitError(loc) << "unknown calling convention";
2019+
if (parser.parseRParen().failed())
2020+
return failure();
2021+
}
2022+
state.addAttribute(callConvNameAttr,
2023+
CallingConvAttr::get(parser.getContext(), callConv));
2024+
19942025
auto parseGlobalDtorCtor =
19952026
[&](StringRef keyword,
19962027
llvm::function_ref<void(std::optional<int> prio)> createAttr)
@@ -2144,6 +2175,7 @@ void cir::FuncOp::print(OpAsmPrinter &p) {
21442175
getGlobalDtorAttrName(),
21452176
getLambdaAttrName(),
21462177
getLinkageAttrName(),
2178+
getCallingConvAttrName(),
21472179
getNoProtoAttrName(),
21482180
getSymVisibilityAttrName(),
21492181
getArgAttrsAttrName(),
@@ -2157,6 +2189,12 @@ void cir::FuncOp::print(OpAsmPrinter &p) {
21572189
p << ")";
21582190
}
21592191

2192+
if (getCallingConv() != CallingConv::C) {
2193+
p << " cc(";
2194+
p << stringifyCallingConv(getCallingConv());
2195+
p << ")";
2196+
}
2197+
21602198
if (auto globalCtor = getGlobalCtorAttr()) {
21612199
p << " global_ctor";
21622200
if (!globalCtor.isDefaultPriority())

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,22 @@ mlir::LLVM::Linkage convertLinkage(mlir::cir::GlobalLinkageKind linkage) {
463463
};
464464
}
465465

466+
mlir::LLVM::CConv convertCallingConv(mlir::cir::CallingConv callinvConv) {
467+
using CIR = mlir::cir::CallingConv;
468+
using LLVM = mlir::LLVM::CConv;
469+
470+
switch (callinvConv) {
471+
case CIR::C:
472+
return LLVM::C;
473+
case CIR::SpirKernel:
474+
return LLVM::SPIR_KERNEL;
475+
case CIR::SpirFunction:
476+
return LLVM::SPIR_FUNC;
477+
default:
478+
llvm_unreachable("Unknown calling convention");
479+
}
480+
}
481+
466482
class CIRCopyOpLowering : public mlir::OpConversionPattern<mlir::cir::CopyOp> {
467483
public:
468484
using mlir::OpConversionPattern<mlir::cir::CopyOp>::OpConversionPattern;
@@ -1529,6 +1545,7 @@ class CIRFuncLowering : public mlir::OpConversionPattern<mlir::cir::FuncOp> {
15291545
if (attr.getName() == mlir::SymbolTable::getSymbolAttrName() ||
15301546
attr.getName() == func.getFunctionTypeAttrName() ||
15311547
attr.getName() == getLinkageAttrNameString() ||
1548+
attr.getName() == func.getCallingConvAttrName() ||
15321549
(filterArgAndResAttrs &&
15331550
(attr.getName() == func.getArgAttrsAttrName() ||
15341551
attr.getName() == func.getResAttrsAttrName())))
@@ -1614,11 +1631,12 @@ class CIRFuncLowering : public mlir::OpConversionPattern<mlir::cir::FuncOp> {
16141631
"expected single location or unknown location here");
16151632

16161633
auto linkage = convertLinkage(op.getLinkage());
1634+
auto cconv = convertCallingConv(op.getCallingConv());
16171635
SmallVector<mlir::NamedAttribute, 4> attributes;
16181636
lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes);
16191637

16201638
auto fn = rewriter.create<mlir::LLVM::LLVMFuncOp>(
1621-
Loc, op.getName(), llvmFnTy, linkage, isDsoLocal, mlir::LLVM::CConv::C,
1639+
Loc, op.getName(), llvmFnTy, linkage, isDsoLocal, cconv,
16221640
mlir::SymbolRefAttr(), attributes);
16231641

16241642
rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end());

clang/test/CIR/IR/func-call-conv.cir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: cir-opt %s -o %t.cir
2+
// RUN: FileCheck --input-file=%t.cir %s
3+
4+
!s32i = !cir.int<s, 32>
5+
6+
#fn_attr = #cir<extra({inline = #cir.inline<no>})>
7+
8+
module {
9+
// CHECK: cir.func @foo() {
10+
cir.func @foo() cc(c) {
11+
cir.return
12+
}
13+
14+
// CHECK: cir.func @bar() cc(spir_kernel)
15+
cir.func @bar() cc(spir_kernel) {
16+
cir.return
17+
}
18+
19+
// CHECK: cir.func @bar_alias() alias(@bar) cc(spir_kernel)
20+
cir.func @bar_alias() alias(@bar) cc(spir_kernel)
21+
22+
// CHECK: cir.func @baz() cc(spir_function) extra(#fn_attr)
23+
cir.func @baz() cc(spir_function) extra(#fn_attr) {
24+
cir.return
25+
}
26+
}
27+

clang/test/CIR/IR/invalid.cir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,3 +1263,13 @@ cir.func @address_space4(%p : !cir.ptr<!u64i, addrspace(foobar)>) { // expected-
12631263
type_qual = [""],
12641264
name = ["foo"]
12651265
>
1266+
1267+
// -----
1268+
1269+
module {
1270+
// expected-error@+1 {{unknown calling convention}}
1271+
cir.func @foo() cc(foobar) {
1272+
cir.return
1273+
}
1274+
}
1275+
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: cir-translate %s -cir-to-llvmir -o %t.ll
2+
// RUN: FileCheck %s --input-file=%t.ll --check-prefix=LLVM
3+
4+
!s32i = !cir.int<s, 32>
5+
module {
6+
// LLVM: define void @foo()
7+
cir.func @foo() cc(c) {
8+
cir.return
9+
}
10+
11+
// LLVM: define spir_kernel void @bar()
12+
cir.func @bar() cc(spir_kernel) {
13+
cir.return
14+
}
15+
16+
// LLVM: define spir_func void @baz()
17+
cir.func @baz() cc(spir_function) {
18+
cir.return
19+
}
20+
}

0 commit comments

Comments
 (0)