@@ -4719,16 +4719,18 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 // the CUDA soft max implementation differs from the CPU implementation
 // instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int rowx = blockDim.x*blockIdx.x + threadIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension

     const int block_size = blockDim.y;
     const int tid = threadIdx.y;

     float max_val = -INFINITY;

     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        max_val = max(max_val, x[i]);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
     }

     // find the max value in the block
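The hunk above folds the scale and the optional mask into the max pass: the row maximum is now taken over `x*scale + y`, so the later `expf` stays numerically stable for the fused values. The `rowx % nrows_y` broadcast is worth spelling out; a hedged sketch with made-up, attention-style shapes (none of these numbers come from the patch):

```cuda
// Hedged sketch of the rowx % nrows_y broadcast, with hypothetical shapes:
// x holds n_head * n_tokens rows while the mask y holds only n_tokens rows,
// so every head reuses the same mask row for a given token row.
#include <cstdio>

int main() {
    const int n_head = 4, n_tokens = 3;     // made-up sizes
    const int nrows_x = n_head * n_tokens;  // rows of x
    const int nrows_y = n_tokens;           // rows of the mask
    for (int rowx = 0; rowx < nrows_x; ++rowx) {
        const int rowy = rowx % nrows_y;    // same rule as the kernel
        std::printf("x row %2d -> mask row %d\n", rowx, rowy);
    }
    return 0;
}
```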
@@ -4740,10 +4742,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
     float tmp = 0.f;

     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        const float val = expf(x[i] - max_val);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
-        dst[i] = val;
+        dst[ix] = val;
     }

     // sum up partial sums
@@ -4755,7 +4758,7 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
     const float inv_tmp = 1.f / tmp;

     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
+        const int i = rowx*ncols + col;
         dst[i] *= inv_tmp;
     }
 }
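Taken together, the three hunks make the kernel compute `softmax(x*scale + mask)` per row instead of plain `softmax(x)`. A minimal single-threaded reference for one row, useful for checking the kernel output (the function name is invented here, and the block reductions elided between the hunks become plain loops):

```cuda
#include <math.h>

// Reference for one row: dst[rowx] = softmax(x[rowx]*scale + y[rowy]),
// with y optional and broadcast via rowx % nrows_y as in the kernel.
static void soft_max_row_ref(const float * x, const float * y, float * dst,
                             int ncols, int nrows_y, float scale, int rowx) {
    const int rowy = y ? rowx % nrows_y : 0;
    float max_val = -INFINITY;
    for (int col = 0; col < ncols; ++col) {
        const float v = x[rowx*ncols + col]*scale + (y ? y[rowy*ncols + col] : 0.0f);
        max_val = fmaxf(max_val, v);
    }
    float sum = 0.0f;
    for (int col = 0; col < ncols; ++col) {
        const float v = expf((x[rowx*ncols + col]*scale + (y ? y[rowy*ncols + col] : 0.0f)) - max_val);
        dst[rowx*ncols + col] = v; // store unnormalized values, like the kernel
        sum += v;
    }
    const float inv_sum = 1.0f / sum;
    for (int col = 0; col < ncols; ++col) {
        dst[rowx*ncols + col] *= inv_sum; // second pass over dst, like the kernel
    }
}
```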
@@ -5792,10 +5795,10 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }

-static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     const dim3 block_dims(1, WARP_SIZE, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }

 static void im2col_f32_f16_cuda(const float * x, half * dst,
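The launch geometry is unchanged: `block_nums(nrows_x, 1, 1)` gives one block per row of `x`, and `block_dims(1, WARP_SIZE, 1)` gives WARP_SIZE threads that stride over the row's columns (hence `tid = threadIdx.y` in the kernel). The in-block max/sum reductions live outside these hunks; a hedged sketch of the usual warp-shuffle form such a reduction takes (the helper name is invented, not the file's actual helper):

```cuda
// Hedged sketch of a warp-wide max reduction of the kind the kernel relies on
// between its per-thread loops; not necessarily the file's implementation.
static __device__ __forceinline__ float warp_reduce_max_sketch(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        // each lane takes the max with its XOR partner until all lanes agree
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
    }
    return x;
}
```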
@@ -6846,14 +6849,18 @@ inline void ggml_cuda_op_soft_max(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
     const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 0;

-    soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+    float scale = 1.0f;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);

-    (void) src1;
     (void) dst;
-    (void) src1_dd;
 }

 inline void ggml_cuda_op_scale(
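On the host side, the scale is not passed in by the op's caller; it is read out of the node's `op_params` blob via `memcpy`, so whatever builds the graph must have stored it there first. A hedged stand-in showing that contract (the struct and helpers are invented for illustration; ggml's real tensor type is not part of this diff):

```cuda
#include <cstring>

// Invented stand-in for the relevant slice of a graph node: the first float
// of op_params carries the soft-max scale, matching the memcpy read above.
struct fake_node { char op_params[16]; };

static void set_soft_max_scale(fake_node * node, float scale) {
    std::memcpy(node->op_params, &scale, sizeof(float)); // write side
}

static float get_soft_max_scale(const fake_node * node) {
    float scale = 1.0f;                                  // default, as above
    std::memcpy(&scale, node->op_params, sizeof(float)); // read side
    return scale;
}
```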