Skip to content

Commit 6cfb31f

Browse files
committed
metal : add indirect mat-vec kernels for all quantization types
1 parent 016f9bb commit 6cfb31f

File tree

2 files changed

+1255
-82
lines changed

2 files changed

+1255
-82
lines changed

Diff for: ggml-metal.m

+199-11
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,21 @@
102102
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
103103
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
104104
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
105+
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
106+
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
107+
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
108+
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
109+
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
110+
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
111+
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
112+
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
113+
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
114+
GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
115+
GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
116+
GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
117+
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
118+
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
119+
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
105120
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
106121
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
107122
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -354,6 +369,21 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
354369
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
355370
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
356371
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
372+
GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
373+
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
374+
GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
375+
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
376+
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
377+
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
378+
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
379+
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
380+
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
381+
GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
382+
GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
383+
GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
384+
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
385+
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
386+
GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
357387
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
358388
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
359389
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -454,6 +484,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
454484
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
455485
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
456486
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
487+
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
488+
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
489+
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
490+
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
491+
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
492+
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
493+
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
494+
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
495+
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
496+
GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
497+
GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
498+
GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
499+
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
500+
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
501+
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
457502
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
458503
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
459504
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -1491,17 +1536,22 @@ void ggml_metal_graph_compute(
14911536

14921537
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
14931538
// to the matrix-vector kernel
1494-
int ne11_mm_min = 0;
1539+
int ne11_mm_min = 1;
14951540

14961541
const int idx = ((int32_t *) dst->op_params)[0];
14971542

14981543
// batch size
14991544
GGML_ASSERT(ne01 == ne11);
15001545

1546+
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
1547+
15011548
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
15021549
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1503-
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1504-
ne11 > ne11_mm_min) {
1550+
// !!!
1551+
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
1552+
// indirect matrix multiplication
1553+
// !!!
1554+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
15051555
switch (src2->type) {
15061556
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
15071557
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
@@ -1517,7 +1567,6 @@ void ggml_metal_graph_compute(
15171567
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
15181568
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
15191569
}
1520-
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
15211570
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
15221571
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
15231572
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -1549,14 +1598,153 @@ void ggml_metal_graph_compute(
15491598

15501599
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
15511600

1552-
[encoder dispatchThreadgroups:MTLSizeMake( (1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1553-
//[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1554-
//for (int64_t i01 = 0; i01 < src0->ne[1]; i01++) {
1555-
// [encoder setBuffer:id_src0 offset:offs_src0 + i01*nb01 atIndex:0];
1556-
// [encoder setBuffer:id_src1 offset:offs_src1 + i01*nb11 atIndex:1];
1557-
// [encoder setBuffer:id_dst offset:offs_dst + i01*nb1 atIndex:2];
1601+
// TODO: processing one row at a time (ne11 -> 1) is not efficient
1602+
[encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1603+
} else {
1604+
int nth0 = 32;
1605+
int nth1 = 1;
1606+
int nrows = 1;
1607+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1608+
1609+
// use custom matrix x vector kernel
1610+
switch (src2t) {
1611+
case GGML_TYPE_F32:
1612+
{
1613+
GGML_ASSERT(src1t == GGML_TYPE_F32);
1614+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
1615+
nrows = 4;
1616+
} break;
1617+
case GGML_TYPE_F16:
1618+
{
1619+
GGML_ASSERT(src1t == GGML_TYPE_F32);
1620+
nth0 = 32;
1621+
nth1 = 1;
1622+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
1623+
} break;
1624+
case GGML_TYPE_Q4_0:
1625+
{
1626+
nth0 = 8;
1627+
nth1 = 8;
1628+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
1629+
} break;
1630+
case GGML_TYPE_Q4_1:
1631+
{
1632+
nth0 = 8;
1633+
nth1 = 8;
1634+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
1635+
} break;
1636+
case GGML_TYPE_Q5_0:
1637+
{
1638+
nth0 = 8;
1639+
nth1 = 8;
1640+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
1641+
} break;
1642+
case GGML_TYPE_Q5_1:
1643+
{
1644+
nth0 = 8;
1645+
nth1 = 8;
1646+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
1647+
} break;
1648+
case GGML_TYPE_Q8_0:
1649+
{
1650+
nth0 = 8;
1651+
nth1 = 8;
1652+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
1653+
} break;
1654+
case GGML_TYPE_Q2_K:
1655+
{
1656+
nth0 = 2;
1657+
nth1 = 32;
1658+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
1659+
} break;
1660+
case GGML_TYPE_Q3_K:
1661+
{
1662+
nth0 = 2;
1663+
nth1 = 32;
1664+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
1665+
} break;
1666+
case GGML_TYPE_Q4_K:
1667+
{
1668+
nth0 = 4; //1;
1669+
nth1 = 8; //32;
1670+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
1671+
} break;
1672+
case GGML_TYPE_Q5_K:
1673+
{
1674+
nth0 = 2;
1675+
nth1 = 32;
1676+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
1677+
} break;
1678+
case GGML_TYPE_Q6_K:
1679+
{
1680+
nth0 = 2;
1681+
nth1 = 32;
1682+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
1683+
} break;
1684+
default:
1685+
{
1686+
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1687+
GGML_ASSERT(false && "not implemented");
1688+
}
1689+
};
1690+
1691+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1692+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1693+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1694+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1695+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1696+
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1697+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1698+
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1699+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1700+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1701+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1702+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1703+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1704+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1705+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1706+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1707+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1708+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1709+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1710+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1711+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1712+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1713+
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1714+
// TODO: how to make this an array? read Metal docs
1715+
for (int j = 0; j < n_as; ++j) {
1716+
struct ggml_tensor * src_cur = dst->src[2 + j];
1717+
1718+
size_t offs_src_cur = 0;
1719+
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1720+
1721+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1722+
}
15581723

1559-
//}
1724+
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1725+
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1726+
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
1727+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1728+
}
1729+
else if (src2t == GGML_TYPE_Q4_K) {
1730+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1731+
}
1732+
else if (src2t == GGML_TYPE_Q3_K) {
1733+
#ifdef GGML_QKK_64
1734+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1735+
#else
1736+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1737+
#endif
1738+
}
1739+
else if (src2t == GGML_TYPE_Q5_K) {
1740+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1741+
}
1742+
else if (src2t == GGML_TYPE_Q6_K) {
1743+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1744+
} else {
1745+
const int64_t ny = (_ne1 + nrows - 1)/nrows;
1746+
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1747+
}
15601748
}
15611749
} break;
15621750
case GGML_OP_GET_ROWS:

0 commit comments

Comments
 (0)