@@ -259,6 +259,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
@@ -3940,6 +3941,29 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
     dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
 }

+static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
+                                 const int n_heads_log2_floor, const float m0, const float m1) {
+    const int col = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i = row*ncols + col;
+
+    const int k = row/k_rows;
+
+    float m_k;
+    if (k < n_heads_log2_floor) {
+        m_k = powf(m0, k + 1);
+    } else {
+        m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+    }
+
+    dst[i] = col * m_k + x[i];
+}
+
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
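The kernel biases each element by col * m_k, where the slope m_k depends on which head the row belongs to: the first n_heads_log2_floor heads use successive powers of m0, and any remaining heads use odd powers of m1. For reference, a standalone host-side sketch of the same slope schedule; the helper name alibi_slope, the head count, and the max_bias value are illustrative and not part of the patch:

```cpp
// Host-side sketch of the per-head ALiBi slope used by alibi_f32 (illustrative only).
#include <math.h>
#include <stdio.h>

static float alibi_slope(int k, int n_head, float max_bias) {
    const int   n_heads_log2_floor = 1 << (int) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -max_bias          / n_heads_log2_floor);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
    // first n_heads_log2_floor heads: powers of m0; remaining heads: odd powers of m1
    return k < n_heads_log2_floor
        ? powf(m0, k + 1)
        : powf(m1, 2 * (k - n_heads_log2_floor) + 1);
}

int main() {
    const int   n_head   = 12;   // sample head count
    const float max_bias = 8.0f; // sample bias, typical for ALiBi models
    for (int k = 0; k < n_head; ++k) {
        printf("head %2d: slope %.6f\n", k, alibi_slope(k, n_head, max_bias));
    }
    return 0;
}
```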
@@ -4766,6 +4790,15 @@ static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, con
     rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
 }

+static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
+                           const int k_rows, const int n_heads_log2_floor, const float m0,
+                           const float m1, cudaStream_t stream) {
+    const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1);
+    const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);
+    const dim3 block_nums(num_blocks_x, nrows, 1);
+    alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
+}
+
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
     const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
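The launcher maps one grid row per matrix row (blockIdx.y) and tiles the columns in blocks of CUDA_ALIBI_BLOCK_SIZE. A minimal smoke test for it could look like the sketch below; it assumes the kernel and launcher above are visible in the same translation unit, and all sizes plus the max_bias value are made up for illustration:

```cpp
// Hypothetical smoke test for alibi_f32_cuda; not part of the patch.
#include <cuda_runtime.h>
#include <math.h>
#include <stdio.h>
#include <vector>

static void alibi_smoke_test() {
    const int ncols  = 8;  // ne00: tokens per row
    const int n_head = 4;  // ne02: number of heads
    const int k_rows = 2;  // ne01: rows per head
    const int nrows  = n_head * k_rows;

    std::vector<float> h_src(ncols * nrows, 0.0f); // all-zero input -> dst holds the pure bias
    std::vector<float> h_dst(ncols * nrows, 0.0f);

    float * d_src = nullptr;
    float * d_dst = nullptr;
    cudaMalloc((void **) &d_src, h_src.size() * sizeof(float));
    cudaMalloc((void **) &d_dst, h_dst.size() * sizeof(float));
    cudaMemcpy(d_src, h_src.data(), h_src.size() * sizeof(float), cudaMemcpyHostToDevice);

    // same slope parameters that ggml_cuda_op_alibi (next hunk) derives from max_bias
    const int   n_heads_log2_floor = 1 << (int) floor(log2(n_head));
    const float max_bias = 8.0f;
    const float m0 = powf(2.0f, -max_bias          / n_heads_log2_floor);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

    alibi_f32_cuda(d_src, d_dst, ncols, nrows, k_rows, n_heads_log2_floor, m0, m1, 0);

    cudaMemcpy(h_dst.data(), d_dst, h_dst.size() * sizeof(float), cudaMemcpyDeviceToHost);
    printf("row 0, last column: %f\n", h_dst[ncols - 1]); // expect (ncols - 1) * m0

    cudaFree(d_src);
    cudaFree(d_dst);
}
```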
@@ -5501,6 +5534,41 @@ inline void ggml_cuda_op_rope(
     (void) i1;
 }

+inline void ggml_cuda_op_alibi(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_head = ((int32_t *) dst->op_params)[1];
+    float max_bias;
+    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
+
+    GGML_ASSERT(ne01 + n_past == ne00);
+    GGML_ASSERT(n_head == ne02);
+
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
+
+    // compute
+    alibi_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_heads_log2_floor, m0, m1, cudaStream_main);
+
+    (void) src1;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
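The wrapper reads n_past and n_head from the first two int32 slots of dst->op_params and bit-copies max_bias out of the third slot. The sketch below shows a packing helper matching that layout; the function name is hypothetical and the snippet only mirrors what the read side above expects, it is not ggml API code from the patch:

```cpp
// Hypothetical helper producing the op_params layout ggml_cuda_op_alibi reads:
// two int32 values (n_past, n_head) followed by max_bias stored as raw float bits.
#include <stdint.h>
#include <string.h>

static void pack_alibi_op_params(int32_t * op_params, int32_t n_past, int32_t n_head, float max_bias) {
    op_params[0] = n_past;
    op_params[1] = n_head;
    // the float is bit-copied into the third int32 slot, exactly as the memcpy
    // on the read side reinterprets it back
    memcpy(op_params + 2, &max_bias, sizeof(float));
}
```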
@@ -6121,6 +6189,11 @@ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
 }

+void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_alibi, true, true);
+}
+
 void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -6456,6 +6529,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_rope;
             break;
+        case GGML_OP_ALIBI:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_alibi;
+            break;
         default:
             return false;
     }
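With GGML_OP_ALIBI wired into the dispatch switch, one way to spot-check the GPU path is to compare it against a plain CPU loop applying the same bias. The reference below is a hypothetical helper mirroring the alibi_f32 indexing (rows grouped per head, k_rows consecutive rows sharing one slope); it is not code from the patch:

```cpp
// Hypothetical CPU reference for the bias applied by alibi_f32, for validation.
#include <math.h>

static void alibi_f32_ref(const float * x, float * dst, int ncols, int nrows, int k_rows,
                          int n_heads_log2_floor, float m0, float m1) {
    for (int row = 0; row < nrows; ++row) {
        const int   k   = row / k_rows; // head index for this row
        const float m_k = k < n_heads_log2_floor
            ? powf(m0, k + 1)
            : powf(m1, 2 * (k - n_heads_log2_floor) + 1);
        for (int col = 0; col < ncols; ++col) {
            dst[row*ncols + col] = col * m_k + x[row*ncols + col];
        }
    }
}
```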