Skip to content

Commit c6af296

Browse files
committed
[SYCL]Link Fallback Device Libraries On Demand
Signed-off-by: gejin <[email protected]>
1 parent 9e9faf9 commit c6af296

File tree

10 files changed

+219
-24
lines changed

10 files changed

+219
-24
lines changed

llvm/include/llvm/Support/PropertySetIO.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ class PropertySetRegistry {
123123
// Specific property category names used by tools.
124124
static constexpr char SYCL_SPECIALIZATION_CONSTANTS[] =
125125
"SYCL/specialization constants";
126+
static constexpr char SYCL_DEVICELIB_REQ_MASK[] = "SYCL/devicelib req mask";
126127

127128
// Function for bulk addition of an entire property set under given category
128129
// (property set name).
@@ -160,4 +161,4 @@ class PropertySetRegistry {
160161
} // namespace util
161162
} // namespace llvm
162163

163-
#endif // #define LLVM_SUPPORT_PROPERTYSETIO_H
164+
#endif // #define LLVM_SUPPORT_PROPERTYSETIO_H

llvm/lib/Support/PropertySetIO.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -112,5 +112,6 @@ template <> PropertyValue::Type PropertyValue::getTypeTag<uint32_t>() {
112112
}
113113

114114
constexpr char PropertySetRegistry::SYCL_SPECIALIZATION_CONSTANTS[];
115+
constexpr char PropertySetRegistry::SYCL_DEVICELIB_REQ_MASK[];
115116
} // namespace util
116117
} // namespace llvm

llvm/lib/Support/SimpleTable.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ int SimpleTable::getColumnId(StringRef ColName) const {
6464

6565
Error SimpleTable::addColumnName(StringRef ColName) {
6666
if (ColumnName2Num.find(ColName) != ColumnName2Num.end())
67-
return makeError("column already exists" + ColName);
67+
return makeError("column already exists " + ColName);
6868
ColumnNames.emplace_back(ColName.str());
6969
ColumnName2Num[ColumnNames.back()] = static_cast<int>(ColumnNames.size()) - 1;
7070
ColumnNum2Name.push_back(std::prev(ColumnNames.end()));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//===- DeviceLibFunctions.h - record the functions in each device library--===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef __DEVICELIB_FUNCS_LIST__
9+
#define __DEVICELIB_FUNCS_LIST__
10+
#include <string>
11+
// All __devicelib_* functions must be sorted in alphabetic order as we
12+
// will use binary search to find some entry in them.
13+
static std::string CmathDeviceLibFunctions[] = {
14+
"__devicelib_acosf", "__devicelib_acoshf", "__devicelib_asinf",
15+
"__devicelib_asinhf", "__devicelib_atan2f", "__devicelib_atanf",
16+
"__devicelib_atanhf", "__devicelib_cbrtf", "__devicelib_cosf",
17+
"__devicelib_coshf", "__devicelib_erfcf", "__devicelib_erff",
18+
"__devicelib_exp2f", "__devicelib_expf", "__devicelib_expm1f",
19+
"__devicelib_fdimf", "__devicelib_fmaf", "__devicelib_fmodf",
20+
"__devicelib_frexpf", "__devicelib_hypotf", "__devicelib_ilogbf",
21+
"__devicelib_ldexpf", "__devicelib_lgammaf", "__devicelib_log10f",
22+
"__devicelib_log1pf", "__devicelib_log2f", "__devicelib_logbf",
23+
"__devicelib_logf", "__devicelib_modff", "__devicelib_nextafterf",
24+
"__devicelib_powf", "__devicelib_remainderf", "__devicelib_remquof",
25+
"__devicelib_sinf", "__devicelib_sinhf", "__devicelib_sqrtf",
26+
"__devicelib_tanf", "__devicelib_tanhf", "__devicelib_tgammaf"};
27+
28+
static std::string Cmath64DeviceLibFunctions[] = {
29+
"__devicelib_acos", "__devicelib_acosh", "__devicelib_asin",
30+
"__devicelib_asinh", "__devicelib_atan", "__devicelib_atan2",
31+
"__devicelib_atanh", "__devicelib_cbrt", "__devicelib_cos",
32+
"__devicelib_cosh", "__devicelib_erf", "__devicelib_erfc",
33+
"__devicelib_exp", "__devicelib_exp2", "__devicelib_expm1",
34+
"__devicelib_fdim", "__devicelib_fma", "__devicelib_fmod",
35+
"__devicelib_frexp", "__devicelib_hypot", "__devicelib_ilogb",
36+
"__devicelib_ldexp", "__devicelib_lgamma", "__devicelib_log",
37+
"__devicelib_log10", "__devicelib_log1p", "__devicelib_log2",
38+
"__devicelib_logb", "__devicelib_modf", "__devicelib_nextafter",
39+
"__devicelib_pow", "__devicelib_remainder", "__devicelib_remquo",
40+
"__devicelib_sin", "__devicelib_sinh", "__devicelib_sqrt",
41+
"__devicelib_tan", "__devicelib_tanh", "__devicelib_tgamma"};
42+
43+
static std::string ComplexDeviceLibFunctions[] = {
44+
"__devicelib___divsc3", "__devicelib___mulsc3", "__devicelib_cabsf",
45+
"__devicelib_cacosf", "__devicelib_cacoshf", "__devicelib_cargf",
46+
"__devicelib_casinf", "__devicelib_casinhf", "__devicelib_catanf",
47+
"__devicelib_catanhf", "__devicelib_ccosf", "__devicelib_ccoshf",
48+
"__devicelib_cexpf", "__devicelib_cimagf", "__devicelib_clogf",
49+
"__devicelib_cpolarf", "__devicelib_cpowf", "__devicelib_cprojf",
50+
"__devicelib_crealf", "__devicelib_csinf", "__devicelib_csinhf",
51+
"__devicelib_csqrtf", "__devicelib_ctanf", "__devicelib_ctanhf"};
52+
53+
static std::string Complex64DeviceLibFunctions[] = {
54+
"__devicelib___divdc3", "__devicelib___muldc3", "__devicelib_cabs",
55+
"__devicelib_cacos", "__devicelib_cacosh", "__devicelib_carg",
56+
"__devicelib_casin", "__devicelib_casinh", "__devicelib_catan",
57+
"__devicelib_catanh", "__devicelib_ccos", "__devicelib_ccosh",
58+
"__devicelib_cexp", "__devicelib_cimag", "__devicelib_clog",
59+
"__devicelib_cpolar", "__devicelib_cpow", "__devicelib_cproj",
60+
"__devicelib_creal", "__devicelib_csin", "__devicelib_csinh",
61+
"__devicelib_csqrt", "__devicelib_ctan", "__devicelib_ctanh"};
62+
#endif

llvm/tools/sycl-post-link/sycl-post-link.cpp

+102-13
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// - specialization constant intrinsic transformation
1414
//===----------------------------------------------------------------------===//
1515

16+
#include "DeviceLibFunctions.h"
1617
#include "SpecConstants.h"
1718
#include "llvm/ADT/SetVector.h"
1819
#include "llvm/Bitcode/BitcodeWriterPass.h"
@@ -295,15 +296,95 @@ saveResultModules(std::vector<std::unique_ptr<Module>> &ResModules) {
295296
return Res;
296297
}
297298

299+
// Each fallback device library corresponds to one bit in "require mask" which
300+
// is an unsigned int32. getDeviceLibBit checks which fallback device library
301+
// is required for FuncName and returns the corresponding bit. The corresponding
302+
// mask for each fallback device library is:
303+
// fallback-cassert: 0x1
304+
// fallback-cmath: 0x2
305+
// fallback-cmath-fp64: 0x4
306+
// fallback-complex: 0x8
307+
// fallback-complex-fp64: 0x10
308+
static uint32_t getDeviceLibBits(const std::string &FuncName) {
309+
310+
static constexpr uint32_t DeviceLibAssert = 0x1;
311+
static constexpr uint32_t DeviceLibCmath = 0x2;
312+
static constexpr uint32_t DeviceLibCmath64 = 0x4;
313+
static constexpr uint32_t DeviceLibComplex = 0x8;
314+
static constexpr uint32_t DeviceLibComplex64 = 0x10;
315+
if (FuncName == "__devicelib_assert_fail") {
316+
return DeviceLibAssert;
317+
}
318+
size_t Len =
319+
sizeof(CmathDeviceLibFunctions) / sizeof(CmathDeviceLibFunctions[0]);
320+
if (std::binary_search(CmathDeviceLibFunctions, CmathDeviceLibFunctions + Len,
321+
FuncName)) {
322+
return DeviceLibCmath;
323+
}
324+
Len =
325+
sizeof(Cmath64DeviceLibFunctions) / sizeof(Cmath64DeviceLibFunctions[0]);
326+
if (std::binary_search(Cmath64DeviceLibFunctions,
327+
Cmath64DeviceLibFunctions + Len, FuncName)) {
328+
return DeviceLibCmath64;
329+
}
330+
Len =
331+
sizeof(ComplexDeviceLibFunctions) / sizeof(ComplexDeviceLibFunctions[0]);
332+
if (std::binary_search(ComplexDeviceLibFunctions,
333+
ComplexDeviceLibFunctions + Len, FuncName)) {
334+
return DeviceLibComplex;
335+
}
336+
Len = sizeof(Complex64DeviceLibFunctions) /
337+
sizeof(Complex64DeviceLibFunctions[0]);
338+
if (std::binary_search(Complex64DeviceLibFunctions,
339+
Complex64DeviceLibFunctions + Len, FuncName)) {
340+
return DeviceLibComplex64;
341+
}
342+
return 0;
343+
}
344+
345+
// For each device image module, we go through all functions which meets
346+
// 1. The function name has prefix "__devicelib_"
347+
// 2. The function has SPIR_FUNC calling convention
348+
// 3. The function is declaration which means it doesn't have function body
349+
static uint32_t getModuleReqMask(const std::unique_ptr<Module> &MPtr) {
350+
uint32_t ReqMask = 0;
351+
uint32_t DeviceLibBits = 0;
352+
for (const Function &SF : *MPtr) {
353+
if (SF.getName().startswith("__devicelib_") &&
354+
(SF.getCallingConv() == CallingConv::SPIR_FUNC) && SF.isDeclaration()) {
355+
DeviceLibBits = getDeviceLibBits(SF.getName().str());
356+
ReqMask |= DeviceLibBits;
357+
}
358+
}
359+
return ReqMask;
360+
}
361+
362+
static void
363+
getDeviceLibReqMasks(const std::vector<std::unique_ptr<Module>> &ResModules,
364+
std::vector<uint32_t> &DeviceLibReqMaskVec) {
365+
for (auto &MPtr : ResModules) {
366+
uint32_t ModuleReqMask = getModuleReqMask(MPtr);
367+
DeviceLibReqMaskVec.push_back(ModuleReqMask);
368+
}
369+
}
370+
298371
static string_vector
299-
saveSpecConstantIDMaps(const std::vector<SpecIDMapTy> &Maps) {
372+
saveDeviceImageProperty(const std::vector<uint32_t> ReqMaskVec,
373+
const std::vector<SpecIDMapTy> &Maps) {
300374
string_vector Res;
301-
302-
for (size_t I = 0; I < Maps.size(); ++I) {
375+
bool saveSpecIDMaps =
376+
(Maps.size() != 0) && (ReqMaskVec.size() == Maps.size());
377+
for (size_t I = 0; I < ReqMaskVec.size(); ++I) {
303378
std::string SCFile = makeResultFileName(".prop", I);
304379
llvm::util::PropertySetRegistry PropSet;
305-
PropSet.add(llvm::util::PropertySetRegistry::SYCL_SPECIALIZATION_CONSTANTS,
306-
Maps[I]);
380+
std::map<StringRef, uint32_t> reqMaskEntry;
381+
reqMaskEntry["devicelib_req_mask"] = ReqMaskVec[I];
382+
PropSet.add(llvm::util::PropertySetRegistry::SYCL_DEVICELIB_REQ_MASK,
383+
reqMaskEntry);
384+
if (saveSpecIDMaps)
385+
PropSet.add(
386+
llvm::util::PropertySetRegistry::SYCL_SPECIALIZATION_CONSTANTS,
387+
Maps[I]);
307388
std::error_code EC;
308389
raw_fd_ostream SCOut(SCFile, EC);
309390
PropSet.write(SCOut);
@@ -456,15 +537,23 @@ int main(int argc, char **argv) {
456537
Error Err = Table.addColumn(COL_CODE, Files);
457538
CHECK_AND_EXIT(Err);
458539
}
459-
if (DoSpecConst && SetSpecConstAtRT) {
460-
// extract spec constant maps per each module
461-
for (auto &MUptr : ResultModules) {
462-
ResultSpecIDMaps.emplace_back(SpecIDMapTy());
463-
if (SpecConstsMet)
464-
SpecConstantsPass::collectSpecConstantMetadata(*MUptr.get(),
465-
ResultSpecIDMaps.back());
540+
{
541+
// Device library req mask is collected and stored in device image property
542+
// as default and each device image module will have one req mask.
543+
std::vector<uint32_t> DeviceLibReqMaskVec;
544+
getDeviceLibReqMasks(ResultModules, DeviceLibReqMaskVec);
545+
if (DoSpecConst && SetSpecConstAtRT) {
546+
// extract spec constant maps per each module
547+
for (auto &MUptr : ResultModules) {
548+
ResultSpecIDMaps.emplace_back(SpecIDMapTy());
549+
if (SpecConstsMet)
550+
SpecConstantsPass::collectSpecConstantMetadata(
551+
*MUptr.get(), ResultSpecIDMaps.back());
552+
}
553+
assert(DeviceLibReqMaskVec.size() == ResultSpecIDMaps.size());
466554
}
467-
string_vector Files = saveSpecConstantIDMaps(ResultSpecIDMaps);
555+
string_vector Files =
556+
saveDeviceImageProperty(DeviceLibReqMaskVec, ResultSpecIDMaps);
468557
Error Err = Table.addColumn(COL_PROPS, Files);
469558
CHECK_AND_EXIT(Err);
470559
}

sycl/include/CL/sycl/detail/pi.h

+1
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,7 @@ static const uint8_t PI_DEVICE_BINARY_OFFLOAD_KIND_SYCL = 4;
650650
/// PropertySetRegistry::SYCL_SPECIALIZATION_CONSTANTS defined in
651651
/// PropertySetIO.h
652652
#define PI_PROPERTY_SET_SPEC_CONST_MAP "SYCL/specialization constants"
653+
#define PI_PROPERTY_SET_DEVICELIB_REQ_MASK "SYCL/devicelib req mask"
653654

654655
/// This struct is a record of the device binary information. If the Kind field
655656
/// denotes a portable binary type (SPIR-V or LLVM IR), the DeviceTargetSpec

sycl/include/CL/sycl/detail/pi.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ class DeviceBinaryImage {
293293
/// name of the property is the specializaion constant symbolic ID and the
294294
/// value is 32-bit unsigned integer ID.
295295
const PropertyRange &getSpecConstants() const { return SpecConstIDMap; }
296+
const PropertyRange &getDeviceLibReqMask() const { return DeviceLibReqMask; }
296297
virtual ~DeviceBinaryImage() {}
297298

298299
protected:
@@ -302,6 +303,7 @@ class DeviceBinaryImage {
302303
pi_device_binary Bin;
303304
pi::PiDeviceBinaryType Format = PI_DEVICE_BINARY_TYPE_NONE;
304305
DeviceBinaryImage::PropertyRange SpecConstIDMap;
306+
DeviceBinaryImage::PropertyRange DeviceLibReqMask;
305307
};
306308

307309
/// Tries to determine the device binary image foramat. Returns

sycl/source/detail/pi.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,7 @@ void DeviceBinaryImage::init(pi_device_binary Bin) {
510510
Format = getBinaryImageFormat(Bin->BinaryStart, getSize());
511511

512512
SpecConstIDMap.init(Bin, PI_PROPERTY_SET_SPEC_CONST_MAP);
513+
DeviceLibReqMask.init(Bin, PI_PROPERTY_SET_DEVICELIB_REQ_MASK);
513514
}
514515

515516
} // namespace pi

sycl/source/detail/program_manager/program_manager.cpp

+45-8
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,11 @@ RT::PiProgram ProgramManager::getBuiltPIProgram(OSModuleHandle M,
377377
// Link a fallback implementation of device libraries if they are not
378378
// supported by a device compiler.
379379
// Pre-compiled programs are supposed to be already linked.
380-
const bool LinkDeviceLibs = Img.getFormat() == PI_DEVICE_BINARY_TYPE_SPIRV;
380+
// If device image is not SPIRV, DeviceLibReqMask will be 0 which means
381+
// no fallback device library will be linked.
382+
uint32_t DeviceLibReqMask = 0;
383+
if (Img.getFormat() == PI_DEVICE_BINARY_TYPE_SPIRV)
384+
DeviceLibReqMask = getDeviceLibReqMask(Img);
381385

382386
const std::vector<device> &Devices = ContextImpl->getDevices();
383387
std::vector<RT::PiDevice> PiDevices(Devices.size());
@@ -388,7 +392,7 @@ RT::PiProgram ProgramManager::getBuiltPIProgram(OSModuleHandle M,
388392
ProgramPtr BuiltProgram =
389393
build(std::move(ProgramManaged), ContextImpl, Img.getCompileOptions(),
390394
Img.getLinkOptions(), PiDevices,
391-
ContextImpl->getCachedLibPrograms(), LinkDeviceLibs);
395+
ContextImpl->getCachedLibPrograms(), DeviceLibReqMask);
392396

393397
return BuiltProgram.release();
394398
};
@@ -659,15 +663,37 @@ RTDeviceBinaryImage &ProgramManager::getDeviceImage(OSModuleHandle M,
659663
return *Img;
660664
}
661665

666+
static bool isDeviceLibRequired(DeviceLibExt Ext, uint32_t DeviceLibReqMask) {
667+
static constexpr uint32_t DeviceLibAssert = 0x1;
668+
static constexpr uint32_t DeviceLibCmath = 0x2;
669+
static constexpr uint32_t DeviceLibCmath64 = 0x4;
670+
static constexpr uint32_t DeviceLibComplex = 0x8;
671+
static constexpr uint32_t DeviceLibComplex64 = 0x10;
672+
switch (Ext) {
673+
case cl_intel_devicelib_assert:
674+
return (DeviceLibReqMask & DeviceLibAssert) == DeviceLibAssert;
675+
case cl_intel_devicelib_math:
676+
return (DeviceLibReqMask & DeviceLibCmath) == DeviceLibCmath;
677+
case cl_intel_devicelib_math_fp64:
678+
return (DeviceLibReqMask & DeviceLibCmath64) == DeviceLibCmath64;
679+
case cl_intel_devicelib_complex:
680+
return (DeviceLibReqMask & DeviceLibComplex) == DeviceLibComplex;
681+
case cl_intel_devicelib_complex_fp64:
682+
return (DeviceLibReqMask & DeviceLibComplex64) == DeviceLibComplex64;
683+
default:
684+
break;
685+
}
686+
687+
return false;
688+
}
689+
662690
static std::vector<RT::PiProgram>
663691
getDeviceLibPrograms(const ContextImplPtr Context,
664692
const std::vector<RT::PiDevice> &Devices,
665-
std::map<DeviceLibExt, RT::PiProgram> &CachedLibPrograms) {
693+
std::map<DeviceLibExt, RT::PiProgram> &CachedLibPrograms,
694+
uint32_t DeviceLibReqMask) {
666695
std::vector<RT::PiProgram> Programs;
667696

668-
// TODO: SYCL compiler should generate a list of required extensions for a
669-
// particular program in order to allow us do a more fine-grained check here.
670-
// Require *all* possible devicelib extensions for now.
671697
std::pair<DeviceLibExt, bool> RequiredDeviceLibExt[] = {
672698
{cl_intel_devicelib_assert, /* is fallback loaded? */ false},
673699
{cl_intel_devicelib_math, false},
@@ -701,6 +727,9 @@ getDeviceLibPrograms(const ContextImplPtr Context,
701727
continue;
702728
}
703729

730+
if (!isDeviceLibRequired(Ext, DeviceLibReqMask)) {
731+
continue;
732+
}
704733
if ((Ext == cl_intel_devicelib_math_fp64 ||
705734
Ext == cl_intel_devicelib_complex_fp64) && !fp64Support) {
706735
continue;
@@ -731,14 +760,15 @@ ProgramManager::build(ProgramPtr Program, const ContextImplPtr Context,
731760
const string_class &LinkOptions,
732761
const std::vector<RT::PiDevice> &Devices,
733762
std::map<DeviceLibExt, RT::PiProgram> &CachedLibPrograms,
734-
bool LinkDeviceLibs) {
763+
uint32_t DeviceLibReqMask) {
735764

736765
if (DbgProgMgr > 0) {
737766
std::cerr << ">>> ProgramManager::build(" << Program.get() << ", "
738767
<< CompileOptions << ", " << LinkOptions << ", ... "
739768
<< Devices.size() << ")\n";
740769
}
741770

771+
bool LinkDeviceLibs = (DeviceLibReqMask != 0);
742772
const char *CompileOpts = std::getenv("SYCL_PROGRAM_COMPILE_OPTIONS");
743773
if (!CompileOpts) {
744774
CompileOpts = CompileOptions.c_str();
@@ -750,7 +780,8 @@ ProgramManager::build(ProgramPtr Program, const ContextImplPtr Context,
750780

751781
std::vector<RT::PiProgram> LinkPrograms;
752782
if (LinkDeviceLibs) {
753-
LinkPrograms = getDeviceLibPrograms(Context, Devices, CachedLibPrograms);
783+
LinkPrograms = getDeviceLibPrograms(Context, Devices, CachedLibPrograms,
784+
DeviceLibReqMask);
754785
}
755786

756787
const detail::plugin &Plugin = Context->getPlugin();
@@ -951,6 +982,12 @@ void ProgramManager::flushSpecConstants(const program_impl &Prg,
951982
Prg.flush_spec_constants(*Img, NativePrg);
952983
}
953984

985+
uint32_t ProgramManager::getDeviceLibReqMask(const RTDeviceBinaryImage &Img) {
986+
const pi::DeviceBinaryImage::PropertyRange &DLMRange =
987+
Img.getDeviceLibReqMask();
988+
return pi::DeviceBinaryProperty(*(DLMRange.begin())).asUint32();
989+
}
990+
954991
} // namespace detail
955992
} // namespace sycl
956993
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/source/detail/program_manager/program_manager.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class ProgramManager {
102102
void flushSpecConstants(const program_impl &Prg,
103103
pi::PiProgram NativePrg = nullptr,
104104
const RTDeviceBinaryImage *Img = nullptr);
105+
uint32_t getDeviceLibReqMask(const RTDeviceBinaryImage &Img);
105106

106107
private:
107108
ProgramManager();
@@ -118,7 +119,7 @@ class ProgramManager {
118119
const string_class &LinkOptions,
119120
const std::vector<RT::PiDevice> &Devices,
120121
std::map<DeviceLibExt, RT::PiProgram> &CachedLibPrograms,
121-
bool LinkDeviceLibs = false);
122+
uint32_t DeviceLibReqMask);
122123
/// Provides a new kernel set id for grouping kernel names together
123124
KernelSetId getNextKernelSetId() const;
124125
/// Returns the kernel set associated with the kernel, handles some special

0 commit comments

Comments
 (0)