
Commit 6a66f69

ggml : implement soft_max_ext (CPU)
1 parent 88519fb commit 6a66f69

File tree

1 file changed: ggml.c (+38 −14 lines)

Diff for: ggml.c (+38 −14)
@@ -4829,7 +4829,9 @@ static struct ggml_tensor * ggml_soft_max_impl(
         struct ggml_tensor  * mask,
         float                 scale,
         bool                  inplace) {
+    GGML_ASSERT(ggml_is_contiguous(a));
     if (mask) {
+        GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(mask->ne[2] == 1);
         GGML_ASSERT(mask->ne[3] == 1);
         GGML_ASSERT(ggml_can_repeat_rows(mask, a));
@@ -10571,20 +10573,25 @@ static void ggml_compute_forward_diag_mask_zero(
 static void ggml_compute_forward_soft_max_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(ggml_is_contiguous(dst));
+    assert(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
+    float scale = 1.0f;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
     const int nth = params->nth;
 
+    const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
     const int nc = src0->ne[0];
     const int nr = ggml_nrows(src0);
@@ -10595,29 +10602,39 @@ static void ggml_compute_forward_soft_max_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    float * wdata = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
     for (int i1 = ir0; i1 < ir1; i1++) {
-        float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float *dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+
+        // broadcast the mask across rows
+        float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+
+        float * wp = wdata;
+        for (int i = 0; i < nc; i++) {
+            wp[i] = sp[i]*scale + (mp ? mp[i] : 0.0f);
+        }
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(sp[i]));
+            assert(!isnan(wp[i]));
         }
 #endif
 
         float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, sp);
+        ggml_vec_max_f32(nc, &max, wp);
 
         ggml_float sum = 0.0;
 
         uint16_t scvt;
         for (int i = 0; i < nc; i++) {
-            if (sp[i] == -INFINITY) {
+            if (wp[i] == -INFINITY) {
                 dp[i] = 0.0f;
             } else {
-                // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
-                ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
+                // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
+                ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
                 memcpy(&scvt, &s, sizeof(scvt));
                 const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
                 sum += (ggml_float)val;
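Functionally, the loop above replaces softmax(x) with softmax(x*scale + mask): the mask row is broadcast over the rows of src0 via i1 % ne11, and the scaled, masked values are staged in the per-thread wdata scratch row before the usual max-subtraction and exponentiation. A self-contained sketch of the per-row math, using plain expf instead of ggml's FP16 exp table (function and buffer names are made up for illustration):

    #include <math.h>
    #include <stddef.h>

    // dst = softmax(src*scale + mask) over one row; mask may be NULL.
    // Mirrors the kernel above, but with expf instead of ggml's FP16 lookup table.
    static void soft_max_ext_row(const float * sp, const float * mp, float * dp,
                                 size_t nc, float scale) {
        // stage the scaled + masked inputs and find the row maximum
        float max = -INFINITY;
        for (size_t i = 0; i < nc; i++) {
            dp[i] = sp[i]*scale + (mp ? mp[i] : 0.0f);   // reuse dst as the scratch row
            if (dp[i] > max) {
                max = dp[i];
            }
        }

        // exponentiate relative to the maximum; -INF entries become exact zeros
        float sum = 0.0f;
        for (size_t i = 0; i < nc; i++) {
            const float val = (dp[i] == -INFINITY) ? 0.0f : expf(dp[i] - max);
            dp[i] = val;
            sum  += val;
        }

        // normalize so the row sums to 1
        for (size_t i = 0; i < nc; i++) {
            dp[i] /= sum;
        }
    }

The fused form lets callers fold the usual attention scaling and masking into the softmax itself instead of emitting separate scale and add nodes, while keeping the familiar max-subtraction for numerical stability.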
@@ -10642,11 +10659,12 @@ static void ggml_compute_forward_soft_max_f32(
 static void ggml_compute_forward_soft_max(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_f32(params, src0, dst);
+                ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -13883,7 +13901,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             } break;
         case GGML_OP_SOFT_MAX:
             {
-                ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
+                ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
@@ -15919,6 +15937,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                 }
             } break;
+        case GGML_OP_SOFT_MAX:
+            {
+                n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
+
+                cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 GGML_ASSERT(node->src[0]->ne[3] == 1);
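The new GGML_OP_SOFT_MAX case in ggml_graph_plan reserves one scratch row of ne0 floats per task; that allocation is what the kernel indexes as wdata with a (nc + CACHE_LINE_SIZE_F32) * ith per-thread stride. A rough illustration of the sizing formula with made-up numbers (row length and thread count are assumptions, not values from this commit):

    #include <stdio.h>

    int main(void) {
        // hypothetical attention softmax: row length and task count are assumptions
        const long ne0     = 4096;   // node->ne[0], elements per row
        const long n_tasks = 8;      // MIN(n_threads, ggml_nrows(src0)) in the plan

        // mirrors: cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks
        const long cur = (long) sizeof(float) * ne0 * n_tasks;
        printf("soft_max work buffer: %ld bytes (%.1f KiB)\n", cur, cur/1024.0);
        return 0;
    }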
