89
89
GGML_METAL_DECL_KERNEL (get_rows_q6_K);
90
90
GGML_METAL_DECL_KERNEL (get_rows_i32);
91
91
GGML_METAL_DECL_KERNEL (get_rows_iq2_xxs);
92
+ GGML_METAL_DECL_KERNEL (get_rows_iq2_xs);
92
93
GGML_METAL_DECL_KERNEL (rms_norm);
93
94
GGML_METAL_DECL_KERNEL (group_norm);
94
95
GGML_METAL_DECL_KERNEL (norm);
108
109
GGML_METAL_DECL_KERNEL (mul_mv_q5_K_f32);
109
110
GGML_METAL_DECL_KERNEL (mul_mv_q6_K_f32);
110
111
GGML_METAL_DECL_KERNEL (mul_mv_iq2_xxs_f32);
112
+ GGML_METAL_DECL_KERNEL (mul_mv_iq2_xs_f32);
111
113
GGML_METAL_DECL_KERNEL (mul_mv_id_f32_f32);
112
114
// GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
113
115
GGML_METAL_DECL_KERNEL (mul_mv_id_f16_f32);
124
126
GGML_METAL_DECL_KERNEL (mul_mv_id_q5_K_f32);
125
127
GGML_METAL_DECL_KERNEL (mul_mv_id_q6_K_f32);
126
128
GGML_METAL_DECL_KERNEL (mul_mv_id_iq2_xxs_f32);
129
+ GGML_METAL_DECL_KERNEL (mul_mv_id_iq2_xs_f32);
127
130
GGML_METAL_DECL_KERNEL (mul_mm_f32_f32);
128
131
GGML_METAL_DECL_KERNEL (mul_mm_f16_f32);
129
132
GGML_METAL_DECL_KERNEL (mul_mm_q4_0_f32);
137
140
GGML_METAL_DECL_KERNEL (mul_mm_q5_K_f32);
138
141
GGML_METAL_DECL_KERNEL (mul_mm_q6_K_f32);
139
142
GGML_METAL_DECL_KERNEL (mul_mm_iq2_xxs_f32);
143
+ GGML_METAL_DECL_KERNEL (mul_mm_iq2_xs_f32);
140
144
GGML_METAL_DECL_KERNEL (mul_mm_id_f32_f32);
141
145
GGML_METAL_DECL_KERNEL (mul_mm_id_f16_f32);
142
146
GGML_METAL_DECL_KERNEL (mul_mm_id_q4_0_f32);
150
154
GGML_METAL_DECL_KERNEL (mul_mm_id_q5_K_f32);
151
155
GGML_METAL_DECL_KERNEL (mul_mm_id_q6_K_f32);
152
156
GGML_METAL_DECL_KERNEL (mul_mm_id_iq2_xxs_f32);
157
+ GGML_METAL_DECL_KERNEL (mul_mm_id_iq2_xs_f32);
153
158
GGML_METAL_DECL_KERNEL (rope_f32);
154
159
GGML_METAL_DECL_KERNEL (rope_f16);
155
160
GGML_METAL_DECL_KERNEL (alibi_f32);
@@ -385,6 +390,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
385
390
GGML_METAL_ADD_KERNEL (get_rows_q6_K);
386
391
GGML_METAL_ADD_KERNEL (get_rows_i32);
387
392
GGML_METAL_ADD_KERNEL (get_rows_iq2_xxs);
393
+ GGML_METAL_ADD_KERNEL (get_rows_iq2_xs);
388
394
GGML_METAL_ADD_KERNEL (rms_norm);
389
395
GGML_METAL_ADD_KERNEL (group_norm);
390
396
GGML_METAL_ADD_KERNEL (norm);
@@ -404,6 +410,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
404
410
GGML_METAL_ADD_KERNEL (mul_mv_q5_K_f32);
405
411
GGML_METAL_ADD_KERNEL (mul_mv_q6_K_f32);
406
412
GGML_METAL_ADD_KERNEL (mul_mv_iq2_xxs_f32);
413
+ GGML_METAL_ADD_KERNEL (mul_mv_iq2_xs_f32);
407
414
GGML_METAL_ADD_KERNEL (mul_mv_id_f32_f32);
408
415
// GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
409
416
GGML_METAL_ADD_KERNEL (mul_mv_id_f16_f32);
@@ -420,6 +427,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
420
427
GGML_METAL_ADD_KERNEL (mul_mv_id_q5_K_f32);
421
428
GGML_METAL_ADD_KERNEL (mul_mv_id_q6_K_f32);
422
429
GGML_METAL_ADD_KERNEL (mul_mv_id_iq2_xxs_f32);
430
+ GGML_METAL_ADD_KERNEL (mul_mv_id_iq2_xs_f32);
423
431
if ([ctx->device supportsFamily: MTLGPUFamilyApple7]) {
424
432
GGML_METAL_ADD_KERNEL (mul_mm_f32_f32);
425
433
GGML_METAL_ADD_KERNEL (mul_mm_f16_f32);
@@ -434,6 +442,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
434
442
GGML_METAL_ADD_KERNEL (mul_mm_q5_K_f32);
435
443
GGML_METAL_ADD_KERNEL (mul_mm_q6_K_f32);
436
444
GGML_METAL_ADD_KERNEL (mul_mm_iq2_xxs_f32);
445
+ GGML_METAL_ADD_KERNEL (mul_mm_iq2_xs_f32);
437
446
GGML_METAL_ADD_KERNEL (mul_mm_id_f32_f32);
438
447
GGML_METAL_ADD_KERNEL (mul_mm_id_f16_f32);
439
448
GGML_METAL_ADD_KERNEL (mul_mm_id_q4_0_f32);
@@ -447,6 +456,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
447
456
GGML_METAL_ADD_KERNEL (mul_mm_id_q5_K_f32);
448
457
GGML_METAL_ADD_KERNEL (mul_mm_id_q6_K_f32);
449
458
GGML_METAL_ADD_KERNEL (mul_mm_id_iq2_xxs_f32);
459
+ GGML_METAL_ADD_KERNEL (mul_mm_id_iq2_xs_f32);
450
460
}
451
461
GGML_METAL_ADD_KERNEL (rope_f32);
452
462
GGML_METAL_ADD_KERNEL (rope_f16);
@@ -513,6 +523,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
513
523
GGML_METAL_DEL_KERNEL (get_rows_q6_K);
514
524
GGML_METAL_DEL_KERNEL (get_rows_i32);
515
525
GGML_METAL_DEL_KERNEL (get_rows_iq2_xxs);
526
+ GGML_METAL_DEL_KERNEL (get_rows_iq2_xs);
516
527
GGML_METAL_DEL_KERNEL (rms_norm);
517
528
GGML_METAL_DEL_KERNEL (group_norm);
518
529
GGML_METAL_DEL_KERNEL (norm);
@@ -532,6 +543,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
532
543
GGML_METAL_DEL_KERNEL (mul_mv_q5_K_f32);
533
544
GGML_METAL_DEL_KERNEL (mul_mv_q6_K_f32);
534
545
GGML_METAL_DEL_KERNEL (mul_mv_iq2_xxs_f32);
546
+ GGML_METAL_DEL_KERNEL (mul_mv_iq2_xs_f32);
535
547
GGML_METAL_DEL_KERNEL (mul_mv_id_f32_f32);
536
548
// GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
537
549
GGML_METAL_DEL_KERNEL (mul_mv_id_f16_f32);
@@ -548,6 +560,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
548
560
GGML_METAL_DEL_KERNEL (mul_mv_id_q5_K_f32);
549
561
GGML_METAL_DEL_KERNEL (mul_mv_id_q6_K_f32);
550
562
GGML_METAL_DEL_KERNEL (mul_mv_id_iq2_xxs_f32);
563
+ GGML_METAL_DEL_KERNEL (mul_mv_id_iq2_xs_f32);
551
564
if ([ctx->device supportsFamily: MTLGPUFamilyApple7]) {
552
565
GGML_METAL_DEL_KERNEL (mul_mm_f32_f32);
553
566
GGML_METAL_DEL_KERNEL (mul_mm_f16_f32);
@@ -562,6 +575,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
562
575
GGML_METAL_DEL_KERNEL (mul_mm_q5_K_f32);
563
576
GGML_METAL_DEL_KERNEL (mul_mm_q6_K_f32);
564
577
GGML_METAL_DEL_KERNEL (mul_mm_iq2_xxs_f32);
578
+ GGML_METAL_DEL_KERNEL (mul_mm_iq2_xs_f32);
565
579
GGML_METAL_DEL_KERNEL (mul_mm_id_f32_f32);
566
580
GGML_METAL_DEL_KERNEL (mul_mm_id_f16_f32);
567
581
GGML_METAL_DEL_KERNEL (mul_mm_id_q4_0_f32);
@@ -575,6 +589,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
575
589
GGML_METAL_DEL_KERNEL (mul_mm_id_q5_K_f32);
576
590
GGML_METAL_DEL_KERNEL (mul_mm_id_q6_K_f32);
577
591
GGML_METAL_DEL_KERNEL (mul_mm_id_iq2_xxs_f32);
592
+ GGML_METAL_DEL_KERNEL (mul_mm_id_iq2_xs_f32);
578
593
}
579
594
GGML_METAL_DEL_KERNEL (rope_f32);
580
595
GGML_METAL_DEL_KERNEL (rope_f16);
@@ -1561,6 +1576,7 @@ bool ggml_metal_graph_compute(
1561
1576
case GGML_TYPE_Q5_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q5_K_f32]; break ;
1562
1577
case GGML_TYPE_Q6_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q6_K_f32]; break ;
1563
1578
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState: ctx->pipeline_mul_mm_iq2_xxs_f32]; break ;
1579
+ case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState: ctx->pipeline_mul_mm_iq2_xs_f32]; break ;
1564
1580
default : GGML_ASSERT (false && " MUL MAT-MAT not implemented" );
1565
1581
}
1566
1582
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
@@ -1679,6 +1695,12 @@ bool ggml_metal_graph_compute(
1679
1695
nth1 = 16 ;
1680
1696
[encoder setComputePipelineState: ctx->pipeline_mul_mv_iq2_xxs_f32];
1681
1697
} break ;
1698
+ case GGML_TYPE_IQ2_XS:
1699
+ {
1700
+ nth0 = 4 ;
1701
+ nth1 = 16 ;
1702
+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_iq2_xs_f32];
1703
+ } break ;
1682
1704
default :
1683
1705
{
1684
1706
GGML_METAL_LOG_ERROR (" Asserting on type %d \n " , (int )src0t);
@@ -1712,12 +1734,12 @@ bool ggml_metal_graph_compute(
1712
1734
1713
1735
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1714
1736
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1715
- // src0t == GGML_TYPE_IQ2_XXS ||
1716
1737
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1717
1738
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1718
1739
}
1719
- else if (src0t == GGML_TYPE_IQ2_XXS) {
1720
- [encoder setThreadgroupMemoryLength: (256 *8 +128 ) atIndex: 0 ];
1740
+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1741
+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256 *8 +128 : 512 *8 +128 ;
1742
+ [encoder setThreadgroupMemoryLength: mem_size atIndex: 0 ];
1721
1743
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1722
1744
}
1723
1745
else if (src0t == GGML_TYPE_Q4_K) {
@@ -1810,6 +1832,7 @@ bool ggml_metal_graph_compute(
1810
1832
case GGML_TYPE_Q5_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_id_q5_K_f32]; break ;
1811
1833
case GGML_TYPE_Q6_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_id_q6_K_f32]; break ;
1812
1834
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState: ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break ;
1835
+ case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState: ctx->pipeline_mul_mm_id_iq2_xs_f32]; break ;
1813
1836
default : GGML_ASSERT (false && " MUL_MAT_ID not implemented" );
1814
1837
}
1815
1838
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
@@ -1931,6 +1954,12 @@ bool ggml_metal_graph_compute(
1931
1954
nth1 = 16 ;
1932
1955
[encoder setComputePipelineState: ctx->pipeline_mul_mv_id_iq2_xxs_f32];
1933
1956
} break ;
1957
+ case GGML_TYPE_IQ2_XS:
1958
+ {
1959
+ nth0 = 4 ;
1960
+ nth1 = 16 ;
1961
+ [encoder setComputePipelineState: ctx->pipeline_mul_mv_id_iq2_xs_f32];
1962
+ } break ;
1934
1963
default :
1935
1964
{
1936
1965
GGML_METAL_LOG_ERROR (" Asserting on type %d \n " , (int )src2t);
@@ -1980,12 +2009,12 @@ bool ggml_metal_graph_compute(
1980
2009
1981
2010
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1982
2011
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1983
- // src2t == GGML_TYPE_IQ2_XXS ||
1984
2012
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
1985
2013
[encoder dispatchThreadgroups: MTLSizeMake ((ne21 + 7 )/8 , _ne1, ne01*ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1986
2014
}
1987
- else if (src2t == GGML_TYPE_IQ2_XXS) {
1988
- [encoder setThreadgroupMemoryLength: (256 *8 +128 ) atIndex: 0 ];
2015
+ else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
2016
+ const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256 *8 +128 : 512 *8 +128 ;
2017
+ [encoder setThreadgroupMemoryLength: mem_size atIndex: 0 ];
1989
2018
[encoder dispatchThreadgroups: MTLSizeMake ((ne21 + 7 )/8 , _ne1, ne01*ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1990
2019
}
1991
2020
else if (src2t == GGML_TYPE_Q4_K) {
@@ -2026,6 +2055,7 @@ bool ggml_metal_graph_compute(
2026
2055
case GGML_TYPE_Q6_K: [encoder setComputePipelineState: ctx->pipeline_get_rows_q6_K]; break ;
2027
2056
case GGML_TYPE_I32: [encoder setComputePipelineState: ctx->pipeline_get_rows_i32]; break ;
2028
2057
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState: ctx->pipeline_get_rows_iq2_xxs]; break ;
2058
+ case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState: ctx->pipeline_get_rows_iq2_xs]; break ;
2029
2059
default : GGML_ASSERT (false && " not implemented" );
2030
2060
}
2031
2061
0 commit comments