@@ -4829,7 +4829,9 @@ static struct ggml_tensor * ggml_soft_max_impl(
         struct ggml_tensor  * mask,
         float                 scale,
         bool                  inplace) {
+    GGML_ASSERT(ggml_is_contiguous(a));
     if (mask) {
+        GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(mask->ne[2] == 1);
         GGML_ASSERT(mask->ne[3] == 1);
         GGML_ASSERT(ggml_can_repeat_rows(mask, a));
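The two new assertions make the fused op's layout requirements explicit: both the input and the optional mask must be contiguous, and ggml_can_repeat_rows lets a single mask row block be broadcast across the input's rows. Below is a minimal sketch of how a caller might build the fused op; it assumes a public wrapper ggml_soft_max_ext(ctx, a, mask, scale) that forwards to ggml_soft_max_impl with inplace = false (the wrapper itself is not part of this hunk), and the tensor shapes are illustrative.

// Sketch only, not part of this hunk: assumes ggml_soft_max_ext(ctx, a, mask, scale)
// forwards to ggml_soft_max_impl(ctx, a, mask, scale, /*inplace=*/false).
#include <math.h>
#include "ggml.h"

// kq      : attention scores, e.g. [n_kv, n_tokens, n_head, 1]
// kq_mask : additive mask,    e.g. [n_kv, n_tokens, 1, 1], broadcast across heads
static struct ggml_tensor * build_masked_soft_max(
        struct ggml_context * ctx,
        struct ggml_tensor  * kq,
        struct ggml_tensor  * kq_mask,
        int                   n_embd_head) {
    const float scale = 1.0f/sqrtf((float) n_embd_head);
    return ggml_soft_max_ext(ctx, kq, kq_mask, scale);
}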
@@ -10571,20 +10573,25 @@ static void ggml_compute_forward_diag_mask_zero(
 static void ggml_compute_forward_soft_max_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(ggml_is_contiguous(dst));
+    assert(ggml_are_same_shape(src0, dst));

     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }

+    float scale = 1.0f;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
     // TODO: handle transposed/permuted matrices

     const int ith = params->ith;
     const int nth = params->nth;

+    const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
     const int nc = src0->ne[0];
     const int nr = ggml_nrows(src0);

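The scale consumed by the memcpy above is whatever the graph-build step packed into the first float of dst->op_params. A sketch of the assumed producer side, using ggml.c's internal ggml_set_op_params helper (the setter is outside this hunk):

// Assumed producer-side sketch (not shown in this hunk): store the scale as the
// first float of op_params so the kernel's memcpy above can read it back.
static void soft_max_store_scale(struct ggml_tensor * result, float scale) {
    float op_params[] = { scale };
    ggml_set_op_params(result, op_params, sizeof(op_params));
}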
@@ -10595,29 +10602,39 @@ static void ggml_compute_forward_soft_max_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);

+    float * wdata = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
     for (int i1 = ir0; i1 < ir1; i1++) {
-        float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float *dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+
+        // broadcast the mask across rows
+        float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+
+        float * wp = wdata;
+        for (int i = 0; i < nc; i++) {
+            wp[i] = sp[i]*scale + (mp ? mp[i] : 0.0f);
+        }

 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(sp[i]));
+            assert(!isnan(wp[i]));
         }
 #endif

         float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, sp);
+        ggml_vec_max_f32(nc, &max, wp);

         ggml_float sum = 0.0;

         uint16_t scvt;
         for (int i = 0; i < nc; i++) {
-            if (sp[i] == -INFINITY) {
+            if (wp[i] == -INFINITY) {
                 dp[i] = 0.0f;
             } else {
-                // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
-                ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
+                // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
+                ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
                 memcpy(&scvt, &s, sizeof(scvt));
                 const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
                 sum += (ggml_float)val;
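Taken together, the new per-row computation is: form wp = sp*scale + mask_row (with mask row i1 % ne11 broadcast over the input rows), then take a numerically stable softmax over wp into dp. A standalone scalar reference of that row math, as an illustrative sketch that uses plain expf instead of ggml's FP16 exp table:

// Illustrative reference of what one row of the fused op computes (sketch, not
// the ggml kernel). Assumes at least one finite entry per row, like the kernel.
#include <math.h>
#include <stddef.h>

static void soft_max_row_ref(float * dp, const float * sp, const float * mp,
                             float scale, int nc) {
    float max = -INFINITY;
    for (int i = 0; i < nc; i++) {
        const float w = sp[i]*scale + (mp ? mp[i] : 0.0f);
        if (w > max) max = w;
    }
    float sum = 0.0f;
    for (int i = 0; i < nc; i++) {
        const float w = sp[i]*scale + (mp ? mp[i] : 0.0f);
        dp[i] = (w == -INFINITY) ? 0.0f : expf(w - max);
        sum += dp[i];
    }
    for (int i = 0; i < nc; i++) {
        dp[i] /= sum;  // normalize so the row sums to 1
    }
}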
@@ -10642,11 +10659,12 @@ static void ggml_compute_forward_soft_max_f32(
 static void ggml_compute_forward_soft_max(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_f32(params, src0, dst);
+                ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -13883,7 +13901,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX:
             {
-                ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
+                ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
@@ -15919,6 +15937,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
+            case GGML_OP_SOFT_MAX:
+                {
+                    n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
+
+                    cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+                } break;
            case GGML_OP_CONV_TRANSPOSE_1D:
                {
                    GGML_ASSERT(node->src[0]->ne[3] == 1);
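The new scheduling case caps n_tasks at the number of rows and reserves one F32 scratch row per task in the plan's work buffer, which backs the per-thread wp row the kernel scales and masks before the softmax. A rough sizing sketch with assumed, illustrative numbers:

// Illustrative sketch only (numbers are assumptions, not from the commit):
// work size = one F32 scratch row per task.
#include <stddef.h>
#include <stdint.h>

static size_t soft_max_wsize(int64_t ne0, int n_tasks) {
    return sizeof(float) * (size_t) ne0 * (size_t) n_tasks;
}

// Example: ne0 = 4096 columns and n_tasks = 8 threads -> 4*4096*8 = 131072 bytes (128 KiB).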