@@ -274,58 +274,99 @@ template <int Dimensions = 1> class group {
274
274
__spirv_MemoryBarrier (__spv::Scope::Workgroup, flags);
275
275
}
276
276
277
+ // / Asynchronously copies a number of elements specified by \p numElements
278
+ // / from the source pointed by \p src to destination pointed by \p dest
279
+ // / with a source stride specified by \p srcStride, and returns a SYCL
280
+ // / device_event which can be used to wait on the completion of the copy.
281
+ // / Permitted types for dataT are all scalar and vector types, except boolean.
277
282
template <typename dataT>
278
- device_event async_work_group_copy (local_ptr< dataT> dest,
279
- global_ptr<dataT> src,
280
- size_t numElements ) const {
283
+ detail:: enable_if_t <!detail::is_bool< dataT>::value, device_event>
284
+ async_work_group_copy (local_ptr<dataT> dest, global_ptr<dataT> src,
285
+ size_t numElements, size_t srcStride ) const {
281
286
using DestT = detail::ConvertToOpenCLType_t<decltype (dest)>;
282
287
using SrcT = detail::ConvertToOpenCLType_t<decltype (src)>;
283
288
284
- __ocl_event_t e = OpGroupAsyncCopyGlobalToLocal (
289
+ __ocl_event_t E = OpGroupAsyncCopyGlobalToLocal (
285
290
__spv::Scope::Workgroup, DestT (dest.get ()), SrcT (src.get ()),
286
- numElements, 1 , 0 );
287
- return device_event (&e );
291
+ numElements, srcStride , 0 );
292
+ return device_event (&E );
288
293
}
289
294
295
+ // / Asynchronously copies a number of elements specified by \p numElements
296
+ // / from the source pointed by \p src to destination pointed by \p dest with
297
+ // / the destination stride specified by \p destStride, and returns a SYCL
298
+ // / device_event which can be used to wait on the completion of the copy.
299
+ // / Permitted types for dataT are all scalar and vector types, except boolean.
290
300
template <typename dataT>
291
- device_event async_work_group_copy (global_ptr< dataT> dest,
292
- local_ptr<dataT> src,
293
- size_t numElements ) const {
301
+ detail:: enable_if_t <!detail::is_bool< dataT>::value, device_event>
302
+ async_work_group_copy (global_ptr<dataT> dest, local_ptr<dataT> src,
303
+ size_t numElements, size_t destStride ) const {
294
304
using DestT = detail::ConvertToOpenCLType_t<decltype (dest)>;
295
305
using SrcT = detail::ConvertToOpenCLType_t<decltype (src)>;
296
306
297
- __ocl_event_t e = OpGroupAsyncCopyLocalToGlobal (
307
+ __ocl_event_t E = OpGroupAsyncCopyLocalToGlobal (
298
308
__spv::Scope::Workgroup, DestT (dest.get ()), SrcT (src.get ()),
299
- numElements, 1 , 0 );
300
- return device_event (&e);
309
+ numElements, destStride, 0 );
310
+ return device_event (&E);
311
+ }
312
+
313
+ // / Specialization for scalar bool type.
314
+ // / Asynchronously copies a number of elements specified by \p NumElements
315
+ // / from the source pointed by \p Src to destination pointed by \p Dest
316
+ // / with a stride specified by \p Stride, and returns a SYCL device_event
317
+ // / which can be used to wait on the completion of the copy.
318
+ template <typename T, access::address_space DestS, access::address_space SrcS>
319
+ detail::enable_if_t <detail::is_scalar_bool<T>::value, device_event>
320
+ async_work_group_copy (multi_ptr<T, DestS> Dest, multi_ptr<T, SrcS> Src,
321
+ size_t NumElements, size_t Stride) const {
322
+ static_assert (sizeof (bool ) == sizeof (uint8_t ),
323
+ " Async copy to/from bool memory is not supported." );
324
+ auto DestP =
325
+ multi_ptr<uint8_t , DestS>(reinterpret_cast <uint8_t *>(Dest.get ()));
326
+ auto SrcP =
327
+ multi_ptr<uint8_t , SrcS>(reinterpret_cast <uint8_t *>(Src.get ()));
328
+ return async_work_group_copy (DestP, SrcP, NumElements, Stride);
329
+ }
330
+
331
+ // / Specialization for vector bool type.
332
+ // / Asynchronously copies a number of elements specified by \p NumElements
333
+ // / from the source pointed by \p Src to destination pointed by \p Dest
334
+ // / with a stride specified by \p Stride, and returns a SYCL device_event
335
+ // / which can be used to wait on the completion of the copy.
336
+ template <typename T, access::address_space DestS, access::address_space SrcS>
337
+ detail::enable_if_t <detail::is_vector_bool<T>::value, device_event>
338
+ async_work_group_copy (multi_ptr<T, DestS> Dest, multi_ptr<T, SrcS> Src,
339
+ size_t NumElements, size_t Stride) const {
340
+ static_assert (sizeof (bool ) == sizeof (uint8_t ),
341
+ " Async copy to/from bool memory is not supported." );
342
+ using VecT = detail::change_base_type_t <T, uint8_t >;
343
+ auto DestP = multi_ptr<VecT, DestS>(reinterpret_cast <VecT *>(Dest.get ()));
344
+ auto SrcP = multi_ptr<VecT, SrcS>(reinterpret_cast <VecT *>(Src.get ()));
345
+ return async_work_group_copy (DestP, SrcP, NumElements, Stride);
301
346
}
302
347
348
+ // / Asynchronously copies a number of elements specified by \p numElements
349
+ // / from the source pointed by \p src to destination pointed by \p dest and
350
+ // / returns a SYCL device_event which can be used to wait on the completion
351
+ // / of the copy.
352
+ // / Permitted types for dataT are all scalar and vector types.
303
353
template <typename dataT>
304
354
device_event async_work_group_copy (local_ptr<dataT> dest,
305
355
global_ptr<dataT> src,
306
- size_t numElements,
307
- size_t srcStride) const {
308
- using DestT = detail::ConvertToOpenCLType_t<decltype (dest)>;
309
- using SrcT = detail::ConvertToOpenCLType_t<decltype (src)>;
310
-
311
- __ocl_event_t e = OpGroupAsyncCopyGlobalToLocal (
312
- __spv::Scope::Workgroup, DestT (dest.get ()), SrcT (src.get ()),
313
- numElements, srcStride, 0 );
314
- return device_event (&e);
356
+ size_t numElements) const {
357
+ return async_work_group_copy (dest, src, numElements, 1 );
315
358
}
316
359
360
+ // / Asynchronously copies a number of elements specified by \p numElements
361
+ // / from the source pointed by \p src to destination pointed by \p dest and
362
+ // / returns a SYCL device_event which can be used to wait on the completion
363
+ // / of the copy.
364
+ // / Permitted types for dataT are all scalar and vector types.
317
365
template <typename dataT>
318
366
device_event async_work_group_copy (global_ptr<dataT> dest,
319
367
local_ptr<dataT> src,
320
- size_t numElements,
321
- size_t destStride) const {
322
- using DestT = detail::ConvertToOpenCLType_t<decltype (dest)>;
323
- using SrcT = detail::ConvertToOpenCLType_t<decltype (src)>;
324
-
325
- __ocl_event_t e = OpGroupAsyncCopyLocalToGlobal (
326
- __spv::Scope::Workgroup, DestT (dest.get ()), SrcT (src.get ()),
327
- numElements, destStride, 0 );
328
- return device_event (&e);
368
+ size_t numElements) const {
369
+ return async_work_group_copy (dest, src, numElements, 1 );
329
370
}
330
371
331
372
template <typename ... eventTN>
0 commit comments