diff --git a/sycl/include/CL/sycl/detail/pi.h b/sycl/include/CL/sycl/detail/pi.h index 03bf0e8f5b2c8..d8da917b94cbc 100644 --- a/sycl/include/CL/sycl/detail/pi.h +++ b/sycl/include/CL/sycl/detail/pi.h @@ -60,6 +60,10 @@ using pi_uint64 = uint64_t; using pi_bool = pi_uint32; using pi_bitfield = pi_uint64; +// For selection of SYCL RT back-end, now manually through the "SYCL_BE" +// environment variable. +enum Backend { SYCL_BE_PI_OPENCL, SYCL_BE_PI_CUDA, SYCL_BE_PI_OTHER }; + // // NOTE: prefer to map 1:1 to OpenCL so that no translation is needed // for PI <-> OpenCL ways. The PI <-> to other BE translation is almost @@ -1346,6 +1350,8 @@ struct _pi_plugin { const char PiVersion[4] = _PI_H_VERSION_STRING; // Plugin edits this. char PluginVersion[4] = _PI_H_VERSION_STRING; + // Plugin type + Backend backend; char *Targets; struct FunctionPointers { #define _PI_API(api) decltype(::api) *api; diff --git a/sycl/include/CL/sycl/detail/pi.hpp b/sycl/include/CL/sycl/detail/pi.hpp index 73ba98a4e4530..4acf5fc3022ac 100644 --- a/sycl/include/CL/sycl/detail/pi.hpp +++ b/sycl/include/CL/sycl/detail/pi.hpp @@ -103,10 +103,6 @@ void *loadOsLibrary(const std::string &Library); // library, implementation is OS dependent. void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName); -// For selection of SYCL RT back-end, now manually through the "SYCL_BE" -// environment variable. -enum Backend { SYCL_BE_PI_OPENCL, SYCL_BE_PI_CUDA, SYCL_BE_PI_OTHER }; - // Check for manually selected BE at run-time. bool useBackend(Backend Backend); diff --git a/sycl/plugins/cuda/pi_cuda.cpp b/sycl/plugins/cuda/pi_cuda.cpp index 936bedad90912..50026eb3f25de 100644 --- a/sycl/plugins/cuda/pi_cuda.cpp +++ b/sycl/plugins/cuda/pi_cuda.cpp @@ -3564,6 +3564,9 @@ pi_result piPluginInit(pi_plugin *PluginInit) { // PI interface supports higher version or the same version. strncpy(PluginInit->PluginVersion, SupportedVersion, 4); + // Set plugin Type + PluginInit->backend = SYCL_BE_PI_CUDA; + // Set whole function table to zero to make it easier to detect if // functions are not set up below. std::memset(&(PluginInit->PiFunctionTable), 0, diff --git a/sycl/plugins/opencl/pi_opencl.cpp b/sycl/plugins/opencl/pi_opencl.cpp index 5f3f832c35e38..8e3655bb7fca9 100644 --- a/sycl/plugins/opencl/pi_opencl.cpp +++ b/sycl/plugins/opencl/pi_opencl.cpp @@ -1054,6 +1054,9 @@ pi_result piPluginInit(pi_plugin *PluginInit) { // PI interface supports higher version or the same version. strncpy(PluginInit->PluginVersion, SupportedVersion, 4); + // Set plugin Type + PluginInit->backend = SYCL_BE_PI_OPENCL; + #define _PI_CL(pi_api, ocl_api) \ (PluginInit->PiFunctionTable).pi_api = (decltype(&::pi_api))(&ocl_api); diff --git a/sycl/source/detail/plugin.hpp b/sycl/source/detail/plugin.hpp index 14ddf8f9560e2..c92862db35281 100644 --- a/sycl/source/detail/plugin.hpp +++ b/sycl/source/detail/plugin.hpp @@ -31,6 +31,10 @@ class plugin { const RT::PiPlugin &getPiPlugin() const { return MPlugin; } + bool isBackendType(Backend backend) const { + return MPlugin.backend == backend; + } + /// Checks return value from PI calls. /// /// \throw Exception if pi_result is not a PI_SUCCESS. diff --git a/sycl/source/detail/program_manager/program_manager.cpp b/sycl/source/detail/program_manager/program_manager.cpp index 20f9fa5a91004..69d4248919239 100644 --- a/sycl/source/detail/program_manager/program_manager.cpp +++ b/sycl/source/detail/program_manager/program_manager.cpp @@ -269,8 +269,10 @@ static bool isDeviceBinaryTypeSupported(const context &C, return false; } + ContextImplPtr ContextImpl = getSyclObjImpl(C); + // OpenCL 2.1 and greater require clCreateProgramWithIL - if (pi::useBackend(pi::SYCL_BE_PI_OPENCL) && + if (ContextImpl->getPlugin().isBackendType(Backend::SYCL_BE_PI_OPENCL) && C.get_platform().get_info() >= "2.1") return true; diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index b0e00e505010d..ef6fe42c9b629 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -1661,7 +1661,7 @@ cl_int ExecCGCommand::enqueueImp() { Requirement *Req = (Requirement *)(Arg.MPtr); AllocaCommandBase *AllocaCmd = getAllocaForReq(Req); RT::PiMem MemArg = (RT::PiMem)AllocaCmd->getMemAllocation(); - if (RT::useBackend(pi::Backend::SYCL_BE_PI_OPENCL)) { + if (Plugin.isBackendType(Backend::SYCL_BE_PI_OPENCL)) { Plugin.call(Kernel, Arg.MIndex, sizeof(RT::PiMem), &MemArg); } else {