Skip to content

Commit cc4c4e3

Browse files
Engininja2YellowRoseCx
authored andcommitted
New __dp4a assembly
Now compatible with gfx900 and faster as well.
1 parent 1a03b70 commit cc4c4e3

File tree

1 file changed

+11
-17
lines changed

1 file changed

+11
-17
lines changed

ggml-cuda.cu

+11-17
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
#include "ggml.h"
7373

7474
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
75-
#define CC_TURING 10000
75+
#define CC_TURING 1000000000
7676

7777
#if defined(GGML_USE_HIPBLAS)
7878
#define __CUDA_ARCH__ 1300
@@ -88,24 +88,18 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
8888
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
8989
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
9090
c = __builtin_amdgcn_sdot4(a, b, c, false);
91-
#elif defined(__gfx1010__)// || defined(__gfx900__)
92-
int ashift;
93-
int bshift;
94-
int aext;
95-
int bext;
91+
#elif defined(__gfx1010__) || defined(__gfx900__)
92+
int tmp1;
93+
int tmp2;
9694
asm("\n \
97-
v_pk_ashrrev_i16 %1, 0x80008, %5 \n \
98-
v_pk_ashrrev_i16 %2, 0x80008, %6 \n \
99-
v_mov_b32_sdwa %3, sext(%5) dst_sel:WORD_1 src0_sel:BYTE_2 \n \
100-
v_mov_b32_sdwa %3, sext(%5) dst_sel:WORD_0 dst_unused:UNUSED_PRESERVE src0_sel:BYTE_0 \n \
101-
v_mov_b32_sdwa %4, sext(%6) dst_sel:WORD_1 src0_sel:BYTE_2 \n \
102-
v_mov_b32_sdwa %4, sext(%6) dst_sel:WORD_0 dst_unused:UNUSED_PRESERVE src0_sel:BYTE_0 \n \
103-
v_mad_i32_i16 %0, %1, %2, %0 op_sel:[0, 0, 0, 0] \n \
104-
v_mad_i32_i16 %0, %1, %2, %0 op_sel:[1, 1, 0, 0] \n \
105-
v_mad_i32_i16 %0, %3, %4, %0 op_sel:[0, 0, 0, 0] \n \
106-
v_mad_i32_i16 %0, %3, %4, %0 op_sel:[1, 1, 0, 0] \n \
95+
v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
96+
v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
97+
v_add3_u32 %0, %1, %2, %0 \n \
98+
v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
99+
v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
100+
v_add3_u32 %0, %1, %2, %0 \n \
107101
"
108-
: "+v"(c), "=&v"(ashift), "=&v"(bshift), "=&v"(aext), "=&v"(bext)
102+
: "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
109103
: "v"(a), "v"(b)
110104
);
111105
#else

0 commit comments

Comments
 (0)