Skip to content

Commit a8e66ef

Browse files
committed
Revert "ggml : add ggml_soft_max_ext (ggml-org#4256)"
This reverts commit ef47ec1.
1 parent a829a1e commit a8e66ef

File tree

8 files changed

+183
-298
lines changed

8 files changed

+183
-298
lines changed

Diff for: examples/batched-bench/batched-bench.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ int main(int argc, char ** argv) {
155155
}
156156

157157
LOG_TEE("\n");
158-
LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d, n_threads = %d, n_threads_batch = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq, ctx_params.n_threads, ctx_params.n_threads_batch);
158+
LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq);
159159
LOG_TEE("\n");
160160

161161
LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");

Diff for: ggml-alloc.c

+1-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
137137

138138
#ifdef GGML_ALLOCATOR_DEBUG
139139
add_allocated_tensor(alloc, tensor);
140-
size_t cur_max = (char*)addr - (char*)alloc->base + size;
140+
size_t cur_max = (char*)addr - (char*)alloc->data + size;
141141
if (cur_max > alloc->max_size) {
142142
printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
143143
for (int i = 0; i < 1024; i++) {

Diff for: ggml-cuda.cu

+43-87
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,6 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
443443
#define CUDA_SCALE_BLOCK_SIZE 256
444444
#define CUDA_CLAMP_BLOCK_SIZE 256
445445
#define CUDA_ROPE_BLOCK_SIZE 256
446-
#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
447446
#define CUDA_ALIBI_BLOCK_SIZE 32
448447
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
449448
#define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -503,31 +502,6 @@ static size_t g_scratch_offset = 0;
503502

504503
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
505504

506-
static __device__ __forceinline__ float warp_reduce_sum(float x) {
507-
#pragma unroll
508-
for (int mask = 16; mask > 0; mask >>= 1) {
509-
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
510-
}
511-
return x;
512-
}
513-
514-
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
515-
#pragma unroll
516-
for (int mask = 16; mask > 0; mask >>= 1) {
517-
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
518-
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
519-
}
520-
return a;
521-
}
522-
523-
static __device__ __forceinline__ float warp_reduce_max(float x) {
524-
#pragma unroll
525-
for (int mask = 16; mask > 0; mask >>= 1) {
526-
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
527-
}
528-
return x;
529-
}
530-
531505
static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
532506
const int i = blockDim.x*blockIdx.x + threadIdx.x;
533507

@@ -604,6 +578,15 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
604578
dst[i] = x[i] * x[i];
605579
}
606580

581+
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
582+
#pragma unroll
583+
for (int mask = 16; mask > 0; mask >>= 1) {
584+
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
585+
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
586+
}
587+
return a;
588+
}
589+
607590
template <int block_size>
608591
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
609592
const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -642,6 +625,14 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
642625
}
643626
}
644627

628+
static __device__ __forceinline__ float warp_reduce_sum(float x) {
629+
#pragma unroll
630+
for (int mask = 16; mask > 0; mask >>= 1) {
631+
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
632+
}
633+
return x;
634+
}
635+
645636
template <int block_size>
646637
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
647638
const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4727,74 +4718,45 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
47274718
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
47284719
}
47294720

4730-
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
4731-
const int tid = threadIdx.x;
4732-
const int rowx = blockIdx.x;
4733-
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
4734-
4735-
const int block_size = blockDim.x;
4736-
4737-
const int warp_id = threadIdx.x / WARP_SIZE;
4738-
const int lane_id = threadIdx.x % WARP_SIZE;
4739-
4740-
__shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
4721+
// the CUDA soft max implementation differs from the CPU implementation
4722+
// instead of doubles floats are used
4723+
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
4724+
const int row = blockDim.x*blockIdx.x + threadIdx.x;
4725+
const int block_size = blockDim.y;
4726+
const int tid = threadIdx.y;
47414727

47424728
float max_val = -INFINITY;
47434729

47444730
for (int col = tid; col < ncols; col += block_size) {
4745-
const int ix = rowx*ncols + col;
4746-
const int iy = rowy*ncols + col;
4747-
max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
4731+
const int i = row*ncols + col;
4732+
max_val = max(max_val, x[i]);
47484733
}
47494734

47504735
// find the max value in the block
4751-
max_val = warp_reduce_max(max_val);
4752-
if (block_size > WARP_SIZE) {
4753-
if (warp_id == 0) {
4754-
buf[lane_id] = -INFINITY;
4755-
}
4756-
__syncthreads();
4757-
4758-
if (lane_id == 0) {
4759-
buf[warp_id] = max_val;
4760-
}
4761-
__syncthreads();
4762-
4763-
max_val = buf[lane_id];
4764-
max_val = warp_reduce_max(max_val);
4736+
#pragma unroll
4737+
for (int mask = 16; mask > 0; mask >>= 1) {
4738+
max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
47654739
}
47664740

47674741
float tmp = 0.f;
47684742

47694743
for (int col = tid; col < ncols; col += block_size) {
4770-
const int ix = rowx*ncols + col;
4771-
const int iy = rowy*ncols + col;
4772-
const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
4744+
const int i = row*ncols + col;
4745+
const float val = expf(x[i] - max_val);
47734746
tmp += val;
4774-
dst[ix] = val;
4747+
dst[i] = val;
47754748
}
47764749

4777-
// find the sum of exps in the block
4778-
tmp = warp_reduce_sum(tmp);
4779-
if (block_size > WARP_SIZE) {
4780-
if (warp_id == 0) {
4781-
buf[lane_id] = 0.f;
4782-
}
4783-
__syncthreads();
4784-
4785-
if (lane_id == 0) {
4786-
buf[warp_id] = tmp;
4787-
}
4788-
__syncthreads();
4789-
4790-
tmp = buf[lane_id];
4791-
tmp = warp_reduce_sum(tmp);
4750+
// sum up partial sums
4751+
#pragma unroll
4752+
for (int mask = 16; mask > 0; mask >>= 1) {
4753+
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
47924754
}
47934755

47944756
const float inv_tmp = 1.f / tmp;
47954757

47964758
for (int col = tid; col < ncols; col += block_size) {
4797-
const int i = rowx*ncols + col;
4759+
const int i = row*ncols + col;
47984760
dst[i] *= inv_tmp;
47994761
}
48004762
}
@@ -5831,12 +5793,10 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
58315793
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
58325794
}
58335795

5834-
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
5835-
int nth = WARP_SIZE;
5836-
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
5837-
const dim3 block_dims(nth, 1, 1);
5796+
static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
5797+
const dim3 block_dims(1, WARP_SIZE, 1);
58385798
const dim3 block_nums(nrows_x, 1, 1);
5839-
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
5799+
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
58405800
}
58415801

58425802
static void im2col_f32_f16_cuda(const float * x, half * dst,
@@ -6875,18 +6835,14 @@ inline void ggml_cuda_op_soft_max(
68756835
GGML_ASSERT(src0->type == GGML_TYPE_F32);
68766836
GGML_ASSERT( dst->type == GGML_TYPE_F32);
68776837

6878-
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
6879-
68806838
const int64_t ne00 = src0->ne[0];
6881-
const int64_t nrows_x = ggml_nrows(src0);
6882-
const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
6883-
6884-
float scale = 1.0f;
6885-
memcpy(&scale, dst->op_params, sizeof(float));
6839+
const int64_t nrows = ggml_nrows(src0);
68866840

6887-
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
6841+
soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
68886842

6843+
(void) src1;
68896844
(void) dst;
6845+
(void) src1_dd;
68906846
}
68916847

68926848
inline void ggml_cuda_op_scale(

Diff for: ggml-metal.m

+16-27
Original file line numberDiff line numberDiff line change
@@ -1028,27 +1028,20 @@ void ggml_metal_graph_compute(
10281028
int nth = 32; // SIMD width
10291029

10301030
if (ne00%4 == 0) {
1031-
while (nth < ne00/4 && nth < 256) {
1032-
nth *= 2;
1033-
}
10341031
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
10351032
} else {
1036-
while (nth < ne00 && nth < 1024) {
1033+
do {
10371034
nth *= 2;
1038-
}
1035+
} while (nth <= ne00 && nth <= 1024);
1036+
nth /= 2;
10391037
[encoder setComputePipelineState:ctx->pipeline_soft_max];
10401038
}
1041-
1042-
const float scale = ((float *) dst->op_params)[0];
1043-
1044-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1045-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1046-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1047-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1048-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1049-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1050-
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1051-
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1039+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1040+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1041+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1042+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1043+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1044+
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
10521045

10531046
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
10541047
} break;
@@ -1358,19 +1351,15 @@ void ggml_metal_graph_compute(
13581351
float eps;
13591352
memcpy(&eps, dst->op_params, sizeof(float));
13601353

1361-
int nth = 32; // SIMD width
1362-
1363-
while (nth < ne00/4 && nth < 1024) {
1364-
nth *= 2;
1365-
}
1354+
const int nth = MIN(512, ne00);
13661355

13671356
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
1368-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1369-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1370-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1371-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1372-
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
1373-
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1357+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1358+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1359+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1360+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1361+
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
1362+
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
13741363

13751364
const int64_t nrows = ggml_nrows(src0);
13761365

0 commit comments

Comments
 (0)