Commit 88519fb

cuda : implement soft_max_ext
1 parent e89597c

3 files changed: +28, -14 lines

Diff for: ggml-cuda.cu (+21, -14)
@@ -4719,16 +4719,18 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 
 // the CUDA soft max implementation differs from the CPU implementation
 // instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int rowx = blockDim.x*blockIdx.x + threadIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
     const int block_size = blockDim.y;
     const int tid = threadIdx.y;
 
     float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        max_val = max(max_val, x[i]);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
     // find the max value in the block
@@ -4740,10 +4742,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
     float tmp = 0.f;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        const float val = expf(x[i] - max_val);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
-        dst[i] = val;
+        dst[ix] = val;
     }
 
     // sum up partial sums
@@ -4755,7 +4758,7 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
     const float inv_tmp = 1.f / tmp;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
+        const int i = rowx*ncols + col;
         dst[i] *= inv_tmp;
     }
 }
@@ -5792,10 +5795,10 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
-static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     const dim3 block_dims(1, WARP_SIZE, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
 
 static void im2col_f32_f16_cuda(const float * x, half * dst,
@@ -6846,14 +6849,18 @@ inline void ggml_cuda_op_soft_max(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
     const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 0;
 
-    soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+    float scale = 1.0f;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
 
-    (void) src1;
     (void) dst;
-    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_scale(
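
Taken together, the changes above fuse scaling and masking into the softmax itself: for each row the kernel now computes softmax(x*scale + mask), with the optional mask broadcast across rows via rowy = rowx % nrows_y. The following is a minimal, self-contained sketch (not the commit's code) of that computation. The one-thread-per-row kernel soft_max_ref, the guard for a missing mask, and the toy shapes and 0.5f scale in main are illustrative assumptions; the real kernel parallelizes each row across a warp and reduces the max and the sum with block-level primitives.

// Sketch only: naive fused softmax(x*scale + mask) with row-broadcast mask.
#include <cstdio>
#include <cmath>
#include <cuda_runtime.h>

__global__ void soft_max_ref(const float * x, const float * y, float * dst,
                             int ncols, int nrows_x, int nrows_y, float scale) {
    const int rowx = blockIdx.x*blockDim.x + threadIdx.x;
    if (rowx >= nrows_x) return;
    // guard added for this sketch: without a mask there is no row to broadcast
    const int rowy = nrows_y > 0 ? rowx % nrows_y : 0;

    // pass 1: row maximum of the scaled, masked input
    float max_val = -INFINITY;
    for (int col = 0; col < ncols; ++col) {
        const float v = x[rowx*ncols + col]*scale + (y ? y[rowy*ncols + col] : 0.0f);
        max_val = fmaxf(max_val, v);
    }

    // pass 2: exponentiate and accumulate the row sum
    float sum = 0.0f;
    for (int col = 0; col < ncols; ++col) {
        const float v = expf(x[rowx*ncols + col]*scale + (y ? y[rowy*ncols + col] : 0.0f) - max_val);
        dst[rowx*ncols + col] = v;
        sum += v;
    }

    // pass 3: normalize
    const float inv_sum = 1.0f/sum;
    for (int col = 0; col < ncols; ++col) {
        dst[rowx*ncols + col] *= inv_sum;
    }
}

int main() {
    const int ncols = 4, nrows_x = 2, nrows_y = 1;   // one mask row shared by both x rows
    float hx[nrows_x*ncols] = {1, 2, 3, 4,  4, 3, 2, 1};
    float hy[nrows_y*ncols] = {0, 0, -INFINITY, 0};  // masked-out column gets probability 0

    float *dx, *dy, *dd;
    cudaMalloc((void **) &dx, sizeof(hx));
    cudaMalloc((void **) &dy, sizeof(hy));
    cudaMalloc((void **) &dd, sizeof(hx));
    cudaMemcpy(dx, hx, sizeof(hx), cudaMemcpyHostToDevice);
    cudaMemcpy(dy, hy, sizeof(hy), cudaMemcpyHostToDevice);

    soft_max_ref<<<1, nrows_x>>>(dx, dy, dd, ncols, nrows_x, nrows_y, 0.5f);

    float hd[nrows_x*ncols];
    cudaMemcpy(hd, dd, sizeof(hd), cudaMemcpyDeviceToHost);
    for (int r = 0; r < nrows_x; ++r) {
        for (int c = 0; c < ncols; ++c) {
            printf("%.4f ", hd[r*ncols + c]);
        }
        printf("\n");
    }

    cudaFree(dx); cudaFree(dy); cudaFree(dd);
    return 0;
}

With nrows_y = 1 the single mask row is applied to every row of x. Folding the scale in before the row maximum is taken gives the same result as applying the scale and the mask as separate ops and then running the original softmax, which is what lets the extended op replace the kq_scaled/kq_masked/kq_soft_max chain.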

Diff for: ggml.c (+6)
@@ -4829,6 +4829,12 @@ static struct ggml_tensor * ggml_soft_max_impl(
         struct ggml_tensor  * mask,
         float                 scale,
         bool                  inplace) {
+    if (mask) {
+        GGML_ASSERT(mask->ne[2] == 1);
+        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+    }
+
     bool is_node = false;
 
     if (a->grad) {
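
These asserts restrict the optional mask to shapes the CUDA kernel's row-wise broadcast can handle. A hedged sketch of that rule follows (the exact semantics of ggml_can_repeat_rows are assumed here, not shown in this diff): the mask must be 2D, share the column count of a, and have a row count that evenly divides the total number of rows of a.

// Hedged sketch, not ggml code: the broadcast-compatibility rule the new
// asserts imply for the optional soft_max mask. a_ne and mask_ne mirror the
// ggml ne[4] shape arrays (ne[0] = columns, ne[1..3] = the row dimensions).
#include <stdbool.h>
#include <stdint.h>

static bool soft_max_mask_compatible(const int64_t a_ne[4], const int64_t mask_ne[4]) {
    const int64_t nrows_a    = a_ne[1]*a_ne[2]*a_ne[3]; // like ggml_nrows(a)
    const int64_t nrows_mask = mask_ne[1];              // mask is 2D, so ne[1] rows
    return mask_ne[2] == 1 && mask_ne[3] == 1           // mask->ne[2] == 1, mask->ne[3] == 1
        && mask_ne[0] == a_ne[0]                        // same number of columns
        && nrows_mask > 0
        && nrows_a % nrows_mask == 0;                   // rows broadcast cleanly
}

The mod-based broadcast in the kernel (rowy = rowx % nrows_y) is what makes the row-count divisibility requirement sufficient: every row of a maps to exactly one mask row.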

Diff for: llama.cpp (+1)
@@ -5048,6 +5048,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "kq_scaled_alibi", OFFLOAD_FUNC_KQ },
     { "kq_masked", OFFLOAD_FUNC_KQ },
     { "kq_soft_max", OFFLOAD_FUNC_V },
+    { "kq_soft_max_ext", OFFLOAD_FUNC_V },
     { "v", OFFLOAD_FUNC_V },
     { "kqv", OFFLOAD_FUNC_V },
     { "kqv_merged", OFFLOAD_FUNC_V },
