Skip to content

Commit fbbc428

Browse files
authored
ggml : reuse ggml_get_n_tasks() in ggml_graph_plan() (#4308)
* ggml : fix soft max out-of-bounds access

  ggml-ci

* ggml : reuse ggml_get_n_tasks() in ggml_graph_plan()

  ggml-ci
1 parent adf3de4 commit fbbc428

File tree

1 file changed

+2
-21
lines changed

1 file changed

+2
-21
lines changed

Diff for: ggml.c

+2 -21
Original file line number | Diff line number | Diff line change
@@ -15879,35 +15879,29 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1587915879

1588015880
// thread scheduling for the different operations + work buffer size estimation
1588115881
for (int i = 0; i < cgraph->n_nodes; i++) {
15882-
int n_tasks = 1;
15883-
1588415882
struct ggml_tensor * node = cgraph->nodes[i];
1588515883

15884+
const int n_tasks = ggml_get_n_tasks(node, n_threads);
15885+
1588615886
size_t cur = 0;
1588715887

1588815888
switch (node->op) {
1588915889
case GGML_OP_CPY:
1589015890
case GGML_OP_DUP:
1589115891
{
15892-
n_tasks = n_threads;
15893-
1589415892
if (ggml_is_quantized(node->type)) {
1589515893
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
1589615894
}
1589715895
} break;
1589815896
case GGML_OP_ADD:
1589915897
case GGML_OP_ADD1:
1590015898
{
15901-
n_tasks = n_threads;
15902-
1590315899
if (ggml_is_quantized(node->src[0]->type)) {
1590415900
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
1590515901
}
1590615902
} break;
1590715903
case GGML_OP_ACC:
1590815904
{
15909-
n_tasks = n_threads;
15910-
1591115905
if (ggml_is_quantized(node->src[0]->type)) {
1591215906
cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
1591315907
}
@@ -15935,16 +15929,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1593515929
} break;
1593615930
case GGML_OP_OUT_PROD:
1593715931
{
15938-
n_tasks = n_threads;
15939-
1594015932
if (ggml_is_quantized(node->src[0]->type)) {
1594115933
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
1594215934
}
1594315935
} break;
1594415936
case GGML_OP_SOFT_MAX:
1594515937
{
15946-
n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
15947-
1594815938
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
1594915939
} break;
1595015940
case GGML_OP_CONV_TRANSPOSE_1D:
@@ -15974,7 +15964,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1597415964
} break;
1597515965
case GGML_OP_IM2COL:
1597615966
{
15977-
n_tasks = n_threads;
1597815967
} break;
1597915968
case GGML_OP_CONV_TRANSPOSE_2D:
1598015969
{
@@ -15992,8 +15981,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1599215981
} break;
1599315982
case GGML_OP_FLASH_ATTN:
1599415983
{
15995-
n_tasks = n_threads;
15996-
1599715984
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
1599815985

1599915986
if (node->src[1]->type == GGML_TYPE_F32) {
@@ -16006,8 +15993,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1600615993
} break;
1600715994
case GGML_OP_FLASH_FF:
1600815995
{
16009-
n_tasks = n_threads;
16010-
1601115996
if (node->src[1]->type == GGML_TYPE_F32) {
1601215997
cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
1601315998
cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
@@ -16018,8 +16003,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1601816003
} break;
1601916004
case GGML_OP_FLASH_ATTN_BACK:
1602016005
{
16021-
n_tasks = n_threads;
16022-
1602316006
const int64_t D = node->src[0]->ne[0];
1602416007
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
1602516008
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
@@ -16034,8 +16017,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1603416017

1603516018
case GGML_OP_CROSS_ENTROPY_LOSS:
1603616019
{
16037-
n_tasks = n_threads;
16038-
1603916020
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
1604016021
} break;
1604116022
case GGML_OP_COUNT:

0 commit comments

Comments
 (0)