Skip to content

Commit 72d1735

Browse files
[NFCI][SYCL] Refactor address space casts functionality
1 parent 23fed07 commit 72d1735

File tree

6 files changed

+180
-117
lines changed

6 files changed

+180
-117
lines changed

sycl/include/sycl/access/access.hpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,137 @@ template <typename ToT, typename FromT> inline ToT cast_AS(FromT from) {
377377
}
378378
}
379379

380+
#ifdef __SYCL_DEVICE_ONLY__
381+
inline constexpr bool
382+
address_space_cast_is_possible(access::address_space Src,
383+
access::address_space Dst) {
384+
auto generic_space = access::address_space::generic_space;
385+
if (Src == Dst || Src == generic_space || Dst == generic_space)
386+
return true;
387+
388+
// global_host/global_device could be casted to/from global
389+
auto global_space = access::address_space::global_space;
390+
auto global_device = access::address_space::ext_intel_global_device_space;
391+
auto global_host = access::address_space::ext_intel_global_host_space;
392+
393+
if (Src == global_space || Dst == global_space) {
394+
auto Other = Src == global_space ? Dst : Src;
395+
if (Other == global_device || Other == global_host)
396+
return true;
397+
}
398+
399+
// No more compatible combinations.
400+
return false;
401+
}
402+
template <access::address_space Space, typename ElementType>
403+
auto static_address_cast(ElementType *Ptr) {
404+
constexpr auto generic_space = access::address_space::generic_space;
405+
constexpr auto global_space = access::address_space::global_space;
406+
constexpr auto local_space = access::address_space::local_space;
407+
constexpr auto private_space = access::address_space::private_space;
408+
constexpr auto global_device =
409+
access::address_space::ext_intel_global_device_space;
410+
constexpr auto global_host =
411+
access::address_space::ext_intel_global_host_space;
412+
413+
constexpr auto SrcAS = deduce_AS<ElementType *>::value;
414+
static_assert(address_space_cast_is_possible(SrcAS, Space));
415+
416+
using dst_type = typename DecoratedType<
417+
std::remove_pointer_t<remove_decoration_t<ElementType *>>, Space>::type *;
418+
419+
// Note: reinterpret_cast isn't enough for some of the casts between different
420+
// address spaces, use C-style cast instead.
421+
#if !defined(__SPIR__)
422+
return (dst_type)Ptr;
423+
#else
424+
if constexpr (SrcAS != generic_space) {
425+
return (dst_type)Ptr;
426+
} else if constexpr (Space == global_space) {
427+
return (dst_type)__spirv_GenericCastToPtr_ToGlobal(
428+
Ptr, __spv::StorageClass::CrossWorkgroup);
429+
} else if constexpr (Space == local_space) {
430+
return (dst_type)__spirv_GenericCastToPtr_ToLocal(
431+
Ptr, __spv::StorageClass::Workgroup);
432+
} else if constexpr (Space == private_space) {
433+
return (dst_type)__spirv_GenericCastToPtr_ToPrivate(
434+
Ptr, __spv::StorageClass::Function);
435+
#if !defined(__ENABLE_USM_ADDR_SPACE__)
436+
} else if constexpr (Space == global_device || Space == global_host) {
437+
// If __ENABLE_USM_ADDR_SPACE__ isn't defined then both
438+
// global_device/global_host are just aliases for global_space.
439+
return (dst_type)__spirv_GenericCastToPtr_ToGlobal(
440+
Ptr, __spv::StorageClass::CrossWorkgroup);
441+
#endif
442+
} else {
443+
return (dst_type)Ptr;
444+
}
445+
#endif
446+
}
447+
template <access::address_space Space, typename ElementType>
448+
auto dynamic_address_cast(ElementType *Ptr) {
449+
constexpr auto generic_space = access::address_space::generic_space;
450+
constexpr auto global_space = access::address_space::global_space;
451+
constexpr auto local_space = access::address_space::local_space;
452+
constexpr auto private_space = access::address_space::private_space;
453+
constexpr auto global_device =
454+
access::address_space::ext_intel_global_device_space;
455+
constexpr auto global_host =
456+
access::address_space::ext_intel_global_host_space;
457+
458+
constexpr auto SrcAS = deduce_AS<ElementType *>::value;
459+
using dst_type = typename DecoratedType<
460+
std::remove_pointer_t<remove_decoration_t<ElementType *>>, Space>::type *;
461+
462+
if constexpr (!address_space_cast_is_possible(SrcAS, Space)) {
463+
return (dst_type) nullptr;
464+
} else if constexpr (Space == generic_space) {
465+
return (dst_type)Ptr;
466+
} else if constexpr (Space == global_space &&
467+
(SrcAS == global_device || SrcAS == global_host)) {
468+
return (dst_type)Ptr;
469+
} else if constexpr (SrcAS == global_space &&
470+
(Space == global_device || Space == global_host)) {
471+
#if defined(__ENABLE_USM_ADDR_SPACE__)
472+
static_assert(Space != Space, "Not supported yet!");
473+
#else
474+
// If __ENABLE_USM_ADDR_SPACE__ isn't defined then both
475+
// global_device/global_host are just aliases for global_space.
476+
static_assert(std::is_same_v<dst_type, ElementType *>);
477+
return (dst_type)Ptr;
478+
#endif
479+
#if defined(__SPIR__)
480+
} else if constexpr (Space == global_space) {
481+
return (dst_type)__spirv_GenericCastToPtrExplicit_ToGlobal(
482+
Ptr, __spv::StorageClass::CrossWorkgroup);
483+
} else if constexpr (Space == local_space) {
484+
return (dst_type)__spirv_GenericCastToPtrExplicit_ToLocal(
485+
Ptr, __spv::StorageClass::Workgroup);
486+
} else if constexpr (Space == private_space) {
487+
return (dst_type)__spirv_GenericCastToPtrExplicit_ToPrivate(
488+
Ptr, __spv::StorageClass::Function);
489+
#if !defined(__ENABLE_USM_ADDR_SPACE__)
490+
} else if constexpr (SrcAS == generic_space &&
491+
(Space == global_device || Space == global_host)) {
492+
return (dst_type)__spirv_GenericCastToPtrExplicit_ToGlobal(
493+
Ptr, __spv::StorageClass::CrossWorkgroup);
494+
#endif
495+
#endif
496+
} else {
497+
static_assert(Space != Space, "Not supported yet!");
498+
return (dst_type) nullptr;
499+
}
500+
}
501+
#else // __SYCL_DEVICE_ONLY__
502+
template <access::address_space Space, typename ElementType>
503+
auto static_address_cast(ElementType *Ptr) {
504+
return Ptr;
505+
}
506+
template <access::address_space Space, typename ElementType>
507+
auto dynamic_address_cast(ElementType *Ptr) {
508+
return Ptr;
509+
}
510+
#endif // __SYCL_DEVICE_ONLY__
380511
} // namespace detail
381512

382513
#undef __OPENCL_GLOBAL_AS__

sycl/include/sycl/ext/oneapi/experimental/address_cast.hpp

Lines changed: 14 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -13,47 +13,25 @@
1313

1414
namespace sycl {
1515
inline namespace _V1 {
16-
namespace ext {
17-
namespace oneapi {
18-
namespace experimental {
16+
namespace ext::oneapi ::experimental {
1917
namespace detail {
2018
using namespace sycl::detail;
2119
}
2220
// Shorthands for address space names
23-
constexpr inline access::address_space global_space = access::address_space::global_space;
24-
constexpr inline access::address_space local_space = access::address_space::local_space;
25-
constexpr inline access::address_space private_space = access::address_space::private_space;
26-
constexpr inline access::address_space generic_space = access::address_space::generic_space;
21+
constexpr inline access::address_space global_space =
22+
access::address_space::global_space;
23+
constexpr inline access::address_space local_space =
24+
access::address_space::local_space;
25+
constexpr inline access::address_space private_space =
26+
access::address_space::private_space;
27+
constexpr inline access::address_space generic_space =
28+
access::address_space::generic_space;
2729

2830
template <access::address_space Space, typename ElementType>
2931
multi_ptr<ElementType, Space, access::decorated::no>
3032
static_address_cast(ElementType *Ptr) {
3133
using ret_ty = multi_ptr<ElementType, Space, access::decorated::no>;
32-
#ifdef __SYCL_DEVICE_ONLY__
33-
static_assert(std::is_same_v<ElementType, remove_decoration_t<ElementType>>,
34-
"The extension expects undecorated raw pointers only!");
35-
if constexpr (Space == generic_space) {
36-
// Undecorated raw pointer is in generic AS already, no extra casts needed.
37-
return ret_ty(Ptr);
38-
} else if constexpr (Space == access::address_space::
39-
ext_intel_global_device_space ||
40-
Space ==
41-
access::address_space::ext_intel_global_host_space) {
42-
#ifdef __ENABLE_USM_ADDR_SPACE__
43-
// No SPIR-V intrinsic for this yet.
44-
using raw_type = detail::DecoratedType<ElementType, Space>::type *;
45-
auto CastPtr = (raw_type)(Ptr);
46-
#else
47-
auto CastPtr = sycl::detail::spirv::GenericCastToPtr<global_space>(Ptr);
48-
#endif
49-
return ret_ty(CastPtr);
50-
} else {
51-
auto CastPtr = sycl::detail::spirv::GenericCastToPtr<Space>(Ptr);
52-
return ret_ty(CastPtr);
53-
}
54-
#else
55-
return ret_ty(Ptr);
56-
#endif
34+
return ret_ty{detail::static_address_cast<Space>(Ptr)};
5735
}
5836

5937
template <access::address_space Space, access::decorated DecorateAddress,
@@ -63,39 +41,14 @@ multi_ptr<ElementType, Space, DecorateAddress> static_address_cast(
6341
if constexpr (Space == generic_space)
6442
return Ptr;
6543
else
66-
return {static_address_cast<Space>(Ptr.get_raw())};
44+
return {static_address_cast<Space>(Ptr.get_decorated())};
6745
}
6846

6947
template <access::address_space Space, typename ElementType>
7048
multi_ptr<ElementType, Space, access::decorated::no>
7149
dynamic_address_cast(ElementType *Ptr) {
7250
using ret_ty = multi_ptr<ElementType, Space, access::decorated::no>;
73-
#ifdef __SYCL_DEVICE_ONLY__
74-
static_assert(std::is_same_v<ElementType, remove_decoration_t<ElementType>>,
75-
"The extension expects undecorated raw pointers only!");
76-
if constexpr (Space == generic_space) {
77-
return ret_ty(Ptr);
78-
} else if constexpr (Space == access::address_space::
79-
ext_intel_global_device_space ||
80-
Space ==
81-
access::address_space::ext_intel_global_host_space) {
82-
#ifdef __ENABLE_USM_ADDR_SPACE__
83-
static_assert(
84-
Space != access::address_space::ext_intel_global_device_space &&
85-
Space != access::address_space::ext_intel_global_host_space,
86-
"Not supported yet!");
87-
return ret_ty(nullptr);
88-
#else
89-
auto CastPtr = sycl::detail::spirv::GenericCastToPtr<global_space>(Ptr);
90-
return ret_ty(CastPtr);
91-
#endif
92-
} else {
93-
auto CastPtr = sycl::detail::spirv::GenericCastToPtrExplicit<Space>(Ptr);
94-
return ret_ty(CastPtr);
95-
}
96-
#else
97-
return ret_ty(Ptr);
98-
#endif
51+
return ret_ty{detail::dynamic_address_cast<Space>(Ptr)};
9952
}
10053

10154
template <access::address_space Space, access::decorated DecorateAddress,
@@ -105,11 +58,9 @@ multi_ptr<ElementType, Space, DecorateAddress> dynamic_address_cast(
10558
if constexpr (Space == generic_space)
10659
return Ptr;
10760
else
108-
return {dynamic_address_cast<Space>(Ptr.get_raw())};
61+
return {dynamic_address_cast<Space>(Ptr.get_decorated())};
10962
}
11063

111-
} // namespace experimental
112-
} // namespace oneapi
113-
} // namespace ext
64+
} // namespace ext::oneapi::experimental
11465
} // namespace _V1
11566
} // namespace sycl

sycl/include/sycl/ext/oneapi/experimental/group_load_store.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,11 +217,10 @@ auto get_block_op_ptr(IteratorT iter, [[maybe_unused]] Properties props) {
217217
if constexpr (AS == access::address_space::global_space) {
218218
return is_aligned ? reinterpret_cast<block_pointer_type>(iter) : nullptr;
219219
} else if constexpr (AS == access::address_space::generic_space) {
220-
return is_aligned
221-
? reinterpret_cast<block_pointer_type>(
222-
__SYCL_GenericCastToPtrExplicit_ToGlobal<value_type>(
223-
iter))
224-
: nullptr;
220+
return is_aligned ? reinterpret_cast<block_pointer_type>(
221+
detail::dynamic_address_cast<
222+
access::address_space::global_space>(iter))
223+
: nullptr;
225224
} else {
226225
return nullptr;
227226
}

sycl/include/sycl/sub_group.hpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -228,12 +228,14 @@ struct sub_group {
228228
#if defined(__NVPTX__) || defined(__AMDGCN__)
229229
return src[get_local_id()[0]];
230230
#else // __NVPTX__ || __AMDGCN__
231-
auto l = __SYCL_GenericCastToPtrExplicit_ToLocal<T>(src);
232-
if (l)
231+
if (auto l =
232+
detail::dynamic_address_cast<access::address_space::local_space>(
233+
src))
233234
return load(l);
234235

235-
auto g = __SYCL_GenericCastToPtrExplicit_ToGlobal<T>(src);
236-
if (g)
236+
if (auto g =
237+
detail::dynamic_address_cast<access::address_space::global_space>(
238+
src))
237239
return load(g);
238240

239241
// Sub-group load() is supported for local or global pointers only.
@@ -418,14 +420,16 @@ struct sub_group {
418420
#if defined(__NVPTX__) || defined(__AMDGCN__)
419421
dst[get_local_id()[0]] = x;
420422
#else // __NVPTX__ || __AMDGCN__
421-
auto l = __SYCL_GenericCastToPtrExplicit_ToLocal<T>(dst);
422-
if (l) {
423+
if (auto l =
424+
detail::dynamic_address_cast<access::address_space::local_space>(
425+
dst)) {
423426
store(l, x);
424427
return;
425428
}
426429

427-
auto g = __SYCL_GenericCastToPtrExplicit_ToGlobal<T>(dst);
428-
if (g) {
430+
if (auto g =
431+
detail::dynamic_address_cast<access::address_space::global_space>(
432+
dst)) {
429433
store(g, x);
430434
return;
431435
}
Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
// RUN: %clangxx -fsycl-device-only -O3 -S -emit-llvm -Xclang -no-enable-noundef-analysis %s -o - | FileCheck %s --check-prefix CHECK-O3
2-
// RUN: %clangxx -fsycl-device-only -O0 -S -emit-llvm -Xclang -no-enable-noundef-analysis %s -o - | FileCheck %s --check-prefix CHECK-O0
3-
// Test compilation with -O3 when all methods are inlined in kernel function
4-
// and -O0 when helper methods are preserved.
1+
// RUN: %clangxx -fsycl-device-only -O3 -S -emit-llvm -Xclang -no-enable-noundef-analysis %s -o - | FileCheck %s
52
#include <cstdint>
63
#include <cstdio>
74
#include <cstdlib>
@@ -38,49 +35,30 @@ SYCL_EXTERNAL void test(sycl::accessor<int, 1, sycl::access::mode::read_write,
3835
}
3936

4037
// clang-format off
41-
// CHECK-O3: call spir_func void {{.*}}spirv_ControlBarrierjjj
38+
// CHECK: call spir_func void {{.*}}spirv_ControlBarrierjjj
4239

4340
// load() for global address space
44-
// CHECK-O3: call spir_func ptr addrspace(3) {{.*}}spirv_GenericCastToPtrExplicit_ToLocal{{.*}}(ptr addrspace(4)
45-
// CHECK-O3: {{.*}}SubgroupLocalInvocationId
46-
// CHECK-O3: call spir_func ptr addrspace(1) {{.*}}spirv_GenericCastToPtrExplicit_ToGlobal{{.*}}(ptr addrspace(4)
47-
// CHECK-O3: call spir_func i32 {{.*}}spirv_SubgroupBlockRead{{.*}}(ptr addrspace(1)
41+
// CHECK: call spir_func ptr addrspace(3) {{.*}}spirv_GenericCastToPtrExplicit_ToLocal{{.*}}(ptr addrspace(4)
42+
// CHECK: {{.*}}SubgroupLocalInvocationId
43+
// CHECK: call spir_func ptr addrspace(1) {{.*}}spirv_GenericCastToPtrExplicit_ToGlobal{{.*}}(ptr addrspace(4)
44+
// CHECK: call spir_func i32 {{.*}}spirv_SubgroupBlockRead{{.*}}(ptr addrspace(1)
4845

4946

5047
// load() for local address space
51-
// CHECK-O3: call spir_func ptr addrspace(3) {{.*}}spirv_GenericCastToPtrExplicit_ToLocal{{.*}}(ptr addrspace(4)
52-
// CHECK-O3: {{.*}}SubgroupLocalInvocationId
53-
// CHECK-O3: call spir_func ptr addrspace(1) {{.*}}spirv_GenericCastToPtrExplicit_ToGlobal{{.*}}(ptr addrspace(4)
54-
// CHECK-O3: call spir_func i32 {{.*}}spirv_SubgroupBlockRead{{.*}}(ptr addrspace(1)
48+
// CHECK: call spir_func ptr addrspace(3) {{.*}}spirv_GenericCastToPtrExplicit_ToLocal{{.*}}(ptr addrspace(4)
49+
// CHECK: {{.*}}SubgroupLocalInvocationId
50+
// CHECK: call spir_func ptr addrspace(1) {{.*}}spirv_GenericCastToPtrExplicit_ToGlobal{{.*}}(ptr addrspace(4)
51+
// CHECK: call spir_func i32 {{.*}}spirv_SubgroupBlockRead{{.*}}(ptr addrspace(1)
5552

5653
// load() for private address space
57-
// CHECK-O3: call spir_func ptr addrspace(3) {{.*}}spirv_GenericCastToPtrExplicit_ToLocal{{.*}}(ptr addrspace(4)
58-
// CHECK-O3: {{.*}}SubgroupLocalInvocationId
59-
// CHECK-O3: call spir_func ptr addrspace(1) {{.*}}spirv_GenericCastToPtrExplicit_ToGlobal{{.*}}(ptr addrspace(4)
60-
// CHECK-O3: call spir_func i32 {{.*}}spirv_SubgroupBlockRead{{.*}}(ptr addrspace(1)
54+
// CHECK: call spir_func ptr addrspace(3) {{.*}}spirv_GenericCastToPtrExplicit_ToLocal{{.*}}(ptr addrspace(4)
55+
// CHECK: {{.*}}SubgroupLocalInvocationId
56+
// CHECK: call spir_func ptr addrspace(1) {{.*}}spirv_GenericCastToPtrExplicit_ToGlobal{{.*}}(ptr addrspace(4)
57+
// CHECK: call spir_func i32 {{.*}}spirv_SubgroupBlockRead{{.*}}(ptr addrspace(1)
6158

6259
// store() for global address space
6360
// NOTE: Call to __spirv_GenericCastToPtrExplicit_ToLocal is consolidated with an earlier call to it.
64-
// CHECK-O3: {{.*}}SubgroupLocalInvocationId
65-
// CHECK-O3: call spir_func ptr addrspace(1) {{.*}}spirv_GenericCastToPtrExplicit_ToGlobal{{.*}}(ptr addrspace(4)
66-
// CHECK-O3: call spir_func void {{.*}}spirv_SubgroupBlockWriteINTEL{{.*}}(ptr addrspace(1)
67-
68-
// load() accepting raw pointers method
69-
// CHECK-O0: define{{.*}}spir_func i32 {{.*}}4sycl3_V19sub_group4load{{.*}}addrspace(4) %
70-
// CHECK-O0: call spir_func ptr addrspace(3) {{.*}}SYCL_GenericCastToPtrExplicit_ToLocal{{.*}}(ptr addrspace(4)
71-
// CHECK-O0: call spir_func i32 {{.*}}sycl3_V19sub_group4load{{.*}}ptr addrspace(3) %
72-
// CHECK-O0: call spir_func ptr addrspace(1) {{.*}}SYCL_GenericCastToPtrExplicit_ToGlobal{{.*}}(ptr addrspace(4)
73-
// CHECK-O0: call spir_func i32 {{.*}}sycl3_V19sub_group4load{{.*}}ptr addrspace(1) %
74-
75-
// store() accepting raw pointers method
76-
// CHECK-O0: define{{.*}}spir_func void {{.*}}4sycl3_V19sub_group5store{{.*}}ptr addrspace(4) %
77-
// CHECK-O0: call spir_func ptr addrspace(3) {{.*}}SYCL_GenericCastToPtrExplicit_ToLocal{{.*}}(ptr addrspace(4)
78-
// CHECK-O0: call spir_func void {{.*}}4sycl3_V19sub_group5store{{.*}}, ptr addrspace(3) %
79-
// CHECK-O0: call spir_func ptr addrspace(1) {{.*}}SYCL_GenericCastToPtrExplicit_ToGlobal{{.*}}(ptr addrspace(4)
80-
// CHECK-O0: call spir_func void {{.*}}4sycl3_V19sub_group5store{{.*}}, ptr addrspace(1) %
81-
82-
// CHECK-O0: define {{.*}}spir_func ptr addrspace(3) {{.*}}SYCL_GenericCastToPtrExplicit_ToLocal{{.*}}(ptr addrspace(4) %
83-
// CHECK-O0: call spir_func ptr addrspace(3) {{.*}}spirv_GenericCastToPtrExplicit_ToLocal{{.*}}(ptr addrspace(4)
84-
// CHECK-O0: define {{.*}}spir_func ptr addrspace(1) {{.*}}SYCL_GenericCastToPtrExplicit_ToGlobal{{.*}}(ptr addrspace(4) %
85-
// CHECK-O0: call spir_func ptr addrspace(1) {{.*}}spirv_GenericCastToPtrExplicit_ToGlobal{{.*}}(ptr addrspace(4)
61+
// CHECK: {{.*}}SubgroupLocalInvocationId
62+
// CHECK: call spir_func ptr addrspace(1) {{.*}}spirv_GenericCastToPtrExplicit_ToGlobal{{.*}}(ptr addrspace(4)
63+
// CHECK: call spir_func void {{.*}}spirv_SubgroupBlockWriteINTEL{{.*}}(ptr addrspace(1)
8664
// clang-format on

sycl/test/extensions/address_cast_negative.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
using namespace sycl::ext::oneapi::experimental;
66

77
SYCL_EXTERNAL void test(int *p) {
8-
// expected-error-re@sycl/ext/oneapi/experimental/address_cast.hpp:* {{{{.*}}Not supported yet!}}
8+
// expected-error-re@sycl/access/access.hpp:* {{{{.*}}Not supported yet!}}
99
std::ignore = dynamic_address_cast<
1010
sycl::access::address_space::ext_intel_global_device_space>(p);
11-
// expected-error-re@sycl/ext/oneapi/experimental/address_cast.hpp:* {{{{.*}}Not supported yet!}}
11+
// expected-error-re@sycl/access/access.hpp:* {{{{.*}}Not supported yet!}}
1212
std::ignore = dynamic_address_cast<
1313
sycl::access::address_space::ext_intel_global_host_space>(p);
1414
}

0 commit comments

Comments
 (0)