Skip to content

Commit 24ea3c6

Browse files
CUDA: implement __hmax and __hmax2 for CUDA < 11.7
1 parent c4ec9c0 commit 24ea3c6

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

ggml-cuda/common.cuh

+17-4
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@
137137
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
138138

139139
#define WARP_SIZE 32
140-
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
140+
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
141+
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
141142

142143
#define CC_PASCAL 600
143144
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
@@ -293,8 +294,20 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
293294
return x;
294295
}
295296

297+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
298+
static __device__ __forceinline__ half __hmax(const half a, const half b) {
299+
return __half2float(a) > __half2float(b) ? a : b;
300+
}
301+
static __device__ __forceinline__ half2 __hmax2(const half2 a, const half2 b) {
302+
half2 ret;
303+
reinterpret_cast<half&>(ret.x) = __low2float(a) > __low2float(b) ? __low2half(a) : __low2half(b);
304+
reinterpret_cast<half&>(ret.y) = __high2float(a) > __high2float(b) ? __high2half(a) : __high2half(b);
305+
return ret;
306+
}
307+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
308+
296309
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
297-
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
310+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
298311
#pragma unroll
299312
for (int mask = 16; mask > 0; mask >>= 1) {
300313
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
@@ -303,10 +316,10 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
303316
#else
304317
GGML_UNUSED(x);
305318
NO_DEVICE_CODE;
306-
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
319+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
307320
}
308321

309-
#if CUDART_VERSION < 12000
322+
#if CUDART_VERSION < CUDART_HMASK
310323
static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
311324
const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
312325
const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));

0 commit comments

Comments
 (0)