Skip to content

Commit 041c8e4

Browse files
seven-milelanza
authored andcommitted
[CIR][Dialect] Emit OpenCL kernel metadata (#705)
This PR introduces a new attribute `OpenCLKernelMetadataAttr` to model the OpenCL kernel metadata structurally in CIR, with its corresponding implementations of CodeGen, Lowering and Translation. The `"TypeAttr":$vec_type_hint` part is tricky because of the absence of the signless feature of LLVM IR, while SPIR-V requires it. According to the spec, the final LLVM IR should encode signedness with an extra `i32` boolean value. In this PR, the droping logic from CIR's `TypeConverter` is still used to avoid code duplication when lowering to LLVM dialect. However, the signedness is then restored (still capsuled by a CIR attribute) and dropped again in the translation into LLVM IR.
1 parent 1caf737 commit 041c8e4

File tree

11 files changed

+406
-5
lines changed

11 files changed

+406
-5
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -979,4 +979,6 @@ def BitfieldInfoAttr : CIR_Attr<"BitfieldInfo", "bitfield_info"> {
979979
];
980980
}
981981

982+
include "clang/CIR/Dialect/IR/CIROpenCLAttrs.td"
983+
982984
#endif // MLIR_CIR_DIALECT_CIR_ATTRS
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
//===- CIROpenCLAttrs.td - CIR dialect attrs for OpenCL ----*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares the CIR dialect attributes for OpenCL.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CIR_DIALECT_CIR_OPENCL_ATTRS
14+
#define MLIR_CIR_DIALECT_CIR_OPENCL_ATTRS
15+
16+
//===----------------------------------------------------------------------===//
17+
// OpenCLKernelMetadataAttr
18+
//===----------------------------------------------------------------------===//
19+
20+
def OpenCLKernelMetadataAttr
21+
: CIR_Attr<"OpenCLKernelMetadata", "cl.kernel_metadata"> {
22+
23+
let summary = "OpenCL kernel metadata";
24+
let description = [{
25+
Provide the required information of an OpenCL kernel for the SPIR-V backend.
26+
27+
The `work_group_size_hint` and `reqd_work_group_size` parameter are integer
28+
arrays with 3 elements that provide hints for the work-group size and the
29+
required work-group size, respectively.
30+
31+
The `vec_type_hint` parameter is a type attribute that provides a hint for
32+
the vectorization. It can be a CIR or LLVM type, depending on the lowering
33+
stage.
34+
35+
The `vec_type_hint_signedness` parameter is a boolean that indicates the
36+
signedness of the vector type hint. It's useful when LLVM type is set in
37+
`vec_type_hint`, which is signless by design. It should be set if and only
38+
if the `vec_type_hint` is present.
39+
40+
The `intel_reqd_sub_group_size` parameter is an integer that restricts the
41+
sub-group size to the specified value.
42+
43+
Example:
44+
```
45+
#fn_attr = #cir<extra({cl.kernel_metadata = #cir.cl.kernel_metadata<
46+
work_group_size_hint = [8 : i32, 16 : i32, 32 : i32],
47+
reqd_work_group_size = [1 : i32, 2 : i32, 4 : i32],
48+
vec_type_hint = !s32i,
49+
vec_type_hint_signedness = 1,
50+
intel_reqd_sub_group_size = 8 : i32
51+
>})>
52+
53+
cir.func @kernel(%arg0: !s32i) extra(#fn_attr) {
54+
cir.return
55+
}
56+
```
57+
}];
58+
59+
let parameters = (ins
60+
OptionalParameter<"ArrayAttr">:$work_group_size_hint,
61+
OptionalParameter<"ArrayAttr">:$reqd_work_group_size,
62+
OptionalParameter<"TypeAttr">:$vec_type_hint,
63+
OptionalParameter<"std::optional<bool>">:$vec_type_hint_signedness,
64+
OptionalParameter<"IntegerAttr">:$intel_reqd_sub_group_size
65+
);
66+
67+
let assemblyFormat = "`<` struct(params) `>`";
68+
69+
let genVerifyDecl = 1;
70+
71+
let extraClassDeclaration = [{
72+
/// Extract the signedness from int or int vector types.
73+
static std::optional<bool> isSignedHint(mlir::Type vecTypeHint);
74+
}];
75+
76+
let extraClassDefinition = [{
77+
std::optional<bool> $cppClass::isSignedHint(mlir::Type hintQTy) {
78+
// Only types in CIR carry signedness
79+
if (!mlir::isa<mlir::cir::CIRDialect>(hintQTy.getDialect()))
80+
return std::nullopt;
81+
82+
// See also clang::CodeGen::CodeGenFunction::EmitKernelMetadata
83+
auto hintEltQTy = mlir::dyn_cast<mlir::cir::VectorType>(hintQTy);
84+
auto isCIRSignedIntType = [](mlir::Type t) {
85+
return mlir::isa<mlir::cir::IntType>(t) &&
86+
mlir::cast<mlir::cir::IntType>(t).isSigned();
87+
};
88+
return isCIRSignedIntType(hintQTy) ||
89+
(hintEltQTy && isCIRSignedIntType(hintEltQTy.getEltType()));
90+
}
91+
}];
92+
93+
}
94+
95+
#endif // MLIR_CIR_DIALECT_CIR_OPENCL_ATTRS

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ struct MissingFeatures {
142142
static bool getFPFeaturesInEffect() { return false; }
143143
static bool cxxABI() { return false; }
144144
static bool openCL() { return false; }
145+
static bool openCLGenKernelMetadata() { return false; }
145146
static bool CUDA() { return false; }
146147
static bool openMP() { return false; }
147148
static bool openMPRuntime() { return false; }

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -993,8 +993,7 @@ void CIRGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
993993
llvm_unreachable("NYI");
994994

995995
if (FD && getLangOpts().OpenCL) {
996-
// TODO(cir): Emit OpenCL kernel metadata
997-
assert(!MissingFeatures::openCL());
996+
buildKernelMetadata(FD, Fn);
998997
}
999998

1000999
// If we are checking function types, emit a function type signature as
@@ -1720,3 +1719,67 @@ CIRGenFunction::buildArrayLength(const clang::ArrayType *origArrayType,
17201719

17211720
return numElements;
17221721
}
1722+
1723+
void CIRGenFunction::buildKernelMetadata(const FunctionDecl *FD,
1724+
mlir::cir::FuncOp Fn) {
1725+
if (!FD->hasAttr<OpenCLKernelAttr>() && !FD->hasAttr<CUDAGlobalAttr>())
1726+
return;
1727+
1728+
// TODO(cir): CGM.genKernelArgMetadata(Fn, FD, this);
1729+
assert(!MissingFeatures::openCLGenKernelMetadata());
1730+
1731+
if (!getLangOpts().OpenCL)
1732+
return;
1733+
1734+
using mlir::cir::OpenCLKernelMetadataAttr;
1735+
1736+
mlir::ArrayAttr workGroupSizeHintAttr, reqdWorkGroupSizeAttr;
1737+
mlir::TypeAttr vecTypeHintAttr;
1738+
std::optional<bool> vecTypeHintSignedness;
1739+
mlir::IntegerAttr intelReqdSubGroupSizeAttr;
1740+
1741+
if (const VecTypeHintAttr *A = FD->getAttr<VecTypeHintAttr>()) {
1742+
mlir::Type typeHintValue = getTypes().ConvertType(A->getTypeHint());
1743+
vecTypeHintAttr = mlir::TypeAttr::get(typeHintValue);
1744+
vecTypeHintSignedness =
1745+
OpenCLKernelMetadataAttr::isSignedHint(typeHintValue);
1746+
}
1747+
1748+
if (const WorkGroupSizeHintAttr *A = FD->getAttr<WorkGroupSizeHintAttr>()) {
1749+
workGroupSizeHintAttr = builder.getI32ArrayAttr({
1750+
static_cast<int32_t>(A->getXDim()),
1751+
static_cast<int32_t>(A->getYDim()),
1752+
static_cast<int32_t>(A->getZDim()),
1753+
});
1754+
}
1755+
1756+
if (const ReqdWorkGroupSizeAttr *A = FD->getAttr<ReqdWorkGroupSizeAttr>()) {
1757+
reqdWorkGroupSizeAttr = builder.getI32ArrayAttr({
1758+
static_cast<int32_t>(A->getXDim()),
1759+
static_cast<int32_t>(A->getYDim()),
1760+
static_cast<int32_t>(A->getZDim()),
1761+
});
1762+
}
1763+
1764+
if (const OpenCLIntelReqdSubGroupSizeAttr *A =
1765+
FD->getAttr<OpenCLIntelReqdSubGroupSizeAttr>()) {
1766+
intelReqdSubGroupSizeAttr = builder.getI32IntegerAttr(A->getSubGroupSize());
1767+
}
1768+
1769+
// Skip the metadata attr if no hints are present.
1770+
if (!vecTypeHintAttr && !workGroupSizeHintAttr && !reqdWorkGroupSizeAttr &&
1771+
!intelReqdSubGroupSizeAttr)
1772+
return;
1773+
1774+
// Append the kernel metadata to the extra attributes dictionary.
1775+
mlir::NamedAttrList attrs;
1776+
attrs.append(Fn.getExtraAttrs().getElements());
1777+
1778+
auto kernelMetadataAttr = OpenCLKernelMetadataAttr::get(
1779+
builder.getContext(), workGroupSizeHintAttr, reqdWorkGroupSizeAttr,
1780+
vecTypeHintAttr, vecTypeHintSignedness, intelReqdSubGroupSizeAttr);
1781+
attrs.append(kernelMetadataAttr.getMnemonic(), kernelMetadataAttr);
1782+
1783+
Fn.setExtraAttrsAttr(mlir::cir::ExtraFuncAttributesAttr::get(
1784+
builder.getContext(), attrs.getDictionary(builder.getContext())));
1785+
}

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ class CIRGenFunction : public CIRGenTypeCache {
100100
// enter/leave scopes.
101101
llvm::DenseMap<const Expr *, mlir::Value> VLASizeMap;
102102

103+
/// Add OpenCL kernel arg metadata and the kernel attribute metadata to
104+
/// the function metadata.
105+
void buildKernelMetadata(const FunctionDecl *FD, mlir::cir::FuncOp Fn);
106+
103107
public:
104108
/// A non-RAII class containing all the information about a bound
105109
/// opaque value. OpaqueValueMapping, below, is a RAII wrapper for

clang/lib/CIR/Dialect/IR/CIRAttrs.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
1616
#include "clang/CIR/Dialect/IR/CIRTypes.h"
1717

18+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1819
#include "mlir/IR/Attributes.h"
1920
#include "mlir/IR/Builders.h"
2021
#include "mlir/IR/BuiltinAttributeInterfaces.h"
@@ -499,6 +500,60 @@ LogicalResult DynamicCastInfoAttr::verify(
499500
return success();
500501
}
501502

503+
//===----------------------------------------------------------------------===//
504+
// OpenCLKernelMetadataAttr definitions
505+
//===----------------------------------------------------------------------===//
506+
507+
LogicalResult OpenCLKernelMetadataAttr::verify(
508+
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
509+
ArrayAttr workGroupSizeHint, ArrayAttr reqdWorkGroupSize,
510+
TypeAttr vecTypeHint, std::optional<bool> vecTypeHintSignedness,
511+
IntegerAttr intelReqdSubGroupSize) {
512+
// If no field is present, the attribute is considered invalid.
513+
if (!workGroupSizeHint && !reqdWorkGroupSize && !vecTypeHint &&
514+
!vecTypeHintSignedness && !intelReqdSubGroupSize) {
515+
return emitError()
516+
<< "metadata attribute without any field present is invalid";
517+
}
518+
519+
// Check for 3-dim integer tuples
520+
auto is3dimIntTuple = [](ArrayAttr arr) {
521+
auto isInt = [](Attribute dim) { return mlir::isa<IntegerAttr>(dim); };
522+
return arr.size() == 3 && llvm::all_of(arr, isInt);
523+
};
524+
if (workGroupSizeHint && !is3dimIntTuple(workGroupSizeHint)) {
525+
return emitError()
526+
<< "work_group_size_hint must have exactly 3 integer elements";
527+
}
528+
if (reqdWorkGroupSize && !is3dimIntTuple(reqdWorkGroupSize)) {
529+
return emitError()
530+
<< "reqd_work_group_size must have exactly 3 integer elements";
531+
}
532+
533+
// Check for co-presence of vecTypeHintSignedness
534+
if (!!vecTypeHint != vecTypeHintSignedness.has_value()) {
535+
return emitError() << "vec_type_hint_signedness should be present if and "
536+
"only if vec_type_hint is set";
537+
}
538+
539+
if (vecTypeHint) {
540+
Type vecTypeHintValue = vecTypeHint.getValue();
541+
if (mlir::isa<cir::CIRDialect>(vecTypeHintValue.getDialect())) {
542+
// Check for signedness alignment in CIR
543+
if (isSignedHint(vecTypeHintValue) != vecTypeHintSignedness) {
544+
return emitError() << "vec_type_hint_signedness must match the "
545+
"signedness of the vec_type_hint type";
546+
}
547+
// Check for the dialect of type hint
548+
} else if (!LLVM::isCompatibleType(vecTypeHintValue)) {
549+
return emitError() << "vec_type_hint must be a type from the CIR or LLVM "
550+
"dialect";
551+
}
552+
}
553+
554+
return success();
555+
}
556+
502557
//===----------------------------------------------------------------------===//
503558
// CIR Dialect
504559
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1523,12 +1523,13 @@ class CIRFuncLowering : public mlir::OpConversionPattern<mlir::cir::FuncOp> {
15231523
/// to the name of the attribute in ODS.
15241524
static StringRef getLinkageAttrNameString() { return "linkage"; }
15251525

1526+
/// Convert the `cir.func` attributes to `llvm.func` attributes.
15261527
/// Only retain those attributes that are not constructed by
15271528
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out
15281529
/// argument attributes.
15291530
void
1530-
filterFuncAttributes(mlir::cir::FuncOp func, bool filterArgAndResAttrs,
1531-
SmallVectorImpl<mlir::NamedAttribute> &result) const {
1531+
lowerFuncAttributes(mlir::cir::FuncOp func, bool filterArgAndResAttrs,
1532+
SmallVectorImpl<mlir::NamedAttribute> &result) const {
15321533
for (auto attr : func->getAttrs()) {
15331534
if (attr.getName() == mlir::SymbolTable::getSymbolAttrName() ||
15341535
attr.getName() == func.getFunctionTypeAttrName() ||
@@ -1543,11 +1544,45 @@ class CIRFuncLowering : public mlir::OpConversionPattern<mlir::cir::FuncOp> {
15431544
if (attr.getName() == func.getExtraAttrsAttrName()) {
15441545
std::string cirName = "cir." + func.getExtraAttrsAttrName().str();
15451546
attr.setName(mlir::StringAttr::get(getContext(), cirName));
1547+
1548+
lowerFuncOpenCLKernelMetadata(attr);
15461549
}
15471550
result.push_back(attr);
15481551
}
15491552
}
15501553

1554+
/// When do module translation, we can only translate LLVM-compatible types.
1555+
/// Here we lower possible OpenCLKernelMetadataAttr to use the converted type.
1556+
void
1557+
lowerFuncOpenCLKernelMetadata(mlir::NamedAttribute &extraAttrsEntry) const {
1558+
const auto attrKey = mlir::cir::OpenCLKernelMetadataAttr::getMnemonic();
1559+
auto oldExtraAttrs =
1560+
cast<mlir::cir::ExtraFuncAttributesAttr>(extraAttrsEntry.getValue());
1561+
if (!oldExtraAttrs.getElements().contains(attrKey))
1562+
return;
1563+
1564+
mlir::NamedAttrList newExtraAttrs;
1565+
for (auto entry : oldExtraAttrs.getElements()) {
1566+
if (entry.getName() == attrKey) {
1567+
auto clKernelMetadata =
1568+
cast<mlir::cir::OpenCLKernelMetadataAttr>(entry.getValue());
1569+
if (auto vecTypeHint = clKernelMetadata.getVecTypeHint()) {
1570+
auto newType = typeConverter->convertType(vecTypeHint.getValue());
1571+
auto newTypeHint = mlir::TypeAttr::get(newType);
1572+
auto newCLKMAttr = mlir::cir::OpenCLKernelMetadataAttr::get(
1573+
getContext(), clKernelMetadata.getWorkGroupSizeHint(),
1574+
clKernelMetadata.getReqdWorkGroupSize(), newTypeHint,
1575+
clKernelMetadata.getVecTypeHintSignedness(),
1576+
clKernelMetadata.getIntelReqdSubGroupSize());
1577+
entry.setValue(newCLKMAttr);
1578+
}
1579+
}
1580+
newExtraAttrs.push_back(entry);
1581+
}
1582+
extraAttrsEntry.setValue(mlir::cir::ExtraFuncAttributesAttr::get(
1583+
getContext(), newExtraAttrs.getDictionary(getContext())));
1584+
}
1585+
15511586
mlir::LogicalResult
15521587
matchAndRewrite(mlir::cir::FuncOp op, OpAdaptor adaptor,
15531588
mlir::ConversionPatternRewriter &rewriter) const override {
@@ -1585,7 +1620,7 @@ class CIRFuncLowering : public mlir::OpConversionPattern<mlir::cir::FuncOp> {
15851620

15861621
auto linkage = convertLinkage(op.getLinkage());
15871622
SmallVector<mlir::NamedAttribute, 4> attributes;
1588-
filterFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes);
1623+
lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes);
15891624

15901625
auto fn = rewriter.create<mlir::LLVM::LLVMFuncOp>(
15911626
Loc, op.getName(), llvmFnTy, linkage, isDsoLocal, mlir::LLVM::CConv::C,

0 commit comments

Comments
 (0)