@@ -443,7 +443,6 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_CLAMP_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
- #define CUDA_SOFT_MAX_BLOCK_SIZE 1024
#define CUDA_ALIBI_BLOCK_SIZE 32
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
#define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -503,31 +502,6 @@ static size_t g_scratch_offset = 0;

static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

- static __device__ __forceinline__ float warp_reduce_sum(float x) {
- #pragma unroll
-     for (int mask = 16; mask > 0; mask >>= 1) {
-         x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-     }
-     return x;
- }
-
- static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
- #pragma unroll
-     for (int mask = 16; mask > 0; mask >>= 1) {
-         a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-         a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-     }
-     return a;
- }
-
- static __device__ __forceinline__ float warp_reduce_max(float x) {
- #pragma unroll
-     for (int mask = 16; mask > 0; mask >>= 1) {
-         x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-     }
-     return x;
- }
-
static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
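
The three helpers removed in this hunk (and re-added next to their call sites in the hunks below) all use the same warp-wide butterfly reduction: XOR-shuffles with strides 16, 8, 4, 2, 1 leave every lane holding the combined value after log2(32) = 5 steps. A minimal standalone sketch of the pattern follows; the demo kernel and host harness are illustrative only, not part of this patch.

    // warp_sum_demo.cu -- minimal sketch of the __shfl_xor_sync butterfly reduction
    // (hypothetical demo; only the reduction loop mirrors the helpers in this diff)
    #include <cuda_runtime.h>
    #include <cstdio>

    static __device__ __forceinline__ float warp_reduce_sum(float x) {
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            // exchange with the lane whose id differs in bit `mask`, then accumulate
            x += __shfl_xor_sync(0xffffffff, x, mask, 32);
        }
        return x; // after 5 steps every lane holds the sum over all 32 lanes
    }

    static __global__ void warp_sum_demo(float * out) {
        const float sum = warp_reduce_sum((float) threadIdx.x); // lanes contribute 0..31
        if (threadIdx.x == 0) {
            *out = sum; // expected: 496
        }
    }

    int main() {
        float * d_out = nullptr;
        cudaMalloc(&d_out, sizeof(float));
        warp_sum_demo<<<1, 32>>>(d_out);
        float h_out = 0.0f;
        cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
        printf("warp sum = %.1f\n", h_out);
        cudaFree(d_out);
        return 0;
    }
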
@@ -604,6 +578,15 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
    dst[i] = x[i] * x[i];
}

+ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+ #pragma unroll
+     for (int mask = 16; mask > 0; mask >>= 1) {
+         a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+         a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+     }
+     return a;
+ }
+
template <int block_size>
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
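
For reference, norm_f32 (whose opening lines give this hunk its context) uses the re-added float2 overload to carry a row's running sum and sum of squares through one warp reduction. A hedged CPU equivalent of the per-row math is sketched below; the epsilon is an assumption for illustration and is not taken from this patch.

    // Hedged CPU reference for the per-row math in norm_f32: mean/variance
    // normalization, with (sum, sum of squares) accumulated in a single pass.
    #include <math.h>

    static void norm_row_ref(const float * x, float * dst, int ncols) {
        const float eps = 1e-5f; // assumed value, for illustration only
        float sum = 0.0f, sum_sq = 0.0f;
        for (int i = 0; i < ncols; ++i) {
            sum    += x[i];
            sum_sq += x[i]*x[i];
        }
        const float mean  = sum/ncols;
        const float var   = sum_sq/ncols - mean*mean;
        const float scale = 1.0f/sqrtf(var + eps);
        for (int i = 0; i < ncols; ++i) {
            dst[i] = (x[i] - mean)*scale;
        }
    }
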
@@ -642,6 +625,14 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
    }
}

+ static __device__ __forceinline__ float warp_reduce_sum(float x) {
+ #pragma unroll
+     for (int mask = 16; mask > 0; mask >>= 1) {
+         x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+     }
+     return x;
+ }
+
template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
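
Likewise, the scalar warp_reduce_sum re-added here feeds rms_norm_f32, which skips the mean subtraction and rescales each element by the reciprocal root mean square. A hedged CPU sketch of the per-row computation:

    // Hedged CPU reference for rms_norm_f32: y = x / sqrt(mean(x^2) + eps).
    #include <math.h>

    static void rms_norm_row_ref(const float * x, float * dst, int ncols, float eps) {
        float sum_sq = 0.0f;
        for (int i = 0; i < ncols; ++i) {
            sum_sq += x[i]*x[i];
        }
        const float scale = 1.0f/sqrtf(sum_sq/ncols + eps);
        for (int i = 0; i < ncols; ++i) {
            dst[i] = x[i]*scale;
        }
    }
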
@@ -4727,74 +4718,45 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
}

- static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
-     const int tid  = threadIdx.x;
-     const int rowx = blockIdx.x;
-     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
-
-     const int block_size = blockDim.x;
-
-     const int warp_id = threadIdx.x / WARP_SIZE;
-     const int lane_id = threadIdx.x % WARP_SIZE;
-
-     __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+ // the CUDA soft max implementation differs from the CPU implementation
+ // instead of doubles floats are used
+ static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
+     const int row = blockDim.x*blockIdx.x + threadIdx.x;
+     const int block_size = blockDim.y;
+     const int tid = threadIdx.y;

    float max_val = -INFINITY;

    for (int col = tid; col < ncols; col += block_size) {
-         const int ix = rowx*ncols + col;
-         const int iy = rowy*ncols + col;
-         max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
+         const int i = row*ncols + col;
+         max_val = max(max_val, x[i]);
    }

    // find the max value in the block
-     max_val = warp_reduce_max(max_val);
-     if (block_size > WARP_SIZE) {
-         if (warp_id == 0) {
-             buf[lane_id] = -INFINITY;
-         }
-         __syncthreads();
-
-         if (lane_id == 0) {
-             buf[warp_id] = max_val;
-         }
-         __syncthreads();
-
-         max_val = buf[lane_id];
-         max_val = warp_reduce_max(max_val);
+ #pragma unroll
+     for (int mask = 16; mask > 0; mask >>= 1) {
+         max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
    }

    float tmp = 0.f;

    for (int col = tid; col < ncols; col += block_size) {
-         const int ix = rowx*ncols + col;
-         const int iy = rowy*ncols + col;
-         const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
+         const int i = row*ncols + col;
+         const float val = expf(x[i] - max_val);
        tmp += val;
-         dst[ix] = val;
+         dst[i] = val;
    }

-     // find the sum of exps in the block
-     tmp = warp_reduce_sum(tmp);
-     if (block_size > WARP_SIZE) {
-         if (warp_id == 0) {
-             buf[lane_id] = 0.f;
-         }
-         __syncthreads();
-
-         if (lane_id == 0) {
-             buf[warp_id] = tmp;
-         }
-         __syncthreads();
-
-         tmp = buf[lane_id];
-         tmp = warp_reduce_sum(tmp);
+     // sum up partial sums
+ #pragma unroll
+     for (int mask = 16; mask > 0; mask >>= 1) {
+         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    const float inv_tmp = 1.f / tmp;

    for (int col = tid; col < ncols; col += block_size) {
-         const int i = rowx*ncols + col;
+         const int i = row*ncols + col;
        dst[i] *= inv_tmp;
    }
}
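
The restored kernel is the classic numerically stable softmax, done in three passes over each row: find the row maximum, accumulate exponentials shifted by that maximum, then normalize. A single-threaded CPU sketch of the same per-row math, for clarity:

    // CPU sketch of the per-row computation performed by the soft_max_f32 kernel above.
    #include <math.h>

    static void soft_max_row_ref(const float * x, float * dst, int ncols) {
        float max_val = -INFINITY;
        for (int i = 0; i < ncols; ++i) {
            max_val = fmaxf(max_val, x[i]); // pass 1: row maximum, for numerical stability
        }
        float sum = 0.0f;
        for (int i = 0; i < ncols; ++i) {
            dst[i] = expf(x[i] - max_val);  // pass 2: shifted exponentials
            sum += dst[i];
        }
        const float inv_sum = 1.0f/sum;
        for (int i = 0; i < ncols; ++i) {
            dst[i] *= inv_sum;              // pass 3: normalize so the row sums to 1
        }
    }
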
@@ -5831,12 +5793,10 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
    diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}

- static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
-     int nth = WARP_SIZE;
-     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
-     const dim3 block_dims(nth, 1, 1);
+ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
+     const dim3 block_dims(1, WARP_SIZE, 1);
    const dim3 block_nums(nrows_x, 1, 1);
-     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
}

static void im2col_f32_f16_cuda(const float * x, half * dst,
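
Launch-geometry note: the removed launcher widened the block along x to the next power of two (capped at CUDA_SOFT_MAX_BLOCK_SIZE), while the restored launcher always runs a single 32-thread warp per row with block_dims(1, WARP_SIZE, 1) and one block per row, threads striding over columns. A small hypothetical helper, not part of the patch, showing how the removed variant sized its block:

    // Hypothetical helper mirroring the removed launcher's block sizing:
    // the smallest power of two covering ncols_x, clamped to [32, 1024].
    static int soft_max_block_threads(int ncols_x) {
        int nth = 32;                                  // WARP_SIZE
        while (nth < ncols_x && nth < 1024) nth *= 2;  // CUDA_SOFT_MAX_BLOCK_SIZE == 1024
        return nth;
    }
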
@@ -6875,18 +6835,14 @@ inline void ggml_cuda_op_soft_max(
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

-     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
-
    const int64_t ne00 = src0->ne[0];
-     const int64_t nrows_x = ggml_nrows(src0);
-     const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
-
-     float scale = 1.0f;
-     memcpy(&scale, dst->op_params, sizeof(float));
+     const int64_t nrows = ggml_nrows(src0);

-     soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+     soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);

+     (void) src1;
    (void) dst;
+     (void) src1_dd;
}

inline void ggml_cuda_op_scale(