@@ -75,59 +75,91 @@ class SYCLFuncDescriptor {
75
75
// clang-format on
76
76
77
77
// / Enumerates the kind of FuncId.
78
- enum class FuncIdKind {
78
+ enum class FuncKind {
79
79
Unknown,
80
80
IdCtor, // any sycl::id<n> constructors.
81
81
RangeCtor // any sycl::range<n> constructors.
82
82
};
83
83
84
- // / Returns the funcIdKind given a \p funcId.
85
- static FuncIdKind getFuncIdKind (FuncId funcId);
84
+ // / Each descriptor is uniquely identified by the pair {FuncId, FuncKind}.
85
+ class Id {
86
+ public:
87
+ friend class SYCLFuncRegistry ;
88
+ friend llvm::raw_ostream &operator <<(llvm::raw_ostream &, const Id &);
86
89
87
- // / Retuns a descriptive name for the given \p funcIdKind.
88
- static std::string funcIdKindToName (FuncIdKind funcIdKind);
90
+ Id (FuncId id, FuncKind kind) : funcId(id), funcKind(kind) {
91
+ assert (funcId != FuncId::Unknown && " Illegal function id" );
92
+ assert (funcKind != FuncKind::Unknown && " Illegal function id kind" );
93
+ }
89
94
90
- // / Retuns the FuncIdKind given a descriptive \p name.
91
- static FuncIdKind nameToFuncIdKind (Twine name) ;
95
+ // / Maps a FuncKind to a descriptive name.
96
+ static std::map<SYCLFuncDescriptor::FuncKind, std::string> funcKindToName ;
92
97
93
- // Call the SYCL constructor identified by \p id with the given \p args.
94
- static Value call (FuncId id, ValueRange args,
98
+ // / Maps a descriptive name to a FuncKind.
99
+ static std::map<std::string, SYCLFuncDescriptor::FuncKind> nameToFuncKind;
100
+
101
+ private:
102
+ FuncId funcId = FuncId::Unknown;
103
+ FuncKind funcKind = FuncKind::Unknown;
104
+ };
105
+
106
+ // / Returns true if the given \p funcId is valid.
107
+ virtual bool isValid (FuncId funcId) const { return false ; };
108
+
109
+ // / Call the SYCL constructor identified by \p funcId with the given \p args.
110
+ static Value call (FuncId funcId, ValueRange args,
95
111
const SYCLFuncRegistry ®istry, OpBuilder &b,
96
112
Location loc);
97
113
98
- private:
99
- // / Private constructor: only available to 'SYCLFuncRegistry'.
100
- SYCLFuncDescriptor (FuncId id, StringRef name, Type outputTy,
101
- ArrayRef<Type> argTys)
102
- : funcId(id), funcIdKind(getFuncIdKind(id)), name(name),
103
- outputTy (outputTy), argTys(argTys.begin(), argTys.end()) {
104
- assert (funcId != FuncId::Unknown && " Illegal function id" );
105
- assert (funcIdKind != FuncIdKind::Unknown && " Illegal function id kind" );
106
- }
114
+ protected:
115
+ SYCLFuncDescriptor (FuncId funcId, FuncKind kind, StringRef name,
116
+ Type outputTy, ArrayRef<Type> argTys)
117
+ : descId(funcId, kind), name(name), outputTy(outputTy),
118
+ argTys (argTys.begin(), argTys.end()) {}
107
119
120
+ private:
108
121
// / Inject the declaration for this function into the module.
109
122
void declareFunction (ModuleOp &module, OpBuilder &b);
110
123
111
- // / Returns true if the given \p funcId is for a sycl::id<n> constructor.
112
- static bool isIdCtor (FuncId funcId);
113
-
114
- private:
115
- FuncId funcId = FuncId::Unknown; // SYCL function identifier
116
- FuncIdKind funcIdKind = FuncIdKind::Unknown; // SYCL function kind
124
+ Id descId; // unique identifier
117
125
StringRef name; // SYCL function name
118
126
Type outputTy; // SYCL function output type
119
127
SmallVector<Type, 4 > argTys; // SYCL function arguments types
120
128
FlatSymbolRefAttr funcRef; // Reference to the SYCL function
121
129
};
122
130
131
+ inline llvm::raw_ostream &operator <<(llvm::raw_ostream &os,
132
+ const SYCLFuncDescriptor::Id &id) {
133
+ os << " funcId=" << (int )id.funcId
134
+ << " , funcKind=" << SYCLFuncDescriptor::Id::funcKindToName[id.funcKind ];
135
+ return os;
136
+ }
137
+
123
138
inline llvm::raw_ostream &operator <<(llvm::raw_ostream &os,
124
139
const SYCLFuncDescriptor &desc) {
125
- os << " funcId=" << (int )desc.funcId
126
- << " , funcIdKind=" << SYCLFuncDescriptor::funcIdKindToName (desc.funcIdKind )
127
- << " , name='" << desc.name .str () << " ')" ;
140
+ os << " (" << desc.descId << " , name='" << desc.name .str () << " ')" ;
128
141
return os;
129
142
}
130
143
144
+ #define DEFINE_CTOR_CLASS (ClassName, ClassKind ) \
145
+ class ClassName : public SYCLFuncDescriptor { \
146
+ public: \
147
+ friend class SYCLFuncRegistry ; \
148
+ using FuncId = SYCLFuncDescriptor::FuncId; \
149
+ using FuncKind = SYCLFuncDescriptor::FuncKind; \
150
+ \
151
+ private: \
152
+ ClassName (FuncId funcId, StringRef name, Type outputTy, \
153
+ ArrayRef<Type> argTys) \
154
+ : SYCLFuncDescriptor(funcId, ClassKind, name, outputTy, argTys) { \
155
+ assert (isValid (funcId) && " Invalid function id" ); \
156
+ } \
157
+ bool isValid (FuncId) const override ; \
158
+ };
159
+ DEFINE_CTOR_CLASS (SYCLIdCtorDescriptor, FuncKind::IdCtor)
160
+ DEFINE_CTOR_CLASS (SYCLRangeCtorDescriptor, FuncKind::RangeCtor)
161
+ #undef DEFINE_CTOR_CLASS
162
+
131
163
// / \class SYCLFuncRegistry
132
164
// / Singleton class representing the set of SYCL functions callable from the
133
165
// / compiler.
@@ -137,18 +169,18 @@ class SYCLFuncRegistry {
137
169
138
170
static const SYCLFuncRegistry create (ModuleOp &module, OpBuilder &builder);
139
171
140
- const SYCLFuncDescriptor &getFuncDesc (SYCLFuncDescriptor::FuncId id) const {
141
- assert ((registry.find (id) != registry.end ()) &&
142
- " function identified by 'id' not found in the SYCL function "
172
+ const SYCLFuncDescriptor &
173
+ getFuncDesc (SYCLFuncDescriptor::FuncId funcId) const {
174
+ assert ((registry.find (funcId) != registry.end ()) &&
175
+ " function identified by 'funcId' not found in the SYCL function "
143
176
" registry" );
144
- return registry.at (id );
177
+ return registry.at (funcId );
145
178
}
146
179
147
- // Returns the SYCLFuncDescriptor::FuncId corresponding to the function
148
- // descriptor that matches the given signature and funcIdKind.
149
- SYCLFuncDescriptor::FuncId
150
- getFuncId (SYCLFuncDescriptor::FuncIdKind funcIdKind, Type retType,
151
- TypeRange argTypes) const ;
180
+ // Returns the SYCLFuncDescriptor::Id::FuncId corresponding to the function
181
+ // descriptor that matches the given \p funcKind and signature.
182
+ SYCLFuncDescriptor::FuncId getFuncId (SYCLFuncDescriptor::FuncKind funcKind,
183
+ Type retType, TypeRange argTypes) const ;
152
184
153
185
private:
154
186
SYCLFuncRegistry (ModuleOp &module, OpBuilder &builder);
0 commit comments