@@ -206,6 +206,64 @@ static Optional<Type> convertRangeType(sycl::RangeType type,
206
206
converter);
207
207
}
208
208
209
+ // ===----------------------------------------------------------------------===//
210
+ // CallPattern - Converts `sycl.call` to LLVM.
211
+ // ===----------------------------------------------------------------------===//
212
+
213
+ class CallPattern final : public SYCLToLLVMConversion<sycl::SYCLCallOp> {
214
+ public:
215
+ using SYCLToLLVMConversion<sycl::SYCLCallOp>::SYCLToLLVMConversion;
216
+ using FuncId = SYCLFuncDescriptor::FuncId;
217
+ using Kind = SYCLFuncDescriptor::Kind;
218
+
219
+ LogicalResult
220
+ matchAndRewrite (sycl::SYCLCallOp op, OpAdaptor opAdaptor,
221
+ ConversionPatternRewriter &rewriter) const override {
222
+ assert (op.Type ().has_value () &&
223
+ " Expecting op.Type() to have a valid value" );
224
+ StringRef typeName = op.Type ().value ();
225
+ Kind kind = SYCLFuncDescriptor::Id::nameToKind.at (typeName.str ());
226
+ assert ((kind != Kind::Unknown) && " unknown descriptor kind" );
227
+ return rewriteCall (kind, op, opAdaptor, rewriter);
228
+ }
229
+
230
+ private:
231
+ // / Rewrite sycl.call() {Function = *, Type = *} to a LLVM call to the
232
+ // / appropriate member function.
233
+ LogicalResult rewriteCall (Kind kind, SYCLCallOp op, OpAdaptor opAdaptor,
234
+ ConversionPatternRewriter &rewriter) const {
235
+ assert ((kind != Kind::Unknown) && " Unexpected descriptor kind" );
236
+ LLVM_DEBUG (llvm::dbgs () << " CallPattern: Rewriting op: " ; op.dump ();
237
+ llvm::dbgs () << " \n " );
238
+
239
+ ModuleOp module = op.getOperation ()->getParentOfType <ModuleOp>();
240
+ const auto ®istry = SYCLFuncRegistry::create (module, rewriter);
241
+
242
+ // / Lookup the FuncId corresponding to the member function to use.
243
+ Type retType = op.getODSResults (0 ).empty ()
244
+ ? LLVM::LLVMVoidType::get (module.getContext ())
245
+ : op.Result ().getType ();
246
+
247
+ FuncId funcId =
248
+ registry.getFuncId (kind, retType, opAdaptor.Args ().getTypes ());
249
+ SYCLFuncDescriptor::call (funcId, opAdaptor.getOperands (), registry,
250
+ rewriter, op.getLoc ());
251
+
252
+ LLVM_DEBUG ({
253
+ Operation *func = op->getParentOfType <LLVM::LLVMFuncOp>();
254
+ if (!func)
255
+ func = op->getParentOfType <func::FuncOp>();
256
+
257
+ assert (func && " Could not find parent function" );
258
+ llvm::dbgs () << " ConstructorPattern: Function after rewrite:\n "
259
+ << *func << " \n " ;
260
+ });
261
+
262
+ rewriter.eraseOp (op);
263
+ return success ();
264
+ }
265
+ };
266
+
209
267
// ===----------------------------------------------------------------------===//
210
268
// ConstructorPattern - Converts `sycl.constructor` to LLVM.
211
269
// ===----------------------------------------------------------------------===//
@@ -214,36 +272,34 @@ class ConstructorPattern final
214
272
: public SYCLToLLVMConversion<sycl::SYCLConstructorOp> {
215
273
public:
216
274
using SYCLToLLVMConversion<sycl::SYCLConstructorOp>::SYCLToLLVMConversion;
275
+ using FuncId = SYCLFuncDescriptor::FuncId;
276
+ using Kind = SYCLFuncDescriptor::Kind;
217
277
218
278
LogicalResult
219
- matchAndRewrite (sycl:: SYCLConstructorOp op, OpAdaptor opAdaptor,
279
+ matchAndRewrite (SYCLConstructorOp op, OpAdaptor opAdaptor,
220
280
ConversionPatternRewriter &rewriter) const override {
221
281
return rewriteConstructor (
222
- SYCLFuncDescriptor::Id::nameToFuncKind .at (op.Type ().str ()), op,
223
- opAdaptor, rewriter);
282
+ SYCLFuncDescriptor::Id::nameToKind .at (op.Type ().str ()), op, opAdaptor ,
283
+ rewriter);
224
284
}
225
285
226
286
private:
227
287
// / Rewrite sycl.constructor() { type = * } to a LLVM call to the appropriate
228
288
// / constructor function.
229
- LogicalResult rewriteConstructor (SYCLFuncDescriptor::FuncKind ctorKind ,
230
- SYCLConstructorOp op, OpAdaptor opAdaptor,
289
+ LogicalResult rewriteConstructor (Kind kind, SYCLConstructorOp op ,
290
+ OpAdaptor opAdaptor,
231
291
ConversionPatternRewriter &rewriter) const {
232
- assert ((ctorKind != SYCLFuncDescriptor::FuncKind::Unknown) &&
233
- " Unexpected ctorKind" );
292
+ assert ((kind != Kind::Unknown) && " Unexpected descriptor kind" );
234
293
LLVM_DEBUG (llvm::dbgs () << " ConstructorPattern: Rewriting op: " ; op.dump ();
235
294
llvm::dbgs () << " \n " );
236
295
237
296
ModuleOp module = op.getOperation ()->getParentOfType <ModuleOp>();
238
297
const auto ®istry = SYCLFuncRegistry::create (module, rewriter);
239
298
240
- // / Lookup the FuncId corresponding to the ctor function to use, which is
241
- // / determined based on 'ctorKind) the kind of constructor to search for, and
242
- // / the LLVM types of the sycl.constructor arguments.
243
- SYCLFuncDescriptor::FuncId funcId = registry.getFuncId (
244
- ctorKind, LLVM::LLVMVoidType::get (module.getContext ()),
245
- opAdaptor.Args ().getTypes ());
246
-
299
+ // / Lookup the FuncId corresponding to the ctor function to use.
300
+ auto retType = LLVM::LLVMVoidType::get (module.getContext ());
301
+ FuncId funcId =
302
+ registry.getFuncId (kind, retType, opAdaptor.Args ().getTypes ());
247
303
SYCLFuncDescriptor::call (funcId, opAdaptor.getOperands (), registry,
248
304
rewriter, op.getLoc ());
249
305
@@ -300,5 +356,6 @@ void mlir::sycl::populateSYCLToLLVMConversionPatterns(
300
356
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
301
357
populateSYCLToLLVMTypeConversion (typeConverter);
302
358
359
+ patterns.add <CallPattern>(patterns.getContext (), typeConverter);
303
360
patterns.add <ConstructorPattern>(patterns.getContext (), typeConverter);
304
361
}
0 commit comments