Skip to content

Commit 0b65c98

Browse files
[SYCL] Use static address space cast for atomic_ref ctor in SPIR-V path (intel#15384)
From SYCL 2020 specification: > The sycl::atomic_ref class also has a template parameter AddressSpace, > which allows the application to make an assertion about the address > space of the object of type T that it references. The default value > for this parameter is access::address_space::generic_space, which > indicates that the object could be in either the global or local > address spaces. If the application knows the address space, it can set > this template parameter to either access::address_space::global_space > or access::address_space::local_space as an assertion to the > implementation. Specifying the address space via this template > parameter may allow the implementation to perform certain > optimizations. Specifying an address space that does not match the > object’s actual address space results in undefined behavior We use `ext::oneapi::experimental::static_address_cast` to do that. It's not implemented for CUDA/HIP yet, that path continues using `sycl::address_space_cast` that performs runtime checks: > An implementation must return nullptr if the run-time value of pointer > is not compatible with Space, and must issue a compiletime diagnostic > if the deduced address space for pointer is not compatible with Space.
1 parent 032d36a commit 0b65c98

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

sycl/include/sycl/atomic_ref.hpp

+10-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
#include <sycl/access/access.hpp> // for address_space
1212
#include <sycl/bit_cast.hpp> // for bit_cast
13-
#include <sycl/memory_enums.hpp> // for getStdMemoryOrder, memory_order
13+
#include <sycl/ext/oneapi/experimental/address_cast.hpp>
14+
#include <sycl/memory_enums.hpp> // for getStdMemoryOrder, memory_order
1415

1516
#ifdef __SYCL_DEVICE_ONLY__
1617
#include <sycl/detail/spirv.hpp>
@@ -157,8 +158,16 @@ class atomic_ref_base {
157158
}
158159

159160
#ifdef __SYCL_DEVICE_ONLY__
161+
#if defined(__SPIR__)
162+
explicit atomic_ref_base(T &ref)
163+
: ptr(ext::oneapi::experimental::static_address_cast<AddressSpace>(
164+
&ref)) {}
165+
#else
166+
// CUDA/HIP don't support `ext::oneapi::experimental::static_address_cast`
167+
// yet.
160168
explicit atomic_ref_base(T &ref)
161169
: ptr(address_space_cast<AddressSpace, access::decorated::no>(&ref)) {}
170+
#endif
162171
#else
163172
// FIXME: This reinterpret_cast is UB, but happens to work for now
164173
explicit atomic_ref_base(T &ref)
+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
2+
// RUN: %clangxx -O3 -fsycl -fsycl-device-only -fno-discard-value-names -S -emit-llvm -fno-sycl-instrument-device-code -o - %s | FileCheck %s
3+
4+
#include <sycl/sycl.hpp>
5+
6+
// CHECK-LABEL: define dso_local spir_func noundef i32 @_Z17atomic_ref_globalRi(
7+
// CHECK-SAME: ptr addrspace(4) noundef align 4 dereferenceable(4) [[I:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] !srcloc [[META6:![0-9]+]] !sycl_fixed_targets [[META7:![0-9]+]] {
8+
// CHECK-NEXT: [[ENTRY:.*:]]
9+
// CHECK-NEXT: [[CALL_I_I_I_I_I_I:%.*]] = tail call spir_func noundef ptr addrspace(1) @_Z33__spirv_GenericCastToPtr_ToGlobalPvi(ptr addrspace(4) noundef [[I]], i32 noundef 5) #[[ATTR3:[0-9]+]]
10+
// CHECK-NEXT: [[CALL3_I_I:%.*]] = tail call spir_func noundef i32 @_Z18__spirv_AtomicLoadPU3AS1KiN5__spv5Scope4FlagENS1_19MemorySemanticsMask4FlagE(ptr addrspace(1) noundef [[CALL_I_I_I_I_I_I]], i32 noundef 1, i32 noundef 898) #[[ATTR4:[0-9]+]]
11+
// CHECK-NEXT: ret i32 [[CALL3_I_I]]
12+
//
13+
SYCL_EXTERNAL auto atomic_ref_global(int &i) {
14+
// Verify that we use _Z33__spirv_GenericCastToPtr_ToGlobalPvi that doesn't
15+
// perform dynamic address space validation.
16+
sycl::atomic_ref<int, sycl::memory_order::acq_rel, sycl::memory_scope::device,
17+
sycl::access::address_space::global_space>
18+
a(i);
19+
return a.load();
20+
}

0 commit comments

Comments
 (0)