Skip to content

Commit 50edee4

Browse files
authored
[SYCL] Allow pointers to const and volatile in sub-group load (#3497)
In every `load()` overload, use `T` with `const` and `volatile` removed.
1 parent e47dbad commit 50edee4

File tree

1 file changed

+43
-23
lines changed

1 file changed

+43
-23
lines changed

sycl/include/CL/sycl/ONEAPI/sub_group.hpp

+43-23
Original file line numberDiff line numberDiff line change
@@ -226,20 +226,22 @@ struct sub_group {
226226
/* these can map to SIMD or block read/write hardware where available */
227227
#ifdef __SYCL_DEVICE_ONLY__
228228
// Method for decorated pointer
229-
template <typename T>
229+
template <typename CVT, typename T = std::remove_cv_t<CVT>>
230230
detail::enable_if_t<
231231
!std::is_same<typename detail::remove_AS<T>::type, T>::value, T>
232-
load(T *src) const {
232+
load(CVT *cv_src) const {
233+
T *src = const_cast<T *>(cv_src);
233234
return load(sycl::multi_ptr<typename detail::remove_AS<T>::type,
234235
sycl::detail::deduce_AS<T>::value>(
235236
(typename detail::remove_AS<T>::type *)src));
236237
}
237238

238239
// Method for raw pointer
239-
template <typename T>
240+
template <typename CVT, typename T = std::remove_cv_t<CVT>>
240241
detail::enable_if_t<
241242
std::is_same<typename detail::remove_AS<T>::type, T>::value, T>
242-
load(T *src) const {
243+
load(CVT *cv_src) const {
244+
T *src = const_cast<T *>(cv_src);
243245

244246
#ifdef __NVPTX__
245247
return src[get_local_id()[0]];
@@ -257,17 +259,20 @@ struct sub_group {
257259
#endif // __NVPTX__
258260
}
259261
#else //__SYCL_DEVICE_ONLY__
260-
template <typename T> T load(T *src) const {
262+
template <typename CVT, typename T = std::remove_cv_t<CVT>>
263+
T load(CVT *src) const {
261264
(void)src;
262265
throw runtime_error("Sub-groups are not supported on host device.",
263266
PI_INVALID_DEVICE);
264267
}
265268
#endif //__SYCL_DEVICE_ONLY__
266269

267-
template <typename T, access::address_space Space>
270+
template <typename CVT, access::address_space Space,
271+
typename T = std::remove_cv_t<CVT>>
268272
sycl::detail::enable_if_t<
269273
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value, T>
270-
load(const multi_ptr<T, Space> src) const {
274+
load(const multi_ptr<CVT, Space> cv_src) const {
275+
multi_ptr<T, Space> src = const_cast<T *>(static_cast<CVT *>(cv_src));
271276
#ifdef __SYCL_DEVICE_ONLY__
272277
#ifdef __NVPTX__
273278
return src.get()[get_local_id()[0]];
@@ -281,10 +286,12 @@ struct sub_group {
281286
#endif
282287
}
283288

284-
template <typename T, access::address_space Space>
289+
template <typename CVT, access::address_space Space,
290+
typename T = std::remove_cv_t<CVT>>
285291
sycl::detail::enable_if_t<
286292
sycl::detail::sub_group::AcceptableForLocalLoadStore<T, Space>::value, T>
287-
load(const multi_ptr<T, Space> src) const {
293+
load(const multi_ptr<CVT, Space> cv_src) const {
294+
multi_ptr<T, Space> src = const_cast<T *>(static_cast<CVT *>(cv_src));
288295
#ifdef __SYCL_DEVICE_ONLY__
289296
return src.get()[get_local_id()[0]];
290297
#else
@@ -295,75 +302,88 @@ struct sub_group {
295302
}
296303
#ifdef __SYCL_DEVICE_ONLY__
297304
#ifdef __NVPTX__
298-
template <int N, typename T, access::address_space Space>
305+
template <int N, typename CVT, access::address_space Space,
306+
typename T = std::remove_cv_t<CVT>>
299307
sycl::detail::enable_if_t<
300308
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value,
301309
vec<T, N>>
302-
load(const multi_ptr<T, Space> src) const {
310+
load(const multi_ptr<CVT, Space> cv_src) const {
311+
multi_ptr<T, Space> src = const_cast<T *>(static_cast<CVT *>(cv_src));
303312
vec<T, N> res;
304313
for (int i = 0; i < N; ++i) {
305314
res[i] = *(src.get() + i * get_max_local_range()[0] + get_local_id()[0]);
306315
}
307316
return res;
308317
}
309318
#else // __NVPTX__
310-
template <int N, typename T, access::address_space Space>
319+
template <int N, typename CVT, access::address_space Space,
320+
typename T = std::remove_cv_t<CVT>>
311321
sycl::detail::enable_if_t<
312322
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
313323
N != 1 && N != 3 && N != 16,
314324
vec<T, N>>
315-
load(const multi_ptr<T, Space> src) const {
325+
load(const multi_ptr<CVT, Space> cv_src) const {
326+
multi_ptr<T, Space> src = const_cast<T *>(static_cast<CVT *>(cv_src));
316327
return sycl::detail::sub_group::load<N, T>(src);
317328
}
318329

319-
template <int N, typename T, access::address_space Space>
330+
template <int N, typename CVT, access::address_space Space,
331+
typename T = std::remove_cv_t<CVT>>
320332
sycl::detail::enable_if_t<
321333
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
322334
N == 16,
323335
vec<T, 16>>
324-
load(const multi_ptr<T, Space> src) const {
336+
load(const multi_ptr<CVT, Space> cv_src) const {
337+
multi_ptr<T, Space> src = const_cast<T *>(static_cast<CVT *>(cv_src));
325338
return {sycl::detail::sub_group::load<8, T>(src),
326339
sycl::detail::sub_group::load<8, T>(src +
327340
8 * get_max_local_range()[0])};
328341
}
329342

330-
template <int N, typename T, access::address_space Space>
343+
template <int N, typename CVT, access::address_space Space,
344+
typename T = std::remove_cv_t<CVT>>
331345
sycl::detail::enable_if_t<
332346
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
333347
N == 3,
334348
vec<T, 3>>
335-
load(const multi_ptr<T, Space> src) const {
349+
load(const multi_ptr<CVT, Space> cv_src) const {
350+
multi_ptr<T, Space> src = const_cast<T *>(static_cast<CVT *>(cv_src));
336351
return {
337352
sycl::detail::sub_group::load<1, T>(src),
338353
sycl::detail::sub_group::load<2, T>(src + get_max_local_range()[0])};
339354
}
340355

341-
template <int N, typename T, access::address_space Space>
356+
template <int N, typename CVT, access::address_space Space,
357+
typename T = std::remove_cv_t<CVT>>
342358
sycl::detail::enable_if_t<
343359
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
344360
N == 1,
345361
vec<T, 1>>
346-
load(const multi_ptr<T, Space> src) const {
362+
load(const multi_ptr<CVT, Space> cv_src) const {
363+
multi_ptr<T, Space> src = const_cast<T *>(static_cast<CVT *>(cv_src));
347364
return sycl::detail::sub_group::load(src);
348365
}
349366
#endif // ___NVPTX___
350367
#else // __SYCL_DEVICE_ONLY__
351-
template <int N, typename T, access::address_space Space>
368+
template <int N, typename CVT, access::address_space Space,
369+
typename T = std::remove_cv_t<CVT>>
352370
sycl::detail::enable_if_t<
353371
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value,
354372
vec<T, N>>
355-
load(const multi_ptr<T, Space> src) const {
373+
load(const multi_ptr<CVT, Space> src) const {
356374
(void)src;
357375
throw runtime_error("Sub-groups are not supported on host device.",
358376
PI_INVALID_DEVICE);
359377
}
360378
#endif // __SYCL_DEVICE_ONLY__
361379

362-
template <int N, typename T, access::address_space Space>
380+
template <int N, typename CVT, access::address_space Space,
381+
typename T = std::remove_cv_t<CVT>>
363382
sycl::detail::enable_if_t<
364383
sycl::detail::sub_group::AcceptableForLocalLoadStore<T, Space>::value,
365384
vec<T, N>>
366-
load(const multi_ptr<T, Space> src) const {
385+
load(const multi_ptr<CVT, Space> cv_src) const {
386+
multi_ptr<T, Space> src = const_cast<T *>(static_cast<CVT *>(cv_src));
367387
#ifdef __SYCL_DEVICE_ONLY__
368388
vec<T, N> res;
369389
for (int i = 0; i < N; ++i) {

0 commit comments

Comments
 (0)