Skip to content

Commit ad3951c

Browse files
committed
[SYCL][CUDA] CUDA Device selection changes
* Removes NVIDIA OpenCL from the available list of platforms * CUDA backend is available only if SYCL_BE=PI_CUDA is set Signed-off-by: Ruyman Reyes <[email protected]>
1 parent 66810be commit ad3951c

File tree

2 files changed

+15
-33
lines changed

2 files changed

+15
-33
lines changed

sycl/source/detail/platform_impl.cpp

+14-14
Original file line numberDiff line numberDiff line change
@@ -202,29 +202,29 @@ static bool isDeviceInvalidForBe(const device &Device) {
202202
if (Device.is_host())
203203
return false;
204204

205-
// Taking the version information from the platform gives us more useful
206-
// information than the driver_version of the device.
205+
// Retrieve Platform version to identify CUDA OpenCL platform
206+
// String: OpenCL 1.2 CUDA <version>
207207
const platform platform = Device.get_info<info::device::platform>();
208208
const std::string platformVersion =
209209
platform.get_info<info::platform::version>();
210+
const bool HasOpenCL = (platformVersion.find("OpenCL") != std::string::npos);
211+
const bool HasCUDA = (platformVersion.find("CUDA") != std::string::npos);
210212

211-
backend *BackendPref = detail::SYCLConfig<detail::SYCL_BE>::get();
212-
auto BackendType = detail::getSyclObjImpl(Device)->getPlugin().getBackend();
213-
static_assert(std::is_same<backend, decltype(BackendType)>(),
214-
"Type is not the same");
213+
backend *PrefBackend = detail::SYCLConfig<detail::SYCL_BE>::get();
214+
auto DeviceBackend = detail::getSyclObjImpl(Device)->getPlugin().getBackend();
215215

216-
// If no preference, assume OpenCL and reject CUDA backend
217-
if (BackendType == backend::cuda && !BackendPref) {
216+
// Reject the NVIDIA OpenCL implementation
217+
if (DeviceBackend == backend::opencl && HasCUDA && HasOpenCL)
218218
return true;
219-
} else if (!BackendPref)
220-
return false;
221219

222-
// If using PI_CUDA, don't accept a non-CUDA device
223-
if (BackendType == backend::opencl && *BackendPref == backend::cuda)
220+
// If no preference, assume OpenCL and reject CUDA
221+
if (DeviceBackend == backend::cuda && !PrefBackend) {
224222
return true;
223+
} else if (!PrefBackend)
224+
return false;
225225

226-
// If using PI_OPENCL, don't accept a non-OpenCL device
227-
if (BackendType == backend::cuda && *BackendPref == backend::opencl)
226+
// If using PI_OPENCL, reject the CUDA backend
227+
if (DeviceBackend == backend::cuda && *PrefBackend == backend::opencl)
228228
return true;
229229

230230
return false;

sycl/source/device_selector.cpp

+1-19
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,6 @@ device device_selector::select_device() const {
3434
int score = -1;
3535
const device *res = nullptr;
3636
for (const auto &dev : devices) {
37-
38-
// Reject the NVIDIA OpenCL platform
39-
if (!dev.is_host()) {
40-
string_class PlatformName = dev.get_info<info::device::platform>()
41-
.get_info<info::platform::name>();
42-
const bool IsCUDAPlatform =
43-
PlatformName.find("CUDA") != std::string::npos;
44-
45-
if (detail::getSyclObjImpl(dev)->getPlugin().getBackend() ==
46-
backend::opencl &&
47-
IsCUDAPlatform) {
48-
continue;
49-
}
50-
}
51-
5237
int dev_score = (*this)(dev);
5338
if (detail::pi::trace(detail::pi::TraceLevel::PI_TRACE_ALL)) {
5439
string_class PlatformVersion = dev.get_info<info::device::platform>()
@@ -95,9 +80,7 @@ device device_selector::select_device() const {
9580
}
9681

9782
int default_selector::operator()(const device &dev) const {
98-
9983
int Score = -1;
100-
10184
// Give preference to device of SYCL BE.
10285
if (isDeviceOfPreferredSyclBe(dev))
10386
Score = 50;
@@ -120,12 +103,11 @@ int default_selector::operator()(const device &dev) const {
120103

121104
int gpu_selector::operator()(const device &dev) const {
122105
int Score = -1;
123-
124106
if (dev.is_gpu()) {
125107
Score = 1000;
126108
// Give preference to device of SYCL BE.
127109
if (isDeviceOfPreferredSyclBe(dev))
128-
Score = 50;
110+
Score += 50;
129111
}
130112
return Score;
131113
}

0 commit comments

Comments
 (0)