@@ -15879,35 +15879,29 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 
     // thread scheduling for the different operations + work buffer size estimation
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        int n_tasks = 1;
-
         struct ggml_tensor * node = cgraph->nodes[i];
 
+        const int n_tasks = ggml_get_n_tasks(node, n_threads);
+
         size_t cur = 0;
 
         switch (node->op) {
             case GGML_OP_CPY:
             case GGML_OP_DUP:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                     }
                 } break;
             case GGML_OP_ADD:
             case GGML_OP_ADD1:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
             case GGML_OP_ACC:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
                     }
@@ -15935,16 +15929,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_OUT_PROD:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
             case GGML_OP_SOFT_MAX:
                 {
-                    n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
-
                     cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                 } break;
             case GGML_OP_CONV_TRANSPOSE_1D:
@@ -15974,7 +15964,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_IM2COL:
                 {
-                    n_tasks = n_threads;
                 } break;
             case GGML_OP_CONV_TRANSPOSE_2D:
                 {
@@ -15992,8 +15981,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_ATTN:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
 
                     if (node->src[1]->type == GGML_TYPE_F32) {
@@ -16006,8 +15993,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_FF:
                 {
-                    n_tasks = n_threads;
-
                     if (node->src[1]->type == GGML_TYPE_F32) {
                         cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
                         cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
@@ -16018,8 +16003,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t D = node->src[0]->ne[0];
                     const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
                     const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
@@ -16034,8 +16017,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 
             case GGML_OP_CROSS_ENTROPY_LOSS:
                 {
-                    n_tasks = n_threads;
-
                     cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
                 } break;
             case GGML_OP_COUNT:
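The change collapses the scattered per-op `n_tasks = ...` assignments into a single call to the new `ggml_get_n_tasks()` helper, computed once before the `switch`, so the work-buffer size estimates under each `case` reuse the same task count. As a rough illustration, here is a minimal sketch of that helper reconstructed purely from the assignments removed in this diff; the actual function in ggml.c covers every GGML_OP_* value, not just the cases visible in these hunks.

// Sketch only: reconstructed from the removed assignments above, assuming
// ggml.c's internal MIN macro and ggml_nrows(). The real ggml_get_n_tasks()
// also handles the ops that fall outside these hunks.
static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
    int n_tasks = 1; // ops not listed below previously defaulted to 1

    switch (node->op) {
        case GGML_OP_CPY:
        case GGML_OP_DUP:
        case GGML_OP_ADD:
        case GGML_OP_ADD1:
        case GGML_OP_ACC:
        case GGML_OP_OUT_PROD:
        case GGML_OP_IM2COL:
        case GGML_OP_FLASH_ATTN:
        case GGML_OP_FLASH_FF:
        case GGML_OP_FLASH_ATTN_BACK:
        case GGML_OP_CROSS_ENTROPY_LOSS:
            {
                n_tasks = n_threads;
            } break;
        case GGML_OP_SOFT_MAX:
            {
                // at most 4 tasks, and never more tasks than rows
                n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
            } break;
        default:
            break;
    }

    return n_tasks;
}

Moving the logic into a helper keeps the task-count policy in one place, and the `const` qualifier on the new local makes clear that the size estimates in the `switch` no longer mutate it per case.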