137 | 137 | #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
138 | 138 |
139 | 139 | #define WARP_SIZE 32
140 |     | -#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
    | 140 | +#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
    | 141 | +#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
141 | 142 |
142 | 143 | #define CC_PASCAL 600
143 | 144 | #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
@@ -293,20 +294,54 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
293 | 294 |     return x;
294 | 295 | }
295 | 296 |
    | 297 | +static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
    | 298 | +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
    | 299 | +
    | 300 | +#if CUDART_VERSION >= CUDART_HMAX
    | 301 | +    return __hmax(a, b);
    | 302 | +#else
    | 303 | +    return __half2float(a) > __half2float(b) ? a : b;
    | 304 | +#endif // CUDART_VERSION >= CUDART_HMAX
    | 305 | +
    | 306 | +#else
    | 307 | +    GGML_UNUSED(a);
    | 308 | +    GGML_UNUSED(b);
    | 309 | +    NO_DEVICE_CODE;
    | 310 | +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
    | 311 | +}
    | 312 | +static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
    | 313 | +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
    | 314 | +
    | 315 | +#if CUDART_VERSION >= CUDART_HMAX
    | 316 | +    return __hmax2(a, b);
    | 317 | +#else
    | 318 | +    half2 ret;
    | 319 | +    reinterpret_cast<half&>(ret.x) = __low2float(a) > __low2float(b) ? __low2half(a) : __low2half(b);
    | 320 | +    reinterpret_cast<half&>(ret.y) = __high2float(a) > __high2float(b) ? __high2half(a) : __high2half(b);
    | 321 | +    return ret;
    | 322 | +#endif // CUDART_VERSION >= CUDART_HMAX
    | 323 | +
    | 324 | +#else
    | 325 | +    GGML_UNUSED(a);
    | 326 | +    GGML_UNUSED(b);
    | 327 | +    NO_DEVICE_CODE;
    | 328 | +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
    | 329 | +}
    | 330 | +
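The helpers above mirror the __hmax/__hmax2 intrinsics (available from CUDA 11.7): ggml_cuda_hmax2 takes the maximum of the low halves and of the high halves independently, falls back to float comparisons on older runtimes, and hits NO_DEVICE_CODE on the HIP/AMD path. A minimal usage sketch, assuming the helpers above are in scope; the kernel and buffer names are illustrative, not part of this change:

#include <cuda_fp16.h> // half2

// Element-wise half2 max of two arrays via the fallback-aware helper.
static __global__ void k_hmax2_elementwise(const half2 * a, const half2 * b, half2 * dst, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    dst[i] = ggml_cuda_hmax2(a[i], b[i]); // max of low halves and max of high halves, independently
}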
296 | 331 | static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
297 |     | -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
    | 332 | +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
298 | 333 | #pragma unroll
299 | 334 |     for (int mask = 16; mask > 0; mask >>= 1) {
300 |     | -        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    | 335 | +        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
301 | 336 |     }
302 | 337 |     return x;
303 | 338 | #else
304 | 339 |     GGML_UNUSED(x);
305 | 340 |     NO_DEVICE_CODE;
306 |     | -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
    | 341 | +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
307 | 342 | }
308 | 343 |
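In warp_reduce_max, __shfl_xor_sync with lane offsets 16, 8, 4, 2, 1 forms a butterfly reduction: after the loop every lane of the warp holds the maximum of all 32 inputs (two independent maxima in the half2 case, one per 16-bit half). A sketch of how such a reduction is typically consumed, using the float overload of warp_reduce_max shown in the hunk header; the kernel name and the <cfloat> include are assumptions, not part of this change:

#include <cfloat> // FLT_MAX

// One output per warp: each thread contributes one value, lane 0 writes the warp maximum.
static __global__ void k_warp_max(const float * x, float * dst, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    float v = i < n ? x[i] : -FLT_MAX; // out-of-range lanes contribute a neutral element
    v = warp_reduce_max(v);            // after this call, all 32 lanes hold the same maximum
    if (threadIdx.x % WARP_SIZE == 0) {
        dst[i / WARP_SIZE] = v;
    }
}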
309 |     | -#if CUDART_VERSION < 12000
    | 344 | +#if CUDART_VERSION < CUDART_HMASK
310 | 345 | static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
311 | 346 |     const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
312 | 347 |     const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
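__hgt2_mask returns a 32-bit value in which all bits of a 16-bit half are set where the corresponding element of a is greater than that of b; the native intrinsic only exists from CUDA 12.0, hence the CUDART_HMASK gate around this fallback. A typical consumer is a branchless per-half select; the helper below is an illustrative sketch, not part of this change:

// Pick a where a > b, else b, per 16-bit half, via a bitwise select on the comparison mask.
static __device__ __forceinline__ half2 hgt2_select(const half2 a, const half2 b) {
    const uint32_t mask = __hgt2_mask(a, b);
    const uint32_t va   = reinterpret_cast<const uint32_t&>(a);
    const uint32_t vb   = reinterpret_cast<const uint32_t&>(b);
    const uint32_t vr   = (va & mask) | (vb & ~mask);
    return reinterpret_cast<const half2&>(vr);
}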