Skip to content

Commit 8602f5a

Browse files
committed
Merge branch 'master' into concedo_experimental
2 parents ac36aee + fbbc428 commit 8602f5a

File tree

1 file changed

+6
-22
lines changed

1 file changed

+6
-22
lines changed

Diff for: ggml.c

+6-22
Original file line numberDiff line numberDiff line change
@@ -15629,7 +15629,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
1562915629
} break;
1563015630
case GGML_OP_DIAG_MASK_ZERO:
1563115631
case GGML_OP_DIAG_MASK_INF:
15632-
case GGML_OP_SOFT_MAX:
1563315632
case GGML_OP_SOFT_MAX_BACK:
1563415633
case GGML_OP_ROPE:
1563515634
case GGML_OP_ROPE_BACK:
@@ -15645,6 +15644,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
1564515644
{
1564615645
n_tasks = 1; //TODO
1564715646
} break;
15647+
case GGML_OP_SOFT_MAX:
15648+
{
15649+
n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
15650+
} break;
1564815651
case GGML_OP_CONV_TRANSPOSE_1D:
1564915652
{
1565015653
n_tasks = n_threads;
@@ -15872,35 +15875,29 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1587215875

1587315876
// thread scheduling for the different operations + work buffer size estimation
1587415877
for (int i = 0; i < cgraph->n_nodes; i++) {
15875-
int n_tasks = 1;
15876-
1587715878
struct ggml_tensor * node = cgraph->nodes[i];
1587815879

15880+
const int n_tasks = ggml_get_n_tasks(node, n_threads);
15881+
1587915882
size_t cur = 0;
1588015883

1588115884
switch (node->op) {
1588215885
case GGML_OP_CPY:
1588315886
case GGML_OP_DUP:
1588415887
{
15885-
n_tasks = n_threads;
15886-
1588715888
if (ggml_is_quantized(node->type)) {
1588815889
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
1588915890
}
1589015891
} break;
1589115892
case GGML_OP_ADD:
1589215893
case GGML_OP_ADD1:
1589315894
{
15894-
n_tasks = n_threads;
15895-
1589615895
if (ggml_is_quantized(node->src[0]->type)) {
1589715896
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
1589815897
}
1589915898
} break;
1590015899
case GGML_OP_ACC:
1590115900
{
15902-
n_tasks = n_threads;
15903-
1590415901
if (ggml_is_quantized(node->src[0]->type)) {
1590515902
cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
1590615903
}
@@ -15928,16 +15925,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1592815925
} break;
1592915926
case GGML_OP_OUT_PROD:
1593015927
{
15931-
n_tasks = n_threads;
15932-
1593315928
if (ggml_is_quantized(node->src[0]->type)) {
1593415929
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
1593515930
}
1593615931
} break;
1593715932
case GGML_OP_SOFT_MAX:
1593815933
{
15939-
n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
15940-
1594115934
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
1594215935
} break;
1594315936
case GGML_OP_CONV_TRANSPOSE_1D:
@@ -15967,7 +15960,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1596715960
} break;
1596815961
case GGML_OP_IM2COL:
1596915962
{
15970-
n_tasks = n_threads;
1597115963
} break;
1597215964
case GGML_OP_CONV_TRANSPOSE_2D:
1597315965
{
@@ -15985,8 +15977,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1598515977
} break;
1598615978
case GGML_OP_FLASH_ATTN:
1598715979
{
15988-
n_tasks = n_threads;
15989-
1599015980
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
1599115981

1599215982
if (node->src[1]->type == GGML_TYPE_F32) {
@@ -15999,8 +15989,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1599915989
} break;
1600015990
case GGML_OP_FLASH_FF:
1600115991
{
16002-
n_tasks = n_threads;
16003-
1600415992
if (node->src[1]->type == GGML_TYPE_F32) {
1600515993
cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
1600615994
cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
@@ -16011,8 +15999,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1601115999
} break;
1601216000
case GGML_OP_FLASH_ATTN_BACK:
1601316001
{
16014-
n_tasks = n_threads;
16015-
1601616002
const int64_t D = node->src[0]->ne[0];
1601716003
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
1601816004
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
@@ -16027,8 +16013,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
1602716013

1602816014
case GGML_OP_CROSS_ENTROPY_LOSS:
1602916015
{
16030-
n_tasks = n_threads;
16031-
1603216016
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
1603316017
} break;
1603416018
case GGML_OP_COUNT:

0 commit comments

Comments
 (0)