    GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+   GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
+   //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
+   GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
+   //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
+   //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
+   GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
+   GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
+   GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
+   GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
+   GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
+   GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
+   GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
+   GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
+   GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
+   GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
    GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
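Each GGML_METAL_DECL_KERNEL(name) above reserves the per-kernel state held by the Metal context: a function handle and a compiled pipeline. As a rough sketch of the pattern (not the verbatim definition from this file), the macro expands to a pair of struct members:

    // sketch: what each DECL line contributes to the ggml_metal_context struct
    #define GGML_METAL_DECL_KERNEL(name) \
        id<MTLFunction>             function_##name; \
        id<MTLComputePipelineState> pipeline_##name

So the new mul_mv_id_* entries simply add one MTLFunction/MTLComputePipelineState pair per source-type combination.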
@@ -354,6 +369,21 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
        GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+       GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
+       //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
+       GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
+       //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
+       //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
+       GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
+       GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
+       GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
+       GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
+       GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
+       GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
+       GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
+       GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
+       GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
+       GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
        if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
            GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
            GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
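The matching GGML_METAL_ADD_KERNEL(name) calls run once at context creation; each looks up the compiled Metal function kernel_<name> in the library and builds a compute pipeline state for it. A condensed sketch of the pattern, omitting the file's error handling:

    // sketch: function lookup + pipeline creation per registered kernel
    #define GGML_METAL_ADD_KERNEL(name) \
        ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_" #name]; \
        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name \
                                                                          error:&error]

Note that the mul_mm_* registrations stay behind the MTLGPUFamilyApple7 check: as a comment further down puts it, the matrix-matrix kernels only work on A14+/M1+ SoCs, presumably because they rely on simdgroup matrix operations that older GPUs lack.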
@@ -454,6 +484,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+   GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
+   //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
+   GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
+   //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
+   //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
+   GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
+   GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
+   GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
+   GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
+   GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
+   GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
+   GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
+   GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
+   GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
+   GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
    if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
        GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
        GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
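ggml_metal_free walks the same list a third time; under the file's manual reference counting, GGML_METAL_DEL_KERNEL(name) just releases the two objects, roughly:

    // sketch: teardown counterpart of DECL/ADD
    #define GGML_METAL_DEL_KERNEL(name) \
        [ctx->function_##name release]; \
        [ctx->pipeline_##name release]

Keeping the declare/add/delete lists in sync by hand is the cost of this macro scheme, which is why every new kernel shows up in all three hunks above.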
@@ -1491,17 +1536,22 @@ void ggml_metal_graph_compute(
            // find the break-even point where the matrix-matrix kernel becomes more efficient compared
            // to the matrix-vector kernel
-           int ne11_mm_min = 0;
+           int ne11_mm_min = 1;

            const int idx = ((int32_t *) dst->op_params)[0];

            // batch size
            GGML_ASSERT(ne01 == ne11);

+           const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
+
            // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
            // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-           if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-               ne11 > ne11_mm_min) {
+           // !!!
+           // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+           //       indirect matrix multiplication
+           // !!!
+           if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
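                // note: with _ne1 == 1 and ne11_mm_min == 1 the condition `1 > 1` can never
                // hold, so this mat-mat branch is effectively disabled and the new mat-vec
                // path in the else below is what actually runs (see the TODO above)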
                switch (src2->type) {
                    case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
                    case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
@@ -1517,7 +1567,6 @@ void ggml_metal_graph_compute(
                    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
                    default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
                }
-               const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
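Hoisting _ne1 out of this branch (removed here, added above the if/else in the earlier hunk) matters because setBytes:length:atIndex: copies the bytes from a host pointer at encode time, so the value must live in an addressable variable visible to both code paths. Condensed, the two hunks amount to:

    const int64_t _ne1 = 1;                                  // hoisted above the if/else
    [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11]; // both branches can bind it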
@@ -1549,14 +1598,153 @@ void ggml_metal_graph_compute(
                [encoder setThreadgroupMemoryLength:8192 atIndex:0];

-               [encoder dispatchThreadgroups:MTLSizeMake( (1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
-               //[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
-               //for (int64_t i01 = 0; i01 < src0->ne[1]; i01++) {
-               //    [encoder setBuffer:id_src0 offset:offs_src0 + i01*nb01 atIndex:0];
-               //    [encoder setBuffer:id_src1 offset:offs_src1 + i01*nb11 atIndex:1];
-               //    [encoder setBuffer:id_dst  offset:offs_dst  + i01*nb1  atIndex:2];
+               // TODO: processing one row at a time (ne11 -> 1) is not efficient
+               [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+           } else {
+               int nth0 = 32;
+               int nth1 = 1;
+               int nrows = 1;
+               //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+               // use custom matrix x vector kernel
+               switch (src2t) {
+                   case GGML_TYPE_F32:
+                       {
+                           GGML_ASSERT(src1t == GGML_TYPE_F32);
+                           [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+                           nrows = 4;
+                       } break;
+                   case GGML_TYPE_F16:
+                       {
+                           GGML_ASSERT(src1t == GGML_TYPE_F32);
+                           nth0 = 32;
+                           nth1 = 1;
+                           [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+                       } break;
+                   case GGML_TYPE_Q4_0:
+                       {
+                           nth0 = 8;
+                           nth1 = 8;
+                           [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+                       } break;
+                   case GGML_TYPE_Q4_1:
+                       {
+                           nth0 = 8;
+                           nth1 = 8;
+                           [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+                       } break;
+                   case GGML_TYPE_Q5_0:
+                       {
+                           nth0 = 8;
+                           nth1 = 8;
+                           [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+                       } break;
+                   case GGML_TYPE_Q5_1:
+                       {
+                           nth0 = 8;
+                           nth1 = 8;
+                           [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+                       } break;
+                   case GGML_TYPE_Q8_0:
+                       {
+                           nth0 = 8;
+                           nth1 = 8;
+                           [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+                       } break;
+                   case GGML_TYPE_Q2_K:
+                       {
+                           nth0 = 2;
+                           nth1 = 32;
+                           [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+                       } break;
+                   case GGML_TYPE_Q3_K:
+                       {
+                           nth0 = 2;
+                           nth1 = 32;
+                           [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+                       } break;
+                   case GGML_TYPE_Q4_K:
+                       {
+                           nth0 = 4; //1;
+                           nth1 = 8; //32;
+                           [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+                       } break;
+                   case GGML_TYPE_Q5_K:
+                       {
+                           nth0 = 2;
+                           nth1 = 32;
+                           [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+                       } break;
+                   case GGML_TYPE_Q6_K:
+                       {
+                           nth0 = 2;
+                           nth1 = 32;
+                           [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+                       } break;
+                   default:
+                       {
+                           GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+                           GGML_ASSERT(false && "not implemented");
+                       }
+               };
+
+               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+               [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+               [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+               [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+               [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+               [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+               [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
+               [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
+               [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
+               [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
+               [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+               [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
+               [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+               [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+               [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+               [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+               [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+               [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17];
+               [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
+               [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:19];
+               [encoder setBytes:&r2   length:sizeof(r2)   atIndex:20];
+               [encoder setBytes:&r3   length:sizeof(r3)   atIndex:21];
+               [encoder setBytes:&idx  length:sizeof(idx)  atIndex:22];
+               // TODO: how to make this an array? read Metal docs
+               for (int j = 0; j < n_as; ++j) {
+                   struct ggml_tensor * src_cur = dst->src[2 + j];
+
+                   size_t offs_src_cur = 0;
+                   id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+                   [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
+               }
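                // each expert matrix is bound to its own slot starting at buffer index 23,
                // so the MSL kernel presumably declares one buffer parameter per expert to
                // match; binding them one by one in this loop is the stopgap until the TODO
                // above (a proper buffer array / argument buffer) is resolved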

-               // }
+               if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
+                   src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+                   src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+                   [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+               }
+               else if (src2t == GGML_TYPE_Q4_K) {
+                   [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+               }
+               else if (src2t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+                   [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+                   [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+               }
+               else if (src2t == GGML_TYPE_Q5_K) {
+                   [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+               }
+               else if (src2t == GGML_TYPE_Q6_K) {
+                   [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+               } else {
+                   const int64_t ny = (_ne1 + nrows - 1)/nrows;
+                   [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+               }
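                // all grids above use the ceiling-division idiom (n + k-1)/k so the last,
                // partially filled threadgroup still gets launched: the Q4_0..Q8_0 and Q2_K
                // kernels cover 8 rows per threadgroup ((ne21 + 7)/8), Q3_K/Q4_K/Q5_K cover
                // 4 (2 for Q3_K under GGML_QKK_64), Q6_K covers 2, and the f32/f16 fallback
                // packs nrows of the _ne1 batch dimension into each group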
            }
        } break;
    case GGML_OP_GET_ROWS: