Skip to content

Commit 8261853

Browse files
whitneywhtsangetiotto
authored andcommitted
Add MangledName attribute to sycl.constructor and sycl.call (#58)
We noticed that not all templated function can be register in the registry, e.g., function with template field `Type`, where `Type` can be user defined type. Will create a separate PR for cleaning up `SYCLFuncRegistry`. Signed-off-by: Tsang, Whitney <[email protected]>
1 parent 09df598 commit 8261853

File tree

6 files changed

+60
-53
lines changed

6 files changed

+60
-53
lines changed

mlir-sycl/include/mlir/Dialect/SYCL/IR/SYCLOps.td

+4
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def SYCLConstructorOp : SYCL_Op<"constructor", []> {
7474

7575
let arguments = (ins
7676
FlatSymbolRefAttr:$Type,
77+
FlatSymbolRefAttr:$MangledName,
7778
Variadic<ConstructorArgs>:$Args
7879
);
7980
let results = (outs);
@@ -122,6 +123,7 @@ def SYCLCallOp : SYCL_Op<"call", []> {
122123
let arguments = (ins
123124
OptionalAttr<FlatSymbolRefAttr>:$Type,
124125
FlatSymbolRefAttr:$Function,
126+
FlatSymbolRefAttr:$MangledName,
125127
Variadic<AnyType>:$Args
126128
);
127129
let results = (outs Optional<AnyType>:$Result);
@@ -132,12 +134,14 @@ def SYCLCallOp : SYCL_Op<"call", []> {
132134
"::llvm::Optional<::mlir::Type>":$Result,
133135
"::llvm::Optional<::llvm::StringRef>":$Type,
134136
"::llvm::StringRef":$Function,
137+
"::llvm::StringRef":$MangledName,
135138
"::mlir::ValueRange":$Args), [{
136139
odsState.addOperands(Args);
137140
if (Type.hasValue()) {
138141
odsState.addAttribute(TypeAttrName(odsState.name), ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), Type.getValue()));
139142
}
140143
odsState.addAttribute(FunctionAttrName(odsState.name), ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), Function));
144+
odsState.addAttribute(MangledNameAttrName(odsState.name), ::mlir::SymbolRefAttr::get(odsBuilder.getContext(), MangledName));
141145
if (Result.hasValue()) {
142146
odsState.addTypes(Result.getValue());
143147
}

mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLToLLVM.cpp

+12-13
Original file line numberDiff line numberDiff line change
@@ -237,17 +237,17 @@ class CallPattern final : public SYCLToLLVMConversion<sycl::SYCLCallOp> {
237237
llvm::dbgs() << "\n");
238238

239239
ModuleOp module = op.getOperation()->getParentOfType<ModuleOp>();
240-
const auto &registry = SYCLFuncRegistry::create(module, rewriter);
241240

242241
/// Lookup the FuncId corresponding to the member function to use.
243242
Type retType = op.getODSResults(0).empty()
244243
? LLVM::LLVMVoidType::get(module.getContext())
245244
: op.Result().getType();
246245

247-
FuncId funcId =
248-
registry.getFuncId(kind, retType, opAdaptor.Args().getTypes());
249-
SYCLFuncDescriptor::call(funcId, opAdaptor.getOperands(), registry,
250-
rewriter, op.getLoc());
246+
LLVMBuilder builder(rewriter, op.getLoc());
247+
SmallVector<Type> operandTypes(opAdaptor.Args().getTypes());
248+
FlatSymbolRefAttr funcRef = builder.getOrInsertFuncDecl(
249+
opAdaptor.MangledName(), retType, operandTypes, module);
250+
builder.genCall(funcRef, {}, opAdaptor.getOperands());
251251

252252
LLVM_DEBUG({
253253
Operation *func = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -294,14 +294,13 @@ class ConstructorPattern final
294294
llvm::dbgs() << "\n");
295295

296296
ModuleOp module = op.getOperation()->getParentOfType<ModuleOp>();
297-
const auto &registry = SYCLFuncRegistry::create(module, rewriter);
298-
299-
/// Lookup the FuncId corresponding to the ctor function to use.
300-
auto retType = LLVM::LLVMVoidType::get(module.getContext());
301-
FuncId funcId =
302-
registry.getFuncId(kind, retType, opAdaptor.Args().getTypes());
303-
SYCLFuncDescriptor::call(funcId, opAdaptor.getOperands(), registry,
304-
rewriter, op.getLoc());
297+
298+
LLVMBuilder builder(rewriter, op.getLoc());
299+
SmallVector<Type> operandTypes(opAdaptor.Args().getTypes());
300+
FlatSymbolRefAttr funcRef = builder.getOrInsertFuncDecl(
301+
opAdaptor.MangledName(), LLVM::LLVMVoidType::get(module.getContext()),
302+
operandTypes, module);
303+
builder.genCall(funcRef, {}, opAdaptor.getOperands());
305304

306305
LLVM_DEBUG({
307306
Operation *func = op->getParentOfType<LLVM::LLVMFuncOp>();

mlir-sycl/test/Conversion/SYCLToLLVM/sycl-call-to-llvm.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
// CHECK: llvm.func @_ZN2cl4sycl8accessorIiLi1ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0ENS0_3ext6oneapi22accessor_property_listIJEEEE6__initEPU3AS1iNS0_5rangeILi1EEESE_NS0_2idILi1EEE([[ARG_TYPES:!llvm.struct<\(ptr<struct<"class.cl::sycl::accessor.1",.*]])
1010
func.func @accessorInit1(%arg0: memref<?x!sycl_accessor_1_i32_read_write_global_buffer>, %arg1: memref<?xi32>, %arg2: !sycl.range<1>, %arg3: !sycl.range<1>, %arg4: !sycl.id<1>) {
1111
// CHECK: llvm.call @_ZN2cl4sycl8accessorIiLi1ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0ENS0_3ext6oneapi22accessor_property_listIJEEEE6__initEPU3AS1iNS0_5rangeILi1EEESE_NS0_2idILi1EEE({{.*}}) : ([[ARG_TYPES]]) -> ()
12-
sycl.call(%arg0, %arg1, %arg2, %arg3, %arg4) {Function = @__init, Type = @accessor} : (memref<?x!sycl_accessor_1_i32_read_write_global_buffer>, memref<?xi32>, !sycl.range<1>, !sycl.range<1>, !sycl.id<1>) -> ()
12+
sycl.call(%arg0, %arg1, %arg2, %arg3, %arg4) {Function = @__init, MangledName = @_ZN2cl4sycl8accessorIiLi1ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0ENS0_3ext6oneapi22accessor_property_listIJEEEE6__initEPU3AS1iNS0_5rangeILi1EEESE_NS0_2idILi1EEE, Type = @accessor} : (memref<?x!sycl_accessor_1_i32_read_write_global_buffer>, memref<?xi32>, !sycl.range<1>, !sycl.range<1>, !sycl.id<1>) -> ()
1313
return
1414
}
1515

0 commit comments

Comments
 (0)