Skip to content

Commit 59ceaf4

Browse files
authored
[SYCL] Specialize atomic fetch_min/fetch_max for FP types (#3297)
Minor implementation details aside, this is a follow-up to #2765. The end-to-end tests are already done, the latest update being intel/llvm-test-suite#118. Signed-off-by: Artem Gindinson <[email protected]>
1 parent 98505e4 commit 59ceaf4

File tree

5 files changed

+89
-31
lines changed

5 files changed

+89
-31
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
#include <CL/__spirv/spirv_types.hpp>
1111
#include <CL/sycl/detail/defines.hpp>
1212
#include <CL/sycl/detail/export.hpp>
13+
#include <CL/sycl/detail/stl_type_traits.hpp>
1314
#include <cstddef>
1415
#include <cstdint>
15-
#include <type_traits>
1616

1717
// Convergent attribute
1818
#ifdef __SYCL_DEVICE_ONLY__
@@ -91,6 +91,10 @@ extern SYCL_EXTERNAL TempRetT __spirv_ImageSampleExplicitLod(SampledType,
9191
extern SYCL_EXTERNAL Type __spirv_AtomicUMin( \
9292
AS Type *P, __spv::Scope::Flag S, __spv::MemorySemanticsMask::Flag O, \
9393
Type V);
94+
#define __SPIRV_ATOMIC_FMIN(AS, Type) \
95+
extern SYCL_EXTERNAL Type __spirv_AtomicFMinEXT( \
96+
AS Type *P, __spv::Scope::Flag S, __spv::MemorySemanticsMask::Flag O, \
97+
Type V);
9498
#define __SPIRV_ATOMIC_SMAX(AS, Type) \
9599
extern SYCL_EXTERNAL Type __spirv_AtomicSMax( \
96100
AS Type *P, __spv::Scope::Flag S, __spv::MemorySemanticsMask::Flag O, \
@@ -99,6 +103,10 @@ extern SYCL_EXTERNAL TempRetT __spirv_ImageSampleExplicitLod(SampledType,
99103
extern SYCL_EXTERNAL Type __spirv_AtomicUMax( \
100104
AS Type *P, __spv::Scope::Flag S, __spv::MemorySemanticsMask::Flag O, \
101105
Type V);
106+
#define __SPIRV_ATOMIC_FMAX(AS, Type) \
107+
extern SYCL_EXTERNAL Type __spirv_AtomicFMaxEXT( \
108+
AS Type *P, __spv::Scope::Flag S, __spv::MemorySemanticsMask::Flag O, \
109+
Type V);
102110
#define __SPIRV_ATOMIC_AND(AS, Type) \
103111
extern SYCL_EXTERNAL Type __spirv_AtomicAnd( \
104112
AS Type *P, __spv::Scope::Flag S, __spv::MemorySemanticsMask::Flag O, \
@@ -114,6 +122,8 @@ extern SYCL_EXTERNAL TempRetT __spirv_ImageSampleExplicitLod(SampledType,
114122

115123
#define __SPIRV_ATOMIC_FLOAT(AS, Type) \
116124
__SPIRV_ATOMIC_FADD(AS, Type) \
125+
__SPIRV_ATOMIC_FMIN(AS, Type) \
126+
__SPIRV_ATOMIC_FMAX(AS, Type) \
117127
__SPIRV_ATOMIC_LOAD(AS, Type) \
118128
__SPIRV_ATOMIC_STORE(AS, Type) \
119129
__SPIRV_ATOMIC_EXCHANGE(AS, Type)
@@ -138,21 +148,30 @@ extern SYCL_EXTERNAL TempRetT __spirv_ImageSampleExplicitLod(SampledType,
138148
__SPIRV_ATOMIC_UMAX(AS, Type)
139149

140150
// Helper atomic operations which select correct signed/unsigned version
141-
// of atomic min/max based on the signed-ness of the type
151+
// (or the floating-point EXT version) of atomic min/max based on the type
142152
#define __SPIRV_ATOMIC_MINMAX(AS, Op) \
143153
template <typename T> \
144-
typename std::enable_if<std::is_signed<T>::value, T>::type \
154+
typename cl::sycl::detail::enable_if_t< \
155+
std::is_integral<T>::value && std::is_signed<T>::value, T> \
145156
__spirv_Atomic##Op(AS T *Ptr, __spv::Scope::Flag Memory, \
146157
__spv::MemorySemanticsMask::Flag Semantics, \
147158
T Value) { \
148159
return __spirv_AtomicS##Op(Ptr, Memory, Semantics, Value); \
149160
} \
150161
template <typename T> \
151-
typename std::enable_if<!std::is_signed<T>::value, T>::type \
162+
typename cl::sycl::detail::enable_if_t< \
163+
std::is_integral<T>::value && !std::is_signed<T>::value, T> \
152164
__spirv_Atomic##Op(AS T *Ptr, __spv::Scope::Flag Memory, \
153165
__spv::MemorySemanticsMask::Flag Semantics, \
154166
T Value) { \
155167
return __spirv_AtomicU##Op(Ptr, Memory, Semantics, Value); \
168+
} \
169+
template <typename T> \
170+
typename cl::sycl::detail::enable_if_t<std::is_floating_point<T>::value, T> \
171+
__spirv_Atomic##Op(AS T *Ptr, __spv::Scope::Flag Memory, \
172+
__spv::MemorySemanticsMask::Flag Semantics, \
173+
T Value) { \
174+
return __spirv_AtomicF##Op##EXT(Ptr, Memory, Semantics, Value); \
156175
}
157176

158177
#define __SPIRV_ATOMICS(macro, Arg) \

sycl/include/CL/sycl/ONEAPI/atomic_ref.hpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,6 @@ class atomic_ref_impl<T, DefaultOrder, DefaultScope, AddressSpace,
413413
};
414414

415415
// Partial specialization for floating-point types
416-
// TODO: Leverage floating-point SPIR-V atomics instead of emulation
417416
template <typename T, memory_order DefaultOrder, memory_scope DefaultScope,
418417
access::address_space AddressSpace>
419418
class atomic_ref_impl<
@@ -486,22 +485,34 @@ class atomic_ref_impl<
486485

487486
T fetch_min(T operand, memory_order order = default_read_modify_write_order,
488487
memory_scope scope = default_scope) const noexcept {
488+
// TODO: Remove the "native atomics" macro check once implemented for all
489+
// backends
490+
#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_USE_NATIVE_FP_ATOMICS)
491+
return detail::spirv::AtomicMin(ptr, scope, order, operand);
492+
#else
489493
auto load_order = detail::getLoadOrder(order);
490494
T old = load(load_order, scope);
491495
while (operand < old &&
492496
!compare_exchange_weak(old, operand, order, scope)) {
493497
}
494498
return old;
499+
#endif
495500
}
496501

497502
T fetch_max(T operand, memory_order order = default_read_modify_write_order,
498503
memory_scope scope = default_scope) const noexcept {
504+
// TODO: Remove the "native atomics" macro check once implemented for all
505+
// backends
506+
#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_USE_NATIVE_FP_ATOMICS)
507+
return detail::spirv::AtomicMax(ptr, scope, order, operand);
508+
#else
499509
auto load_order = detail::getLoadOrder(order);
500510
T old = load(load_order, scope);
501511
while (operand > old &&
502512
!compare_exchange_weak(old, operand, order, scope)) {
503513
}
504514
return old;
515+
#endif
505516
}
506517

507518
private:

sycl/include/CL/sycl/detail/spirv.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,16 @@ AtomicMin(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,
435435
return __spirv_AtomicMin(Ptr, SPIRVScope, SPIRVOrder, Value);
436436
}
437437

438+
template <typename T, access::address_space AddressSpace>
439+
inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
440+
AtomicMin(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,
441+
ONEAPI::memory_order Order, T Value) {
442+
auto *Ptr = MPtr.get();
443+
auto SPIRVOrder = getMemorySemanticsMask(Order);
444+
auto SPIRVScope = getScope(Scope);
445+
return __spirv_AtomicMin(Ptr, SPIRVScope, SPIRVOrder, Value);
446+
}
447+
438448
template <typename T, access::address_space AddressSpace>
439449
inline typename detail::enable_if_t<std::is_integral<T>::value, T>
440450
AtomicMax(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,
@@ -445,6 +455,16 @@ AtomicMax(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,
445455
return __spirv_AtomicMax(Ptr, SPIRVScope, SPIRVOrder, Value);
446456
}
447457

458+
template <typename T, access::address_space AddressSpace>
459+
inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
460+
AtomicMax(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,
461+
ONEAPI::memory_order Order, T Value) {
462+
auto *Ptr = MPtr.get();
463+
auto SPIRVOrder = getMemorySemanticsMask(Order);
464+
auto SPIRVScope = getScope(Scope);
465+
return __spirv_AtomicMax(Ptr, SPIRVScope, SPIRVOrder, Value);
466+
}
467+
448468
// Native shuffles map directly to a shuffle intrinsic:
449469
// - The Intel SPIR-V extension natively supports all arithmetic types
450470
// - The CUDA shfl intrinsics do not support vectors, and we use the _i32

sycl/test/atomic_ref/max.cpp

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
// RUN: %clangxx -fsycl -fsycl-unnamed-lambda -DSYCL_USE_NATIVE_FP_ATOMICS \
2+
// RUN: -fsycl-device-only -S %s -o - | FileCheck %s --check-prefix=CHECK-LLVM
13
// RUN: %clangxx -fsycl -fsycl-unnamed-lambda -fsycl-device-only -S %s -o - \
2-
// RUN: | FileCheck %s --check-prefix=CHECK-LLVM
4+
// RUN: | FileCheck %s --check-prefix=CHECK-LLVM-EMU
35
// RUN: %clangxx -fsycl -fsycl-unnamed-lambda -fsycl-targets=%sycl_triple %s -o %t.out
46
// RUN: %RUN_ON_HOST %t.out
57

@@ -83,19 +85,21 @@ int main() {
8385
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicUMax
8486
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32, i64)
8587
max_test<unsigned long long>(q, N);
86-
// CHECK-LLVM: declare dso_local spir_func i32
87-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicLoad
88-
// CHECK-LLVM-SAME: (i32 addrspace(1)*, i32, i32)
89-
// CHECK-LLVM: declare dso_local spir_func i32
90-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicCompareExchange
91-
// CHECK-LLVM-SAME: (i32 addrspace(1)*, i32, i32, i32, i32, i32)
88+
// CHECK-LLVM: declare dso_local spir_func float
89+
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicFMaxEXT
90+
// CHECK-LLVM-SAME: (float addrspace(1)*, i32, i32, float)
91+
// CHECK-LLVM-EMU: declare {{.*}} i32 @{{.*}}__spirv_AtomicLoad
92+
// CHECK-LLVM-EMU-SAME: (i32 addrspace(1)*, i32, i32)
93+
// CHECK-LLVM-EMU: declare {{.*}} i32 @{{.*}}__spirv_AtomicCompareExchange
94+
// CHECK-LLVM-EMU-SAME: (i32 addrspace(1)*, i32, i32, i32, i32, i32)
9295
max_test<float>(q, N);
93-
// CHECK-LLVM: declare dso_local spir_func i64
94-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicLoad
95-
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32)
96-
// CHECK-LLVM: declare dso_local spir_func i64
97-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicCompareExchange
98-
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32, i32, i64, i64)
96+
// CHECK-LLVM: declare dso_local spir_func double
97+
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicFMaxEXT
98+
// CHECK-LLVM-SAME: (double addrspace(1)*, i32, i32, double)
99+
// CHECK-LLVM-EMU: declare {{.*}} i64 @{{.*}}__spirv_AtomicLoad
100+
// CHECK-LLVM-EMU-SAME: (i64 addrspace(1)*, i32, i32)
101+
// CHECK-LLVM-EMU: declare {{.*}} i64 @{{.*}}__spirv_AtomicCompareExchange
102+
// CHECK-LLVM-EMU-SAME: (i64 addrspace(1)*, i32, i32, i32, i64, i64)
99103
max_test<double>(q, N);
100104

101105
std::cout << "Test passed." << std::endl;

sycl/test/atomic_ref/min.cpp

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
// RUN: %clangxx -fsycl -fsycl-unnamed-lambda -DSYCL_USE_NATIVE_FP_ATOMICS \
2+
// RUN: -fsycl-device-only -S %s -o - | FileCheck %s --check-prefix=CHECK-LLVM
13
// RUN: %clangxx -fsycl -fsycl-unnamed-lambda -fsycl-device-only -S %s -o - \
2-
// RUN: | FileCheck %s --check-prefix=CHECK-LLVM
4+
// RUN: | FileCheck %s --check-prefix=CHECK-LLVM-EMU
35
// RUN: %clangxx -fsycl -fsycl-unnamed-lambda -fsycl-targets=%sycl_triple %s -o %t.out
46
// RUN: %RUN_ON_HOST %t.out
57

@@ -81,19 +83,21 @@ int main() {
8183
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicUMin
8284
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32, i64)
8385
min_test<unsigned long long>(q, N);
84-
// CHECK-LLVM: declare dso_local spir_func i32
85-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicLoad
86-
// CHECK-LLVM-SAME: (i32 addrspace(1)*, i32, i32)
87-
// CHECK-LLVM: declare dso_local spir_func i32
88-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicCompareExchange
89-
// CHECK-LLVM-SAME: (i32 addrspace(1)*, i32, i32, i32, i32, i32)
86+
// CHECK-LLVM: declare dso_local spir_func float
87+
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicFMinEXT
88+
// CHECK-LLVM-SAME: (float addrspace(1)*, i32, i32, float)
89+
// CHECK-LLVM-EMU: declare {{.*}} i32 @{{.*}}__spirv_AtomicLoad
90+
// CHECK-LLVM-EMU-SAME: (i32 addrspace(1)*, i32, i32)
91+
// CHECK-LLVM-EMU: declare {{.*}} i32 @{{.*}}__spirv_AtomicCompareExchange
92+
// CHECK-LLVM-EMU-SAME: (i32 addrspace(1)*, i32, i32, i32, i32, i32)
9093
min_test<float>(q, N);
91-
// CHECK-LLVM: declare dso_local spir_func i64
92-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicLoad
93-
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32)
94-
// CHECK-LLVM: declare dso_local spir_func i64
95-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicCompareExchange
96-
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32, i32, i64, i64)
94+
// CHECK-LLVM: declare dso_local spir_func double
95+
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicFMinEXT
96+
// CHECK-LLVM-SAME: (double addrspace(1)*, i32, i32, double)
97+
// CHECK-LLVM-EMU: declare {{.*}} i64 @{{.*}}__spirv_AtomicLoad
98+
// CHECK-LLVM-EMU-SAME: (i64 addrspace(1)*, i32, i32)
99+
// CHECK-LLVM-EMU: declare {{.*}} i64 @{{.*}}__spirv_AtomicCompareExchange
100+
// CHECK-LLVM-EMU-SAME: (i64 addrspace(1)*, i32, i32, i32, i64, i64)
97101
min_test<double>(q, N);
98102

99103
std::cout << "Test passed." << std::endl;

0 commit comments

Comments
 (0)