Skip to content

Commit 73d2403

Browse files
seven-milelanza
authored andcommitted
[CIR][Dialect] Add convergent attribute to functions for SIMT languages (#840)
Fix #805. This PR includes end-to-end implementation. The `convergent` attribute is set depending on languages, which is wrapped as `langOpts.assumeFunctionsAreConvergent()`. Therefore, in ClangIR, every `cir.func` under `#cir.lang<opencl_c>` is set to be convergent. After lowering to LLVM IR, `PostOrderFunctionAttrs` pass will remove unnecessary `convergent` then. In other words, we will still see `convergent` on every function with `-O0`, but not with default optimization level. The test taken from `CodeGenOpenCL/convergent.cl` is a bit complicated. However, the core of it is that `convergent` is set properly for `convfun()` `non_convfun()` `f()` and `g()`. Merge of two `if` is more or less a result of generating the same LLVM IR as OG.
1 parent 6097fcd commit 73d2403

File tree

5 files changed

+166
-0
lines changed

5 files changed

+166
-0
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

+4
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,10 @@ def NoThrowAttr : CIRUnitAttr<"NoThrow", "nothrow"> {
10421042
let storageType = [{ NoThrowAttr }];
10431043
}
10441044

1045+
def ConvergentAttr : CIRUnitAttr<"Convergent", "convergent"> {
1046+
let storageType = [{ ConvergentAttr }];
1047+
}
1048+
10451049
class CIR_GlobalCtorDtor<string name, string attrMnemonic,
10461050
string sum, string desc>
10471051
: CIR_Attr<"Global" # name, "global_" # attrMnemonic> {

clang/lib/CIR/CodeGen/CIRGenCall.cpp

+40
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,9 @@ void CIRGenModule::constructAttributeList(StringRef Name,
440440
if (TargetDecl->hasAttr<ArmLocallyStreamingAttr>())
441441
;
442442
}
443+
444+
getDefaultFunctionAttributes(Name, HasOptnone, AttrOnCallSite, funcAttrs);
445+
443446
}
444447

445448
static mlir::cir::CIRCallOpInterface
@@ -1559,3 +1562,40 @@ mlir::Value CIRGenFunction::buildVAArg(VAArgExpr *VE, Address &VAListAddr) {
15591562
auto vaList = buildVAListRef(VE->getSubExpr()).getPointer();
15601563
return builder.create<mlir::cir::VAArgOp>(loc, type, vaList);
15611564
}
1565+
1566+
static void getTrivialDefaultFunctionAttributes(
1567+
StringRef name, bool hasOptnone, const CodeGenOptions &codeGenOpts,
1568+
const LangOptions &langOpts, bool attrOnCallSite, CIRGenModule &CGM,
1569+
mlir::NamedAttrList &funcAttrs) {
1570+
1571+
if (langOpts.assumeFunctionsAreConvergent()) {
1572+
// Conservatively, mark all functions and calls in CUDA and OpenCL as
1573+
// convergent (meaning, they may call an intrinsically convergent op, such
1574+
// as __syncthreads() / barrier(), and so can't have certain optimizations
1575+
// applied around them). LLVM will remove this attribute where it safely
1576+
// can.
1577+
1578+
auto convgt = mlir::cir::ConvergentAttr::get(CGM.getBuilder().getContext());
1579+
funcAttrs.set(convgt.getMnemonic(), convgt);
1580+
}
1581+
}
1582+
1583+
void CIRGenModule::getTrivialDefaultFunctionAttributes(
1584+
StringRef name, bool hasOptnone, bool attrOnCallSite,
1585+
mlir::NamedAttrList &funcAttrs) {
1586+
::getTrivialDefaultFunctionAttributes(name, hasOptnone, getCodeGenOpts(),
1587+
getLangOpts(), attrOnCallSite, *this,
1588+
funcAttrs);
1589+
}
1590+
1591+
void CIRGenModule::getDefaultFunctionAttributes(StringRef name, bool hasOptnone,
1592+
bool attrOnCallSite,
1593+
mlir::NamedAttrList &funcAttrs) {
1594+
getTrivialDefaultFunctionAttributes(name, hasOptnone, attrOnCallSite,
1595+
funcAttrs);
1596+
// If we're just getting the default, get the default values for mergeable
1597+
// attributes.
1598+
if (!attrOnCallSite) {
1599+
// TODO(cir): addMergableDefaultFunctionAttributes(codeGenOpts, funcAttrs);
1600+
}
1601+
}

clang/lib/CIR/CodeGen/CIRGenModule.h

+13
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,19 @@ class CIRGenModule : public CIRGenTypeCache {
296296
mlir::cir::CallingConv &callingConv,
297297
bool AttrOnCallSite, bool IsThunk);
298298

299+
/// Helper function for getDefaultFunctionAttributes. Builds a set of function
300+
/// attributes which can be simply added to a function.
301+
void getTrivialDefaultFunctionAttributes(StringRef name, bool hasOptnone,
302+
bool attrOnCallSite,
303+
mlir::NamedAttrList &funcAttrs);
304+
305+
/// Helper function for constructAttributeList and
306+
/// addDefaultFunctionDefinitionAttributes. Builds a set of function
307+
/// attributes to add to a function with the given properties.
308+
void getDefaultFunctionAttributes(StringRef name, bool hasOptnone,
309+
bool attrOnCallSite,
310+
mlir::NamedAttrList &funcAttrs);
311+
299312
/// Will return a global variable of the given type. If a variable with a
300313
/// different type already exists then a new variable with the right type
301314
/// will be created and all uses of the old variable will be replaced with a

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVMIR.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ class CIRDialectLLVMIRTranslationInterface
109109
llvmFunc->addFnAttr(llvm::Attribute::OptimizeNone);
110110
} else if (mlir::dyn_cast<mlir::cir::NoThrowAttr>(attr.getValue())) {
111111
llvmFunc->addFnAttr(llvm::Attribute::NoUnwind);
112+
} else if (mlir::dyn_cast<mlir::cir::ConvergentAttr>(attr.getValue())) {
113+
llvmFunc->addFnAttr(llvm::Attribute::Convergent);
112114
} else if (auto clKernelMetadata =
113115
mlir::dyn_cast<mlir::cir::OpenCLKernelMetadataAttr>(
114116
attr.getValue())) {
+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// RUN: %clang_cc1 -fclangir -triple spirv64-unknown-unknown -emit-cir %s -o %t.cir
2+
// RUN: FileCheck %s --input-file=%t.cir --check-prefix=CIR
3+
// RUN: %clang_cc1 -fclangir -triple spirv64-unknown-unknown -emit-llvm %s -o %t.ll
4+
// RUN: FileCheck %s --input-file=%t.ll --check-prefix=LLVM
5+
6+
// In ClangIR for OpenCL, all functions should be marked convergent.
7+
// In LLVM IR, it is initially assumed convergent, but can be deduced to not require it.
8+
9+
// CIR: #fn_attr[[CONV_NOINLINE_ATTR:[0-9]*]] = #cir<extra({convergent = #cir.convergent, inline = #cir.inline<no>
10+
// CIR-NEXT: #fn_attr[[CONV_DECL_ATTR:[0-9]*]] = #cir<extra({convergent = #cir.convergent
11+
// CIR-NEXT: #fn_attr[[CONV_NOTHROW_ATTR:[0-9]*]] = #cir<extra({convergent = #cir.convergent, nothrow = #cir.nothrow
12+
13+
__attribute__((noinline))
14+
void non_convfun(void) {
15+
volatile int* p;
16+
*p = 0;
17+
}
18+
// CIR: cir.func @non_convfun(){{.*}} extra(#fn_attr[[CONV_NOINLINE_ATTR]])
19+
// LLVM: define{{.*}} spir_func void @non_convfun() local_unnamed_addr #[[NON_CONV_ATTR:[0-9]+]]
20+
// LLVM: ret void
21+
22+
// External functions should be assumed convergent.
23+
void f(void);
24+
// CIR: cir.func{{.+}} @f(){{.*}} extra(#fn_attr[[CONV_DECL_ATTR]])
25+
// LLVM: declare {{.+}} spir_func void @f() local_unnamed_addr #[[CONV_ATTR:[0-9]+]]
26+
void g(void);
27+
// CIR: cir.func{{.+}} @g(){{.*}} extra(#fn_attr[[CONV_DECL_ATTR]])
28+
// LLVM: declare {{.+}} spir_func void @g() local_unnamed_addr #[[CONV_ATTR]]
29+
30+
// Test two if's are merged and non_convfun duplicated.
31+
void test_merge_if(int a) {
32+
if (a) {
33+
f();
34+
}
35+
non_convfun();
36+
if (a) {
37+
g();
38+
}
39+
}
40+
// CIR: cir.func @test_merge_if{{.*}} extra(#fn_attr[[CONV_NOTHROW_ATTR]])
41+
42+
// The LLVM IR below is equivalent to:
43+
// if (a) {
44+
// f();
45+
// non_convfun();
46+
// g();
47+
// } else {
48+
// non_convfun();
49+
// }
50+
51+
// LLVM-LABEL: define{{.*}} spir_func void @test_merge_if
52+
// LLVM: %[[tobool:.+]] = icmp eq i32 %[[ARG:.+]], 0
53+
// LLVM: br i1 %[[tobool]], label %[[if_end3_critedge:[^,]+]], label %[[if_then:[^,]+]]
54+
55+
// LLVM: [[if_end3_critedge]]:
56+
// LLVM: tail call spir_func void @non_convfun()
57+
// LLVM: br label %[[if_end3:[^,]+]]
58+
59+
// LLVM: [[if_then]]:
60+
// LLVM: tail call spir_func void @f()
61+
// LLVM: tail call spir_func void @non_convfun()
62+
// LLVM: tail call spir_func void @g()
63+
64+
// LLVM: br label %[[if_end3]]
65+
66+
// LLVM: [[if_end3]]:
67+
// LLVM: ret void
68+
69+
70+
void convfun(void) __attribute__((convergent));
71+
// CIR: cir.func{{.+}} @convfun(){{.*}} extra(#fn_attr[[CONV_DECL_ATTR]])
72+
// LLVM: declare {{.+}} spir_func void @convfun() local_unnamed_addr #[[CONV_ATTR]]
73+
74+
// Test two if's are not merged.
75+
void test_no_merge_if(int a) {
76+
if (a) {
77+
f();
78+
}
79+
convfun();
80+
if(a) {
81+
g();
82+
}
83+
}
84+
// CIR: cir.func @test_no_merge_if{{.*}} extra(#fn_attr[[CONV_NOTHROW_ATTR]])
85+
86+
// LLVM-LABEL: define{{.*}} spir_func void @test_no_merge_if
87+
// LLVM: %[[tobool:.+]] = icmp eq i32 %[[ARG:.+]], 0
88+
// LLVM: br i1 %[[tobool]], label %[[if_end:[^,]+]], label %[[if_then:[^,]+]]
89+
// LLVM: [[if_then]]:
90+
// LLVM: tail call spir_func void @f()
91+
// LLVM-NOT: call spir_func void @convfun()
92+
// LLVM-NOT: call spir_func void @g()
93+
// LLVM: br label %[[if_end]]
94+
// LLVM: [[if_end]]:
95+
// LLVM-NOT: phi i1
96+
// LLVM: tail call spir_func void @convfun()
97+
// LLVM: br i1 %[[tobool]], label %[[if_end3:[^,]+]], label %[[if_then2:[^,]+]]
98+
// LLVM: [[if_then2]]:
99+
// LLVM: tail call spir_func void @g()
100+
// LLVM: br label %[[if_end3:[^,]+]]
101+
// LLVM: [[if_end3]]:
102+
// LLVM: ret void
103+
104+
105+
// LLVM attribute definitions.
106+
// LLVM-NOT: attributes #[[NON_CONV_ATTR]] = { {{.*}}convergent{{.*}} }
107+
// LLVM: attributes #[[CONV_ATTR]] = { {{.*}}convergent{{.*}} }

0 commit comments

Comments
 (0)