Skip to content

Commit ba6cc2c

Browse files
authored
[SYCL][RTC] Define custom destructor for kernel_bundle_impl (#16702)
Adds a custom destructor to `kernel_bundle_impl` to properly clean up runtime information and device binaries for bundles that are runtime-compiled from source: - **Removal of kernels from program manager** The new `ProgramManager::removeImages` method takes `addImages` as a blueprint and reverses the effects of registering the given device binaries. Currently, only a subset of the data structures in the program manager is handled, and `assert`s are in place for the remaining members. Device globals will be addressed after #16565 lands. AFAICT I could clean up most of the other members mechanically as well, but decided against that because I can't test these features (such as virtual functions) right now due to lack of SYCL-RTC support. - **Free device binaries in JIT library** Store the UR-specific data structure that backs a device binary in a map of unique ptrs instead of a vector to make it possible to free it again without invalidating other compilation results. Also introduce a mutex to protect concurrent access to this new map. - **Free raw SPIR-V binaries in JIT context** Again replace a vector with a map of unique ptrs to make it possible to free `KernelBinary` objects in the `JITContext`. `KernelBinary` owns the actual SPIR-V binaries; all other data structures mentioned earlier only store the pointer returned by `KernelBinary::address()`. A new function `destroyBinary` is added to the `sycl-jit` interface. My understanding is that `~device_image_impl` and `~kernel_impl` already take care of releasing the underlying UR resources, hence no changes are required there. --------- Signed-off-by: Julian Oppermann <[email protected]>
1 parent bae7012 commit ba6cc2c

File tree

12 files changed

+244
-22
lines changed

12 files changed

+244
-22
lines changed

sycl-jit/jit-compiler/include/KernelFusion.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ KF_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
104104
View<InMemoryFile> IncludeFiles,
105105
View<const char *> UserArgs);
106106

107+
KF_EXPORT_SYMBOL void destroyBinary(BinaryAddress Address);
108+
107109
/// Clear all previously set options.
108110
KF_EXPORT_SYMBOL void resetJITConfiguration();
109111

sycl-jit/jit-compiler/ld-version-script.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
fuseKernels;
55
materializeSpecConstants;
66
compileSYCL;
7+
destroyBinary;
78
resetJITConfiguration;
89
addToJITConfiguration;
910

sycl-jit/jit-compiler/lib/KernelFusion.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,10 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
317317
return RTCResult{std::move(BundleInfo), BuildLog.c_str()};
318318
}
319319

320+
/// Exported entry point: forwards \p Address to the JIT context singleton so
/// that the kernel binary registered under that address is destroyed.
extern "C" KF_EXPORT_SYMBOL void destroyBinary(BinaryAddress Address) {
  auto &Context = JITContext::getInstance();
  Context.destroyKernelBinary(Address);
}
323+
320324
extern "C" KF_EXPORT_SYMBOL void resetJITConfiguration() {
321325
ConfigHelper::reset();
322326
}

sycl-jit/jit-compiler/lib/fusion/JITContext.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,15 @@ class JITContext {
7070

7171
/// Creates a KernelBinary from the forwarded constructor arguments while
/// holding a write lock on BinariesMutex, and returns a reference to it.
///
/// Diff view: the statement following the `73-` marker is the implementation
/// removed by this commit; the statements following the `+` markers replace it.
/// The new code heap-allocates the binary via std::make_unique, takes the
/// reference before the unique_ptr is moved, and stores the pointer in
/// `Binaries` keyed by the binary's address().
template <typename... Ts> KernelBinary &emplaceKernelBinary(Ts &&...Args) {
7272
WriteLockT WriteLock{BinariesMutex};
73-
return Binaries.emplace_back(std::forward<Ts>(Args)...);
73+
auto KBUPtr = std::make_unique<KernelBinary>(std::forward<Ts>(Args)...);
74+
KernelBinary &KB = *KBUPtr;
75+
Binaries.emplace(KB.address(), std::move(KBUPtr));
76+
return KB;
77+
}
78+
79+
/// Removes the entry stored under \p Addr from `Binaries`, if there is one,
/// while holding a write lock on BinariesMutex. Erasing the entry destroys the
/// unique_ptr that the map holds for it.
void destroyKernelBinary(BinaryAddress Addr) {
  WriteLockT Guard{BinariesMutex};
  if (auto Pos = Binaries.find(Addr); Pos != Binaries.end())
    Binaries.erase(Pos);
}
7583

7684
std::optional<SYCLKernelInfo> getCacheEntry(CacheKeyT &Identifier) const;
@@ -96,7 +104,7 @@ class JITContext {
96104

97105
MutexT BinariesMutex;
98106

99-
std::vector<KernelBinary> Binaries;
107+
std::unordered_map<BinaryAddress, std::unique_ptr<KernelBinary>> Binaries;
100108

101109
mutable MutexT CacheMutex;
102110

sycl/source/detail/jit_compiler.cpp

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,16 @@ jit_compiler::jit_compiler()
9898
"Cannot resolve JIT library function entry point");
9999
return false;
100100
}
101+
102+
this->DestroyBinaryHandle = reinterpret_cast<DestroyBinaryFuncT>(
103+
sycl::detail::ur::getOsLibraryFuncAddress(LibraryPtr.get(),
104+
"destroyBinary"));
105+
if (!this->DestroyBinaryHandle) {
106+
printPerformanceWarning(
107+
"Cannot resolve JIT library function entry point");
108+
return false;
109+
}
110+
101111
LibraryHandle = std::move(LibraryPtr);
102112
return true;
103113
};
@@ -1130,10 +1140,10 @@ sycl_device_binaries jit_compiler::createPIDeviceBinary(
11301140
return JITDeviceBinaries.back().getPIDeviceStruct();
11311141
}
11321142

1133-
sycl_device_binaries jit_compiler::createDeviceBinaryImage(
1143+
sycl_device_binaries jit_compiler::createDeviceBinaries(
11341144
const ::jit_compiler::RTCBundleInfo &BundleInfo,
11351145
const std::string &OffloadEntryPrefix) {
1136-
DeviceBinariesCollection Collection;
1146+
auto Collection = std::make_unique<DeviceBinariesCollection>();
11371147

11381148
for (const auto &DevImgInfo : BundleInfo) {
11391149
DeviceBinaryContainer Binary;
@@ -1164,17 +1174,28 @@ sycl_device_binaries jit_compiler::createDeviceBinaryImage(
11641174
Binary.addProperty(std::move(PropSet));
11651175
}
11661176

1167-
Collection.addDeviceBinary(std::move(Binary),
1168-
DevImgInfo.BinaryInfo.BinaryStart,
1169-
DevImgInfo.BinaryInfo.BinarySize,
1170-
(DevImgInfo.BinaryInfo.AddressBits == 64)
1171-
? __SYCL_DEVICE_BINARY_TARGET_SPIRV64
1172-
: __SYCL_DEVICE_BINARY_TARGET_SPIRV32,
1173-
SYCL_DEVICE_BINARY_TYPE_SPIRV);
1177+
Collection->addDeviceBinary(std::move(Binary),
1178+
DevImgInfo.BinaryInfo.BinaryStart,
1179+
DevImgInfo.BinaryInfo.BinarySize,
1180+
(DevImgInfo.BinaryInfo.AddressBits == 64)
1181+
? __SYCL_DEVICE_BINARY_TARGET_SPIRV64
1182+
: __SYCL_DEVICE_BINARY_TARGET_SPIRV32,
1183+
SYCL_DEVICE_BINARY_TYPE_SPIRV);
11741184
}
11751185

1176-
JITDeviceBinaries.push_back(std::move(Collection));
1177-
return JITDeviceBinaries.back().getPIDeviceStruct();
1186+
sycl_device_binaries Binaries = Collection->getPIDeviceStruct();
1187+
1188+
std::lock_guard<std::mutex> Guard{RTCDeviceBinariesMutex};
1189+
RTCDeviceBinaries.emplace(Binaries, std::move(Collection));
1190+
return Binaries;
1191+
}
1192+
1193+
/// Releases everything associated with a set of SYCL-RTC device binaries.
///
/// While holding RTCDeviceBinariesMutex, invokes DestroyBinaryHandle on the
/// `BinaryStart` of each device binary in \p Binaries, then erases the entry
/// keyed by \p Binaries from RTCDeviceBinaries.
void jit_compiler::destroyDeviceBinaries(sycl_device_binaries Binaries) {
  std::lock_guard<std::mutex> Lock{RTCDeviceBinariesMutex};

  const auto *const First = Binaries->DeviceBinaries;
  const auto *const Last = First + Binaries->NumDeviceBinaries;
  for (const auto *Bin = First; Bin != Last; ++Bin)
    DestroyBinaryHandle(Bin->BinaryStart);

  RTCDeviceBinaries.erase(Binaries);
}
11791200

11801201
std::vector<uint8_t> jit_compiler::encodeArgUsageMask(
@@ -1270,8 +1291,8 @@ sycl_device_binaries jit_compiler::compileSYCL(
12701291
throw sycl::exception(sycl::errc::build, Result.getBuildLog());
12711292
}
12721293

1273-
return createDeviceBinaryImage(Result.getBundleInfo(),
1274-
/*OffloadEntryPrefix=*/CompilationID + '$');
1294+
return createDeviceBinaries(Result.getBundleInfo(),
1295+
/*OffloadEntryPrefix=*/CompilationID + '$');
12751296
}
12761297

12771298
} // namespace detail

sycl/source/detail/jit_compiler.hpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#endif // SYCL_EXT_JIT_ENABLE
1818

1919
#include <functional>
20+
#include <memory>
21+
#include <mutex>
2022
#include <unordered_map>
2123

2224
namespace jit_compiler {
@@ -53,6 +55,8 @@ class jit_compiler {
5355
const std::vector<std::string> &UserArgs, std::string *LogPtr,
5456
const std::vector<std::string> &RegisteredKernelNames);
5557

58+
void destroyDeviceBinaries(sycl_device_binaries Binaries);
59+
5660
bool isAvailable() { return Available; }
5761

5862
static jit_compiler &get_instance() {
@@ -73,8 +77,8 @@ class jit_compiler {
7377
::jit_compiler::BinaryFormat Format);
7478

7579
sycl_device_binaries
76-
createDeviceBinaryImage(const ::jit_compiler::RTCBundleInfo &BundleInfo,
77-
const std::string &OffloadEntryPrefix);
80+
createDeviceBinaries(const ::jit_compiler::RTCBundleInfo &BundleInfo,
81+
const std::string &OffloadEntryPrefix);
7882

7983
std::vector<uint8_t>
8084
encodeArgUsageMask(const ::jit_compiler::ArgUsageMask &Mask) const;
@@ -88,17 +92,27 @@ class jit_compiler {
8892
// Manages the lifetime of the UR structs for device binaries.
8993
std::vector<DeviceBinariesCollection> JITDeviceBinaries;
9094

95+
// Manages the lifetime of the UR structs for device binaries for SYCL-RTC.
96+
std::unordered_map<sycl_device_binaries,
97+
std::unique_ptr<DeviceBinariesCollection>>
98+
RTCDeviceBinaries;
99+
100+
// Protects access to map above.
101+
std::mutex RTCDeviceBinariesMutex;
102+
91103
#if SYCL_EXT_JIT_ENABLE
92104
// Handles to the entry points of the lazily loaded JIT library.
93105
using FuseKernelsFuncT = decltype(::jit_compiler::fuseKernels) *;
94106
using MaterializeSpecConstFuncT =
95107
decltype(::jit_compiler::materializeSpecConstants) *;
96108
using CompileSYCLFuncT = decltype(::jit_compiler::compileSYCL) *;
109+
using DestroyBinaryFuncT = decltype(::jit_compiler::destroyBinary) *;
97110
using ResetConfigFuncT = decltype(::jit_compiler::resetJITConfiguration) *;
98111
using AddToConfigFuncT = decltype(::jit_compiler::addToJITConfiguration) *;
99112
FuseKernelsFuncT FuseKernelsHandle = nullptr;
100113
MaterializeSpecConstFuncT MaterializeSpecConstHandle = nullptr;
101114
CompileSYCLFuncT CompileSYCLHandle = nullptr;
115+
DestroyBinaryFuncT DestroyBinaryHandle = nullptr;
102116
ResetConfigFuncT ResetConfigHandle = nullptr;
103117
AddToConfigFuncT AddToConfigHandle = nullptr;
104118
static std::function<void(void *)> CustomDeleterForLibHandle;

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,8 @@ class kernel_bundle_impl {
380380
// program manager integration, only for sycl_jit language
381381
kernel_bundle_impl(context Ctx, std::vector<device> Devs,
382382
const std::vector<kernel_id> &KernelIDs,
383-
std::vector<std::string> KNames, std::string Pfx,
383+
std::vector<std::string> KNames,
384+
sycl_device_binaries Binaries, std::string Pfx,
384385
syclex::source_language Lang)
385386
: kernel_bundle_impl(std::move(Ctx), std::move(Devs), KernelIDs,
386387
bundle_state::executable) {
@@ -392,6 +393,7 @@ class kernel_bundle_impl {
392393
// from the (unprefixed) kernel name.
393394
MIsInterop = true;
394395
MKernelNames = std::move(KNames);
396+
MDeviceBinaries = Binaries;
395397
MPrefix = std::move(Pfx);
396398
MLanguage = Lang;
397399
}
@@ -520,8 +522,9 @@ class kernel_bundle_impl {
520522
}
521523
}
522524

523-
return std::make_shared<kernel_bundle_impl>(
524-
MContext, MDevices, KernelIDs, KernelNames, Prefix, MLanguage);
525+
return std::make_shared<kernel_bundle_impl>(MContext, MDevices, KernelIDs,
526+
KernelNames, Binaries, Prefix,
527+
MLanguage);
525528
}
526529

527530
ur_program_handle_t UrProgram = nullptr;
@@ -928,6 +931,17 @@ class kernel_bundle_impl {
928931
return true;
929932
}
930933

934+
/// Cleans up after bundles that carry device binaries (MDeviceBinaries is
/// non-null): the images are removed from the program manager and the binaries
/// are then handed to SYCL_JIT_destroy. Any std::exception raised on the way is
/// reported to the stream instead of escaping the destructor.
~kernel_bundle_impl() {
  // Nothing to release for bundles without device binaries.
  if (!MDeviceBinaries)
    return;

  try {
    ProgramManager::getInstance().removeImages(MDeviceBinaries);
    syclex::detail::SYCL_JIT_destroy(MDeviceBinaries);
  } catch (std::exception &Ex) {
    __SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~kernel_bundle_impl", Ex);
  }
}
944+
931945
private:
932946
void fillUniqueDeviceImages() {
933947
assert(MUniqueDeviceImages.empty());
@@ -959,6 +973,7 @@ class kernel_bundle_impl {
959973
const std::variant<std::string, std::vector<std::byte>> MSource;
960974
// only kernel_bundles created from source have KernelNames member.
961975
std::vector<std::string> MKernelNames;
976+
sycl_device_binaries MDeviceBinaries = nullptr;
962977
std::string MPrefix;
963978
include_pairs_t MIncludePairs;
964979
};

sycl/source/detail/kernel_compiler/kernel_compiler_sycl.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,15 @@ std::pair<sycl_device_binaries, std::string> SYCL_JIT_to_SPIRV(
323323
#endif
324324
}
325325

326+
/// Hands \p Binaries to the jit_compiler singleton for destruction when the
/// runtime is built with SYCL_EXT_JIT_ENABLE; otherwise throws a
/// sycl::exception with errc::invalid.
void SYCL_JIT_destroy([[maybe_unused]] sycl_device_binaries Binaries) {
#if !SYCL_EXT_JIT_ENABLE
  throw sycl::exception(sycl::errc::invalid,
                        "kernel_compiler via sycl-jit is not available");
#else
  auto &Compiler = sycl::detail::jit_compiler::get_instance();
  Compiler.destroyDeviceBinaries(Binaries);
#endif
}
334+
326335
} // namespace detail
327336
} // namespace ext::oneapi::experimental
328337
} // namespace _V1

sycl/source/detail/kernel_compiler/kernel_compiler_sycl.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ SYCL_JIT_to_SPIRV(const std::string &Source, include_pairs_t IncludePairs,
4040
const std::vector<std::string> &UserArgs, std::string *LogPtr,
4141
const std::vector<std::string> &RegisteredKernelNames);
4242

43+
void SYCL_JIT_destroy(sycl_device_binaries Binaries);
44+
4345
bool SYCL_JIT_Compilation_Available();
4446

4547
} // namespace detail

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1998,6 +1998,89 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
19981998
}
19991999
}
20002000

2001+
/// Unregisters the given device binaries from the program manager, reversing
/// the bookkeeping that `addImages` set up for them.
///
/// For every non-empty raw image in \p DeviceBinary this:
///  - locates the corresponding RTDeviceBinaryImage via the kernel ID of the
///    image's first offload entry,
///  - drops its eliminated-kernel-argument masks,
///  - unmaps service kernels and kernel name / kernel ID entries,
///  - drops the image -> kernel IDs reverse mapping,
///  - unregisters its exported symbols,
///  - purges native-program entries that point at the image, and
///  - erases the owning unique pointer from m_DeviceImages.
///
/// Only a subset of the state is handled; the remaining per-image properties
/// are asserted to be empty (see the TODO below).
void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
  for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
    sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
    const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
    const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
    // Treat the image as empty one
    if (EntriesB == EntriesE)
      continue;

    // Retrieve RTDeviceBinaryImage by looking up the first offload entry
    kernel_id FirstKernelID = getSYCLKernelID(RawImg->EntriesBegin->name);
    auto RTDBImages = getRawDeviceImages({FirstKernelID});
    assert(RTDBImages.size() == 1);

    RTDeviceBinaryImage *Img = *RTDBImages.begin();

    // Drop the kernel argument mask map
    m_EliminatedKernelArgMasks.erase(Img);

    // Acquire lock to modify maps for kernel bundles
    std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);

    // Unmap the unique kernel IDs for the offload entries
    for (sycl_offload_entry EntriesIt = EntriesB; EntriesIt != EntriesE;
         ++EntriesIt) {

      // Drop entry for service kernel
      if (std::strstr(EntriesIt->name, "__sycl_service_kernel__")) {
        m_ServiceKernels.erase(EntriesIt->name);
        continue;
      }

      // Exported device functions won't have a kernel ID
      if (m_ExportedSymbolImages.find(EntriesIt->name) !=
          m_ExportedSymbolImages.end()) {
        continue;
      }

      auto It = m_KernelName2KernelIDs.find(EntriesIt->name);
      assert(It != m_KernelName2KernelIDs.end());
      // Order matters: the ID -> image entry is keyed by `It->second`, and
      // erasing `It` from the name map invalidates the iterator. Remove the
      // dependent entry first, then the name entry itself.
      m_KernelIDs2BinImage.erase(It->second);
      m_KernelName2KernelIDs.erase(It);
    }

    // Drop reverse mapping
    m_BinImg2KernelIDs.erase(Img);

    // Unregister exported symbols (needs to happen after the ID unmap loop)
    for (const sycl_device_binary_property &ESProp :
         Img->getExportedSymbols()) {
      m_ExportedSymbolImages.erase(ESProp->Name);
    }

    // TODO: Handle other runtime info that was set up by `addImages`
    assert(Img->getVirtualFunctions().empty());
    assert(Img->getAssertUsed().empty());
    assert(!Img->getProperty("sanUsed"));
    assert(Img->getImplicitLocalArg().empty());
    assert(Img->getDeviceGlobals().empty());
    assert(Img->getHostPipes().empty());

    // Purge references to the image in native programs map
    {
      std::lock_guard<std::mutex> NativeProgramsGuard(MNativeProgramsMutex);

      // The map does not keep references to program handles; we can erase the
      // entry without calling UR release
      for (auto It = NativePrograms.begin(); It != NativePrograms.end();) {
        // Advance before erasing so the loop iterator stays valid.
        auto CurIt = It++;
        if (CurIt->second == Img)
          NativePrograms.erase(CurIt);
      }
    }

    // Finally, destroy the image by erasing the associated unique ptr
    auto It =
        std::find_if(m_DeviceImages.begin(), m_DeviceImages.end(),
                     [Img](const auto &UPtr) { return UPtr.get() == Img; });
    assert(It != m_DeviceImages.end());
    m_DeviceImages.erase(It);
  }
}
2083+
20012084
void ProgramManager::debugPrintBinaryImages() const {
20022085
for (const auto &ImgIt : m_BinImg2KernelIDs) {
20032086
ImgIt.first->print();

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ class ProgramManager {
216216
const ContextImplPtr Context);
217217

218218
void addImages(sycl_device_binaries DeviceImages);
219+
void removeImages(sycl_device_binaries DeviceImages);
219220
void debugPrintBinaryImages() const;
220221
static std::string getProgramBuildLog(const ur_program_handle_t &Program,
221222
const ContextImplPtr &Context);

0 commit comments

Comments
 (0)