Skip to content

Commit b066553

Browse files
jeffbolznvtinglou
authored andcommitted
vulkan: Add VK_NV_cooperative_matrix2 support for mul_mat and flash attention (ggml-org#10206)
1 parent 02d6378 commit b066553

File tree

6 files changed

+1669
-101
lines changed

6 files changed

+1669
-101
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

+671-80
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
find_package (Threads REQUIRED)
2+
find_package(Vulkan COMPONENTS glslc REQUIRED)
23

34
set(TARGET vulkan-shaders-gen)
45
add_executable(${TARGET} vulkan-shaders-gen.cpp)
56
install(TARGETS ${TARGET} RUNTIME)
67
target_compile_features(${TARGET} PRIVATE cxx_std_17)
78
target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
9+
target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
2+
#include "types.comp"
3+
4+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
5+
block_q4_0_packed16 block;
6+
};
7+
8+
float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
9+
{
10+
const float16_t d = bl.block.d;
11+
const uint idx = coordInBlock[1];
12+
const uint shift = (idx & 0x10) >> 2;
13+
uint32_t qs = unpack8(uint32_t(bl.block.qs[(idx & 0xE) >> 1]))[idx & 1];
14+
qs >>= shift;
15+
qs &= 0xF;
16+
float16_t ret = (float16_t(qs) - float16_t(8)) * d;
17+
return ret;
18+
}
19+
20+
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
21+
block_q4_1 block;
22+
};
23+
24+
float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
25+
{
26+
const float16_t d = bl.block.d;
27+
const float16_t m = bl.block.m;
28+
const uint idx = coordInBlock[1];
29+
const uint iqs = idx & 0xF;
30+
const uint shift = (idx & 0x10) >> 2;
31+
uint32_t qs = bl.block.qs[iqs];
32+
qs >>= shift;
33+
qs &= 0xF;
34+
float16_t ret = float16_t(qs) * d + m;
35+
return ret;
36+
}
37+
38+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {
39+
block_q5_0 block;
40+
};
41+
42+
float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
43+
{
44+
const float16_t d = bl.block.d;
45+
const uint idx = coordInBlock[1];
46+
const uint iqs = idx & 0xF;
47+
48+
const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0];
49+
const uint qh = ((uint_qh >> idx) << 4) & 0x10;
50+
51+
const uint shift = (idx & 0x10) >> 2;
52+
uint32_t qs = bl.block.qs[iqs];
53+
qs >>= shift;
54+
qs &= 0xF;
55+
56+
float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d;
57+
return ret;
58+
}
59+
60+
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {
61+
block_q5_1 block;
62+
};
63+
64+
float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
65+
{
66+
const float16_t d = bl.block.d;
67+
const float16_t m = bl.block.m;
68+
const uint idx = coordInBlock[1];
69+
const uint iqs = idx & 0xF;
70+
71+
const uint uint_qh = bl.block.qh;
72+
const uint qh = ((uint_qh >> idx) << 4) & 0x10;
73+
74+
const uint shift = (idx & 0x10) >> 2;
75+
uint32_t qs = bl.block.qs[iqs];
76+
qs >>= shift;
77+
qs &= 0xF;
78+
79+
float16_t ret = float16_t(qs | qh) * d + m;
80+
return ret;
81+
}
82+
83+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
84+
block_q8_0_packed16 block;
85+
};
86+
87+
float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
88+
{
89+
const float16_t d = bl.block.d;
90+
const uint idx = coordInBlock[1];
91+
const uint iqs = idx;
92+
93+
// Load 16b and select the byte for this element
94+
int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1];
95+
float16_t ret = float16_t(qs) * d;
96+
return ret;
97+
}
98+
99+
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {
100+
block_q2_K block;
101+
};
102+
103+
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
104+
{
105+
const f16vec2 d = bl.block.d;
106+
const uint idx = coordInBlock[1];
107+
const uint iqs = idx;
108+
109+
const uint qsi = (iqs / 128) * 32 + (iqs % 32); // 0..31
110+
const uint scalesi = iqs / 16; // 0..15
111+
const uint qsshift = ((iqs % 128) / 32) * 2; // 0,2,4,6
112+
113+
uint32_t qs = bl.block.qs[qsi];
114+
const uint scales = bl.block.scales[scalesi];
115+
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4);
116+
return ret;
117+
}
118+
119+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {
120+
block_q3_K block;
121+
};
122+
123+
float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
124+
{
125+
const uint idx = coordInBlock[1];
126+
const uint iqs = idx;
127+
128+
const uint n = iqs / 128; // 0,1
129+
const uint qsi = n * 32 + (iqs % 32); // 0..63
130+
const uint hmi = (iqs % 32); // 0..31
131+
const uint j = (iqs % 128) / 8; // 0..15
132+
const uint is = iqs / 16; // 0..15
133+
const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3
134+
const uint qsshift = halfsplit * 2; // 0,2,4,6
135+
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
136+
137+
uint32_t scaleidx0 = (is < 8) ? is : (is-8);
138+
uint32_t scaleidx0shift = (is < 8) ? 0 : 4;
139+
uint32_t scaleidx1 = is + 8 - (is/4)*4;
140+
uint32_t scaleidx1shift = (is/4)*2;
141+
142+
const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));
143+
144+
const float16_t dl = bl.block.d * float16_t(us - 32);
145+
146+
float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4));
147+
148+
return ret;
149+
}
150+
151+
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
152+
block_q4_K block;
153+
};
154+
155+
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
156+
{
157+
const uint idx = coordInBlock[1];
158+
const uint iqs = idx;
159+
160+
const uint n = iqs / 64; // 0,1,2,3
161+
const uint b = (iqs % 64) / 32; // 0,1
162+
const uint is = (idx & 0xE0) >> 5; // 0..7
163+
const uint qsi = n * 32 + (iqs % 32); // 0..127
164+
165+
const f16vec2 loadd = bl.block.d;
166+
167+
uint32_t sc;
168+
uint32_t mbyte;
169+
170+
uint32_t scidx0 = (is < 4) ? is : (is + 4);
171+
uint32_t scidx1 = (is < 4) ? is : (is - 4);
172+
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
173+
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
174+
uint32_t mbidx0 = is + 4;
175+
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
176+
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
177+
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
178+
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
179+
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
180+
181+
sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
182+
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
183+
184+
const float16_t d = loadd.x * float16_t(sc);
185+
const float16_t m = loadd.y * float16_t(mbyte);
186+
187+
uint32_t dmask = 0xF << (b * 4);
188+
189+
float16_t ret = d * float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) - m;
190+
191+
return ret;
192+
}
193+
194+
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
195+
block_q5_K block;
196+
};
197+
198+
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
199+
{
200+
const uint idx = coordInBlock[1];
201+
const uint iqs = idx;
202+
203+
const uint n = iqs / 64; // 0,1,2,3
204+
const uint b = (iqs % 64) / 32; // 0,1
205+
const uint is = (idx & 0xE0) >> 5; // 0..7
206+
const uint qsi = n * 32 + (iqs % 32); // 0..127
207+
const uint qhi = (iqs % 32); // 0..31
208+
209+
const uint8_t hm = uint8_t(1 << (iqs / 32));
210+
211+
const f16vec2 loadd = bl.block.d;
212+
213+
uint32_t sc;
214+
uint32_t mbyte;
215+
216+
uint32_t scidx0 = (is < 4) ? is : (is + 4);
217+
uint32_t scidx1 = (is < 4) ? is : (is - 4);
218+
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
219+
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
220+
uint32_t mbidx0 = is + 4;
221+
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
222+
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
223+
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
224+
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
225+
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
226+
227+
sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
228+
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
229+
230+
const float16_t d = loadd.x * float16_t(sc);
231+
const float16_t m = loadd.y * float16_t(mbyte);
232+
233+
uint32_t dmask = 0xF << (b * 4);
234+
235+
float16_t ret = d * (float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) + float16_t((bl.block.qh[qhi ] & hm) != 0 ? 16 : 0)) - m;
236+
237+
return ret;
238+
}
239+
240+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
241+
block_q6_K block;
242+
};
243+
244+
float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
245+
{
246+
const uint idx = coordInBlock[1];
247+
const uint iqs = idx;
248+
249+
const uint n = iqs / 128; // 0,1
250+
const uint b = (iqs % 128) / 64; // 0,1
251+
const uint is_b = (iqs % 32) / 16; // 0,1
252+
const uint qhshift = ((iqs % 128) / 32) * 2;// 0,2,4,6
253+
const uint is = 8 * n + qhshift + is_b; // 0..15
254+
const uint qsi = n * 64 + (iqs % 64); // 0..127
255+
const uint qhi = n * 32 + (iqs % 32); // 0..63
256+
257+
const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
258+
259+
float16_t ret = dscale * float16_t(int8_t(((bl.block.ql[qsi ] >> (b * 4)) & 0xF) | (((bl.block.qh[qhi ] >> qhshift) & 3) << 4)) - 32);
260+
261+
return ret;
262+
}
263+
264+
#if defined(DATA_A_IQ4_NL)
265+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
266+
block_iq4_nl block;
267+
};
268+
269+
float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
270+
{
271+
const float16_t d = bl.block.d;
272+
const uint idx = coordInBlock[1];
273+
const uint iqs = idx & 0xF;
274+
const uint shift = (idx & 0x10) >> 2;
275+
uint32_t qs = bl.block.qs[iqs];
276+
qs >>= shift;
277+
qs &= 0xF;
278+
float16_t ret = float16_t(kvalues_iq4nl[qs]) * d;
279+
return ret;
280+
}
281+
#endif
282+
283+
#if defined(DATA_A_Q4_0)
284+
#define dequantFuncA dequantFuncQ4_0
285+
#elif defined(DATA_A_Q4_1)
286+
#define dequantFuncA dequantFuncQ4_1
287+
#elif defined(DATA_A_Q5_0)
288+
#define dequantFuncA dequantFuncQ5_0
289+
#elif defined(DATA_A_Q5_1)
290+
#define dequantFuncA dequantFuncQ5_1
291+
#elif defined(DATA_A_Q8_0)
292+
#define dequantFuncA dequantFuncQ8_0
293+
#elif defined(DATA_A_Q2_K)
294+
#define dequantFuncA dequantFuncQ2_K
295+
#elif defined(DATA_A_Q3_K)
296+
#define dequantFuncA dequantFuncQ3_K
297+
#elif defined(DATA_A_Q4_K)
298+
#define dequantFuncA dequantFuncQ4_K
299+
#elif defined(DATA_A_Q5_K)
300+
#define dequantFuncA dequantFuncQ5_K
301+
#elif defined(DATA_A_Q6_K)
302+
#define dequantFuncA dequantFuncQ6_K
303+
#elif defined(DATA_A_IQ4_NL)
304+
#define dequantFuncA dequantFuncIQ4_NL
305+
#endif

0 commit comments

Comments
 (0)