Skip to content

Commit d233b50

Browse files
authored
cuda : add half2 __shfl_xor() for ROCm 5.5 (#7263)
1 parent 0f98acf commit d233b50

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

ggml-cuda/common.cuh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,20 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
315315
#endif
316316
return c;
317317
}
318+
319+
#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
320+
// __shfl_xor() for half2 was added in ROCm 5.6
321+
static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
322+
typedef union half2_b32 {
323+
half2 val;
324+
int b32;
325+
} half2_b32_t;
326+
half2_b32_t tmp;
327+
tmp.val = var;
328+
tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
329+
return tmp.val;
330+
}
331+
#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
318332
#endif // defined(GGML_USE_HIPBLAS)
319333

320334
#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL

0 commit comments

Comments
 (0)