Skip to content

[CIR][CodeGen] Set address space for OpenCL globals #788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -684,9 +684,10 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {

/// Create a `cir.get_global` operation that yields the address of \p global.
///
/// The result pointer type is derived from the global's symbol type together
/// with the global's address space attribute, so a global placed in a
/// non-default address space (e.g. an OpenCL `global` variable) is referenced
/// through a pointer carrying that same address space.
///
/// \param global       The global being referenced; its location and name are
///                     reused for the new operation.
/// \param threadLocal  Forwarded unchanged to the GetGlobalOp builder.
/// \returns The value produced by the newly created GetGlobalOp.
mlir::Value createGetGlobal(mlir::cir::GlobalOp global,
                            bool threadLocal = false) {
  // Build the pointer type from the global itself so that the address space
  // is never silently dropped.
  mlir::cir::PointerType ptrTy =
      getPointerTo(global.getSymType(), global.getAddrSpaceAttr());
  return create<mlir::cir::GetGlobalOp>(global.getLoc(), ptrTy,
                                        global.getName(), threadLocal);
}

mlir::Value createGetBitfield(mlir::Location loc, mlir::Type resultType,
Expand Down
8 changes: 5 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,9 @@ static LValue buildGlobalVarDeclLValue(CIRGenFunction &CGF, const Expr *E,
auto V = CGF.CGM.getAddrOfGlobalVar(VD);

auto RealVarTy = CGF.getTypes().convertTypeForMem(VD->getType());
auto realPtrTy = CGF.getBuilder().getPointerTo(RealVarTy);
mlir::cir::PointerType realPtrTy = CGF.getBuilder().getPointerTo(
RealVarTy, cast_if_present<mlir::cir::AddressSpaceAttr>(
cast<mlir::cir::PointerType>(V.getType()).getAddrSpace()));
if (realPtrTy != V.getType())
V = CGF.getBuilder().createBitcast(V.getLoc(), V, realPtrTy);

Expand Down Expand Up @@ -1999,8 +2001,8 @@ LValue CIRGenFunction::buildCastLValue(const CastExpr *E) {
case CK_AddressSpaceConversion: {
LValue LV = buildLValue(E->getSubExpr());
QualType DestTy = getContext().getPointerType(E->getType());
auto SrcAS = builder.getAddrSpaceAttr(
E->getSubExpr()->getType().getAddressSpace());
auto SrcAS =
builder.getAddrSpaceAttr(E->getSubExpr()->getType().getAddressSpace());
auto DestAS = builder.getAddrSpaceAttr(E->getType().getAddressSpace());
mlir::Value V = getTargetHooks().performAddrSpaceCast(
*this, LV.getPointer(), SrcAS, DestAS, ConvertType(DestTy));
Expand Down
62 changes: 47 additions & 15 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,11 +636,11 @@ mlir::Value CIRGenModule::getGlobalValue(const Decl *D) {
return CurCGF->symbolTable.lookup(D);
}

mlir::cir::GlobalOp CIRGenModule::createGlobalOp(CIRGenModule &CGM,
mlir::Location loc,
StringRef name, mlir::Type t,
bool isCst,
mlir::Operation *insertPoint) {
mlir::cir::GlobalOp
CIRGenModule::createGlobalOp(CIRGenModule &CGM, mlir::Location loc,
StringRef name, mlir::Type t, bool isCst,
mlir::cir::AddressSpaceAttr addrSpace,
mlir::Operation *insertPoint) {
mlir::cir::GlobalOp g;
auto &builder = CGM.getBuilder();
{
Expand All @@ -654,7 +654,8 @@ mlir::cir::GlobalOp CIRGenModule::createGlobalOp(CIRGenModule &CGM,
if (curCGF)
builder.setInsertionPoint(curCGF->CurFn);

g = builder.create<mlir::cir::GlobalOp>(loc, name, t, isCst);
g = builder.create<mlir::cir::GlobalOp>(
loc, name, t, isCst, GlobalLinkageKind::ExternalLinkage, addrSpace);
if (!curCGF) {
if (insertPoint)
CGM.getModule().insert(insertPoint, g);
Expand Down Expand Up @@ -741,6 +742,12 @@ void CIRGenModule::replaceGlobal(mlir::cir::GlobalOp Old,
// If the types do not match, update all references to Old to the new type.
auto OldTy = Old.getSymType();
auto NewTy = New.getSymType();
mlir::cir::AddressSpaceAttr oldAS = Old.getAddrSpaceAttr();
mlir::cir::AddressSpaceAttr newAS = New.getAddrSpaceAttr();
// TODO(cir): If the AS differs, we should also update all references.
if (oldAS != newAS) {
llvm_unreachable("NYI");
}
if (OldTy != NewTy) {
auto OldSymUses = Old.getSymbolUses(theModule.getOperation());
if (OldSymUses.has_value()) {
Expand Down Expand Up @@ -808,7 +815,7 @@ void CIRGenModule::setTLSMode(mlir::Operation *Op, const VarDecl &D) const {
/// mangled name but some other type.
mlir::cir::GlobalOp
CIRGenModule::getOrCreateCIRGlobal(StringRef MangledName, mlir::Type Ty,
LangAS AddrSpace, const VarDecl *D,
LangAS langAS, const VarDecl *D,
ForDefinition_t IsForDefinition) {
// Lookup the entry, lazily creating it if necessary.
mlir::cir::GlobalOp Entry;
Expand All @@ -817,8 +824,9 @@ CIRGenModule::getOrCreateCIRGlobal(StringRef MangledName, mlir::Type Ty,
Entry = dyn_cast_or_null<mlir::cir::GlobalOp>(V);
}

// unsigned TargetAS = astCtx.getTargetAddressSpace(AddrSpace);
mlir::cir::AddressSpaceAttr cirAS = builder.getAddrSpaceAttr(langAS);
if (Entry) {
auto entryCIRAS = Entry.getAddrSpaceAttr();
if (WeakRefReferences.erase(Entry)) {
if (D && !D->hasAttr<WeakAttr>()) {
auto LT = mlir::cir::GlobalLinkageKind::ExternalLinkage;
Expand All @@ -836,8 +844,7 @@ CIRGenModule::getOrCreateCIRGlobal(StringRef MangledName, mlir::Type Ty,
if (langOpts.OpenMP && !langOpts.OpenMPSimd && D)
getOpenMPRuntime().registerTargetGlobalVariable(D, Entry);

// TODO(cir): check TargetAS matches Entry address space
if (Entry.getSymType() == Ty && !MissingFeatures::addressSpaceInGlobalVar())
if (Entry.getSymType() == Ty && entryCIRAS == cirAS)
return Entry;

// If there are two attempts to define the same mangled name, issue an
Expand Down Expand Up @@ -866,14 +873,16 @@ CIRGenModule::getOrCreateCIRGlobal(StringRef MangledName, mlir::Type Ty,

// TODO(cir): LLVM codegen makes sure the result is of the correct type
// by issuing an address space cast.
if (entryCIRAS != cirAS)
llvm_unreachable("NYI");

// (If global is requested for a definition, we always need to create a new
// global, not just return a bitcast.)
if (!IsForDefinition)
return Entry;
}

// TODO(cir): auto DAddrSpace = GetGlobalVarAddressSpace(D);
auto declCIRAS = builder.getAddrSpaceAttr(getGlobalVarAddressSpace(D));
// TODO(cir): do we need to strip pointer casts for Entry?

auto loc = getLoc(D->getSourceRange());
Expand All @@ -882,6 +891,7 @@ CIRGenModule::getOrCreateCIRGlobal(StringRef MangledName, mlir::Type Ty,
// mark it as such.
auto GV = CIRGenModule::createGlobalOp(*this, loc, MangledName, Ty,
/*isConstant=*/false,
/*addrSpace=*/declCIRAS,
/*insertPoint=*/Entry.getOperation());

// If we already created a global with the same mangled name (but different
Expand Down Expand Up @@ -991,8 +1001,7 @@ mlir::Value CIRGenModule::getAddrOfGlobalVar(const VarDecl *D, mlir::Type Ty,

bool tlsAccess = D->getTLSKind() != VarDecl::TLS_None;
auto g = buildGlobal(D, Ty, IsForDefinition);
auto ptrTy =
mlir::cir::PointerType::get(builder.getContext(), g.getSymType());
auto ptrTy = builder.getPointerTo(g.getSymType(), g.getAddrSpaceAttr());
return builder.create<mlir::cir::GetGlobalOp>(
getLoc(D->getSourceRange()), ptrTy, g.getSymName(), tlsAccess);
}
Expand Down Expand Up @@ -1075,7 +1084,8 @@ void CIRGenModule::buildGlobalVarDefinition(const clang::VarDecl *D,
// If this is OpenMP device, check if it is legal to emit this global
// normally.
QualType ASTTy = D->getType();
if (getLangOpts().OpenCL || getLangOpts().OpenMPIsTargetDevice)
if ((getLangOpts().OpenCL && ASTTy->isSamplerT()) ||
getLangOpts().OpenMPIsTargetDevice)
llvm_unreachable("not implemented");

// TODO(cir): LLVM's codegen uses a llvm::TrackingVH here. Is that
Expand Down Expand Up @@ -1408,7 +1418,7 @@ LangAS CIRGenModule::getLangTempAllocaAddressSpace() const {
if (getLangOpts().OpenCL)
return LangAS::opencl_private;
if (getLangOpts().SYCLIsDevice || getLangOpts().CUDAIsDevice ||
(getLangOpts().OpenMP && getLangOpts().OpenMPIsTargetDevice))
(getLangOpts().OpenMP && getLangOpts().OpenMPIsTargetDevice))
llvm_unreachable("NYI");
return LangAS::Default;
}
Expand Down Expand Up @@ -3099,3 +3109,25 @@ mlir::cir::SourceLanguage CIRGenModule::getCIRSourceLanguage() {
// TODO(cir): support remaining source languages.
llvm_unreachable("CIR does not yet support the given source language");
}

/// Compute the AST address space in which the global variable \p D is placed.
///
/// For OpenCL the address space is read from the declaration's type (or is
/// `opencl_global` when \p D is null). SYCL device code with a default address
/// space, CUDA device code, and OpenMP are not implemented yet. Every other
/// configuration defers to the target hook.
LangAS CIRGenModule::getGlobalVarAddressSpace(const VarDecl *D) {
  if (langOpts.OpenCL) {
    // Without a declaration there is nothing to inspect; use the OpenCL
    // global address space.
    if (!D)
      return LangAS::opencl_global;

    const LangAS declAS = D->getType().getAddressSpace();
    // Program-scope OpenCL variables must live in one of the global-like
    // address spaces, in local/constant, or in a target address space.
    assert(declAS >= LangAS::FirstTargetAddressSpace ||
           declAS == LangAS::opencl_local ||
           declAS == LangAS::opencl_constant ||
           declAS == LangAS::opencl_global_host ||
           declAS == LangAS::opencl_global_device ||
           declAS == LangAS::opencl_global);
    return declAS;
  }

  // SYCL device compilation only needs special handling when the variable
  // carries no explicit address space (or there is no declaration at all).
  const bool hasDefaultAS =
      !D || D->getType().getAddressSpace() == LangAS::Default;
  if (langOpts.SYCLIsDevice && hasDefaultAS)
    llvm_unreachable("NYI");

  // CUDA device-side globals.
  if (langOpts.CUDA && langOpts.CUDAIsDevice)
    llvm_unreachable("NYI");

  // OpenMP globals.
  if (langOpts.OpenMP)
    llvm_unreachable("NYI");

  // Address-space-agnostic languages: let the target decide.
  return getTargetCIRGenInfo().getGlobalVarAddressSpace(*this, D);
}
11 changes: 11 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ class CIRGenModule : public CIRGenTypeCache {
static mlir::cir::GlobalOp
createGlobalOp(CIRGenModule &CGM, mlir::Location loc, StringRef name,
mlir::Type t, bool isCst = false,
mlir::cir::AddressSpaceAttr addrSpace = {},
mlir::Operation *insertPoint = nullptr);

// FIXME: Hardcoding priority here is gross.
Expand Down Expand Up @@ -328,6 +329,16 @@ class CIRGenModule : public CIRGenTypeCache {
return (Twine(".compoundLiteral.") + Twine(CompoundLitaralCnt++)).str();
}

/// Return the AST address space of the underlying global variable for D, as
/// determined by its declaration. Normally this is the same as the address
/// space of D's type, but in CUDA, address spaces are associated with
/// declarations, not types. If D is nullptr, return the default address
/// space for global variables.
///
/// For languages without explicit address spaces, if D has default address
/// space, target-specific global or constant address space may be returned.
LangAS getGlobalVarAddressSpace(const VarDecl *D);

/// Return the AST address space of constant literal, which is used to emit
/// the constant literal as global variable in LLVM IR.
/// Note: This is not necessarily the address space of the constant literal
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/CIR/CodeGen/TargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,15 @@ ABIArgInfo X86_64ABIInfo::classifyReturnType(QualType RetTy) const {
return ABIArgInfo::getDirect(ResType);
}

/// Default target hook: report the address space written on the declaration's
/// type, or the default address space when no declaration is supplied.
///
/// Must only be reached for languages that are agnostic to address spaces,
/// i.e. neither OpenCL nor CUDA device compilation.
clang::LangAS
TargetCIRGenInfo::getGlobalVarAddressSpace(cir::CIRGenModule &CGM,
                                           const clang::VarDecl *D) const {
  assert(!(CGM.getLangOpts().OpenCL ||
           (CGM.getLangOpts().CUDA && CGM.getLangOpts().CUDAIsDevice)) &&
         "Address space agnostic languages only");

  // No declaration: nothing to query, so use the default address space.
  if (!D)
    return LangAS::Default;
  return D->getType().getAddressSpace();
}

mlir::Value TargetCIRGenInfo::performAddrSpaceCast(
CIRGenFunction &CGF, mlir::Value Src, mlir::cir::AddressSpaceAttr SrcAddr,
mlir::cir::AddressSpaceAttr DestAddr, mlir::Type DestTy,
Expand Down
7 changes: 7 additions & 0 deletions clang/lib/CIR/CodeGen/TargetInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ class TargetCIRGenInfo {
std::vector<LValue> &ResultRegDests,
std::string &AsmString, unsigned NumOutputs) const {}

/// Get the target-favored AST address space of a global variable for
/// languages other than OpenCL and CUDA.
/// If \p D is nullptr, returns the default target-favored address space
/// for global variables.
virtual clang::LangAS getGlobalVarAddressSpace(CIRGenModule &CGM,
const clang::VarDecl *D) const;

/// Get the CIR address space for alloca.
virtual mlir::cir::AddressSpaceAttr getCIRAllocaAddressSpace() const {
// Return the null attribute, which means the target does not care about the
Expand Down
23 changes: 23 additions & 0 deletions clang/test/CIR/CodeGen/OpenCL/global.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: %clang_cc1 -cl-std=CL3.0 -O0 -fclangir -emit-cir -triple spirv64-unknown-unknown %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
// RUN: %clang_cc1 -cl-std=CL3.0 -O0 -fclangir -emit-llvm -triple spirv64-unknown-unknown %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=LLVM

global int a = 13;
// CIR-DAG: cir.global external addrspace(offload_global) @a = #cir.int<13> : !s32i
// LLVM-DAG: @a = addrspace(1) global i32 13

global int b = 15;
// CIR-DAG: cir.global external addrspace(offload_global) @b = #cir.int<15> : !s32i
// LLVM-DAG: @b = addrspace(1) global i32 15

// Copies global `b` into global `a`. The checks below expect each access to go
// through cir.get_global yielding a pointer in addrspace(offload_global), and
// the lowered LLVM IR to load/store through addrspace(1) pointers.
kernel void test_get_global() {
a = b;
// CIR: %[[#ADDRB:]] = cir.get_global @b : !cir.ptr<!s32i, addrspace(offload_global)>
// CIR-NEXT: %[[#LOADB:]] = cir.load %[[#ADDRB]] : !cir.ptr<!s32i, addrspace(offload_global)>, !s32i
// CIR-NEXT: %[[#ADDRA:]] = cir.get_global @a : !cir.ptr<!s32i, addrspace(offload_global)>
// CIR-NEXT: cir.store %[[#LOADB]], %[[#ADDRA]] : !s32i, !cir.ptr<!s32i, addrspace(offload_global)>

// LLVM: %[[#LOADB:]] = load i32, ptr addrspace(1) @b, align 4
// LLVM-NEXT: store i32 %[[#LOADB]], ptr addrspace(1) @a, align 4
}