Skip to content

Commit 3544f93

Browse files
JohannesGaesslernopperl
authored andcommitted
CUDA: CUDART < 11.7 workaround for __hmax, __hmax2 (ggml-org#7019)
1 parent 4c00549 commit 3544f93

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

ggml-cuda/common.cuh

+40-5
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@
137137
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
138138

139139
#define WARP_SIZE 32
140-
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
140+
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
141+
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
141142

142143
#define CC_PASCAL 600
143144
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
@@ -293,20 +294,54 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
293294
return x;
294295
}
295296

297+
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
298+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
299+
300+
#if CUDART_VERSION >= CUDART_HMAX
301+
return __hmax(a, b);
302+
#else
303+
return __half2float(a) > __half2float(b) ? a : b;
304+
#endif // CUDART_VERSION >= CUDART_HMAX
305+
306+
#else
307+
GGML_UNUSED(a);
308+
GGML_UNUSED(b);
309+
NO_DEVICE_CODE;
310+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
311+
}
312+
static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
313+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
314+
315+
#if CUDART_VERSION >= CUDART_HMAX
316+
return __hmax2(a, b);
317+
#else
318+
half2 ret;
319+
reinterpret_cast<half&>(ret.x) = __low2float(a) > __low2float(b) ? __low2half(a) : __low2half(b);
320+
reinterpret_cast<half&>(ret.y) = __high2float(a) > __high2float(b) ? __high2half(a) : __high2half(b);
321+
return ret;
322+
#endif // CUDART_VERSION >= CUDART_HMAX
323+
324+
#else
325+
GGML_UNUSED(a);
326+
GGML_UNUSED(b);
327+
NO_DEVICE_CODE;
328+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
329+
}
330+
296331
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
297-
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
332+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
298333
#pragma unroll
299334
for (int mask = 16; mask > 0; mask >>= 1) {
300-
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
335+
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
301336
}
302337
return x;
303338
#else
304339
GGML_UNUSED(x);
305340
NO_DEVICE_CODE;
306-
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
341+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
307342
}
308343

309-
#if CUDART_VERSION < 12000
344+
#if CUDART_VERSION < CUDART_HMASK
310345
static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
311346
const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
312347
const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));

ggml-cuda/fattn.cu

+3-3
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ static __global__ void flash_attn_vec_ext_f16(
116116
sum2 = warp_reduce_sum(sum2);
117117
half sum = __low2half(sum2) + __high2half(sum2);
118118
sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f);
119-
kqmax_new = __hmax(kqmax_new, sum);
119+
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
120120
if (threadIdx.x == 0) {
121121
KQ[i_KQ] = sum;
122122
}
@@ -416,9 +416,9 @@ static __global__ void flash_attn_ext_f16(
416416
const int k = k0 + threadIdx.x;
417417

418418
KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
419-
KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
419+
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
420420
}
421-
KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
421+
KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
422422
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
423423
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
424424
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));

0 commit comments

Comments
 (0)