Skip to content

Commit 42a2b64

Browse files
ikawrakowKawrakow
authored andcommitted
ggml : SOTA 2-bit quants (add IQ2_XS) (ggml-org#4856)
* iq2_xs: basics * iq2_xs: this should have been in the basics * iq2_xs: CUDA and scalar CPU works * iq2_xs: WIP Metal * iq2_xs: Metal now works * iq2_xs: working, but dog slow, ARM_NEON dot product * iq2_xs: better ARM_NEON dot product We are now at 19.5 t/s for TG-128 and 61 t/s for PP-512 when running on the CPU. * iq2_xs: AVX2 dot product - 19.5 t/s * iq2_xs: faster AVX2 dit product 21.4 t/s for TG-128, 59.2 t/s for PP-512. The latter is 2x compared to the previous version. * iq2_xs: had forgotten to delete iq2-data.h * Add llama enum for IQ2_XS --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent eaeb2c2 commit 42a2b64

10 files changed

+1038
-28
lines changed

Diff for: ggml-cuda.cu

+227-5
Large diffs are not rendered by default.

Diff for: ggml-metal.m

+36-6
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
9090
GGML_METAL_DECL_KERNEL(get_rows_i32);
9191
GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
92+
GGML_METAL_DECL_KERNEL(get_rows_iq2_xs);
9293
GGML_METAL_DECL_KERNEL(rms_norm);
9394
GGML_METAL_DECL_KERNEL(group_norm);
9495
GGML_METAL_DECL_KERNEL(norm);
@@ -108,6 +109,7 @@
108109
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
109110
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
110111
GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
112+
GGML_METAL_DECL_KERNEL(mul_mv_iq2_xs_f32);
111113
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
112114
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
113115
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
@@ -124,6 +126,7 @@
124126
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
125127
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
126128
GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
129+
GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xs_f32);
127130
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
128131
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
129132
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -137,6 +140,7 @@
137140
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
138141
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
139142
GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
143+
GGML_METAL_DECL_KERNEL(mul_mm_iq2_xs_f32);
140144
GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
141145
GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
142146
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
@@ -150,6 +154,7 @@
150154
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
151155
GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
152156
GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
157+
GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xs_f32);
153158
GGML_METAL_DECL_KERNEL(rope_f32);
154159
GGML_METAL_DECL_KERNEL(rope_f16);
155160
GGML_METAL_DECL_KERNEL(alibi_f32);
@@ -385,6 +390,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
385390
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
386391
GGML_METAL_ADD_KERNEL(get_rows_i32);
387392
GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
393+
GGML_METAL_ADD_KERNEL(get_rows_iq2_xs);
388394
GGML_METAL_ADD_KERNEL(rms_norm);
389395
GGML_METAL_ADD_KERNEL(group_norm);
390396
GGML_METAL_ADD_KERNEL(norm);
@@ -404,6 +410,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
404410
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
405411
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
406412
GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
413+
GGML_METAL_ADD_KERNEL(mul_mv_iq2_xs_f32);
407414
GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
408415
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
409416
GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
@@ -420,6 +427,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
420427
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
421428
GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
422429
GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
430+
GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xs_f32);
423431
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
424432
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
425433
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -434,6 +442,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
434442
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
435443
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
436444
GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
445+
GGML_METAL_ADD_KERNEL(mul_mm_iq2_xs_f32);
437446
GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
438447
GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
439448
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
@@ -447,6 +456,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
447456
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
448457
GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
449458
GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
459+
GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xs_f32);
450460
}
451461
GGML_METAL_ADD_KERNEL(rope_f32);
452462
GGML_METAL_ADD_KERNEL(rope_f16);
@@ -513,6 +523,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
513523
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
514524
GGML_METAL_DEL_KERNEL(get_rows_i32);
515525
GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
526+
GGML_METAL_DEL_KERNEL(get_rows_iq2_xs);
516527
GGML_METAL_DEL_KERNEL(rms_norm);
517528
GGML_METAL_DEL_KERNEL(group_norm);
518529
GGML_METAL_DEL_KERNEL(norm);
@@ -532,6 +543,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
532543
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
533544
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
534545
GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
546+
GGML_METAL_DEL_KERNEL(mul_mv_iq2_xs_f32);
535547
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
536548
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
537549
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
@@ -548,6 +560,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
548560
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
549561
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
550562
GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
563+
GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xs_f32);
551564
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
552565
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
553566
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -562,6 +575,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
562575
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
563576
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
564577
GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
578+
GGML_METAL_DEL_KERNEL(mul_mm_iq2_xs_f32);
565579
GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
566580
GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
567581
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
@@ -575,6 +589,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
575589
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
576590
GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
577591
GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
592+
GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xs_f32);
578593
}
579594
GGML_METAL_DEL_KERNEL(rope_f32);
580595
GGML_METAL_DEL_KERNEL(rope_f16);
@@ -1561,6 +1576,7 @@ bool ggml_metal_graph_compute(
15611576
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
15621577
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
15631578
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
1579+
case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xs_f32]; break;
15641580
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
15651581
}
15661582
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1679,6 +1695,12 @@ bool ggml_metal_graph_compute(
16791695
nth1 = 16;
16801696
[encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
16811697
} break;
1698+
case GGML_TYPE_IQ2_XS:
1699+
{
1700+
nth0 = 4;
1701+
nth1 = 16;
1702+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xs_f32];
1703+
} break;
16821704
default:
16831705
{
16841706
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1712,12 +1734,12 @@ bool ggml_metal_graph_compute(
17121734

17131735
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
17141736
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1715-
//src0t == GGML_TYPE_IQ2_XXS ||
17161737
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
17171738
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
17181739
}
1719-
else if (src0t == GGML_TYPE_IQ2_XXS) {
1720-
[encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
1740+
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1741+
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1742+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
17211743
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
17221744
}
17231745
else if (src0t == GGML_TYPE_Q4_K) {
@@ -1810,6 +1832,7 @@ bool ggml_metal_graph_compute(
18101832
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
18111833
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
18121834
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
1835+
case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xs_f32]; break;
18131836
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
18141837
}
18151838
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1931,6 +1954,12 @@ bool ggml_metal_graph_compute(
19311954
nth1 = 16;
19321955
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
19331956
} break;
1957+
case GGML_TYPE_IQ2_XS:
1958+
{
1959+
nth0 = 4;
1960+
nth1 = 16;
1961+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xs_f32];
1962+
} break;
19341963
default:
19351964
{
19361965
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1980,12 +2009,12 @@ bool ggml_metal_graph_compute(
19802009

19812010
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
19822011
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1983-
//src2t == GGML_TYPE_IQ2_XXS ||
19842012
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
19852013
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
19862014
}
1987-
else if (src2t == GGML_TYPE_IQ2_XXS) {
1988-
[encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
2015+
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
2016+
const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
2017+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
19892018
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
19902019
}
19912020
else if (src2t == GGML_TYPE_Q4_K) {
@@ -2026,6 +2055,7 @@ bool ggml_metal_graph_compute(
20262055
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
20272056
case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
20282057
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
2058+
case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xs]; break;
20292059
default: GGML_ASSERT(false && "not implemented");
20302060
}
20312061

0 commit comments

Comments
 (0)