
Commit 8591589

ikawrakow (Kawrakow) authored and committed
1.5 bit quantization (ggml-org#5453)
* iq1_s: WIP basics
* iq1_s: CUDA is working
* iq1_s: scalar CPU dot product
* iq1_s: WIP AVX2 dot product - something is not right
* Fix tests
* Fix shadow warnings
* Fix after merge with latest master
* iq1_s: AVX2 finally works
* iq1_s: ARM_NEON dot product. Works, but not very fast
* iq1_s: better grid
* iq1_s: use IQ2_XXS for attn_output
  At a cost of 0.04 extra bpw this gives a big improvement in PPL.
* iq1_s: Metal basics
  Dequantize works, but not dot product
* iq1_s: Metal works, but quite slow
  As usual, Apple Silicon does not like the code I write.
* iq1_s: Tests
* iq1_s: slightly faster dot product

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 302edbf commit 8591589

File tree: 4 files changed (+371, -9 lines)

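For context on the "1.5 bit" in the title: the new IQ1_S type stores a 256-weight super-block in 50 bytes, which works out to 1.5625 bits per weight. The layout sketched below is an assumption about the ggml-quants.h of this era (it is not part of the hunks shown here); only the size arithmetic is the point.

```c
#include <stdint.h>
#include <stdio.h>

#define QK_K 256  // weights per super-block in ggml's K-quant family

// ASSUMED block layout, for illustration only; not taken from this diff.
typedef struct {
    uint16_t d;                // fp16 super-block scale (16 bits)
    uint8_t  qs[QK_K/8];       // packed indices into the quantization grid
    uint8_t  scales[QK_K/16];  // per-group sub-scales / extra index bits
} block_iq1_s_sketch;

int main(void) {
    const int bits = 8 * (int)sizeof(block_iq1_s_sketch);  // 400 bits
    printf("%.4f bpw\n", (double)bits / QK_K);             // 1.5625
    return 0;
}
```

The "0.04 extra bpw" in the commit message is a model-level average: attn_output tensors are kept in the slightly larger IQ2_XXS format because, per the message, that buys a big perplexity improvement.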

ggml-backend.c (1 addition, 1 deletion)

```diff
@@ -756,7 +756,7 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
 GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_CPY:
-            return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS; // missing type_traits.from_float
+            return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS && op->type != GGML_TYPE_IQ1_S; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
             return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
         default:
```
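The guard above says: the CPU backend accepts a GGML_OP_CPY only when the destination type can be produced from F32, and IQ1_S (like IQ2_XXS and IQ2_XS before it) ships without a from_float conversion in this commit. A minimal sketch of the same predicate, built on the ggml_internal_get_type_traits() call already visible in the MUL_MAT branch (the helper name is ours):

```c
#include <stddef.h>
#include "ggml.h"

// Sketch: a CPY (cast/copy) into type T needs a from_float quantizer.
// IQ1_S has none in this commit, hence the explicit exclusion above.
static bool cpu_can_cpy_to(enum ggml_type type) {
    return ggml_internal_get_type_traits(type).from_float != NULL;
}
```

Expressed this way, the growing chain of `!=` comparisons collapses into a single trait lookup; the diff keeps the explicit list instead.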

ggml-metal.m (27 additions, 2 deletions)
```diff
@@ -61,6 +61,7 @@
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -83,6 +84,7 @@
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
     //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -101,6 +103,7 @@
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -116,6 +119,7 @@
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -131,6 +135,7 @@
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
     GGML_METAL_KERNEL_TYPE_ALIBI_F32,
```
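The five hunks above all extend the same pipeline enum, one new GGML_METAL_KERNEL_TYPE_* value per kernel family that gains IQ1_S support: get_rows (dequantizing row gather), mul_mv (matrix-vector), mul_mm (matrix-matrix), and the _ID indirect variants used for mixture-of-experts routing. Each value indexes a fixed table of compiled pipelines; a plain-C sketch of that shape (field and constant names are illustrative, though ctx->kernels[...].pipeline appears verbatim in later hunks):

```c
// Illustrative only: in ggml-metal.m the field holds a Metal
// id<MTLComputePipelineState>, stubbed here as void * to stay in C.
struct ggml_metal_kernel_sketch {
    void * pipeline;  // compiled compute pipeline for one (kernel, type) pair
};

// One slot per enum value; adding IQ1_S means five new slots.
enum { KERNEL_TYPE_COUNT_SKETCH = 512 };  // stands in for the real enum count
struct ggml_metal_kernel_sketch kernels_sketch[KERNEL_TYPE_COUNT_SKETCH];
```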
```diff
@@ -433,6 +438,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -455,6 +461,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -473,6 +480,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -488,6 +496,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -503,6 +512,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
```
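The registration hunks mirror the enum additions one-for-one: each GGML_METAL_ADD_KERNEL(enum, name, supported) call binds an enum slot to the Metal function kernel_<name> and gates it on device capability. The IQ1_S kernels require simdgroup reduction (mat-vec) or simdgroup matrix support (mat-mat), exactly like their IQ2/IQ3 siblings. A hypothetical expansion of the macro, for orientation only (the real definition lives in ggml-metal.m and differs in detail; ggml_metal_compile_kernel is invented here):

```c
// HYPOTHETICAL sketch of what each registration line does.
#define GGML_METAL_ADD_KERNEL(e, name, supported)               \
    if (supported) {                                            \
        ctx->kernels[e].pipeline =                              \
            ggml_metal_compile_kernel(ctx, "kernel_" #name);    \
    }
```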
```diff
@@ -1297,6 +1307,7 @@ static bool ggml_metal_graph_compute(
                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
                 case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
+                case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break;
                 default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
             }
```

```diff
@@ -1431,6 +1442,12 @@ static bool ggml_metal_graph_compute(
                         nth1 = 16;
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
                     } break;
+                case GGML_TYPE_IQ1_S:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
+                    } break;
                 default:
                     {
                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
```
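For the new IQ1_S mat-vec case, nth0 = 4 and nth1 = 16 fix the threadgroup shape consumed by the dispatch call further down: 4 x 16 = 64 threads per threadgroup, i.e. two 32-wide simdgroups on Apple GPUs (the two-simdgroup reading is our inference; the diff only sets the numbers). A trivial sanity check of that arithmetic:

```c
#include <assert.h>

int main(void) {
    const int nth0 = 4, nth1 = 16;    // threadsPerThreadgroup in x and y
    assert(nth0 * nth1 == 64);        // 64 threads per threadgroup
    assert((nth0 * nth1) % 32 == 0);  // a whole number of 32-wide simdgroups
    return 0;
}
```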
```diff
@@ -1465,7 +1482,7 @@ static bool ggml_metal_graph_compute(

             if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                 src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
-                src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
+                src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S) { // || src0t == GGML_TYPE_Q4_K) {
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
             else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
```
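The condition change routes IQ1_S into the branch that hands eight output rows to each threadgroup. A plain-C restatement of the grid arithmetic in that dispatchThreadgroups line:

```c
#include <stdint.h>

// Same arithmetic as MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13):
// ceil-divide the ne01 output rows of src0 by 8 rows per threadgroup.
static inline int64_t n_threadgroups_x(int64_t ne01) {
    return (ne01 + 7) / 8;
}
// e.g. ne01 = 4096 -> 512 threadgroups along x, while y spans the ne11
// columns of src1 and z spans the ne12*ne13 broadcast/batch dimensions.
```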
```diff
@@ -1573,6 +1590,7 @@ static bool ggml_metal_graph_compute(
                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
                 case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
+                case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
                 default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
             }
```
```diff
@@ -1710,6 +1728,12 @@ static bool ggml_metal_graph_compute(
                         nth1 = 16;
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
                     } break;
+                case GGML_TYPE_IQ1_S:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
+                    } break;
                 default:
                     {
                         GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
```
```diff
@@ -1760,7 +1784,7 @@ static bool ggml_metal_graph_compute(

             if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
                 src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
-                src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+                src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S) { // || src2t == GGML_TYPE_Q4_K) {
                 [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
             else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
```
```diff
@@ -1814,6 +1838,7 @@ static bool ggml_metal_graph_compute(
                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
                 case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
+                case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S  ].pipeline; break;
                 case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
                 default: GGML_ASSERT(false && "not implemented");
             }
```
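Taken together, the ggml-metal.m changes wire IQ1_S into every pipeline family a quantized type needs on this backend; as a compact recap (names copied from the hunks above):

```c
// The five pipelines added for IQ1_S in this commit, per the enum hunks:
static const char * iq1_s_metal_pipelines[] = {
    "GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S",       // dequantizing row gather
    "GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32",     // matrix-vector product
    "GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32",  // indirect (MoE) mat-vec
    "GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32",     // matrix-matrix product
    "GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32",  // indirect (MoE) mat-mat
};
```

The GPU-side kernel bodies (get_rows_iq1_s, mul_mv_iq1_s_f32, and so on) are referenced by the registrations above and presumably live in ggml-metal.metal, among the changed files not reproduced in this capture.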
