Skip to content

Commit c67fe68

Browse files
jhen0409 and ggerganov authored
metal : implement q5_0 and q5_1 kernels (#3648)
* metal : implement dequantize_q5_0
* metal : block_q_n_dot_y for block_q5_0 (broken)
* metal : revert unnecessary change
* metal : implement dequantize_q5_1
* metal : block_q_n_dot_y for q5_1 (broken)
* metal : fix block_q_n_dot_y
* minor : spaces / formatting

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 1117d06 commit c67fe68

File tree

2 files changed

+206
-4
lines changed

2 files changed

+206
-4
lines changed

ggml-metal.m

+44-3
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@
7373
GGML_METAL_DECL_KERNEL(get_rows_f16);
7474
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
7575
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
76+
GGML_METAL_DECL_KERNEL(get_rows_q5_0);
77+
GGML_METAL_DECL_KERNEL(get_rows_q5_1);
7678
GGML_METAL_DECL_KERNEL(get_rows_q8_0);
7779
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
7880
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -87,6 +89,8 @@
8789
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
8890
GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
8991
GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
92+
GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
93+
GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
9094
GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
9195
GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
9296
GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
@@ -97,6 +101,8 @@
97101
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
98102
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
99103
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
104+
GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
105+
GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
100106
GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
101107
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
102108
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
@@ -254,6 +260,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
254260
GGML_METAL_ADD_KERNEL(get_rows_f16);
255261
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
256262
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
263+
GGML_METAL_ADD_KERNEL(get_rows_q5_0);
264+
GGML_METAL_ADD_KERNEL(get_rows_q5_1);
257265
GGML_METAL_ADD_KERNEL(get_rows_q8_0);
258266
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
259267
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -268,6 +276,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
268276
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
269277
GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
270278
GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
279+
GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
280+
GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
271281
GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
272282
GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
273283
GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
@@ -278,8 +288,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
278288
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
279289
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
280290
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
281-
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
282291
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
292+
GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
293+
GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
294+
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
283295
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
284296
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
285297
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
@@ -346,6 +358,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
346358
GGML_METAL_DEL_KERNEL(get_rows_f16);
347359
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
348360
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
361+
GGML_METAL_DEL_KERNEL(get_rows_q5_0);
362+
GGML_METAL_DEL_KERNEL(get_rows_q5_1);
349363
GGML_METAL_DEL_KERNEL(get_rows_q8_0);
350364
GGML_METAL_DEL_KERNEL(get_rows_q2_K);
351365
GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -360,6 +374,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
360374
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
361375
GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
362376
GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
377+
GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
378+
GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
363379
GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
364380
GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
365381
GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
@@ -370,8 +386,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
370386
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
371387
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
372388
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
373-
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
374389
GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
390+
GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
391+
GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
392+
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
375393
GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
376394
GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
377395
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
@@ -1052,6 +1070,8 @@ void ggml_metal_graph_compute(
10521070
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
10531071
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
10541072
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
1073+
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
1074+
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
10551075
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
10561076
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
10571077
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
@@ -1121,6 +1141,24 @@ void ggml_metal_graph_compute(
11211141
nth1 = 8;
11221142
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
11231143
} break;
1144+
case GGML_TYPE_Q5_0:
1145+
{
1146+
GGML_ASSERT(ne02 == 1);
1147+
GGML_ASSERT(ne12 == 1);
1148+
1149+
nth0 = 8;
1150+
nth1 = 8;
1151+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
1152+
} break;
1153+
case GGML_TYPE_Q5_1:
1154+
{
1155+
GGML_ASSERT(ne02 == 1);
1156+
GGML_ASSERT(ne12 == 1);
1157+
1158+
nth0 = 8;
1159+
nth1 = 8;
1160+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
1161+
} break;
11241162
case GGML_TYPE_Q8_0:
11251163
{
11261164
GGML_ASSERT(ne02 == 1);
@@ -1201,7 +1239,8 @@ void ggml_metal_graph_compute(
12011239
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
12021240
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
12031241

1204-
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
1242+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1243+
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
12051244
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
12061245
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
12071246
}
@@ -1233,6 +1272,8 @@ void ggml_metal_graph_compute(
12331272
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
12341273
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
12351274
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
1275+
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
1276+
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
12361277
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
12371278
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
12381279
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;

ggml-metal.metal

+162-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,21 @@ typedef struct {
1818
uint8_t qs[QK4_1 / 2]; // nibbles / quants
1919
} block_q4_1;
2020

21+
#define QK5_0 32
22+
// One q5_0 quantization block: QK5_0 quants stored as a scale, one high bit per quant,
// and the low 4 bits of each quant packed two per byte.
typedef struct {
23+
half d; // delta (per-block scale; multiplied with each quant when dequantizing)
24+
uint8_t qh[4]; // 5-th bit of quants (4 bytes = 32 bits, one per quant; read as one uint32 by the kernels)
25+
uint8_t qs[QK5_0 / 2]; // nibbles / quants (low 4 bits, two quants per byte)
26+
} block_q5_0;
27+
28+
#define QK5_1 32
29+
// One q5_1 quantization block: like q5_0 but with an additional additive term m,
// so a dequantized value is d * q + m.
typedef struct {
30+
half d; // delta (per-block scale; multiplied with each quant when dequantizing)
31+
half m; // min (added to every dequantized value)
32+
uint8_t qh[4]; // 5-th bit of quants (4 bytes = 32 bits, one per quant; read as one uint32 by the kernels)
33+
uint8_t qs[QK5_1 / 2]; // nibbles / quants (low 4 bits, two quants per byte)
34+
} block_q5_1;
35+
2136
#define QK8_0 32
2237
typedef struct {
2338
half d; // delta
@@ -399,8 +414,11 @@ kernel void kernel_rms_norm(
399414
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
400415
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
401416
float d = qb_curr->d;
417+
402418
float2 acc = 0.f;
419+
403420
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
421+
404422
for (int i = 0; i < 8; i+=2) {
405423
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
406424
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -417,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
417435
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
418436
float d = qb_curr->d;
419437
float m = qb_curr->m;
420-
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
438+
421439
float2 acc = 0.f;
440+
441+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
442+
422443
for (int i = 0; i < 8; i+=2) {
423444
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
424445
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -428,6 +449,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
428449
return d * (acc[0] + acc[1]) + sumy * m;
429450
}
430451

452+
// Computes the inner product between half of a q5_0 block and 16 floats (yl); sumy is SUM(yl[i]).
453+
// il indicates where the q5 quants begin (0 or QK5_0/4)
454+
// we assume that the yl's have been multiplied with the appropriate scale factor
455+
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
// Each masked nibble is used in place (it is not shifted down to bits 0..3); the matching
// 5th bit taken from qh is moved to the bit position directly above that nibble
// (bit 4, 12, 8 or 16) before the two are OR-ed together.
// Returns d * (acc[0] + acc[1] - 16 * sumy).
456+
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
457+
float d = qb_curr->d;
458+
459+
float2 acc = 0.f;
460+
461+
// skip 3 uint16 words (half d + uint8_t qh[4]) to reach qs, then advance il/2 words
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
462+
// load the four qh bytes as a single 32-bit word
// NOTE(review): qh starts at byte offset 2 of the block, so this 32-bit load is only
// 2-byte aligned - confirm this is acceptable on the targeted devices
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
463+
464+
// each iteration consumes one 16-bit word of qs (4 nibbles):
// acc[0] pairs the low nibble of each byte with qh bits (i+il) and (i+1+il),
// acc[1] pairs the high nibble of each byte with the qh bits QK5_0/2 positions higher
for (int i = 0; i < 8; i+=2) {
465+
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
466+
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
467+
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
468+
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
469+
}
470+
// the sumy * -16 term subtracts 16 times the sum of the inputs before scaling by d
return d * (sumy * -16.f + acc[0] + acc[1]);
471+
}
472+
473+
// Computes the inner product between half of a q5_1 block and 16 floats (yl); sumy is SUM(yl[i]).
474+
// il indicates where the q5 quants begin (0 or QK5_1/4)
475+
// we assume that the yl's have been multiplied with the appropriate scale factor
476+
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
// Each masked nibble is used in place (it is not shifted down to bits 0..3); the matching
// 5th bit taken from qh is moved to the bit position directly above that nibble
// (bit 4, 12, 8 or 16) before the two are OR-ed together.
// Returns d * (acc[0] + acc[1]) + sumy * m.
477+
inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
478+
float d = qb_curr->d;
479+
float m = qb_curr->m;
480+
481+
float2 acc = 0.f;
482+
483+
// skip 4 uint16 words (half d + half m + uint8_t qh[4]) to reach qs, then advance il/2 words
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
484+
// load the four qh bytes as a single 32-bit word
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
485+
486+
// each iteration consumes one 16-bit word of qs (4 nibbles):
// acc[0] pairs the low nibble of each byte with qh bits (i+il) and (i+1+il),
// acc[1] pairs the high nibble of each byte with the qh bits 16 positions higher
// NOTE(review): the shifts below use QK5_0 rather than QK5_1; both are defined as 32,
// so the result is the same - consider QK5_1 for consistency with this block type
for (int i = 0; i < 8; i+=2) {
487+
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
488+
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
489+
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
490+
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
491+
}
492+
// scale the accumulated quant products by d and add m once per input via sumy
return d * (acc[0] + acc[1]) + sumy * m;
493+
}
494+
431495
// putting them in the kernel causes a significant performance penalty
432496
#define N_DST 4 // each SIMD group works on 4 rows
433497
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
@@ -525,6 +589,43 @@ kernel void kernel_mul_mv_q4_1_f32(
525589
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
526590
}
527591

592+
// Kernel entry point for the q5_0 "mul_mv" path.
// It does no work of its own: every argument is forwarded unchanged to
// mul_vec_q_n_f32 instantiated with block_q5_0, N_DST, N_SIMDGROUP and N_SIMDWIDTH.
// Only the arguments from ne01 onwards carry an explicit [[buffer(n)]] index;
// src0, src1, dst and ne00 have none.
kernel void kernel_mul_mv_q5_0_f32(
593+
device const void * src0,
594+
device const float * src1,
595+
device float * dst,
596+
constant int64_t & ne00,
597+
constant int64_t & ne01[[buffer(4)]],
598+
constant int64_t & ne02[[buffer(5)]],
599+
constant int64_t & ne10[[buffer(9)]],
600+
constant int64_t & ne12[[buffer(11)]],
601+
constant int64_t & ne0[[buffer(15)]],
602+
constant int64_t & ne1[[buffer(16)]],
603+
constant uint & gqa[[buffer(17)]],
604+
uint3 tgpig[[threadgroup_position_in_grid]],
605+
uint tiisg[[thread_index_in_simdgroup]],
606+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
607+
mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
608+
}
609+
610+
// Kernel entry point for the q5_1 "mul_mv" path.
// It does no work of its own: every argument is forwarded unchanged to
// mul_vec_q_n_f32 instantiated with block_q5_1, N_DST, N_SIMDGROUP and N_SIMDWIDTH.
// Only the arguments from ne01 onwards carry an explicit [[buffer(n)]] index;
// src0, src1, dst and ne00 have none.
kernel void kernel_mul_mv_q5_1_f32(
611+
device const void * src0,
612+
device const float * src1,
613+
device float * dst,
614+
constant int64_t & ne00,
615+
constant int64_t & ne01[[buffer(4)]],
616+
constant int64_t & ne02[[buffer(5)]],
617+
constant int64_t & ne10[[buffer(9)]],
618+
constant int64_t & ne12[[buffer(11)]],
619+
constant int64_t & ne0[[buffer(15)]],
620+
constant int64_t & ne1[[buffer(16)]],
621+
constant uint & gqa[[buffer(17)]],
622+
uint3 tgpig[[threadgroup_position_in_grid]],
623+
uint tiisg[[thread_index_in_simdgroup]],
624+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
625+
mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
626+
}
627+
628+
528629
#define NB_Q8_0 8
529630

530631
kernel void kernel_mul_mv_q8_0_f32(
@@ -2149,6 +2250,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
21492250
}
21502251
}
21512252

2253+
// Dequantizes 16 values of a q5_0 block into the 4x4 register `reg`.
// il selects which half of the packed data is decoded: il == 0 takes the low nibble of
// every qs byte together with qh bits 0..15, il != 0 takes the high nibble of every
// qs byte together with qh bits 16..31.
// Each output is d * x + md with md = -16 * d, where x is the reassembled 5-bit quant.
template <typename type4x4>
2254+
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
2255+
// skip 3 uint16 words (half d + uint8_t qh[4]) to reach qs
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
2256+
const float d = xb->d;
2257+
// constant offset added to every value: -16 * d
const float md = -16.h * xb->d;
2258+
// nibble selector applied to each byte of qs
const ushort mask = il ? 0x00F0 : 0x000F;
2259+
2260+
// load the four qh bytes as a single 32-bit word
// NOTE(review): qh starts at byte offset 2 of the block, so this 32-bit load is only
// 2-byte aligned - confirm this is acceptable on the targeted devices
const uint32_t qh = *((device const uint32_t *)xb->qh);
2261+
2262+
// right shift that brings the selected nibble down to bits 0..3
const int x_mv = il ? 4 : 0;
2263+
2264+
// gh_mv / gh_bk together place the wanted qh bit at bit 4 (0x10):
// il == 0: (qh >> k) << 4   -> qh bit k
// il != 0: (qh >> (12 + k)) -> qh bit 16 + k
const int gh_mv = il ? 12 : 0;
2265+
const int gh_bk = il ? 0 : 4;
2266+
2267+
// each iteration decodes two quants from one 16-bit word of qs
for (int i = 0; i < 8; i++) {
2268+
// extract the 5-th bits for x0 and x1
2269+
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
2270+
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
2271+
2272+
// combine the 4-bits from qs with the 5th bit
// (x0 comes from the low byte of qs[i], x1 from the high byte)
2273+
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
2274+
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
2275+
2276+
// two consecutive elements of row i/2 are written per iteration
reg[i/2][2*(i%2)+0] = d * x0 + md;
2277+
reg[i/2][2*(i%2)+1] = d * x1 + md;
2278+
}
2279+
}
2280+
2281+
// Dequantizes 16 values of a q5_1 block into the 4x4 register `reg`.
// il selects which half of the packed data is decoded: il == 0 takes the low nibble of
// every qs byte together with qh bits 0..15, il != 0 takes the high nibble of every
// qs byte together with qh bits 16..31.
// Each output is d * x + m, where x is the reassembled 5-bit quant.
template <typename type4x4>
2282+
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
2283+
// skip 4 uint16 words (half d + half m + uint8_t qh[4]) to reach qs
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
2284+
const float d = xb->d;
2285+
const float m = xb->m;
2286+
// nibble selector applied to each byte of qs
const ushort mask = il ? 0x00F0 : 0x000F;
2287+
2288+
// load the four qh bytes as a single 32-bit word
const uint32_t qh = *((device const uint32_t *)xb->qh);
2289+
2290+
// right shift that brings the selected nibble down to bits 0..3
const int x_mv = il ? 4 : 0;
2291+
2292+
// gh_mv / gh_bk together place the wanted qh bit at bit 4 (0x10):
// il == 0: (qh >> k) << 4   -> qh bit k
// il != 0: (qh >> (12 + k)) -> qh bit 16 + k
const int gh_mv = il ? 12 : 0;
2293+
const int gh_bk = il ? 0 : 4;
2294+
2295+
// each iteration decodes two quants from one 16-bit word of qs
for (int i = 0; i < 8; i++) {
2296+
// extract the 5-th bits for x0 and x1
2297+
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
2298+
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
2299+
2300+
// combine the 4-bits from qs with the 5th bit
// (x0 comes from the low byte of qs[i], x1 from the high byte)
2301+
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
2302+
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
2303+
2304+
// two consecutive elements of row i/2 are written per iteration
reg[i/2][2*(i%2)+0] = d * x0 + m;
2305+
reg[i/2][2*(i%2)+1] = d * x1 + m;
2306+
}
2307+
}
2308+
21522309
template <typename type4x4>
21532310
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
21542311
device const int8_t * qs = ((device const int8_t *)xb->qs);
@@ -2490,6 +2647,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
24902647
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
24912648
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
24922649
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
2650+
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
2651+
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
24932652
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
24942653
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
24952654
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2518,6 +2677,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
25182677
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
25192678
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
25202679
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
2680+
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
2681+
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
25212682
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
25222683
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
25232684
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;

0 commit comments

Comments
 (0)