diff --git a/sycl-jit/jit-compiler/include/KernelFusion.h b/sycl-jit/jit-compiler/include/KernelFusion.h
index e79124f016c68..7310d69a91952 100644
--- a/sycl-jit/jit-compiler/include/KernelFusion.h
+++ b/sycl-jit/jit-compiler/include/KernelFusion.h
@@ -104,6 +104,9 @@ KF_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
                                        View<InMemoryFile> IncludeFiles,
                                        View<const char *> UserArgs);
 
+/// Destroy the kernel binary that the JIT context holds for `Address`.
+KF_EXPORT_SYMBOL void destroyBinary(BinaryAddress Address);
+
 /// Clear all previously set options.
 KF_EXPORT_SYMBOL void resetJITConfiguration();
 
diff --git a/sycl-jit/jit-compiler/ld-version-script.txt b/sycl-jit/jit-compiler/ld-version-script.txt
index c12256659ce30..2c6f307c88d03 100644
--- a/sycl-jit/jit-compiler/ld-version-script.txt
+++ b/sycl-jit/jit-compiler/ld-version-script.txt
@@ -4,6 +4,7 @@
     fuseKernels;
     materializeSpecConstants;
     compileSYCL;
+    destroyBinary;
     resetJITConfiguration;
     addToJITConfiguration;
 
diff --git a/sycl-jit/jit-compiler/lib/KernelFusion.cpp b/sycl-jit/jit-compiler/lib/KernelFusion.cpp
index 0ac2b12738f5e..34c67c8fb22b6 100644
--- a/sycl-jit/jit-compiler/lib/KernelFusion.cpp
+++ b/sycl-jit/jit-compiler/lib/KernelFusion.cpp
@@ -317,6 +317,10 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
   return RTCResult{std::move(BundleInfo), BuildLog.c_str()};
 }
 
+extern "C" KF_EXPORT_SYMBOL void destroyBinary(BinaryAddress Address) {
+  JITContext::getInstance().destroyKernelBinary(Address);
+}
+
 extern "C" KF_EXPORT_SYMBOL void resetJITConfiguration() {
   ConfigHelper::reset();
 }
diff --git a/sycl-jit/jit-compiler/lib/fusion/JITContext.h b/sycl-jit/jit-compiler/lib/fusion/JITContext.h
index 30329e000d2d0..1497a1a060362 100644
--- a/sycl-jit/jit-compiler/lib/fusion/JITContext.h
+++ b/sycl-jit/jit-compiler/lib/fusion/JITContext.h
@@ -70,7 +70,17 @@ class JITContext {
 
   template <typename... Ts> KernelBinary &emplaceKernelBinary(Ts &&...Args) {
     WriteLockT WriteLock{BinariesMutex};
-    return Binaries.emplace_back(std::forward<Ts>(Args)...);
+    auto KBUPtr = std::make_unique<KernelBinary>(std::forward<Ts>(Args)...);
+    KernelBinary &KB = *KBUPtr;
+    Binaries.emplace(KB.address(), std::move(KBUPtr));
+    return KB;
+  }
+
+  /// Erase the entry keyed by `Addr`, destroying the KernelBinary it owns.
+  /// Erasing an address that is not in the map is a no-op.
+  void destroyKernelBinary(BinaryAddress Addr) {
+    WriteLockT WriteLock{BinariesMutex};
+    Binaries.erase(Addr);
   }
 
   std::optional<SYCLKernelInfo> getCacheEntry(CacheKeyT &Identifier) const;
@@ -96,7 +106,7 @@ class JITContext {
 
   MutexT BinariesMutex;
 
-  std::vector<KernelBinary> Binaries;
+  std::unordered_map<BinaryAddress, std::unique_ptr<KernelBinary>> Binaries;
 
   mutable MutexT CacheMutex;
 
diff --git a/sycl/source/detail/jit_compiler.cpp b/sycl/source/detail/jit_compiler.cpp
index e95b3ab2e58b8..6fc88bb812a20 100644
--- a/sycl/source/detail/jit_compiler.cpp
+++ b/sycl/source/detail/jit_compiler.cpp
@@ -98,6 +98,16 @@ jit_compiler::jit_compiler()
           "Cannot resolve JIT library function entry point");
       return false;
     }
+
+    this->DestroyBinaryHandle = reinterpret_cast<DestroyBinaryFuncT>(
+        sycl::detail::ur::getOsLibraryFuncAddress(LibraryPtr.get(),
+                                                  "destroyBinary"));
+    if (!this->DestroyBinaryHandle) {
+      printPerformanceWarning(
+          "Cannot resolve JIT library function entry point");
+      return false;
+    }
+
     LibraryHandle = std::move(LibraryPtr);
     return true;
   };
@@ -1130,10 +1140,10 @@ sycl_device_binaries jit_compiler::createPIDeviceBinary(
   return JITDeviceBinaries.back().getPIDeviceStruct();
 }
 
-sycl_device_binaries jit_compiler::createDeviceBinaryImage(
+sycl_device_binaries jit_compiler::createDeviceBinaries(
     const ::jit_compiler::RTCBundleInfo &BundleInfo,
     const std::string &OffloadEntryPrefix) {
-  DeviceBinariesCollection Collection;
+  auto Collection = std::make_unique<DeviceBinariesCollection>();
 
   for (const auto &DevImgInfo : BundleInfo) {
     DeviceBinaryContainer Binary;
@@ -1164,17 +1174,30 @@ sycl_device_binaries jit_compiler::createDeviceBinaryImage(
       Binary.addProperty(std::move(PropSet));
     }
 
-    Collection.addDeviceBinary(std::move(Binary),
-                               DevImgInfo.BinaryInfo.BinaryStart,
-                               DevImgInfo.BinaryInfo.BinarySize,
-                               (DevImgInfo.BinaryInfo.AddressBits == 64)
-                                   ? __SYCL_DEVICE_BINARY_TARGET_SPIRV64
-                                   : __SYCL_DEVICE_BINARY_TARGET_SPIRV32,
-                               SYCL_DEVICE_BINARY_TYPE_SPIRV);
+    Collection->addDeviceBinary(std::move(Binary),
+                                DevImgInfo.BinaryInfo.BinaryStart,
+                                DevImgInfo.BinaryInfo.BinarySize,
+                                (DevImgInfo.BinaryInfo.AddressBits == 64)
+                                    ? __SYCL_DEVICE_BINARY_TARGET_SPIRV64
+                                    : __SYCL_DEVICE_BINARY_TARGET_SPIRV32,
+                                SYCL_DEVICE_BINARY_TYPE_SPIRV);
   }
 
-  JITDeviceBinaries.push_back(std::move(Collection));
-  return JITDeviceBinaries.back().getPIDeviceStruct();
+  sycl_device_binaries Binaries = Collection->getPIDeviceStruct();
+
+  std::lock_guard Guard{RTCDeviceBinariesMutex};
+  RTCDeviceBinaries.emplace(Binaries, std::move(Collection));
+  return Binaries;
+}
+
+// Passes the `BinaryStart` of every image in `Binaries` to the JIT library's
+// `destroyBinary`, then erases the collection stored under `Binaries`.
+void jit_compiler::destroyDeviceBinaries(sycl_device_binaries Binaries) {
+  std::lock_guard Guard{RTCDeviceBinariesMutex};
+  for (uint16_t i = 0; i < Binaries->NumDeviceBinaries; ++i) {
+    DestroyBinaryHandle(Binaries->DeviceBinaries[i].BinaryStart);
+  }
+  RTCDeviceBinaries.erase(Binaries);
 }
 
 std::vector<uint8_t> jit_compiler::encodeArgUsageMask(
@@ -1270,8 +1293,8 @@ sycl_device_binaries jit_compiler::compileSYCL(
     throw sycl::exception(sycl::errc::build, Result.getBuildLog());
   }
 
-  return createDeviceBinaryImage(Result.getBundleInfo(),
-                                 /*OffloadEntryPrefix=*/CompilationID + '$');
+  return createDeviceBinaries(Result.getBundleInfo(),
+                              /*OffloadEntryPrefix=*/CompilationID + '$');
 }
 
 } // namespace detail
diff --git a/sycl/source/detail/jit_compiler.hpp b/sycl/source/detail/jit_compiler.hpp
index aa190a3133afa..cf404e7bb723e 100644
--- a/sycl/source/detail/jit_compiler.hpp
+++ b/sycl/source/detail/jit_compiler.hpp
@@ -17,6 +17,8 @@
 #endif // SYCL_EXT_JIT_ENABLE
 
 #include <functional>
+#include <mutex>
+#include <unordered_map>
 #include <unordered_set>
 
 namespace jit_compiler {
@@ -53,6 +55,8 @@ class jit_compiler {
       const std::vector<std::string> &UserArgs, std::string *LogPtr,
       const std::vector<std::string> &RegisteredKernelNames);
 
+  void destroyDeviceBinaries(sycl_device_binaries Binaries);
+
   bool isAvailable() { return Available; }
 
   static jit_compiler &get_instance() {
@@ -73,8 +77,8 @@ class jit_compiler {
                        ::jit_compiler::BinaryFormat Format);
 
   sycl_device_binaries
-  createDeviceBinaryImage(const ::jit_compiler::RTCBundleInfo &BundleInfo,
-                          const std::string &OffloadEntryPrefix);
+  createDeviceBinaries(const ::jit_compiler::RTCBundleInfo &BundleInfo,
+                       const std::string &OffloadEntryPrefix);
 
   std::vector<uint8_t>
   encodeArgUsageMask(const ::jit_compiler::ArgUsageMask &Mask) const;
@@ -88,17 +92,27 @@ class jit_compiler {
   // Manages the lifetime of the UR structs for device binaries.
   std::vector<DeviceBinariesCollection> JITDeviceBinaries;
 
+  // Manages the lifetime of the UR structs for device binaries for SYCL-RTC.
+  std::unordered_map<sycl_device_binaries,
+                     std::unique_ptr<DeviceBinariesCollection>>
+      RTCDeviceBinaries;
+
+  // Protects access to map above.
+  std::mutex RTCDeviceBinariesMutex;
+
 #if SYCL_EXT_JIT_ENABLE
   // Handles to the entry points of the lazily loaded JIT library.
   using FuseKernelsFuncT = decltype(::jit_compiler::fuseKernels) *;
   using MaterializeSpecConstFuncT =
       decltype(::jit_compiler::materializeSpecConstants) *;
   using CompileSYCLFuncT = decltype(::jit_compiler::compileSYCL) *;
+  using DestroyBinaryFuncT = decltype(::jit_compiler::destroyBinary) *;
   using ResetConfigFuncT = decltype(::jit_compiler::resetJITConfiguration) *;
   using AddToConfigFuncT = decltype(::jit_compiler::addToJITConfiguration) *;
   FuseKernelsFuncT FuseKernelsHandle = nullptr;
   MaterializeSpecConstFuncT MaterializeSpecConstHandle = nullptr;
   CompileSYCLFuncT CompileSYCLHandle = nullptr;
+  DestroyBinaryFuncT DestroyBinaryHandle = nullptr;
   ResetConfigFuncT ResetConfigHandle = nullptr;
   AddToConfigFuncT AddToConfigHandle = nullptr;
   static std::function<void(void *)> CustomDeleterForLibHandle;
diff --git a/sycl/source/detail/kernel_bundle_impl.hpp b/sycl/source/detail/kernel_bundle_impl.hpp
index 8e6c9fc9afd01..be22afa63712a 100644
--- a/sycl/source/detail/kernel_bundle_impl.hpp
+++ b/sycl/source/detail/kernel_bundle_impl.hpp
@@ -380,7 +380,8 @@ class kernel_bundle_impl {
   // program manager integration, only for sycl_jit language
   kernel_bundle_impl(context Ctx, std::vector<device> Devs,
                      const std::vector<kernel_id> &KernelIDs,
-                     std::vector<std::string> KNames, std::string Pfx,
+                     std::vector<std::string> KNames,
+                     sycl_device_binaries Binaries, std::string Pfx,
                      syclex::source_language Lang)
       : kernel_bundle_impl(std::move(Ctx), std::move(Devs), KernelIDs,
                            bundle_state::executable) {
@@ -392,6 +393,7 @@ class kernel_bundle_impl {
     // from the (unprefixed) kernel name.
     MIsInterop = true;
     MKernelNames = std::move(KNames);
+    MDeviceBinaries = Binaries;
     MPrefix = std::move(Pfx);
     MLanguage = Lang;
   }
@@ -520,8 +522,9 @@ class kernel_bundle_impl {
       }
     }
 
-    return std::make_shared<kernel_bundle_impl>(
-        MContext, MDevices, KernelIDs, KernelNames, Prefix, MLanguage);
+    return std::make_shared<kernel_bundle_impl>(MContext, MDevices, KernelIDs,
+                                                KernelNames, Binaries, Prefix,
+                                                MLanguage);
   }
 
   ur_program_handle_t UrProgram = nullptr;
@@ -928,6 +931,19 @@ class kernel_bundle_impl {
     return true;
   }
 
+  // Unregister the images from the program manager first, then release the
+  // binaries via the JIT; exceptions are reported instead of propagated.
+  ~kernel_bundle_impl() {
+    try {
+      if (MDeviceBinaries) {
+        ProgramManager::getInstance().removeImages(MDeviceBinaries);
+        syclex::detail::SYCL_JIT_destroy(MDeviceBinaries);
+      }
+    } catch (std::exception &e) {
+      __SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~kernel_bundle_impl", e);
+    }
+  }
+
 private:
   void fillUniqueDeviceImages() {
     assert(MUniqueDeviceImages.empty());
@@ -959,6 +975,7 @@ class kernel_bundle_impl {
   const std::variant<std::string, std::vector<std::byte>> MSource;
   // only kernel_bundles created from source have KernelNames member.
   std::vector<std::string> MKernelNames;
+  sycl_device_binaries MDeviceBinaries = nullptr;
   std::string MPrefix;
   include_pairs_t MIncludePairs;
 };
diff --git a/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.cpp b/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.cpp
index 9108572bb5b1d..56b79340b1309 100644
--- a/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.cpp
+++ b/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.cpp
@@ -323,6 +323,15 @@ std::pair<sycl_device_binaries, std::string> SYCL_JIT_to_SPIRV(
 #endif
 }
 
+void SYCL_JIT_destroy([[maybe_unused]] sycl_device_binaries Binaries) {
+#if SYCL_EXT_JIT_ENABLE
+  sycl::detail::jit_compiler::get_instance().destroyDeviceBinaries(Binaries);
+#else
+  throw sycl::exception(sycl::errc::invalid,
+                        "kernel_compiler via sycl-jit is not available");
+#endif
+}
+
 } // namespace detail
 } // namespace ext::oneapi::experimental
 } // namespace _V1
diff --git a/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.hpp b/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.hpp
index 8187c5373150a..1a1a2665ae313 100644
--- a/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.hpp
+++ b/sycl/source/detail/kernel_compiler/kernel_compiler_sycl.hpp
@@ -40,6 +40,8 @@ SYCL_JIT_to_SPIRV(const std::string &Source, include_pairs_t IncludePairs,
                   const std::vector<std::string> &UserArgs, std::string *LogPtr,
                   const std::vector<std::string> &RegisteredKernelNames);
 
+void SYCL_JIT_destroy(sycl_device_binaries Binaries);
+
 bool SYCL_JIT_Compilation_Available();
 
 } // namespace detail
diff --git a/sycl/source/detail/program_manager/program_manager.cpp b/sycl/source/detail/program_manager/program_manager.cpp
index 124b4c3822140..21ee135074ef0 100644
--- a/sycl/source/detail/program_manager/program_manager.cpp
+++ b/sycl/source/detail/program_manager/program_manager.cpp
@@ -1998,6 +1998,93 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
   }
 }
 
+// Removes the bookkeeping kept for each non-empty image in `DeviceBinary`
+// and finally destroys the corresponding image objects.
+void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
+  for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
+    sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
+    const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
+    const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
+    // Treat the image as empty one
+    if (EntriesB == EntriesE)
+      continue;
+
+    // Retrieve RTDeviceBinaryImage by looking up the first offload entry
+    kernel_id FirstKernelID = getSYCLKernelID(RawImg->EntriesBegin->name);
+    auto RTDBImages = getRawDeviceImages({FirstKernelID});
+    assert(RTDBImages.size() == 1);
+
+    RTDeviceBinaryImage *Img = *RTDBImages.begin();
+
+    // Drop the kernel argument mask map
+    m_EliminatedKernelArgMasks.erase(Img);
+
+    // Acquire lock to modify maps for kernel bundles
+    std::lock_guard KernelIDsGuard(m_KernelIDsMutex);
+
+    // Unmap the unique kernel IDs for the offload entries
+    for (sycl_offload_entry EntriesIt = EntriesB; EntriesIt != EntriesE;
+         ++EntriesIt) {
+
+      // Drop entry for service kernel
+      if (std::strstr(EntriesIt->name, "__sycl_service_kernel__")) {
+        m_ServiceKernels.erase(EntriesIt->name);
+        continue;
+      }
+
+      // Exported device functions won't have a kernel ID
+      if (m_ExportedSymbolImages.find(EntriesIt->name) !=
+          m_ExportedSymbolImages.end()) {
+        continue;
+      }
+
+      auto It = m_KernelName2KernelIDs.find(EntriesIt->name);
+      assert(It != m_KernelName2KernelIDs.end());
+      // `It->second` must be read while `It` is still valid, so the ID -> image
+      // entry is erased before the iterator itself is erased.
+      m_KernelIDs2BinImage.erase(It->second);
+      m_KernelName2KernelIDs.erase(It);
+    }
+
+    // Drop reverse mapping
+    m_BinImg2KernelIDs.erase(Img);
+
+    // Unregister exported symbols (needs to happen after the ID unmap loop)
+    for (const sycl_device_binary_property &ESProp :
+         Img->getExportedSymbols()) {
+      m_ExportedSymbolImages.erase(ESProp->Name);
+    }
+
+    // TODO: Handle other runtime info that was set up by `addImages`
+    assert(Img->getVirtualFunctions().empty());
+    assert(Img->getAssertUsed().empty());
+    assert(!Img->getProperty("sanUsed"));
+    assert(Img->getImplicitLocalArg().empty());
+    assert(Img->getDeviceGlobals().empty());
+    assert(Img->getHostPipes().empty());
+
+    // Purge references to the image in native programs map
+    {
+      std::lock_guard NativeProgramsGuard(MNativeProgramsMutex);
+
+      // The map does not keep references to program handles; we can erase the
+      // entry without calling UR release
+      for (auto It = NativePrograms.begin(); It != NativePrograms.end();) {
+        auto CurIt = It++;
+        if (CurIt->second == Img)
+          NativePrograms.erase(CurIt);
+      }
+    }
+
+    // Finally, destroy the image by erasing the associated unique ptr
+    auto It =
+        std::find_if(m_DeviceImages.begin(), m_DeviceImages.end(),
+                     [Img](const auto &UPtr) { return UPtr.get() == Img; });
+    assert(It != m_DeviceImages.end());
+    m_DeviceImages.erase(It);
+  }
+}
+
 void ProgramManager::debugPrintBinaryImages() const {
   for (const auto &ImgIt : m_BinImg2KernelIDs) {
     ImgIt.first->print();
diff --git a/sycl/source/detail/program_manager/program_manager.hpp b/sycl/source/detail/program_manager/program_manager.hpp
index 78d3574c01427..ecb3359d93dcc 100644
--- a/sycl/source/detail/program_manager/program_manager.hpp
+++ b/sycl/source/detail/program_manager/program_manager.hpp
@@ -216,6 +216,7 @@ class ProgramManager {
                                                const ContextImplPtr Context);
 
   void addImages(sycl_device_binaries DeviceImages);
+  void removeImages(sycl_device_binaries DeviceImages);
   void debugPrintBinaryImages() const;
   static std::string getProgramBuildLog(const ur_program_handle_t &Program,
                                         const ContextImplPtr &Context);
diff --git a/sycl/test-e2e/KernelCompiler/kernel_compiler_sycl_jit.cpp b/sycl/test-e2e/KernelCompiler/kernel_compiler_sycl_jit.cpp
index 5aca89e5be8a4..dffd1eb79c1ad 100644
--- a/sycl/test-e2e/KernelCompiler/kernel_compiler_sycl_jit.cpp
+++ b/sycl/test-e2e/KernelCompiler/kernel_compiler_sycl_jit.cpp
@@ -61,6 +61,17 @@ void ff_templated(T *ptr, T *unused) {
 }
 )===";
 
+auto constexpr SYCLSource2 = R"""(
+#include <sycl/sycl.hpp>
+
+extern "C" SYCL_EXTERNAL
+SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((sycl::ext::oneapi::experimental::nd_range_kernel<1>))
+void vec_add(float* in1, float* in2, float* out){
+  size_t id = sycl::ext::oneapi::this_work_item::get_nd_item<1>().get_global_linear_id();
+  out[id] = in1[id] + in2[id];
+}
+)""";
+
 auto constexpr ESIMDSource = R"===(
 #include <sycl/sycl.hpp>
 #include <sycl/ext/intel/esimd.hpp>
@@ -229,6 +240,56 @@ int test_build_and_run() {
   return 0;
 }
 
+int test_lifetimes() {
+  namespace syclex = sycl::ext::oneapi::experimental;
+  using source_kb = sycl::kernel_bundle<sycl::bundle_state::ext_oneapi_source>;
+  using exe_kb = sycl::kernel_bundle<sycl::bundle_state::executable>;
+
+  sycl::queue q;
+  sycl::context ctx = q.get_context();
+
+  bool ok =
+      q.get_device().ext_oneapi_can_compile(syclex::source_language::sycl_jit);
+  if (!ok) {
+    std::cout << "Apparently this device does not support `sycl_jit` source "
+                 "kernel bundle extension: "
+              << q.get_device().get_info<sycl::info::device::name>()
+              << std::endl;
+    return -1;
+  }
+
+  source_kb kbSrc = syclex::create_kernel_bundle_from_source(
+      ctx, syclex::source_language::sycl_jit, SYCLSource2);
+
+  exe_kb kbExe1 = syclex::build(kbSrc);
+  assert(sycl::get_kernel_ids().size() == 1);
+
+  {
+    exe_kb kbExe2 = syclex::build(kbSrc);
+    assert(sycl::get_kernel_ids().size() == 2);
+    // kbExe2 goes out of scope; its kernels are removed from program manager.
+  }
+  assert(sycl::get_kernel_ids().size() == 1);
+
+  {
+    std::unique_ptr<sycl::kernel> kPtr;
+    {
+      exe_kb kbExe3 = syclex::build(kbSrc);
+      assert(sycl::get_kernel_ids().size() == 2);
+
+      sycl::kernel k = kbExe3.ext_oneapi_get_kernel("vec_add");
+      kPtr = std::make_unique<sycl::kernel>(k);
+      // kbExe3 goes out of scope, but the kernel keeps the underlying
+      // impl-object alive
+    }
+    assert(sycl::get_kernel_ids().size() == 2);
+    // kPtr goes out of scope, freeing the kernel and its bundle
+  }
+  assert(sycl::get_kernel_ids().size() == 1);
+
+  return 0;
+}
+
 int test_device_code_split() {
   namespace syclex = sycl::ext::oneapi::experimental;
   using source_kb = sycl::kernel_bundle<sycl::bundle_state::ext_oneapi_source>;
@@ -452,8 +513,9 @@ int test_warning() {
 int main(int argc, char **) {
 #ifdef SYCL_EXT_ONEAPI_KERNEL_COMPILER
   int optional_tests = (argc > 1) ? test_warning() : 0;
-  return test_build_and_run() || test_device_code_split() || test_esimd() ||
-         test_unsupported_options() || test_error() || optional_tests;
+  return test_build_and_run() || test_lifetimes() || test_device_code_split() ||
+         test_esimd() || test_unsupported_options() || test_error() ||
+         optional_tests;
 #else
   static_assert(false, "Kernel Compiler feature test macro undefined");
 #endif