@@ -226,20 +226,22 @@ struct sub_group {
226
226
/* these can map to SIMD or block read/write hardware where available */
227
227
#ifdef __SYCL_DEVICE_ONLY__
228
228
// Method for decorated pointer
229
- template <typename T >
229
+ template <typename CVT, typename T = std:: remove_cv_t <CVT> >
230
230
detail::enable_if_t <
231
231
!std::is_same<typename detail::remove_AS<T>::type, T>::value, T>
232
- load (T *src) const {
232
+ load (CVT *cv_src) const {
233
+ T *src = const_cast <T *>(cv_src);
233
234
return load (sycl::multi_ptr<typename detail::remove_AS<T>::type,
234
235
sycl::detail::deduce_AS<T>::value>(
235
236
(typename detail::remove_AS<T>::type *)src));
236
237
}
237
238
238
239
// Method for raw pointer
239
- template <typename T >
240
+ template <typename CVT, typename T = std:: remove_cv_t <CVT> >
240
241
detail::enable_if_t <
241
242
std::is_same<typename detail::remove_AS<T>::type, T>::value, T>
242
- load (T *src) const {
243
+ load (CVT *cv_src) const {
244
+ T *src = const_cast <T *>(cv_src);
243
245
244
246
#ifdef __NVPTX__
245
247
return src[get_local_id ()[0 ]];
@@ -257,17 +259,20 @@ struct sub_group {
257
259
#endif // __NVPTX__
258
260
}
259
261
#else // __SYCL_DEVICE_ONLY__
260
- template <typename T> T load (T *src) const {
262
+ template <typename CVT, typename T = std::remove_cv_t <CVT>>
263
+ T load (CVT *src) const {
261
264
(void )src;
262
265
throw runtime_error (" Sub-groups are not supported on host device." ,
263
266
PI_INVALID_DEVICE);
264
267
}
265
268
#endif // __SYCL_DEVICE_ONLY__
266
269
267
- template <typename T, access::address_space Space>
270
+ template <typename CVT, access::address_space Space,
271
+ typename T = std::remove_cv_t <CVT>>
268
272
sycl::detail::enable_if_t <
269
273
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value, T>
270
- load (const multi_ptr<T, Space> src) const {
274
+ load (const multi_ptr<CVT, Space> cv_src) const {
275
+ multi_ptr<T, Space> src = const_cast <T *>(static_cast <CVT *>(cv_src));
271
276
#ifdef __SYCL_DEVICE_ONLY__
272
277
#ifdef __NVPTX__
273
278
return src.get ()[get_local_id ()[0 ]];
@@ -281,10 +286,12 @@ struct sub_group {
281
286
#endif
282
287
}
283
288
284
- template <typename T, access::address_space Space>
289
+ template <typename CVT, access::address_space Space,
290
+ typename T = std::remove_cv_t <CVT>>
285
291
sycl::detail::enable_if_t <
286
292
sycl::detail::sub_group::AcceptableForLocalLoadStore<T, Space>::value, T>
287
- load (const multi_ptr<T, Space> src) const {
293
+ load (const multi_ptr<CVT, Space> cv_src) const {
294
+ multi_ptr<T, Space> src = const_cast <T *>(static_cast <CVT *>(cv_src));
288
295
#ifdef __SYCL_DEVICE_ONLY__
289
296
return src.get ()[get_local_id ()[0 ]];
290
297
#else
@@ -295,75 +302,88 @@ struct sub_group {
295
302
}
296
303
#ifdef __SYCL_DEVICE_ONLY__
297
304
#ifdef __NVPTX__
298
- template <int N, typename T, access::address_space Space>
305
+ template <int N, typename CVT, access::address_space Space,
306
+ typename T = std::remove_cv_t <CVT>>
299
307
sycl::detail::enable_if_t <
300
308
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value,
301
309
vec<T, N>>
302
- load (const multi_ptr<T, Space> src) const {
310
+ load (const multi_ptr<CVT, Space> cv_src) const {
311
+ multi_ptr<T, Space> src = const_cast <T *>(static_cast <CVT *>(cv_src));
303
312
vec<T, N> res;
304
313
for (int i = 0 ; i < N; ++i) {
305
314
res[i] = *(src.get () + i * get_max_local_range ()[0 ] + get_local_id ()[0 ]);
306
315
}
307
316
return res;
308
317
}
309
318
#else // __NVPTX__
310
- template <int N, typename T, access::address_space Space>
319
+ template <int N, typename CVT, access::address_space Space,
320
+ typename T = std::remove_cv_t <CVT>>
311
321
sycl::detail::enable_if_t <
312
322
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
313
323
N != 1 && N != 3 && N != 16 ,
314
324
vec<T, N>>
315
- load (const multi_ptr<T, Space> src) const {
325
+ load (const multi_ptr<CVT, Space> cv_src) const {
326
+ multi_ptr<T, Space> src = const_cast <T *>(static_cast <CVT *>(cv_src));
316
327
return sycl::detail::sub_group::load<N, T>(src);
317
328
}
318
329
319
- template <int N, typename T, access::address_space Space>
330
+ template <int N, typename CVT, access::address_space Space,
331
+ typename T = std::remove_cv_t <CVT>>
320
332
sycl::detail::enable_if_t <
321
333
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
322
334
N == 16 ,
323
335
vec<T, 16 >>
324
- load (const multi_ptr<T, Space> src) const {
336
+ load (const multi_ptr<CVT, Space> cv_src) const {
337
+ multi_ptr<T, Space> src = const_cast <T *>(static_cast <CVT *>(cv_src));
325
338
return {sycl::detail::sub_group::load<8 , T>(src),
326
339
sycl::detail::sub_group::load<8 , T>(src +
327
340
8 * get_max_local_range ()[0 ])};
328
341
}
329
342
330
- template <int N, typename T, access::address_space Space>
343
+ template <int N, typename CVT, access::address_space Space,
344
+ typename T = std::remove_cv_t <CVT>>
331
345
sycl::detail::enable_if_t <
332
346
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
333
347
N == 3 ,
334
348
vec<T, 3 >>
335
- load (const multi_ptr<T, Space> src) const {
349
+ load (const multi_ptr<CVT, Space> cv_src) const {
350
+ multi_ptr<T, Space> src = const_cast <T *>(static_cast <CVT *>(cv_src));
336
351
return {
337
352
sycl::detail::sub_group::load<1 , T>(src),
338
353
sycl::detail::sub_group::load<2 , T>(src + get_max_local_range ()[0 ])};
339
354
}
340
355
341
- template <int N, typename T, access::address_space Space>
356
+ template <int N, typename CVT, access::address_space Space,
357
+ typename T = std::remove_cv_t <CVT>>
342
358
sycl::detail::enable_if_t <
343
359
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value &&
344
360
N == 1 ,
345
361
vec<T, 1 >>
346
- load (const multi_ptr<T, Space> src) const {
362
+ load (const multi_ptr<CVT, Space> cv_src) const {
363
+ multi_ptr<T, Space> src = const_cast <T *>(static_cast <CVT *>(cv_src));
347
364
return sycl::detail::sub_group::load (src);
348
365
}
349
366
#endif // ___NVPTX___
350
367
#else // __SYCL_DEVICE_ONLY__
351
- template <int N, typename T, access::address_space Space>
368
+ template <int N, typename CVT, access::address_space Space,
369
+ typename T = std::remove_cv_t <CVT>>
352
370
sycl::detail::enable_if_t <
353
371
sycl::detail::sub_group::AcceptableForGlobalLoadStore<T, Space>::value,
354
372
vec<T, N>>
355
- load (const multi_ptr<T , Space> src) const {
373
+ load (const multi_ptr<CVT , Space> src) const {
356
374
(void )src;
357
375
throw runtime_error (" Sub-groups are not supported on host device." ,
358
376
PI_INVALID_DEVICE);
359
377
}
360
378
#endif // __SYCL_DEVICE_ONLY__
361
379
362
- template <int N, typename T, access::address_space Space>
380
+ template <int N, typename CVT, access::address_space Space,
381
+ typename T = std::remove_cv_t <CVT>>
363
382
sycl::detail::enable_if_t <
364
383
sycl::detail::sub_group::AcceptableForLocalLoadStore<T, Space>::value,
365
384
vec<T, N>>
366
- load (const multi_ptr<T, Space> src) const {
385
+ load (const multi_ptr<CVT, Space> cv_src) const {
386
+ multi_ptr<T, Space> src = const_cast <T *>(static_cast <CVT *>(cv_src));
367
387
#ifdef __SYCL_DEVICE_ONLY__
368
388
vec<T, N> res;
369
389
for (int i = 0 ; i < N; ++i) {
0 commit comments