Skip to content

Commit 09df598

Browse files
committed
Conversion for sycl.call operator (#56)
This PR adds a conversion pattern named CallPattern to the SYCLToLLVM conversion pass. This new conversion pattern lowers a sycl.call operator to an LLVM call to the SYCL member function named by the operation. Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent b994ff4 commit 09df598

File tree

5 files changed

+131
-60
lines changed

5 files changed

+131
-60
lines changed

mlir-sycl/include/mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h

+20-21
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
namespace mlir {
2323
class LLVMTypeConverter;
24+
2425
namespace sycl {
2526
class SYCLFuncRegistry;
2627

@@ -79,39 +80,37 @@ class SYCLFuncDescriptor {
7980
Range1CopyCtor, // sycl::range<1>::range(sycl::range<1> const&)
8081
Range2CopyCtor, // sycl::range<2>::range(sycl::range<2> const&)
8182
Range3CopyCtor, // sycl::range<3>::range(sycl::range<3> const&)
82-
83-
Arr1CtorSizeT, // sycl::detail::array<1>::array<1>(std::enable_if<(1)==(1), unsigned long>::type)
8483
};
8584
// clang-format on
8685

87-
/// Enumerates the kind of FuncId.
88-
enum class FuncKind {
86+
/// Enumerates the descriptor kind.
87+
enum class Kind {
8988
Unknown,
9089
Accessor, // sycl::accessor class
9190
Id, // sycl::id<n> class
9291
Range, // sycl::range<n> class
9392
};
9493

95-
/// Each descriptor is uniquely identified by the pair {FuncId, FuncKind}.
94+
/// Each descriptor is uniquely identified by the pair {FuncId, Kind}.
9695
class Id {
9796
public:
9897
friend class SYCLFuncRegistry;
9998
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &, const Id &);
10099

101-
Id(FuncId id, FuncKind kind) : funcId(id), funcKind(kind) {
100+
Id(FuncId id, Kind kind) : funcId(id), kind(kind) {
102101
assert(funcId != FuncId::Unknown && "Illegal function id");
103-
assert(funcKind != FuncKind::Unknown && "Illegal function id kind");
102+
assert(kind != Kind::Unknown && "Illegal descriptor kind");
104103
}
105104

106-
/// Maps a FuncKind to a descriptive name.
107-
static std::map<SYCLFuncDescriptor::FuncKind, std::string> funcKindToName;
105+
/// Maps a Kind to a descriptive name.
106+
static std::map<SYCLFuncDescriptor::Kind, std::string> kindToName;
108107

109-
/// Maps a descriptive name to a FuncKind.
110-
static std::map<std::string, SYCLFuncDescriptor::FuncKind> nameToFuncKind;
108+
/// Maps a descriptive name to a Kind.
109+
static std::map<std::string, SYCLFuncDescriptor::Kind> nameToKind;
111110

112111
private:
113112
FuncId funcId = FuncId::Unknown;
114-
FuncKind funcKind = FuncKind::Unknown;
113+
Kind kind = Kind::Unknown;
115114
};
116115

117116
/// Returns true if the given \p funcId is valid.
@@ -123,7 +122,7 @@ class SYCLFuncDescriptor {
123122
Location loc);
124123

125124
protected:
126-
SYCLFuncDescriptor(FuncId funcId, FuncKind kind, StringRef name,
125+
SYCLFuncDescriptor(FuncId funcId, Kind kind, StringRef name,
127126
Type outputTy, ArrayRef<Type> argTys)
128127
: descId(funcId, kind), name(name), outputTy(outputTy),
129128
argTys(argTys.begin(), argTys.end()) {}
@@ -142,7 +141,7 @@ class SYCLFuncDescriptor {
142141
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
143142
const SYCLFuncDescriptor::Id &id) {
144143
os << "funcId=" << (int)id.funcId
145-
<< ", funcKind=" << SYCLFuncDescriptor::Id::funcKindToName[id.funcKind];
144+
<< ", kind=" << SYCLFuncDescriptor::Id::kindToName[id.kind];
146145
return os;
147146
}
148147

@@ -161,7 +160,7 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
161160
public: \
162161
friend class SYCLFuncRegistry; \
163162
using FuncId = SYCLFuncDescriptor::FuncId; \
164-
using FuncKind = SYCLFuncDescriptor::FuncKind; \
163+
using Kind = SYCLFuncDescriptor::Kind; \
165164
\
166165
private: \
167166
ClassName(FuncId funcId, StringRef name, Type outputTy, \
@@ -171,17 +170,17 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
171170
} \
172171
bool isValid(FuncId) const override; \
173172
};
174-
DEFINE_CLASS(SYCLAccessorFuncDescriptor, FuncKind::Accessor)
175-
DEFINE_CLASS(SYCLIdFuncDescriptor, FuncKind::Id)
176-
DEFINE_CLASS(SYCLRangeFuncDescriptor, FuncKind::Range)
173+
DEFINE_CLASS(SYCLAccessorFuncDescriptor, Kind::Accessor)
174+
DEFINE_CLASS(SYCLIdFuncDescriptor, Kind::Id)
175+
DEFINE_CLASS(SYCLRangeFuncDescriptor, Kind::Range)
177176
#undef DEFINE_CLASS
178177

179178
/// \class SYCLFuncRegistry
180179
/// Singleton class representing the set of SYCL functions callable from the
181180
/// compiler.
182181
class SYCLFuncRegistry {
183182
using FuncId = SYCLFuncDescriptor::FuncId;
184-
using FuncKind = SYCLFuncDescriptor::FuncKind;
183+
using Kind = SYCLFuncDescriptor::Kind;
185184
using Registry = std::map<FuncId, SYCLFuncDescriptor>;
186185

187186
public:
@@ -199,8 +198,8 @@ class SYCLFuncRegistry {
199198
}
200199

201200
/// Returns the SYCLFuncDescriptor::Id::FuncId corresponding to the function
202-
/// descriptor that matches the given \p funcKind and signature.
203-
FuncId getFuncId(FuncKind funcKind, Type retType, TypeRange argTypes) const;
201+
/// descriptor that matches the given \p kind and signature.
202+
FuncId getFuncId(Kind kind, Type retType, TypeRange argTypes) const;
204203

205204
private:
206205
SYCLFuncRegistry(ModuleOp &module, OpBuilder &builder);

mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLFuncRegistry.cpp

+24-25
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,20 @@ using namespace mlir::sycl;
2828
// SYCLFuncDescriptor::Id
2929
//===----------------------------------------------------------------------===//
3030

31-
std::map<SYCLFuncDescriptor::FuncKind, std::string>
32-
SYCLFuncDescriptor::Id::funcKindToName = {
33-
{FuncKind::Accessor, "accessor"},
34-
{FuncKind::Id, "id"},
35-
{FuncKind::Range, "range"},
36-
{FuncKind::Unknown, "unknown"},
31+
std::map<SYCLFuncDescriptor::Kind, std::string>
32+
SYCLFuncDescriptor::Id::kindToName = {
33+
{Kind::Accessor, "accessor"},
34+
{Kind::Id, "id"},
35+
{Kind::Range, "range"},
36+
{Kind::Unknown, "unknown"},
3737
};
3838

39-
std::map<std::string, SYCLFuncDescriptor::FuncKind>
40-
SYCLFuncDescriptor::Id::nameToFuncKind = {
41-
{"accessor", FuncKind::Accessor},
42-
{"id", FuncKind::Id},
43-
{"range", FuncKind::Range},
44-
{"unknown", FuncKind::Unknown},
39+
std::map<std::string, SYCLFuncDescriptor::Kind>
40+
SYCLFuncDescriptor::Id::nameToKind = {
41+
{"accessor", Kind::Accessor},
42+
{"id", Kind::Id},
43+
{"range", Kind::Range},
44+
{"unknown", Kind::Unknown},
4545
};
4646

4747
//===----------------------------------------------------------------------===//
@@ -157,20 +157,19 @@ const SYCLFuncRegistry SYCLFuncRegistry::create(ModuleOp &module,
157157
}
158158

159159
SYCLFuncDescriptor::FuncId
160-
SYCLFuncRegistry::getFuncId(SYCLFuncDescriptor::FuncKind funcKind, Type retType,
160+
SYCLFuncRegistry::getFuncId(SYCLFuncDescriptor::Kind kind, Type retType,
161161
TypeRange argTypes) const {
162-
assert(funcKind != SYCLFuncDescriptor::FuncKind::Unknown &&
163-
"Invalid funcKind");
162+
assert(kind != Kind::Unknown && "Invalid descriptor kind");
164163
LLVM_DEBUG(llvm::dbgs() << "Looking up function of kind: "
165-
<< SYCLFuncDescriptor::Id::funcKindToName[funcKind]
164+
<< SYCLFuncDescriptor::Id::kindToName.at(kind)
166165
<< "\n";);
167166

168167
for (const auto &entry : registry) {
169168
const SYCLFuncDescriptor &desc = entry.second;
170169
LLVM_DEBUG(llvm::dbgs() << desc << "\n");
171170

172-
// Skip through entries that do not match the requested funcIdKind.
173-
if (desc.descId.funcKind != funcKind) {
171+
// Skip through entries that do not match the requested kind.
172+
if (desc.descId.kind != kind) {
174173
LLVM_DEBUG(llvm::dbgs() << "\tskip, kind does not match\n");
175174
continue;
176175
}
@@ -186,7 +185,12 @@ SYCLFuncRegistry::getFuncId(SYCLFuncDescriptor::FuncKind funcKind, Type retType,
186185
continue;
187186
}
188187
if (!std::equal(argTypes.begin(), argTypes.end(), desc.argTys.begin())) {
189-
LLVM_DEBUG(llvm::dbgs() << "\tskip, arguments types do not match\n");
188+
LLVM_DEBUG({
189+
auto pair = std::mismatch(argTypes.begin(), argTypes.end(),
190+
desc.argTys.begin());
191+
llvm::dbgs() << "\tskip, arguments " << *pair.first << " and "
192+
<< *pair.second << " do not match\n";
193+
});
190194
continue;
191195
}
192196

@@ -268,14 +272,8 @@ void SYCLFuncRegistry::declareIdFuncDescriptors(LLVMTypeConverter &converter,
268272
converter.convertType(MemRefType::get(-1, IDType::get(context, 2)));
269273
Type id3PtrTy =
270274
converter.convertType(MemRefType::get(-1, IDType::get(context, 3)));
271-
272275
auto voidTy = LLVM::LLVMVoidType::get(context);
273276
auto i64Ty = IntegerType::get(context, 64);
274-
auto indexTy = IndexType::get(context);
275-
276-
auto arrayMemref = mlir::MemRefType::get(1, indexTy);
277-
Type arr1PtrTy =
278-
converter.convertType(mlir::MemRefType::get(-1, arrayMemref));
279277

280278
// Construct the SYCL functions descriptors for the sycl::id<n> type.
281279
// Descriptor format: (enum, function name, signature).
@@ -329,6 +327,7 @@ void SYCLFuncRegistry::declareIdFuncDescriptors(LLVMTypeConverter &converter,
329327
SYCLIdFuncDescriptor(FuncId::Id3Ctor3SizeT,
330328
"_ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEmm",
331329
voidTy, {id3PtrTy, i64Ty, i64Ty, i64Ty}),
330+
332331
// sycl::id<1>::id(sycl::id<1> const&)
333332
SYCLIdFuncDescriptor(FuncId::Id1CopyCtor,
334333
"_ZN2cl4sycl2idILi1EEC1ERKS2_", voidTy, {id1PtrTy, id1PtrTy}),

mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLToLLVM.cpp

+71-14
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,64 @@ static Optional<Type> convertRangeType(sycl::RangeType type,
206206
converter);
207207
}
208208

209+
//===----------------------------------------------------------------------===//
210+
// CallPattern - Converts `sycl.call` to LLVM.
211+
//===----------------------------------------------------------------------===//
212+
213+
class CallPattern final : public SYCLToLLVMConversion<sycl::SYCLCallOp> {
214+
public:
215+
using SYCLToLLVMConversion<sycl::SYCLCallOp>::SYCLToLLVMConversion;
216+
using FuncId = SYCLFuncDescriptor::FuncId;
217+
using Kind = SYCLFuncDescriptor::Kind;
218+
219+
LogicalResult
220+
matchAndRewrite(sycl::SYCLCallOp op, OpAdaptor opAdaptor,
221+
ConversionPatternRewriter &rewriter) const override {
222+
assert(op.Type().has_value() &&
223+
"Expecting op.Type() to have a valid value");
224+
StringRef typeName = op.Type().value();
225+
Kind kind = SYCLFuncDescriptor::Id::nameToKind.at(typeName.str());
226+
assert((kind != Kind::Unknown) && "unknown descriptor kind");
227+
return rewriteCall(kind, op, opAdaptor, rewriter);
228+
}
229+
230+
private:
231+
/// Rewrite sycl.call() {Function = *, Type = *} to a LLVM call to the
232+
/// appropriate member function.
233+
LogicalResult rewriteCall(Kind kind, SYCLCallOp op, OpAdaptor opAdaptor,
234+
ConversionPatternRewriter &rewriter) const {
235+
assert((kind != Kind::Unknown) && "Unexpected descriptor kind");
236+
LLVM_DEBUG(llvm::dbgs() << "CallPattern: Rewriting op: "; op.dump();
237+
llvm::dbgs() << "\n");
238+
239+
ModuleOp module = op.getOperation()->getParentOfType<ModuleOp>();
240+
const auto &registry = SYCLFuncRegistry::create(module, rewriter);
241+
242+
/// Lookup the FuncId corresponding to the member function to use.
243+
Type retType = op.getODSResults(0).empty()
244+
? LLVM::LLVMVoidType::get(module.getContext())
245+
: op.Result().getType();
246+
247+
FuncId funcId =
248+
registry.getFuncId(kind, retType, opAdaptor.Args().getTypes());
249+
SYCLFuncDescriptor::call(funcId, opAdaptor.getOperands(), registry,
250+
rewriter, op.getLoc());
251+
252+
LLVM_DEBUG({
253+
Operation *func = op->getParentOfType<LLVM::LLVMFuncOp>();
254+
if (!func)
255+
func = op->getParentOfType<func::FuncOp>();
256+
257+
assert(func && "Could not find parent function");
258+
llvm::dbgs() << "ConstructorPattern: Function after rewrite:\n"
259+
<< *func << "\n";
260+
});
261+
262+
rewriter.eraseOp(op);
263+
return success();
264+
}
265+
};
266+
209267
//===----------------------------------------------------------------------===//
210268
// ConstructorPattern - Converts `sycl.constructor` to LLVM.
211269
//===----------------------------------------------------------------------===//
@@ -214,36 +272,34 @@ class ConstructorPattern final
214272
: public SYCLToLLVMConversion<sycl::SYCLConstructorOp> {
215273
public:
216274
using SYCLToLLVMConversion<sycl::SYCLConstructorOp>::SYCLToLLVMConversion;
275+
using FuncId = SYCLFuncDescriptor::FuncId;
276+
using Kind = SYCLFuncDescriptor::Kind;
217277

218278
LogicalResult
219-
matchAndRewrite(sycl::SYCLConstructorOp op, OpAdaptor opAdaptor,
279+
matchAndRewrite(SYCLConstructorOp op, OpAdaptor opAdaptor,
220280
ConversionPatternRewriter &rewriter) const override {
221281
return rewriteConstructor(
222-
SYCLFuncDescriptor::Id::nameToFuncKind.at(op.Type().str()), op,
223-
opAdaptor, rewriter);
282+
SYCLFuncDescriptor::Id::nameToKind.at(op.Type().str()), op, opAdaptor,
283+
rewriter);
224284
}
225285

226286
private:
227287
/// Rewrite sycl.constructor() { type = * } to a LLVM call to the appropriate
228288
/// constructor function.
229-
LogicalResult rewriteConstructor(SYCLFuncDescriptor::FuncKind ctorKind,
230-
SYCLConstructorOp op, OpAdaptor opAdaptor,
289+
LogicalResult rewriteConstructor(Kind kind, SYCLConstructorOp op,
290+
OpAdaptor opAdaptor,
231291
ConversionPatternRewriter &rewriter) const {
232-
assert((ctorKind != SYCLFuncDescriptor::FuncKind::Unknown) &&
233-
"Unexpected ctorKind");
292+
assert((kind != Kind::Unknown) && "Unexpected descriptor kind");
234293
LLVM_DEBUG(llvm::dbgs() << "ConstructorPattern: Rewriting op: "; op.dump();
235294
llvm::dbgs() << "\n");
236295

237296
ModuleOp module = op.getOperation()->getParentOfType<ModuleOp>();
238297
const auto &registry = SYCLFuncRegistry::create(module, rewriter);
239298

240-
/// Lookup the FuncId corresponding to the ctor function to use, which is
241-
/// determined based on 'ctorKind) the kind of constructor to search for, and
242-
/// the LLVM types of the sycl.constructor arguments.
243-
SYCLFuncDescriptor::FuncId funcId = registry.getFuncId(
244-
ctorKind, LLVM::LLVMVoidType::get(module.getContext()),
245-
opAdaptor.Args().getTypes());
246-
299+
/// Lookup the FuncId corresponding to the ctor function to use.
300+
auto retType = LLVM::LLVMVoidType::get(module.getContext());
301+
FuncId funcId =
302+
registry.getFuncId(kind, retType, opAdaptor.Args().getTypes());
247303
SYCLFuncDescriptor::call(funcId, opAdaptor.getOperands(), registry,
248304
rewriter, op.getLoc());
249305

@@ -300,5 +356,6 @@ void mlir::sycl::populateSYCLToLLVMConversionPatterns(
300356
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
301357
populateSYCLToLLVMTypeConversion(typeConverter);
302358

359+
patterns.add<CallPattern>(patterns.getContext(), typeConverter);
303360
patterns.add<ConstructorPattern>(patterns.getContext(), typeConverter);
304361
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: sycl-mlir-opt -split-input-file -convert-sycl-to-llvm -verify-diagnostics %s | FileCheck %s
2+
3+
//===-------------------------------------------------------------------------------------------------===//
4+
// Member functions for sycl::accessor
5+
//===-------------------------------------------------------------------------------------------------===//
6+
7+
!sycl_accessor_1_i32_read_write_global_buffer = !sycl.accessor<[1, i32, read_write, global_buffer], (!sycl.accessor_impl_device<[1], (!sycl.id<1>, !sycl.range<1>, !sycl.range<1>)>, !llvm.struct<(ptr<i32, 1>)>)>
8+
9+
// CHECK: llvm.func @_ZN2cl4sycl8accessorIiLi1ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0ENS0_3ext6oneapi22accessor_property_listIJEEEE6__initEPU3AS1iNS0_5rangeILi1EEESE_NS0_2idILi1EEE([[ARG_TYPES:!llvm.struct<\(ptr<struct<"class.cl::sycl::accessor.1",.*]])
10+
func.func @accessorInit1(%arg0: memref<?x!sycl_accessor_1_i32_read_write_global_buffer>, %arg1: memref<?xi32>, %arg2: !sycl.range<1>, %arg3: !sycl.range<1>, %arg4: !sycl.id<1>) {
11+
// CHECK: llvm.call @_ZN2cl4sycl8accessorIiLi1ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0ENS0_3ext6oneapi22accessor_property_listIJEEEE6__initEPU3AS1iNS0_5rangeILi1EEESE_NS0_2idILi1EEE({{.*}}) : ([[ARG_TYPES]]) -> ()
12+
sycl.call(%arg0, %arg1, %arg2, %arg3, %arg4) {Function = @__init, Type = @accessor} : (memref<?x!sycl_accessor_1_i32_read_write_global_buffer>, memref<?xi32>, !sycl.range<1>, !sycl.range<1>, !sycl.id<1>) -> ()
13+
return
14+
}
15+
16+
// -----

0 commit comments

Comments
 (0)