Skip to content

Commit 3ddb33a

Browse files
bcardosolopeslanza
authored andcommitted
[CIR][CIRGen][LLVMLowering] Add support retrieving thread local global addresses
1 parent cbd6a37 commit 3ddb33a

File tree

9 files changed

+80
-22
lines changed

9 files changed

+80
-22
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

+17-10
Original file line numberDiff line numberDiff line change
@@ -1775,24 +1775,31 @@ def GetGlobalOp : CIR_Op<"get_global",
17751775
[Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
17761776
let summary = "Get the address of a global variable";
17771777
let description = [{
1778-
The `cir.get_global` operation retrieves the address pointing to a
1779-
named global variable. If the global variable is marked constant, writing
1780-
to the resulting address (such as through a `cir.store` operation) is
1781-
undefined. Resulting type must always be a `!cir.ptr<...>` type.
1778+
The `cir.get_global` operation retrieves the address pointing to a
1779+
named global variable. If the global variable is marked constant, writing
1780+
to the resulting address (such as through a `cir.store` operation) is
1781+
undefined. Resulting type must always be a `!cir.ptr<...>` type.
17821782

1783-
Example:
1783+
Addresses of thread local globals can only be retrieved if this operation
1784+
is marked `thread_local`, which indicates the address isn't constant.
17841785

1785-
```mlir
1786-
%x = cir.get_global @foo : !cir.ptr<i32>
1787-
```
1786+
Example:
1787+
```mlir
1788+
%x = cir.get_global @foo : !cir.ptr<i32>
1789+
...
1790+
%y = cir.get_global thread_local @batata : !cir.ptr<i32>
1791+
```
17881792
}];
17891793

1790-
let arguments = (ins FlatSymbolRefAttr:$name);
1794+
let arguments = (ins FlatSymbolRefAttr:$name, UnitAttr:$tls);
17911795
let results = (outs Res<CIR_PointerType, "", []>:$addr);
17921796

17931797
// FIXME: we should not be printing `cir.ptr` below, that should come
17941798
// from the pointer type directly.
1795-
let assemblyFormat = "$name `:` `cir.ptr` type($addr) attr-dict";
1799+
let assemblyFormat = [{
1800+
(`thread_local` $tls^)?
1801+
$name `:` `cir.ptr` type($addr) attr-dict
1802+
}];
17961803

17971804
// `GetGlobalOp` is fully verified by its traits.
17981805
let hasVerifier = 0;

clang/lib/CIR/CodeGen/CIRGenBuilder.h

+5-3
Original file line numberDiff line numberDiff line change
@@ -697,9 +697,11 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
697697
return create<mlir::cir::GlobalOp>(loc, uniqueName, type, isConst, linkage);
698698
}
699699

700-
mlir::Value createGetGlobal(mlir::cir::GlobalOp global) {
701-
return create<mlir::cir::GetGlobalOp>(
702-
global.getLoc(), getPointerTo(global.getSymType()), global.getName());
700+
mlir::Value createGetGlobal(mlir::cir::GlobalOp global,
701+
bool threadLocal = false) {
702+
return create<mlir::cir::GetGlobalOp>(global.getLoc(),
703+
getPointerTo(global.getSymType()),
704+
global.getName(), threadLocal);
703705
}
704706

705707
mlir::Value createGetBitfield(mlir::Location loc, mlir::Type resultType,

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -719,11 +719,10 @@ static LValue buildGlobalVarDeclLValue(CIRGenFunction &CGF, const Expr *E,
719719
if (CGF.getLangOpts().OpenMP)
720720
llvm_unreachable("not implemented");
721721

722+
// Traditional LLVM codegen handles thread local separately, CIR handles
723+
// as part of getAddrOfGlobalVar.
722724
auto V = CGF.CGM.getAddrOfGlobalVar(VD);
723725

724-
if (VD->getTLSKind() != VarDecl::TLS_None)
725-
llvm_unreachable("NYI");
726-
727726
auto RealVarTy = CGF.getTypes().convertTypeForMem(VD->getType());
728727
auto realPtrTy = CGF.getBuilder().getPointerTo(RealVarTy);
729728
if (realPtrTy != V.getType())

clang/lib/CIR/CodeGen/CIRGenModule.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -836,11 +836,12 @@ mlir::Value CIRGenModule::getAddrOfGlobalVar(const VarDecl *D, mlir::Type Ty,
836836
if (!Ty)
837837
Ty = getTypes().convertTypeForMem(ASTTy);
838838

839+
bool tlsAccess = D->getTLSKind() != VarDecl::TLS_None;
839840
auto g = buildGlobal(D, Ty, IsForDefinition);
840841
auto ptrTy =
841842
mlir::cir::PointerType::get(builder.getContext(), g.getSymType());
842-
return builder.create<mlir::cir::GetGlobalOp>(getLoc(D->getSourceRange()),
843-
ptrTy, g.getSymName());
843+
return builder.create<mlir::cir::GetGlobalOp>(
844+
getLoc(D->getSourceRange()), ptrTy, g.getSymName(), tlsAccess);
844845
}
845846

846847
mlir::cir::GlobalViewAttr

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -1634,9 +1634,13 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
16341634
<< "' does not reference a valid cir.global or cir.func";
16351635

16361636
mlir::Type symTy;
1637-
if (auto g = dyn_cast<GlobalOp>(op))
1637+
if (auto g = dyn_cast<GlobalOp>(op)) {
16381638
symTy = g.getSymType();
1639-
else if (auto f = dyn_cast<FuncOp>(op))
1639+
// Verify that for thread local global access, the global needs to
1640+
// be marked with tls bits.
1641+
if (getTls() && !g.getTlsModel())
1642+
return emitOpError("access to global not marked thread local");
1643+
} else if (auto f = dyn_cast<FuncOp>(op))
16401644
symTy = f.getFunctionType();
16411645
else
16421646
llvm_unreachable("shall not get here");

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+10-1
Original file line numberDiff line numberDiff line change
@@ -1613,7 +1613,16 @@ class CIRGetGlobalOpLowering
16131613

16141614
auto type = getTypeConverter()->convertType(op.getType());
16151615
auto symbol = op.getName();
1616-
rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(op, type, symbol);
1616+
mlir::Operation *newop =
1617+
rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), type, symbol);
1618+
1619+
if (op.getTls()) {
1620+
// Handle access to TLS via intrinsic.
1621+
newop = rewriter.create<mlir::LLVM::ThreadlocalAddressOp>(
1622+
op.getLoc(), type, newop->getResult(0));
1623+
}
1624+
1625+
rewriter.replaceOp(op, newop);
16171626
return mlir::success();
16181627
}
16191628
};

clang/test/CIR/CodeGen/tls.c

+11
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@
33
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t.ll
44
// RUN: FileCheck --check-prefix=LLVM --input-file=%t.ll %s
55

6+
extern __thread int b;
7+
int c(void) { return *&b; }
8+
// CIR: cir.global "private" external tls_dyn @b : !s32i
9+
// CIR: cir.func @c() -> !s32i
10+
// CIR: %[[TLS_ADDR:.*]] = cir.get_global thread_local @b : cir.ptr <!s32i>
11+
612
__thread int a;
713
// CIR: cir.global external tls_dyn @a = #cir.int<0> : !s32i
14+
15+
// LLVM: @b = external thread_local global i32
816
// LLVM: @a = thread_local global i32 0
17+
18+
// LLVM-LABEL: @c
19+
// LLVM: = call ptr @llvm.threadlocal.address.p0(ptr @b)

clang/test/CIR/IR/global.cir

+13-1
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ module {
6363
cir.global external tls_local_dyn @model1 = #cir.int<0> : !s32i
6464
cir.global external tls_init_exec @model2 = #cir.int<0> : !s32i
6565
cir.global external tls_local_exec @model3 = #cir.int<0> : !s32i
66+
67+
cir.global "private" external tls_dyn @batata : !s32i
68+
cir.func @f35() {
69+
%0 = cir.get_global thread_local @batata : cir.ptr <!s32i>
70+
cir.return
71+
}
6672
}
6773

6874
// CHECK: cir.global external @a = #cir.int<3> : !s32i
@@ -91,4 +97,10 @@ module {
9197
// CHECK: cir.global external tls_dyn @model0 = #cir.int<0> : !s32i
9298
// CHECK: cir.global external tls_local_dyn @model1 = #cir.int<0> : !s32i
9399
// CHECK: cir.global external tls_init_exec @model2 = #cir.int<0> : !s32i
94-
// CHECK: cir.global external tls_local_exec @model3 = #cir.int<0> : !s32i
100+
// CHECK: cir.global external tls_local_exec @model3 = #cir.int<0> : !s32i
101+
102+
// CHECK: cir.global "private" external tls_dyn @batata : !s32i
103+
// CHECK: cir.func @f35() {
104+
// CHECK: %0 = cir.get_global thread_local @batata : cir.ptr <!s32i>
105+
// CHECK: cir.return
106+
// CHECK: }

clang/test/CIR/IR/invalid.cir

+13
Original file line numberDiff line numberDiff line change
@@ -1033,3 +1033,16 @@ cir.func @bad_fetch(%x: !cir.ptr<!cir.float>, %y: !cir.float) -> () {
10331033
%12 = cir.atomic.fetch(xor, %x : !cir.ptr<!cir.float>, %y : !cir.float, seq_cst) : !cir.float
10341034
cir.return
10351035
}
1036+
1037+
// -----
1038+
1039+
!s32i = !cir.int<s, 32>
1040+
1041+
module {
1042+
cir.global "private" external @batata : !s32i
1043+
cir.func @f35() {
1044+
// expected-error@+1 {{access to global not marked thread local}}
1045+
%0 = cir.get_global thread_local @batata : cir.ptr <!s32i>
1046+
cir.return
1047+
}
1048+
}

0 commit comments

Comments
 (0)