Skip to content

Commit 58e76e7

Browse files
ghehglanza
authored andcommitted
[CIR][LLVMLowering] Lower cir.objectsize (#545)
Lowers `cir.objectsize` to `llvm.objectsize`
1 parent 967c779 commit 58e76e7

File tree

10 files changed

+83
-82
lines changed

10 files changed

+83
-82
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

+10-17
Original file line numberDiff line numberDiff line change
@@ -1775,31 +1775,24 @@ def GetGlobalOp : CIR_Op<"get_global",
17751775
[Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
17761776
let summary = "Get the address of a global variable";
17771777
let description = [{
1778-
The `cir.get_global` operation retrieves the address pointing to a
1779-
named global variable. If the global variable is marked constant, writing
1780-
to the resulting address (such as through a `cir.store` operation) is
1781-
undefined. Resulting type must always be a `!cir.ptr<...>` type.
1778+
The `cir.get_global` operation retrieves the address pointing to a
1779+
named global variable. If the global variable is marked constant, writing
1780+
to the resulting address (such as through a `cir.store` operation) is
1781+
undefined. Resulting type must always be a `!cir.ptr<...>` type.
17821782

1783-
Addresses of thread local globals can only be retrieved if this operation
1784-
is marked `thread_local`, which indicates the address isn't constant.
1783+
Example:
17851784

1786-
Example:
1787-
```mlir
1788-
%x = cir.get_global @foo : !cir.ptr<i32>
1789-
...
1790-
%y = cir.get_global thread_local @batata : !cir.ptr<i32>
1791-
```
1785+
```mlir
1786+
%x = cir.get_global @foo : !cir.ptr<i32>
1787+
```
17921788
}];
17931789

1794-
let arguments = (ins FlatSymbolRefAttr:$name, UnitAttr:$tls);
1790+
let arguments = (ins FlatSymbolRefAttr:$name);
17951791
let results = (outs Res<CIR_PointerType, "", []>:$addr);
17961792

17971793
// FIXME: we should not be printing `cir.ptr` below, that should come
17981794
// from the pointer type directly.
1799-
let assemblyFormat = [{
1800-
(`thread_local` $tls^)?
1801-
$name `:` `cir.ptr` type($addr) attr-dict
1802-
}];
1795+
let assemblyFormat = "$name `:` `cir.ptr` type($addr) attr-dict";
18031796

18041797
// `GetGlobalOp` is fully verified by its traits.
18051798
let hasVerifier = 0;

clang/lib/CIR/CodeGen/CIRGenBuilder.h

+3-5
Original file line numberDiff line numberDiff line change
@@ -697,11 +697,9 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
697697
return create<mlir::cir::GlobalOp>(loc, uniqueName, type, isConst, linkage);
698698
}
699699

700-
mlir::Value createGetGlobal(mlir::cir::GlobalOp global,
701-
bool threadLocal = false) {
702-
return create<mlir::cir::GetGlobalOp>(global.getLoc(),
703-
getPointerTo(global.getSymType()),
704-
global.getName(), threadLocal);
700+
mlir::Value createGetGlobal(mlir::cir::GlobalOp global) {
701+
return create<mlir::cir::GetGlobalOp>(
702+
global.getLoc(), getPointerTo(global.getSymType()), global.getName());
705703
}
706704

707705
mlir::Value createGetBitfield(mlir::Location loc, mlir::Type resultType,

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -720,10 +720,11 @@ static LValue buildGlobalVarDeclLValue(CIRGenFunction &CGF, const Expr *E,
720720
if (CGF.getLangOpts().OpenMP)
721721
llvm_unreachable("not implemented");
722722

723-
// Traditional LLVM codegen handles thread local separately, CIR handles
724-
// as part of getAddrOfGlobalVar.
725723
auto V = CGF.CGM.getAddrOfGlobalVar(VD);
726724

725+
if (VD->getTLSKind() != VarDecl::TLS_None)
726+
llvm_unreachable("NYI");
727+
727728
auto RealVarTy = CGF.getTypes().convertTypeForMem(VD->getType());
728729
auto realPtrTy = CGF.getBuilder().getPointerTo(RealVarTy);
729730
if (realPtrTy != V.getType())

clang/lib/CIR/CodeGen/CIRGenModule.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -837,12 +837,11 @@ mlir::Value CIRGenModule::getAddrOfGlobalVar(const VarDecl *D, mlir::Type Ty,
837837
if (!Ty)
838838
Ty = getTypes().convertTypeForMem(ASTTy);
839839

840-
bool tlsAccess = D->getTLSKind() != VarDecl::TLS_None;
841840
auto g = buildGlobal(D, Ty, IsForDefinition);
842841
auto ptrTy =
843842
mlir::cir::PointerType::get(builder.getContext(), g.getSymType());
844-
return builder.create<mlir::cir::GetGlobalOp>(
845-
getLoc(D->getSourceRange()), ptrTy, g.getSymName(), tlsAccess);
843+
return builder.create<mlir::cir::GetGlobalOp>(getLoc(D->getSourceRange()),
844+
ptrTy, g.getSymName());
846845
}
847846

848847
mlir::cir::GlobalViewAttr

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -1634,13 +1634,9 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
16341634
<< "' does not reference a valid cir.global or cir.func";
16351635

16361636
mlir::Type symTy;
1637-
if (auto g = dyn_cast<GlobalOp>(op)) {
1637+
if (auto g = dyn_cast<GlobalOp>(op))
16381638
symTy = g.getSymType();
1639-
// Verify that for thread local global access, the global needs to
1640-
// be marked with tls bits.
1641-
if (getTls() && !g.getTlsModel())
1642-
return emitOpError("access to global not marked thread local");
1643-
} else if (auto f = dyn_cast<FuncOp>(op))
1639+
else if (auto f = dyn_cast<FuncOp>(op))
16441640
symTy = f.getFunctionType();
16451641
else
16461642
llvm_unreachable("shall not get here");

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+33-12
Original file line numberDiff line numberDiff line change
@@ -1614,16 +1614,7 @@ class CIRGetGlobalOpLowering
16141614

16151615
auto type = getTypeConverter()->convertType(op.getType());
16161616
auto symbol = op.getName();
1617-
mlir::Operation *newop =
1618-
rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), type, symbol);
1619-
1620-
if (op.getTls()) {
1621-
// Handle access to TLS via intrinsic.
1622-
newop = rewriter.create<mlir::LLVM::ThreadlocalAddressOp>(
1623-
op.getLoc(), type, newop->getResult(0));
1624-
}
1625-
1626-
rewriter.replaceOp(op, newop);
1617+
rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(op, type, symbol);
16271618
return mlir::success();
16281619
}
16291620
};
@@ -2288,6 +2279,36 @@ class CIRBitClrsbOpLowering
22882279
}
22892280
};
22902281

2282+
class CIRObjSizeOpLowering
2283+
: public mlir::OpConversionPattern<mlir::cir::ObjSizeOp> {
2284+
public:
2285+
using OpConversionPattern<mlir::cir::ObjSizeOp>::OpConversionPattern;
2286+
2287+
mlir::LogicalResult
2288+
matchAndRewrite(mlir::cir::ObjSizeOp op, OpAdaptor adaptor,
2289+
mlir::ConversionPatternRewriter &rewriter) const override {
2290+
auto llvmResTy = getTypeConverter()->convertType(op.getType());
2291+
auto loc = op->getLoc();
2292+
2293+
auto llvmIntrinNameAttr =
2294+
mlir::StringAttr::get(rewriter.getContext(), "llvm.objectsize");
2295+
mlir::cir::SizeInfoType kindInfo = op.getKind();
2296+
auto falseValue = rewriter.create<mlir::LLVM::ConstantOp>(
2297+
loc, rewriter.getI1Type(), false);
2298+
auto trueValue = rewriter.create<mlir::LLVM::ConstantOp>(
2299+
loc, rewriter.getI1Type(), true);
2300+
2301+
rewriter.replaceOpWithNewOp<mlir::LLVM::CallIntrinsicOp>(
2302+
op, llvmResTy, llvmIntrinNameAttr,
2303+
mlir::ValueRange{adaptor.getPtr(),
2304+
kindInfo == mlir::cir::SizeInfoType::max ? falseValue
2305+
: trueValue,
2306+
trueValue, op.getDynamic() ? trueValue : falseValue});
2307+
2308+
return mlir::LogicalResult::success();
2309+
}
2310+
};
2311+
22912312
class CIRBitClzOpLowering
22922313
: public mlir::OpConversionPattern<mlir::cir::BitClzOp> {
22932314
public:
@@ -3035,8 +3056,8 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
30353056
CIRVectorShuffleVecLowering, CIRStackSaveLowering,
30363057
CIRStackRestoreLowering, CIRUnreachableLowering, CIRTrapLowering,
30373058
CIRInlineAsmOpLowering, CIRSetBitfieldLowering, CIRGetBitfieldLowering,
3038-
CIRPrefetchLowering, CIRIsConstantOpLowering>(converter,
3039-
patterns.getContext());
3059+
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering>(
3060+
converter, patterns.getContext());
30403061
}
30413062

30423063
namespace {
+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t.ll
4+
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=LLVM
5+
6+
void b(void *__attribute__((pass_object_size(0))));
7+
void e(void *__attribute__((pass_object_size(2))));
8+
void c() {
9+
int a;
10+
int d[a];
11+
b(d);
12+
e(d);
13+
}
14+
15+
// CIR: cir.func no_proto @c()
16+
// CIR: [[TMP0:%.*]] = cir.alloca !s32i, cir.ptr <!s32i>, %{{[0-9]+}} : !u64i, ["vla"] {alignment = 16 : i64}
17+
// CIR: [[TMP1:%.*]] = cir.cast(bitcast, [[TMP0]] : !cir.ptr<!s32i>), !cir.ptr<!void>
18+
// CIR-NEXT: [[TMP2:%.*]] = cir.objsize([[TMP1]] : <!void>, max) -> !u64i
19+
// CIR-NEXT: cir.call @b([[TMP1]], [[TMP2]]) : (!cir.ptr<!void>, !u64i) -> ()
20+
// CIR: [[TMP3:%.*]] = cir.cast(bitcast, [[TMP0]] : !cir.ptr<!s32i>), !cir.ptr<!void>
21+
// CIR: [[TMP4:%.*]] = cir.objsize([[TMP3]] : <!void>, min) -> !u64i
22+
// CIR-NEXT: cir.call @e([[TMP3]], [[TMP4]]) : (!cir.ptr<!void>, !u64i) -> ()
23+
24+
// LLVM: define void @c()
25+
// LLVM: [[TMP0:%.*]] = alloca i32, i64 %{{[0-9]+}},
26+
// LLVM: [[TMP1:%.*]] = call i64 @llvm.objectsize.i64.p0(ptr [[TMP0]], i1 false, i1 true, i1 false),
27+
// LLVM-NEXT: call void @b(ptr [[TMP0]], i64 [[TMP1]])
28+
// LLVM: [[TMP2:%.*]] = call i64 @llvm.objectsize.i64.p0(ptr [[TMP0]], i1 true, i1 true, i1 false),
29+
// LLVM-NEXT: call void @e(ptr [[TMP0]], i64 [[TMP2]])

clang/test/CIR/CodeGen/tls.c

-11
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,6 @@
33
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t.ll
44
// RUN: FileCheck --check-prefix=LLVM --input-file=%t.ll %s
55

6-
extern __thread int b;
7-
int c(void) { return *&b; }
8-
// CIR: cir.global "private" external tls_dyn @b : !s32i
9-
// CIR: cir.func @c() -> !s32i
10-
// CIR: %[[TLS_ADDR:.*]] = cir.get_global thread_local @b : cir.ptr <!s32i>
11-
126
__thread int a;
137
// CIR: cir.global external tls_dyn @a = #cir.int<0> : !s32i
14-
15-
// LLVM: @b = external thread_local global i32
168
// LLVM: @a = thread_local global i32 0
17-
18-
// LLVM-LABEL: @c
19-
// LLVM: = call ptr @llvm.threadlocal.address.p0(ptr @b)

clang/test/CIR/IR/global.cir

+1-13
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,6 @@ module {
6363
cir.global external tls_local_dyn @model1 = #cir.int<0> : !s32i
6464
cir.global external tls_init_exec @model2 = #cir.int<0> : !s32i
6565
cir.global external tls_local_exec @model3 = #cir.int<0> : !s32i
66-
67-
cir.global "private" external tls_dyn @batata : !s32i
68-
cir.func @f35() {
69-
%0 = cir.get_global thread_local @batata : cir.ptr <!s32i>
70-
cir.return
71-
}
7266
}
7367

7468
// CHECK: cir.global external @a = #cir.int<3> : !s32i
@@ -97,10 +91,4 @@ module {
9791
// CHECK: cir.global external tls_dyn @model0 = #cir.int<0> : !s32i
9892
// CHECK: cir.global external tls_local_dyn @model1 = #cir.int<0> : !s32i
9993
// CHECK: cir.global external tls_init_exec @model2 = #cir.int<0> : !s32i
100-
// CHECK: cir.global external tls_local_exec @model3 = #cir.int<0> : !s32i
101-
102-
// CHECK: cir.global "private" external tls_dyn @batata : !s32i
103-
// CHECK: cir.func @f35() {
104-
// CHECK: %0 = cir.get_global thread_local @batata : cir.ptr <!s32i>
105-
// CHECK: cir.return
106-
// CHECK: }
94+
// CHECK: cir.global external tls_local_exec @model3 = #cir.int<0> : !s32i

clang/test/CIR/IR/invalid.cir

-13
Original file line numberDiff line numberDiff line change
@@ -1034,16 +1034,3 @@ cir.func @bad_fetch(%x: !cir.ptr<!cir.float>, %y: !cir.float) -> () {
10341034
%12 = cir.atomic.fetch(xor, %x : !cir.ptr<!cir.float>, %y : !cir.float, seq_cst) : !cir.float
10351035
cir.return
10361036
}
1037-
1038-
// -----
1039-
1040-
!s32i = !cir.int<s, 32>
1041-
1042-
module {
1043-
cir.global "private" external @batata : !s32i
1044-
cir.func @f35() {
1045-
// expected-error@+1 {{access to global not marked thread local}}
1046-
%0 = cir.get_global thread_local @batata : cir.ptr <!s32i>
1047-
cir.return
1048-
}
1049-
}

0 commit comments

Comments
 (0)