
Commit c71b9bf

ikawrakow (Kawrakow) authored and committed
1.5 bit quantization (ggml-org#5453)
* iq1_s: WIP basics
* iq1_s: CUDA is working
* iq1_s: scalar CPU dot product
* iq1_s: WIP AVX2 dot product - something is not right
* Fix tests
* Fix shadow warnings
* Fix after merge with latest master
* iq1_s: AVX2 finally works
* iq1_s: ARM_NEON dot product. Works, but not very fast
* iq1_s: better grid
* iq1_s: use IQ2_XXS for attn_output. At a cost of 0.04 extra bpw this gives a big improvement in PPL.
* iq1_s: Metal basics. Dequantize works, but not dot product.
* iq1_s: Metal works, but quite slow. As usual, Apple Silicon does not like the code I write.
* iq1_s: Tests
* iq1_s: slightly faster dot product

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 7189c5e · commit c71b9bf

File tree

12 files changed: +1286, -48 lines

examples/quantize/quantize.cpp

Lines changed: 4 additions & 2 deletions
@@ -23,6 +23,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q5_1",   LLAMA_FTYPE_MOSTLY_Q5_1,   " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", },
     { "IQ2_XXS",LLAMA_FTYPE_MOSTLY_IQ2_XXS," 2.06 bpw quantization", },
     { "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS, " 2.31 bpw quantization", },
+    { "IQ1_S",  LLAMA_FTYPE_MOSTLY_IQ1_S,  " 1.56 bpw quantization", },
     { "Q2_K",   LLAMA_FTYPE_MOSTLY_Q2_K,   " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
     { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
     { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },
@@ -287,9 +288,10 @@ int main(int argc, char ** argv) {
         }
     }

-    if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) && imatrix_data.empty()) {
+    if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS ||
+         params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) && imatrix_data.empty()) {
         fprintf(stderr, "\n===============================================================================================\n");
-        fprintf(stderr, "Please do not use IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
+        fprintf(stderr, "Please do not use IQ1_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
         fprintf(stderr, "===============================================================================================\n\n\n");
         return 1;
     }
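
The gate above hard-codes the four sub-2.5 bpw types that refuse to quantize without an importance matrix. A minimal sketch of the same check factored into a helper (the function name and table-driven form are illustrative, not part of the commit):

```cpp
#include <set>
#include "llama.h"

// Hypothetical helper: true for the ftypes quantize.cpp rejects when no
// importance matrix was supplied (the same four types as in the diff).
static bool requires_imatrix(enum llama_ftype ftype) {
    static const std::set<llama_ftype> low_bpw = {
        LLAMA_FTYPE_MOSTLY_IQ1_S,   // 1.56 bpw, added by this commit
        LLAMA_FTYPE_MOSTLY_IQ2_XXS, // 2.06 bpw
        LLAMA_FTYPE_MOSTLY_IQ2_XS,  // 2.31 bpw
        LLAMA_FTYPE_MOSTLY_Q2_K_S,
    };
    return low_bpw.count(ftype) > 0;
}
```

At these bit rates the quantizer depends on per-weight importance data to decide where the scarce precision goes, hence a hard error rather than a warning.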

ggml-backend.c

Lines changed: 1 addition & 1 deletion
@@ -756,7 +756,7 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
 GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_CPY:
-            return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS; // missing type_traits.from_float
+            return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS && op->type != GGML_TYPE_IQ1_S; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
             return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
         default:
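
The "missing type_traits.from_float" comment refers to ggml's per-type function table: a type can only be the destination of GGML_OP_CPY on the CPU backend if it registers a from_float converter, which the initial IQ1_S does not. A hedged sketch of the generic check the hard-coded list stands in for (the helper itself is illustrative; ggml_internal_get_type_traits and the from_float field are real API):

```cpp
#include "ggml.h"

// Illustrative only: a type whose traits lack a from_float converter
// cannot be the destination of a CPY on the CPU backend.
static bool cpu_can_cpy_to(enum ggml_type type) {
    return ggml_internal_get_type_traits(type).from_float != NULL;
}
```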

ggml-cuda.cu

Lines changed: 223 additions & 1 deletion
(Large diff not rendered by default.)
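
The CUDA diff mostly adds dequantization and matrix-vector kernels for the new type. For orientation, the IQ1_S block layout this commit introduces in ggml-quants.h packs 256 weights into 50 bytes; a self-contained sketch with the bits-per-weight arithmetic spelled out (ggml_fp16_t is aliased locally so the snippet stands alone; the field comments paraphrase the packing, whose exact details live in ggml-quants.c):

```cpp
#include <cstdint>

typedef uint16_t ggml_fp16_t;  // ggml's half-precision storage type
#define QK_K 256               // super-block size shared with the K-quants

typedef struct {
    ggml_fp16_t d;             //  2 bytes: super-block scale
    uint8_t qs[QK_K/8];        // 32 bytes: packed codebook (grid) indices
    uint8_t scales[QK_K/16];   // 16 bytes: per-group scales + index high bits
} block_iq1_s;

static_assert(sizeof(block_iq1_s) == 2 + QK_K/8 + QK_K/16,
              "50 bytes per 256 weights: 400/256 = 1.5625 bpw, "
              "the 1.56 bpw advertised in quantize.cpp");
```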

ggml-metal.m

Lines changed: 27 additions & 2 deletions
@@ -61,6 +61,7 @@
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -83,6 +84,7 @@
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
     //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -101,6 +103,7 @@
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
@@ -116,6 +119,7 @@
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -131,6 +135,7 @@
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
     GGML_METAL_KERNEL_TYPE_ALIBI_F32,
@@ -433,6 +438,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
@@ -455,6 +461,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -473,6 +480,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
@@ -488,6 +496,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
@@ -503,6 +512,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
@@ -1318,6 +1328,7 @@ static bool ggml_metal_graph_compute(
                         case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
                         case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
                         case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
+                        case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break;
                         default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
                     }

@@ -1452,6 +1463,12 @@
                             nth1 = 16;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
                         } break;
+                    case GGML_TYPE_IQ1_S:
+                        {
+                            nth0 = 4;
+                            nth1 = 16;
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
+                        } break;
                     default:
                         {
                             GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1486,7 +1503,7 @@

                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                     src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
-                    src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
+                    src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S) { // || src0t == GGML_TYPE_Q4_K) {
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
                 else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1592,6 +1609,7 @@
                         case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
                         case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
                         case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
+                        case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
                         default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
                     }

@@ -1729,6 +1747,12 @@
                             nth1 = 16;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
                         } break;
+                    case GGML_TYPE_IQ1_S:
+                        {
+                            nth0 = 4;
+                            nth1 = 16;
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
+                        } break;
                     default:
                         {
                             GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1779,7 +1803,7 @@

                 if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
                     src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
-                    src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+                    src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S) { // || src2t == GGML_TYPE_Q4_K) {
                     [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
                 else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
@@ -1833,6 +1857,7 @@
                     case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
                     case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
                     case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
+                    case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S  ].pipeline; break;
                     case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
                     default: GGML_ASSERT(false && "not implemented");
                 }
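
In the new MUL_MV branches, IQ1_S joins the dispatch path that covers 8 output rows per threadgroup with a 4x16 thread shape (nth0 = 4, nth1 = 16), so the grid width is the ceiling division (ne01 + 7)/8. A standalone illustration of that arithmetic (the shapes are made up; this is not the commit's code):

```cpp
#include <cstdio>

int main() {
    // Hypothetical shapes: ne01 = rows of src0, ne11 = columns of src1,
    // ne12*ne13 = batch dimensions, mirroring the Metal code above.
    const int ne01 = 4097, ne11 = 1, ne12 = 2, ne13 = 1;

    const int rows_per_tg = 8;
    const int tg_x = (ne01 + rows_per_tg - 1) / rows_per_tg; // == (ne01 + 7)/8

    // 4097 rows -> 513 threadgroups in x: the +7 rounds up so the final
    // partial group of rows is still covered.
    std::printf("grid = %d x %d x %d, threads per group = 4 x 16\n",
                tg_x, ne11, ne12 * ne13);
    return 0;
}
```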
