Skip to content

Commit c99ff03

Browse files
seven-milelanza
authored andcommitted
[CIR][Dialect] Emit OpenCL kernel argument metadata (#767)
Similar to #705, this PR implements the remaining `genKernelArgMetadata()` logic. The attribute `cir.cl.kernel_arg_metadata` is also intentionally placed in the `cir.func`'s `extra_attrs` rather than `cir.func`'s standard `arg_attrs` list. Also, the metadata is stored by `Array` with proper verification on it. See the tablegen doc string for details. This is in order to * keep it side-by-side with `cl.kernel_metadata`. * still emit metadata when kernel has an *empty* arg list (see the test `kernel-arg-meatadata.cl`). * avoid horrors of repeating the long name `cir.cl.kernel_arg_metadata` for `numArgs` times. Because clangir doesn't support OpenCL built-in types and the `half` floating point type yet, their changes and test cases are not included. Corresponding missing feature flag is added.
1 parent b0cf6ab commit c99ff03

File tree

11 files changed

+519
-3
lines changed

11 files changed

+519
-3
lines changed

clang/include/clang/CIR/Dialect/IR/CIROpenCLAttrs.td

+55
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,59 @@ def OpenCLKernelMetadataAttr
9292

9393
}
9494

95+
//===----------------------------------------------------------------------===//
96+
// OpenCLKernelArgMetadataAttr
97+
//===----------------------------------------------------------------------===//
98+
99+
def OpenCLKernelArgMetadataAttr
100+
: CIR_Attr<"OpenCLKernelArgMetadata", "cl.kernel_arg_metadata"> {
101+
102+
let summary = "OpenCL kernel argument metadata";
103+
let description = [{
104+
Provide the required information of an OpenCL kernel argument for the SPIR-V
105+
backend.
106+
107+
All parameters are arrays, containing the information of the argument in
108+
the same order as they appear in the source code.
109+
110+
The `addr_space` parameter is an array of I32 that provides the address
111+
space of the argument. It's useful for special types like `image`, which
112+
have implicit global address space.
113+
114+
Other parameters are arrays of strings that pass through the information
115+
from the source code correspondingly.
116+
117+
All the fields are mandatory except for `name`, which is optional.
118+
119+
Example:
120+
```
121+
#fn_attr = #cir<extra({cl.kernel_arg_metadata = #cir.cl.kernel_arg_metadata<
122+
addr_space = [1 : i32],
123+
access_qual = ["none"],
124+
type = ["char*"],
125+
base_type = ["char*"],
126+
type_qual = [""],
127+
name = ["in"]
128+
>})>
129+
130+
cir.func @kernel(%arg0: !s32i) extra(#fn_attr) {
131+
cir.return
132+
}
133+
```
134+
}];
135+
136+
let parameters = (ins
137+
"ArrayAttr":$addr_space,
138+
"ArrayAttr":$access_qual,
139+
"ArrayAttr":$type,
140+
"ArrayAttr":$base_type,
141+
"ArrayAttr":$type_qual,
142+
OptionalParameter<"ArrayAttr">:$name
143+
);
144+
145+
let assemblyFormat = "`<` struct(params) `>`";
146+
147+
let genVerifyDecl = 1;
148+
}
149+
95150
#endif // MLIR_CIR_DIALECT_CIR_OPENCL_ATTRS

clang/include/clang/CIR/MissingFeatures.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ struct MissingFeatures {
142142
static bool getFPFeaturesInEffect() { return false; }
143143
static bool cxxABI() { return false; }
144144
static bool openCL() { return false; }
145-
static bool openCLGenKernelMetadata() { return false; }
145+
static bool openCLBuiltinTypes() { return false; }
146146
static bool CUDA() { return false; }
147147
static bool openMP() { return false; }
148148
static bool openMPRuntime() { return false; }

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -1725,8 +1725,7 @@ void CIRGenFunction::buildKernelMetadata(const FunctionDecl *FD,
17251725
if (!FD->hasAttr<OpenCLKernelAttr>() && !FD->hasAttr<CUDAGlobalAttr>())
17261726
return;
17271727

1728-
// TODO(cir): CGM.genKernelArgMetadata(Fn, FD, this);
1729-
assert(!MissingFeatures::openCLGenKernelMetadata());
1728+
CGM.genKernelArgMetadata(Fn, FD, this);
17301729

17311730
if (!getLangOpts().OpenCL)
17321731
return;

clang/lib/CIR/CodeGen/CIRGenModule.cpp

+171
Original file line numberDiff line numberDiff line change
@@ -3061,3 +3061,174 @@ mlir::cir::SourceLanguage CIRGenModule::getCIRSourceLanguage() {
30613061
// TODO(cir): support remaining source languages.
30623062
llvm_unreachable("CIR does not yet support the given source language");
30633063
}
3064+
3065+
// Returns the address space id that should be produced to the
3066+
// kernel_arg_addr_space metadata. This is always fixed to the ids
3067+
// as specified in the SPIR 2.0 specification in order to differentiate
3068+
// for example in clGetKernelArgInfo() implementation between the address
3069+
// spaces with targets without unique mapping to the OpenCL address spaces
3070+
// (basically all single AS CPUs).
3071+
static unsigned ArgInfoAddressSpace(LangAS AS) {
3072+
switch (AS) {
3073+
case LangAS::opencl_global:
3074+
return 1;
3075+
case LangAS::opencl_constant:
3076+
return 2;
3077+
case LangAS::opencl_local:
3078+
return 3;
3079+
case LangAS::opencl_generic:
3080+
return 4; // Not in SPIR 2.0 specs.
3081+
case LangAS::opencl_global_device:
3082+
return 5;
3083+
case LangAS::opencl_global_host:
3084+
return 6;
3085+
default:
3086+
return 0; // Assume private.
3087+
}
3088+
}
3089+
3090+
void CIRGenModule::genKernelArgMetadata(mlir::cir::FuncOp Fn,
3091+
const FunctionDecl *FD,
3092+
CIRGenFunction *CGF) {
3093+
assert(((FD && CGF) || (!FD && !CGF)) &&
3094+
"Incorrect use - FD and CGF should either be both null or not!");
3095+
// Create MDNodes that represent the kernel arg metadata.
3096+
// Each MDNode is a list in the form of "key", N number of values which is
3097+
// the same number of values as their are kernel arguments.
3098+
3099+
const PrintingPolicy &Policy = getASTContext().getPrintingPolicy();
3100+
3101+
// Integer values for the kernel argument address space qualifiers.
3102+
SmallVector<int32_t, 8> addressQuals;
3103+
3104+
// Attrs for the kernel argument access qualifiers (images only).
3105+
SmallVector<mlir::Attribute, 8> accessQuals;
3106+
3107+
// Attrs for the kernel argument type names.
3108+
SmallVector<mlir::Attribute, 8> argTypeNames;
3109+
3110+
// Attrs for the kernel argument base type names.
3111+
SmallVector<mlir::Attribute, 8> argBaseTypeNames;
3112+
3113+
// Attrs for the kernel argument type qualifiers.
3114+
SmallVector<mlir::Attribute, 8> argTypeQuals;
3115+
3116+
// Attrs for the kernel argument names.
3117+
SmallVector<mlir::Attribute, 8> argNames;
3118+
3119+
// OpenCL image and pipe types require special treatments for some metadata
3120+
assert(!MissingFeatures::openCLBuiltinTypes());
3121+
3122+
if (FD && CGF)
3123+
for (unsigned i = 0, e = FD->getNumParams(); i != e; ++i) {
3124+
const ParmVarDecl *parm = FD->getParamDecl(i);
3125+
// Get argument name.
3126+
argNames.push_back(builder.getStringAttr(parm->getName()));
3127+
3128+
if (!getLangOpts().OpenCL)
3129+
continue;
3130+
QualType ty = parm->getType();
3131+
std::string typeQuals;
3132+
3133+
// Get image and pipe access qualifier:
3134+
if (ty->isImageType() || ty->isPipeType()) {
3135+
llvm_unreachable("NYI");
3136+
} else
3137+
accessQuals.push_back(builder.getStringAttr("none"));
3138+
3139+
auto getTypeSpelling = [&](QualType Ty) {
3140+
auto typeName = Ty.getUnqualifiedType().getAsString(Policy);
3141+
3142+
if (Ty.isCanonical()) {
3143+
StringRef typeNameRef = typeName;
3144+
// Turn "unsigned type" to "utype"
3145+
if (typeNameRef.consume_front("unsigned "))
3146+
return std::string("u") + typeNameRef.str();
3147+
if (typeNameRef.consume_front("signed "))
3148+
return typeNameRef.str();
3149+
}
3150+
3151+
return typeName;
3152+
};
3153+
3154+
if (ty->isPointerType()) {
3155+
QualType pointeeTy = ty->getPointeeType();
3156+
3157+
// Get address qualifier.
3158+
addressQuals.push_back(
3159+
ArgInfoAddressSpace(pointeeTy.getAddressSpace()));
3160+
3161+
// Get argument type name.
3162+
std::string typeName = getTypeSpelling(pointeeTy) + "*";
3163+
std::string baseTypeName =
3164+
getTypeSpelling(pointeeTy.getCanonicalType()) + "*";
3165+
argTypeNames.push_back(builder.getStringAttr(typeName));
3166+
argBaseTypeNames.push_back(builder.getStringAttr(baseTypeName));
3167+
3168+
// Get argument type qualifiers:
3169+
if (ty.isRestrictQualified())
3170+
typeQuals = "restrict";
3171+
if (pointeeTy.isConstQualified() ||
3172+
(pointeeTy.getAddressSpace() == LangAS::opencl_constant))
3173+
typeQuals += typeQuals.empty() ? "const" : " const";
3174+
if (pointeeTy.isVolatileQualified())
3175+
typeQuals += typeQuals.empty() ? "volatile" : " volatile";
3176+
} else {
3177+
uint32_t AddrSpc = 0;
3178+
bool isPipe = ty->isPipeType();
3179+
if (ty->isImageType() || isPipe)
3180+
llvm_unreachable("NYI");
3181+
3182+
addressQuals.push_back(AddrSpc);
3183+
3184+
// Get argument type name.
3185+
ty = isPipe ? ty->castAs<PipeType>()->getElementType() : ty;
3186+
std::string typeName = getTypeSpelling(ty);
3187+
std::string baseTypeName = getTypeSpelling(ty.getCanonicalType());
3188+
3189+
// Remove access qualifiers on images
3190+
// (as they are inseparable from type in clang implementation,
3191+
// but OpenCL spec provides a special query to get access qualifier
3192+
// via clGetKernelArgInfo with CL_KERNEL_ARG_ACCESS_QUALIFIER):
3193+
if (ty->isImageType()) {
3194+
llvm_unreachable("NYI");
3195+
}
3196+
3197+
argTypeNames.push_back(builder.getStringAttr(typeName));
3198+
argBaseTypeNames.push_back(builder.getStringAttr(baseTypeName));
3199+
3200+
if (isPipe)
3201+
llvm_unreachable("NYI");
3202+
}
3203+
argTypeQuals.push_back(builder.getStringAttr(typeQuals));
3204+
}
3205+
3206+
bool shouldEmitArgName = getCodeGenOpts().EmitOpenCLArgMetadata ||
3207+
getCodeGenOpts().HIPSaveKernelArgName;
3208+
3209+
if (getLangOpts().OpenCL) {
3210+
// The kernel arg name is emitted only when `-cl-kernel-arg-info` is on,
3211+
// since it is only used to support `clGetKernelArgInfo` which requires
3212+
// `-cl-kernel-arg-info` to work. The other metadata are mandatory because
3213+
// they are necessary for OpenCL runtime to set kernel argument.
3214+
mlir::ArrayAttr resArgNames = {};
3215+
if (shouldEmitArgName)
3216+
resArgNames = builder.getArrayAttr(argNames);
3217+
3218+
// Update the function's extra attributes with the kernel argument metadata.
3219+
auto value = mlir::cir::OpenCLKernelArgMetadataAttr::get(
3220+
Fn.getContext(), builder.getI32ArrayAttr(addressQuals),
3221+
builder.getArrayAttr(accessQuals), builder.getArrayAttr(argTypeNames),
3222+
builder.getArrayAttr(argBaseTypeNames),
3223+
builder.getArrayAttr(argTypeQuals), resArgNames);
3224+
mlir::NamedAttrList items{Fn.getExtraAttrs().getElements().getValue()};
3225+
auto oldValue = items.set(value.getMnemonic(), value);
3226+
if (oldValue != value) {
3227+
Fn.setExtraAttrsAttr(mlir::cir::ExtraFuncAttributesAttr::get(
3228+
builder.getContext(), builder.getDictionaryAttr(items)));
3229+
}
3230+
} else {
3231+
if (shouldEmitArgName)
3232+
llvm_unreachable("NYI HIPSaveKernelArgName");
3233+
}
3234+
}

clang/lib/CIR/CodeGen/CIRGenModule.h

+14
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,20 @@ class CIRGenModule : public CIRGenTypeCache {
688688
return *openMPRuntime;
689689
}
690690

691+
/// OpenCL v1.2 s5.6.4.6 allows the compiler to store kernel argument
692+
/// information in the program executable. The argument information stored
693+
/// includes the argument name, its type, the address and access qualifiers
694+
/// used. This helper can be used to generate metadata for source code kernel
695+
/// function as well as generated implicitly kernels. If a kernel is generated
696+
/// implicitly null value has to be passed to the last two parameters,
697+
/// otherwise all parameters must have valid non-null values.
698+
/// \param FN is a pointer to IR function being generated.
699+
/// \param FD is a pointer to function declaration if any.
700+
/// \param CGF is a pointer to CIRGenFunction that generates this function.
701+
void genKernelArgMetadata(mlir::cir::FuncOp FN,
702+
const FunctionDecl *FD = nullptr,
703+
CIRGenFunction *CGF = nullptr);
704+
691705
private:
692706
// An ordered map of canonical GlobalDecls to their mangled names.
693707
llvm::MapVector<clang::GlobalDecl, llvm::StringRef> MangledDeclNames;

clang/lib/CIR/Dialect/IR/CIRAttrs.cpp

+37
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,43 @@ LogicalResult OpenCLKernelMetadataAttr::verify(
554554
return success();
555555
}
556556

557+
//===----------------------------------------------------------------------===//
558+
// OpenCLKernelArgMetadataAttr definitions
559+
//===----------------------------------------------------------------------===//
560+
561+
LogicalResult OpenCLKernelArgMetadataAttr::verify(
562+
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
563+
ArrayAttr addrSpaces, ArrayAttr accessQuals, ArrayAttr types,
564+
ArrayAttr baseTypes, ArrayAttr typeQuals, ArrayAttr argNames) {
565+
auto isIntArray = [](ArrayAttr elt) {
566+
return llvm::all_of(
567+
elt, [](Attribute elt) { return mlir::isa<IntegerAttr>(elt); });
568+
};
569+
auto isStrArray = [](ArrayAttr elt) {
570+
return llvm::all_of(
571+
elt, [](Attribute elt) { return mlir::isa<StringAttr>(elt); });
572+
};
573+
574+
if (!isIntArray(addrSpaces))
575+
return emitError() << "addr_space must be integer arrays";
576+
if (!llvm::all_of<ArrayRef<ArrayAttr>>(
577+
{accessQuals, types, baseTypes, typeQuals}, isStrArray))
578+
return emitError()
579+
<< "access_qual, type, base_type, type_qual must be string arrays";
580+
if (argNames && !isStrArray(argNames)) {
581+
return emitError() << "name must be a string array";
582+
}
583+
584+
if (!llvm::all_of<ArrayRef<ArrayAttr>>(
585+
{addrSpaces, accessQuals, types, baseTypes, typeQuals, argNames},
586+
[&](ArrayAttr arr) {
587+
return !arr || arr.size() == addrSpaces.size();
588+
})) {
589+
return emitError() << "all arrays must have the same number of elements";
590+
}
591+
return success();
592+
}
593+
557594
//===----------------------------------------------------------------------===//
558595
// AddressSpaceAttr definitions
559596
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)