Skip to content

Commit 133be13

Browse files
committed
Refactor the SYCLFuncDescriptor class (#53)
This PR refactors SYCLFuncDescriptor by creating subclasses for the sycl::id<n> and sycl::range<n> constructors.
1 parent 5d6c920 commit 133be13

File tree

3 files changed

+189
-142
lines changed

3 files changed

+189
-142
lines changed

mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h

+68-36
Original file line numberDiff line numberDiff line change
@@ -75,59 +75,91 @@ class SYCLFuncDescriptor {
7575
// clang-format on
7676

7777
/// Enumerates the kind of FuncId.
78-
enum class FuncIdKind {
78+
enum class FuncKind {
7979
Unknown,
8080
IdCtor, // any sycl::id<n> constructors.
8181
RangeCtor // any sycl::range<n> constructors.
8282
};
8383

84-
/// Returns the funcIdKind given a \p funcId.
85-
static FuncIdKind getFuncIdKind(FuncId funcId);
84+
/// Each descriptor is uniquely identified by the pair {FuncId, FuncKind}.
85+
class Id {
86+
public:
87+
friend class SYCLFuncRegistry;
88+
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &, const Id &);
8689

87-
/// Retuns a descriptive name for the given \p funcIdKind.
88-
static std::string funcIdKindToName(FuncIdKind funcIdKind);
90+
Id(FuncId id, FuncKind kind) : funcId(id), funcKind(kind) {
91+
assert(funcId != FuncId::Unknown && "Illegal function id");
92+
assert(funcKind != FuncKind::Unknown && "Illegal function id kind");
93+
}
8994

90-
/// Retuns the FuncIdKind given a descriptive \p name.
91-
static FuncIdKind nameToFuncIdKind(Twine name);
95+
/// Maps a FuncKind to a descriptive name.
96+
static std::map<SYCLFuncDescriptor::FuncKind, std::string> funcKindToName;
9297

93-
// Call the SYCL constructor identified by \p id with the given \p args.
94-
static Value call(FuncId id, ValueRange args,
98+
/// Maps a descriptive name to a FuncKind.
99+
static std::map<std::string, SYCLFuncDescriptor::FuncKind> nameToFuncKind;
100+
101+
private:
102+
FuncId funcId = FuncId::Unknown;
103+
FuncKind funcKind = FuncKind::Unknown;
104+
};
105+
106+
/// Returns true if the given \p funcId is valid.
107+
virtual bool isValid(FuncId funcId) const { return false; };
108+
109+
/// Call the SYCL constructor identified by \p funcId with the given \p args.
110+
static Value call(FuncId funcId, ValueRange args,
95111
const SYCLFuncRegistry &registry, OpBuilder &b,
96112
Location loc);
97113

98-
private:
99-
/// Private constructor: only available to 'SYCLFuncRegistry'.
100-
SYCLFuncDescriptor(FuncId id, StringRef name, Type outputTy,
101-
ArrayRef<Type> argTys)
102-
: funcId(id), funcIdKind(getFuncIdKind(id)), name(name),
103-
outputTy(outputTy), argTys(argTys.begin(), argTys.end()) {
104-
assert(funcId != FuncId::Unknown && "Illegal function id");
105-
assert(funcIdKind != FuncIdKind::Unknown && "Illegal function id kind");
106-
}
114+
protected:
115+
SYCLFuncDescriptor(FuncId funcId, FuncKind kind, StringRef name,
116+
Type outputTy, ArrayRef<Type> argTys)
117+
: descId(funcId, kind), name(name), outputTy(outputTy),
118+
argTys(argTys.begin(), argTys.end()) {}
107119

120+
private:
108121
/// Inject the declaration for this function into the module.
109122
void declareFunction(ModuleOp &module, OpBuilder &b);
110123

111-
/// Returns true if the given \p funcId is for a sycl::id<n> constructor.
112-
static bool isIdCtor(FuncId funcId);
113-
114-
private:
115-
FuncId funcId = FuncId::Unknown; // SYCL function identifier
116-
FuncIdKind funcIdKind = FuncIdKind::Unknown; // SYCL function kind
124+
Id descId; // unique identifier
117125
StringRef name; // SYCL function name
118126
Type outputTy; // SYCL function output type
119127
SmallVector<Type, 4> argTys; // SYCL function arguments types
120128
FlatSymbolRefAttr funcRef; // Reference to the SYCL function
121129
};
122130

131+
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
132+
const SYCLFuncDescriptor::Id &id) {
133+
os << "funcId=" << (int)id.funcId
134+
<< ", funcKind=" << SYCLFuncDescriptor::Id::funcKindToName[id.funcKind];
135+
return os;
136+
}
137+
123138
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
124139
const SYCLFuncDescriptor &desc) {
125-
os << "funcId=" << (int)desc.funcId
126-
<< ", funcIdKind=" << SYCLFuncDescriptor::funcIdKindToName(desc.funcIdKind)
127-
<< ", name='" << desc.name.str() << "')";
140+
os << "(" << desc.descId << ", name='" << desc.name.str() << "')";
128141
return os;
129142
}
130143

144+
#define DEFINE_CTOR_CLASS(ClassName, ClassKind) \
145+
class ClassName : public SYCLFuncDescriptor { \
146+
public: \
147+
friend class SYCLFuncRegistry; \
148+
using FuncId = SYCLFuncDescriptor::FuncId; \
149+
using FuncKind = SYCLFuncDescriptor::FuncKind; \
150+
\
151+
private: \
152+
ClassName(FuncId funcId, StringRef name, Type outputTy, \
153+
ArrayRef<Type> argTys) \
154+
: SYCLFuncDescriptor(funcId, ClassKind, name, outputTy, argTys) { \
155+
assert(isValid(funcId) && "Invalid function id"); \
156+
} \
157+
bool isValid(FuncId) const override; \
158+
};
159+
DEFINE_CTOR_CLASS(SYCLIdCtorDescriptor, FuncKind::IdCtor)
160+
DEFINE_CTOR_CLASS(SYCLRangeCtorDescriptor, FuncKind::RangeCtor)
161+
#undef DEFINE_CTOR_CLASS
162+
131163
/// \class SYCLFuncRegistry
132164
/// Singleton class representing the set of SYCL functions callable from the
133165
/// compiler.
@@ -137,18 +169,18 @@ class SYCLFuncRegistry {
137169

138170
static const SYCLFuncRegistry create(ModuleOp &module, OpBuilder &builder);
139171

140-
const SYCLFuncDescriptor &getFuncDesc(SYCLFuncDescriptor::FuncId id) const {
141-
assert((registry.find(id) != registry.end()) &&
142-
"function identified by 'id' not found in the SYCL function "
172+
const SYCLFuncDescriptor &
173+
getFuncDesc(SYCLFuncDescriptor::FuncId funcId) const {
174+
assert((registry.find(funcId) != registry.end()) &&
175+
"function identified by 'funcId' not found in the SYCL function "
143176
"registry");
144-
return registry.at(id);
177+
return registry.at(funcId);
145178
}
146179

147-
// Returns the SYCLFuncDescriptor::FuncId corresponding to the function
148-
// descriptor that matches the given signature and funcIdKind.
149-
SYCLFuncDescriptor::FuncId
150-
getFuncId(SYCLFuncDescriptor::FuncIdKind funcIdKind, Type retType,
151-
TypeRange argTypes) const;
180+
// Returns the SYCLFuncDescriptor::Id::FuncId corresponding to the function
181+
// descriptor that matches the given \p funcKind and signature.
182+
SYCLFuncDescriptor::FuncId getFuncId(SYCLFuncDescriptor::FuncKind funcKind,
183+
Type retType, TypeRange argTypes) const;
152184

153185
private:
154186
SYCLFuncRegistry(ModuleOp &module, OpBuilder &builder);

0 commit comments

Comments
 (0)