@@ -15629,7 +15629,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_DIAG_MASK_ZERO:
         case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_SOFT_MAX:
         case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_ROPE:
         case GGML_OP_ROPE_BACK:
@@ -15645,6 +15644,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = 1; //TODO
             } break;
+        case GGML_OP_SOFT_MAX:
+            {
+                n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
+            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 n_tasks = n_threads;
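The new GGML_OP_SOFT_MAX branch caps the task count at 4 and at the number of rows in the input, so small tensors no longer fan work out to threads that would sit idle. A minimal standalone sketch of the capping rule, with a hypothetical helper name and sample values that are not part of the commit:

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

/* mirrors: n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0])); */
static int soft_max_n_tasks(int n_threads, int n_rows) {
    return MIN(MIN(4, n_threads), n_rows);
}

int main(void) {
    printf("%d\n", soft_max_n_tasks(8, 512)); /* 4: capped by the constant  */
    printf("%d\n", soft_max_n_tasks(2, 512)); /* 2: capped by n_threads     */
    printf("%d\n", soft_max_n_tasks(8, 1));   /* 1: capped by the row count */
    return 0;
}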
@@ -15872,35 +15875,29 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 
     // thread scheduling for the different operations + work buffer size estimation
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        int n_tasks = 1;
-
         struct ggml_tensor * node = cgraph->nodes[i];
 
+        const int n_tasks = ggml_get_n_tasks(node, n_threads);
+
         size_t cur = 0;
 
         switch (node->op) {
             case GGML_OP_CPY:
             case GGML_OP_DUP:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                     }
                 } break;
             case GGML_OP_ADD:
             case GGML_OP_ADD1:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
             case GGML_OP_ACC:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
                     }
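With the duplicated assignments removed, ggml_graph_plan computes n_tasks exactly once per node through ggml_get_n_tasks, and the work-buffer estimation only reads it. A simplified sketch of that loop shape, using toy types and names rather than the real ggml structs:

#include <stdio.h>
#include <stddef.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

enum toy_op { TOY_OP_SOFT_MAX, TOY_OP_OTHER };

struct toy_node { enum toy_op op; int n_rows; int ne0; };

/* stand-in for ggml_get_n_tasks(): the scheduling policy lives in one place */
static int toy_get_n_tasks(const struct toy_node * node, int n_threads) {
    switch (node->op) {
        case TOY_OP_SOFT_MAX: return MIN(MIN(4, n_threads), node->n_rows);
        default:              return n_threads;
    }
}

/* stand-in for the planning loop: n_tasks is read-only from here on */
static size_t toy_plan_work_size(const struct toy_node * nodes, int n_nodes, int n_threads) {
    size_t work_size = 0;
    for (int i = 0; i < n_nodes; i++) {
        const int n_tasks = toy_get_n_tasks(&nodes[i], n_threads); /* once per node */

        size_t cur = 0;
        if (nodes[i].op == TOY_OP_SOFT_MAX) {
            /* one scratch row of floats per task, as in the SOFT_MAX case below */
            cur = sizeof(float) * (size_t) nodes[i].ne0 * (size_t) n_tasks;
        }
        if (cur > work_size) {
            work_size = cur; /* the plan keeps the largest per-node estimate */
        }
    }
    return work_size;
}

int main(void) {
    struct toy_node nodes[2] = { { TOY_OP_SOFT_MAX, 512, 4096 }, { TOY_OP_OTHER, 0, 0 } };
    printf("work_size = %zu bytes\n", toy_plan_work_size(nodes, 2, 8)); /* 4096 cols * 4 tasks * 4 bytes */
    return 0;
}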
@@ -15928,16 +15925,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_OUT_PROD:
                 {
-                    n_tasks = n_threads;
-
                     if (ggml_is_quantized(node->src[0]->type)) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
             case GGML_OP_SOFT_MAX:
                 {
-                    n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
-
                     cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                 } break;
             case GGML_OP_CONV_TRANSPOSE_1D:
@@ -15967,7 +15960,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_IM2COL:
                 {
-                    n_tasks = n_threads;
                 } break;
             case GGML_OP_CONV_TRANSPOSE_2D:
                 {
@@ -15985,8 +15977,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_ATTN:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
 
                     if (node->src[1]->type == GGML_TYPE_F32) {
@@ -15999,8 +15989,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_FF:
                 {
-                    n_tasks = n_threads;
-
                     if (node->src[1]->type == GGML_TYPE_F32) {
                         cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
                         cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
@@ -16011,8 +15999,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
-                    n_tasks = n_threads;
-
                     const int64_t D = node->src[0]->ne[0];
                     const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
                     const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
@@ -16027,8 +16013,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 
             case GGML_OP_CROSS_ENTROPY_LOSS:
                 {
-                    n_tasks = n_threads;
-
                     cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
                 } break;
             case GGML_OP_COUNT:
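To make the last formula concrete: for GGML_OP_CROSS_ENTROPY_LOSS the scratch estimate is cur = ggml_type_size(node->type) * (n_tasks + node->src[0]->ne[0]*n_tasks). Under assumed values that are not from the commit (F32 at 4 bytes, n_tasks = 4, ne[0] = 32000, a typical vocab size), that works out to 4 * (4 + 32000 * 4) = 512016 bytes, roughly 0.5 MiB, and it grows linearly with the task count that ggml_get_n_tasks now supplies.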