Skip to content

Commit 9ea0f20

Browse files
[SYCL] Refactor address space casts functionality (#15543)
This follows the interfaces designed in https://github.com/intel/llvm/blob/3a1c3cb53566f904a73361d5c57b939d981564b5/sycl/doc/extensions/proposed/sycl_ext_oneapi_address_cast.asciidoc, but instead of operating on `multi_ptr`, these work on decorated C++ pointers (as that's what we need throughout our implementation, including the `multi_ptr` implementation itself). Basically, I've moved the implementation of the extension to the new `detail::static_address_cast` and `detail::dynamic_address_cast` functions and replaced all uses of the old `detail::cast_AS` (which had inconsistent static vs. dynamic behavior depending on address spaces/backends), as well as direct invocations of the SPIR-V builtins/wrappers. This isn't NFC, because by doing that I've changed "dynamic" behavior to "static" whenever the spec allows it (e.g. when it is UB for the runtime pointer not to point to a proper allocation).
1 parent ba99338 commit 9ea0f20

File tree

12 files changed

+269
-436
lines changed

12 files changed

+269
-436
lines changed

sycl/include/sycl/__spirv/spirv_ops.hpp

Lines changed: 0 additions & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -540,190 +540,6 @@ __SPIRV_ATOMICS(__SPIRV_ATOMIC_MINMAX, Max)
540540
#undef __SPIRV_ATOMIC_UNSIGNED
541541
#undef __SPIRV_ATOMIC_XOR
542542

543-
template <typename dataT>
544-
extern __attribute__((opencl_global)) dataT *
545-
__SYCL_GenericCastToPtrExplicit_ToGlobal(void *Ptr) noexcept {
546-
return (__attribute__((opencl_global)) dataT *)
547-
__spirv_GenericCastToPtrExplicit_ToGlobal(
548-
Ptr, __spv::StorageClass::CrossWorkgroup);
549-
}
550-
551-
template <typename dataT>
552-
extern const __attribute__((opencl_global)) dataT *
553-
__SYCL_GenericCastToPtrExplicit_ToGlobal(const void *Ptr) noexcept {
554-
return (const __attribute__((opencl_global)) dataT *)
555-
__spirv_GenericCastToPtrExplicit_ToGlobal(
556-
Ptr, __spv::StorageClass::CrossWorkgroup);
557-
}
558-
559-
template <typename dataT>
560-
extern volatile __attribute__((opencl_global)) dataT *
561-
__SYCL_GenericCastToPtrExplicit_ToGlobal(volatile void *Ptr) noexcept {
562-
return (volatile __attribute__((opencl_global)) dataT *)
563-
__spirv_GenericCastToPtrExplicit_ToGlobal(
564-
Ptr, __spv::StorageClass::CrossWorkgroup);
565-
}
566-
567-
template <typename dataT>
568-
extern const volatile __attribute__((opencl_global)) dataT *
569-
__SYCL_GenericCastToPtrExplicit_ToGlobal(const volatile void *Ptr) noexcept {
570-
return (const volatile __attribute__((opencl_global)) dataT *)
571-
__spirv_GenericCastToPtrExplicit_ToGlobal(
572-
Ptr, __spv::StorageClass::CrossWorkgroup);
573-
}
574-
575-
template <typename dataT>
576-
extern __attribute__((opencl_local)) dataT *
577-
__SYCL_GenericCastToPtrExplicit_ToLocal(void *Ptr) noexcept {
578-
return (__attribute__((opencl_local)) dataT *)
579-
__spirv_GenericCastToPtrExplicit_ToLocal(Ptr,
580-
__spv::StorageClass::Workgroup);
581-
}
582-
583-
template <typename dataT>
584-
extern const __attribute__((opencl_local)) dataT *
585-
__SYCL_GenericCastToPtrExplicit_ToLocal(const void *Ptr) noexcept {
586-
return (const __attribute__((opencl_local)) dataT *)
587-
__spirv_GenericCastToPtrExplicit_ToLocal(Ptr,
588-
__spv::StorageClass::Workgroup);
589-
}
590-
591-
template <typename dataT>
592-
extern volatile __attribute__((opencl_local)) dataT *
593-
__SYCL_GenericCastToPtrExplicit_ToLocal(volatile void *Ptr) noexcept {
594-
return (volatile __attribute__((opencl_local)) dataT *)
595-
__spirv_GenericCastToPtrExplicit_ToLocal(Ptr,
596-
__spv::StorageClass::Workgroup);
597-
}
598-
599-
template <typename dataT>
600-
extern const volatile __attribute__((opencl_local)) dataT *
601-
__SYCL_GenericCastToPtrExplicit_ToLocal(const volatile void *Ptr) noexcept {
602-
return (const volatile __attribute__((opencl_local)) dataT *)
603-
__spirv_GenericCastToPtrExplicit_ToLocal(Ptr,
604-
__spv::StorageClass::Workgroup);
605-
}
606-
607-
template <typename dataT>
608-
extern __attribute__((opencl_private)) dataT *
609-
__SYCL_GenericCastToPtrExplicit_ToPrivate(void *Ptr) noexcept {
610-
return (__attribute__((opencl_private)) dataT *)
611-
__spirv_GenericCastToPtrExplicit_ToPrivate(Ptr,
612-
__spv::StorageClass::Function);
613-
}
614-
615-
template <typename dataT>
616-
extern const __attribute__((opencl_private)) dataT *
617-
__SYCL_GenericCastToPtrExplicit_ToPrivate(const void *Ptr) noexcept {
618-
return (const __attribute__((opencl_private)) dataT *)
619-
__spirv_GenericCastToPtrExplicit_ToPrivate(Ptr,
620-
__spv::StorageClass::Function);
621-
}
622-
623-
template <typename dataT>
624-
extern volatile __attribute__((opencl_private)) dataT *
625-
__SYCL_GenericCastToPtrExplicit_ToPrivate(volatile void *Ptr) noexcept {
626-
return (volatile __attribute__((opencl_private)) dataT *)
627-
__spirv_GenericCastToPtrExplicit_ToPrivate(Ptr,
628-
__spv::StorageClass::Function);
629-
}
630-
631-
template <typename dataT>
632-
extern const volatile __attribute__((opencl_private)) dataT *
633-
__SYCL_GenericCastToPtrExplicit_ToPrivate(const volatile void *Ptr) noexcept {
634-
return (const volatile __attribute__((opencl_private)) dataT *)
635-
__spirv_GenericCastToPtrExplicit_ToPrivate(Ptr,
636-
__spv::StorageClass::Function);
637-
}
638-
639-
template <typename dataT>
640-
extern __attribute__((opencl_global)) dataT *
641-
__SYCL_GenericCastToPtr_ToGlobal(void *Ptr) noexcept {
642-
return (__attribute__((opencl_global)) dataT *)
643-
__spirv_GenericCastToPtr_ToGlobal(Ptr,
644-
__spv::StorageClass::CrossWorkgroup);
645-
}
646-
647-
template <typename dataT>
648-
extern const __attribute__((opencl_global)) dataT *
649-
__SYCL_GenericCastToPtr_ToGlobal(const void *Ptr) noexcept {
650-
return (const __attribute__((opencl_global)) dataT *)
651-
__spirv_GenericCastToPtr_ToGlobal(Ptr,
652-
__spv::StorageClass::CrossWorkgroup);
653-
}
654-
655-
template <typename dataT>
656-
extern volatile __attribute__((opencl_global)) dataT *
657-
__SYCL_GenericCastToPtr_ToGlobal(volatile void *Ptr) noexcept {
658-
return (volatile __attribute__((opencl_global)) dataT *)
659-
__spirv_GenericCastToPtr_ToGlobal(Ptr,
660-
__spv::StorageClass::CrossWorkgroup);
661-
}
662-
663-
template <typename dataT>
664-
extern const volatile __attribute__((opencl_global)) dataT *
665-
__SYCL_GenericCastToPtr_ToGlobal(const volatile void *Ptr) noexcept {
666-
return (const volatile __attribute__((opencl_global)) dataT *)
667-
__spirv_GenericCastToPtr_ToGlobal(Ptr,
668-
__spv::StorageClass::CrossWorkgroup);
669-
}
670-
671-
template <typename dataT>
672-
extern __attribute__((opencl_local)) dataT *
673-
__SYCL_GenericCastToPtr_ToLocal(void *Ptr) noexcept {
674-
return (__attribute__((opencl_local)) dataT *)
675-
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
676-
}
677-
678-
template <typename dataT>
679-
extern const __attribute__((opencl_local)) dataT *
680-
__SYCL_GenericCastToPtr_ToLocal(const void *Ptr) noexcept {
681-
return (const __attribute__((opencl_local)) dataT *)
682-
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
683-
}
684-
685-
template <typename dataT>
686-
extern volatile __attribute__((opencl_local)) dataT *
687-
__SYCL_GenericCastToPtr_ToLocal(volatile void *Ptr) noexcept {
688-
return (volatile __attribute__((opencl_local)) dataT *)
689-
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
690-
}
691-
692-
template <typename dataT>
693-
extern const volatile __attribute__((opencl_local)) dataT *
694-
__SYCL_GenericCastToPtr_ToLocal(const volatile void *Ptr) noexcept {
695-
return (const volatile __attribute__((opencl_local)) dataT *)
696-
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
697-
}
698-
699-
template <typename dataT>
700-
extern __attribute__((opencl_private)) dataT *
701-
__SYCL_GenericCastToPtr_ToPrivate(void *Ptr) noexcept {
702-
return (__attribute__((opencl_private)) dataT *)
703-
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
704-
}
705-
706-
template <typename dataT>
707-
extern const __attribute__((opencl_private)) dataT *
708-
__SYCL_GenericCastToPtr_ToPrivate(const void *Ptr) noexcept {
709-
return (const __attribute__((opencl_private)) dataT *)
710-
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
711-
}
712-
713-
template <typename dataT>
714-
extern volatile __attribute__((opencl_private)) dataT *
715-
__SYCL_GenericCastToPtr_ToPrivate(volatile void *Ptr) noexcept {
716-
return (volatile __attribute__((opencl_private)) dataT *)
717-
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
718-
}
719-
720-
template <typename dataT>
721-
extern const volatile __attribute__((opencl_private)) dataT *
722-
__SYCL_GenericCastToPtr_ToPrivate(const volatile void *Ptr) noexcept {
723-
return (const volatile __attribute__((opencl_private)) dataT *)
724-
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
725-
}
726-
727543
template <typename dataT>
728544
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL dataT
729545
__spirv_SubgroupShuffleINTEL(dataT Data, uint32_t InvocationId) noexcept;

sycl/include/sycl/access/access.hpp

Lines changed: 143 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -325,58 +325,154 @@ template <typename T>
325325
using remove_decoration_t = typename remove_decoration<T>::type;
326326

327327
namespace detail {
328-
329-
// Helper function for selecting appropriate casts between address spaces.
330-
template <typename ToT, typename FromT> inline ToT cast_AS(FromT from) {
331328
#ifdef __SYCL_DEVICE_ONLY__
332-
constexpr access::address_space ToAS = deduce_AS<ToT>::value;
333-
constexpr access::address_space FromAS = deduce_AS<FromT>::value;
334-
if constexpr (FromAS == access::address_space::generic_space) {
335-
#if defined(__NVPTX__) || defined(__AMDGCN__) || defined(__SYCL_NATIVE_CPU__)
336-
// TODO: NVPTX and AMDGCN backends do not currently support the
337-
// __spirv_GenericCastToPtrExplicit_* builtins, so to work around this
338-
// we do C-style casting. This may produce warnings when targetting
339-
// these backends.
340-
return (ToT)from;
329+
inline constexpr bool
330+
address_space_cast_is_possible(access::address_space Src,
331+
access::address_space Dst) {
332+
// constant_space is unique and is not interchangeable with any other.
333+
auto constant_space = access::address_space::constant_space;
334+
if (Src == constant_space || Dst == constant_space)
335+
return Src == Dst;
336+
337+
auto generic_space = access::address_space::generic_space;
338+
if (Src == Dst || Src == generic_space || Dst == generic_space)
339+
return true;
340+
341+
// global_host/global_device can be cast to/from global
342+
auto global_space = access::address_space::global_space;
343+
auto global_device = access::address_space::ext_intel_global_device_space;
344+
auto global_host = access::address_space::ext_intel_global_host_space;
345+
346+
if (Src == global_space || Dst == global_space) {
347+
auto Other = Src == global_space ? Dst : Src;
348+
if (Other == global_device || Other == global_host)
349+
return true;
350+
}
351+
352+
// No more compatible combinations.
353+
return false;
354+
}
355+
356+
template <access::address_space Space, typename ElementType>
357+
auto static_address_cast(ElementType *Ptr) {
358+
constexpr auto generic_space = access::address_space::generic_space;
359+
constexpr auto global_space = access::address_space::global_space;
360+
constexpr auto local_space = access::address_space::local_space;
361+
constexpr auto private_space = access::address_space::private_space;
362+
constexpr auto global_device =
363+
access::address_space::ext_intel_global_device_space;
364+
constexpr auto global_host =
365+
access::address_space::ext_intel_global_host_space;
366+
367+
constexpr auto SrcAS = deduce_AS<ElementType *>::value;
368+
static_assert(address_space_cast_is_possible(SrcAS, Space));
369+
370+
using dst_type = typename DecoratedType<
371+
std::remove_pointer_t<remove_decoration_t<ElementType *>>, Space>::type *;
372+
373+
// Note: reinterpret_cast isn't enough for some of the casts between different
374+
// address spaces, use C-style cast instead.
375+
#if !defined(__SPIR__)
376+
return (dst_type)Ptr;
341377
#else
342-
using ToElemT = std::remove_pointer_t<remove_decoration_t<ToT>>;
343-
if constexpr (ToAS == access::address_space::global_space)
344-
return __SYCL_GenericCastToPtrExplicit_ToGlobal<ToElemT>(from);
345-
else if constexpr (ToAS == access::address_space::local_space)
346-
return __SYCL_GenericCastToPtrExplicit_ToLocal<ToElemT>(from);
347-
else if constexpr (ToAS == access::address_space::private_space)
348-
return __SYCL_GenericCastToPtrExplicit_ToPrivate<ToElemT>(from);
349-
#ifdef __ENABLE_USM_ADDR_SPACE__
350-
else if constexpr (ToAS == access::address_space::
351-
ext_intel_global_device_space ||
352-
ToAS ==
353-
access::address_space::ext_intel_global_host_space)
354-
// For extended address spaces we do not currently have a SPIR-V
355-
// conversion function, so we do a C-style cast. This may produce
356-
// warnings.
357-
return (ToT)from;
358-
#endif // __ENABLE_USM_ADDR_SPACE__
359-
else
360-
return reinterpret_cast<ToT>(from);
361-
#endif // defined(__NVPTX__) || defined(__AMDGCN__)
362-
} else
363-
#ifdef __ENABLE_USM_ADDR_SPACE__
364-
if constexpr (FromAS == access::address_space::global_space &&
365-
(ToAS ==
366-
access::address_space::ext_intel_global_device_space ||
367-
ToAS ==
368-
access::address_space::ext_intel_global_host_space)) {
369-
// Casting from global address space to the global device and host address
370-
// spaces is allowed.
371-
return (ToT)from;
372-
} else
373-
#endif // __ENABLE_USM_ADDR_SPACE__
374-
#endif // __SYCL_DEVICE_ONLY__
375-
{
376-
return reinterpret_cast<ToT>(from);
378+
if constexpr (SrcAS != generic_space) {
379+
return (dst_type)Ptr;
380+
} else if constexpr (Space == global_space) {
381+
return (dst_type)__spirv_GenericCastToPtr_ToGlobal(
382+
Ptr, __spv::StorageClass::CrossWorkgroup);
383+
} else if constexpr (Space == local_space) {
384+
return (dst_type)__spirv_GenericCastToPtr_ToLocal(
385+
Ptr, __spv::StorageClass::Workgroup);
386+
} else if constexpr (Space == private_space) {
387+
return (dst_type)__spirv_GenericCastToPtr_ToPrivate(
388+
Ptr, __spv::StorageClass::Function);
389+
#if !defined(__ENABLE_USM_ADDR_SPACE__)
390+
} else if constexpr (Space == global_device || Space == global_host) {
391+
// If __ENABLE_USM_ADDR_SPACE__ isn't defined then both
392+
// global_device/global_host are just aliases for global_space.
393+
return (dst_type)__spirv_GenericCastToPtr_ToGlobal(
394+
Ptr, __spv::StorageClass::CrossWorkgroup);
395+
#endif
396+
} else {
397+
return (dst_type)Ptr;
377398
}
399+
#endif
378400
}
379401

402+
// Previous implementation (`cast_AS`, used in `multi_ptr` ctors among other
403+
// places), used C-style cast instead of a proper dynamic check for some
404+
// backends/spaces. `SupressNotImplementedAssert = true` parameter is emulating
405+
// that previous behavior until the proper support is added for compatibility
406+
// reasons.
407+
template <access::address_space Space, bool SupressNotImplementedAssert = false,
408+
typename ElementType>
409+
auto dynamic_address_cast(ElementType *Ptr) {
410+
constexpr auto generic_space = access::address_space::generic_space;
411+
constexpr auto global_space = access::address_space::global_space;
412+
constexpr auto local_space = access::address_space::local_space;
413+
constexpr auto private_space = access::address_space::private_space;
414+
constexpr auto global_device =
415+
access::address_space::ext_intel_global_device_space;
416+
constexpr auto global_host =
417+
access::address_space::ext_intel_global_host_space;
418+
419+
constexpr auto SrcAS = deduce_AS<ElementType *>::value;
420+
using dst_type = typename DecoratedType<
421+
std::remove_pointer_t<remove_decoration_t<ElementType *>>, Space>::type *;
422+
423+
if constexpr (!address_space_cast_is_possible(SrcAS, Space)) {
424+
return (dst_type) nullptr;
425+
} else if constexpr (Space == generic_space) {
426+
return (dst_type)Ptr;
427+
} else if constexpr (Space == global_space &&
428+
(SrcAS == global_device || SrcAS == global_host)) {
429+
return (dst_type)Ptr;
430+
} else if constexpr (SrcAS == global_space &&
431+
(Space == global_device || Space == global_host)) {
432+
#if defined(__ENABLE_USM_ADDR_SPACE__)
433+
static_assert(SupressNotImplementedAssert || Space != Space,
434+
"Not supported yet!");
435+
return static_address_cast<Space>(Ptr);
436+
#else
437+
// If __ENABLE_USM_ADDR_SPACE__ isn't defined then both
438+
// global_device/global_host are just aliases for global_space.
439+
static_assert(std::is_same_v<dst_type, ElementType *>);
440+
return (dst_type)Ptr;
441+
#endif
442+
#if defined(__SPIR__)
443+
} else if constexpr (Space == global_space) {
444+
return (dst_type)__spirv_GenericCastToPtrExplicit_ToGlobal(
445+
Ptr, __spv::StorageClass::CrossWorkgroup);
446+
} else if constexpr (Space == local_space) {
447+
return (dst_type)__spirv_GenericCastToPtrExplicit_ToLocal(
448+
Ptr, __spv::StorageClass::Workgroup);
449+
} else if constexpr (Space == private_space) {
450+
return (dst_type)__spirv_GenericCastToPtrExplicit_ToPrivate(
451+
Ptr, __spv::StorageClass::Function);
452+
#if !defined(__ENABLE_USM_ADDR_SPACE__)
453+
} else if constexpr (SrcAS == generic_space &&
454+
(Space == global_device || Space == global_host)) {
455+
return (dst_type)__spirv_GenericCastToPtrExplicit_ToGlobal(
456+
Ptr, __spv::StorageClass::CrossWorkgroup);
457+
#endif
458+
#endif
459+
} else {
460+
static_assert(SupressNotImplementedAssert || Space != Space,
461+
"Not supported yet!");
462+
return static_address_cast<Space>(Ptr);
463+
}
464+
}
465+
#else // __SYCL_DEVICE_ONLY__
466+
template <access::address_space Space, typename ElementType>
467+
auto static_address_cast(ElementType *Ptr) {
468+
return Ptr;
469+
}
470+
template <access::address_space Space, bool SupressNotImplementedAssert = false,
471+
typename ElementType>
472+
auto dynamic_address_cast(ElementType *Ptr) {
473+
return Ptr;
474+
}
475+
#endif // __SYCL_DEVICE_ONLY__
380476
} // namespace detail
381477

382478
#undef __OPENCL_GLOBAL_AS__

0 commit comments

Comments
 (0)