@@ -3558,9 +3558,49 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
3558
3558
cpy_1 (cx + x_offset, cdst + dst_offset);
3559
3559
}
3560
3560
3561
+ static __device__ float rope_ntkv2_ramp (const float low, const float high, const int i0) {
3562
+ const float y = (i0 / 2 - low) / min (0 .001f , high - low);
3563
+ return 1 .0f - min (1 .0f , max (0 .0f , y));
3564
+ }
3565
+
3566
+ struct rope_corr_factors {
3567
+ float v[4 ];
3568
+ };
3569
+
3570
+ // NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
3571
+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
3572
+ static __device__ float rope_ntkv2 (
3573
+ const float theta_base,
3574
+ const float theta_linear,
3575
+ const float theta_ntk,
3576
+ const rope_corr_factors corr_factors,
3577
+ const int64_t i0,
3578
+ const float ntk_factor,
3579
+ const float ext_factor) {
3580
+ float ramp_mix;
3581
+ float theta;
3582
+
3583
+ ramp_mix = rope_ntkv2_ramp (corr_factors.v [0 ], corr_factors.v [1 ], i0) * ntk_factor;
3584
+ theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
3585
+
3586
+ ramp_mix = rope_ntkv2_ramp (corr_factors.v [2 ], corr_factors.v [3 ], i0) * ext_factor;
3587
+ theta = theta * (1 - ramp_mix) + theta_base * ramp_mix;
3588
+ return theta;
3589
+ }
3590
+
3561
3591
// rope == RoPE == rotary positional embedding
3562
- static __global__ void rope_f32 (const float * x, float * dst, const int ncols, const float p0,
3563
- const float p_delta, const int p_delta_rows, const float theta_scale) {
3592
+ static __global__ void rope_f32 (
3593
+ const float * x,
3594
+ float * dst,
3595
+ const int ncols,
3596
+ const float freq_scale,
3597
+ const float ntk_factor,
3598
+ const float ext_factor,
3599
+ const float theta_scale,
3600
+ const float theta_ntk_scale,
3601
+ const float p0,
3602
+ const int p_delta_rows,
3603
+ const rope_corr_factors corr_factors) {
3564
3604
const int col = 2 *(blockDim .x *blockIdx .x + threadIdx .x );
3565
3605
3566
3606
if (col >= ncols) {
@@ -3570,7 +3610,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
3570
3610
const int row = blockDim .y *blockIdx .y + threadIdx .y ;
3571
3611
const int i = row*ncols + col;
3572
3612
3573
- const float theta = (p0 + p_delta * (row/p_delta_rows))*powf (theta_scale, col/2 );
3613
+ const float p = p0 + row / p_delta_rows;
3614
+ const float theta_base = p*powf (theta_scale, col/2 );
3615
+ const float theta_linear = freq_scale * theta_base;
3616
+ const float theta_ntk = p*powf (theta_ntk_scale, col/2 );
3617
+ const float theta = rope_ntkv2 (theta_base, theta_linear, theta_ntk, corr_factors, col, ntk_factor, ext_factor);
3574
3618
const float sin_theta = sinf (theta);
3575
3619
const float cos_theta = cosf (theta);
3576
3620
@@ -4234,13 +4278,26 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
4234
4278
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0 , stream>>> (x, dst, scale, k);
4235
4279
}
4236
4280
4237
- static void rope_f32_cuda (const float * x, float * dst, const int ncols, const int nrows, const float p0,
4238
- const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
4281
+ static void rope_f32_cuda (
4282
+ const float * x,
4283
+ float * dst,
4284
+ const int ncols,
4285
+ const int nrows,
4286
+ const float freq_scale,
4287
+ const float ntk_factor,
4288
+ const float ext_factor,
4289
+ const float theta_scale,
4290
+ const float theta_ntk_scale,
4291
+ const float p0,
4292
+ const int p_delta_rows,
4293
+ const rope_corr_factors corr_factors,
4294
+ cudaStream_t stream) {
4239
4295
GGML_ASSERT (nrows % 2 == 0 );
4240
4296
const dim3 block_dims (2 *CUDA_ROPE_BLOCK_SIZE, 1 , 1 );
4241
4297
const int num_blocks_x = (ncols + 2 *CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 *CUDA_ROPE_BLOCK_SIZE);
4242
4298
const dim3 block_nums (num_blocks_x, nrows, 1 );
4243
- rope_f32<<<block_nums, block_dims, 0 , stream>>> (x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
4299
+ rope_f32<<<block_nums, block_dims, 0 , stream>>> (x, dst, ncols, freq_scale, ntk_factor, ext_factor, theta_scale,
4300
+ theta_ntk_scale, p0, p_delta_rows, corr_factors);
4244
4301
}
4245
4302
4246
4303
static void rope_glm_f32_cuda (const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
@@ -4941,11 +4998,13 @@ inline void ggml_cuda_op_rope(
4941
4998
const int n_dims = ((int32_t *) dst->op_params )[1 ];
4942
4999
const int mode = ((int32_t *) dst->op_params )[2 ];
4943
5000
const int n_ctx = ((int32_t *) dst->op_params )[3 ];
4944
- // RoPE alteration for extended context
4945
5001
4946
- float freq_base, freq_scale;
5002
+ // RoPE alteration for extended context
5003
+ float freq_base, freq_scale, ntk_factor, ext_factor;
4947
5004
memcpy (&freq_base, (int32_t *) dst->op_params + 4 , sizeof (float ));
4948
5005
memcpy (&freq_scale, (int32_t *) dst->op_params + 5 , sizeof (float ));
5006
+ memcpy (&ntk_factor, (int32_t *) dst->op_params + 6 , sizeof (float ));
5007
+ memcpy (&ext_factor, (int32_t *) dst->op_params + 7 , sizeof (float ));
4949
5008
4950
5009
const float theta_scale = powf (freq_base, -2 .0f /n_dims);
4951
5010
@@ -4958,8 +5017,13 @@ inline void ggml_cuda_op_rope(
4958
5017
const float block_p = max (p - (n_ctx - 2 .f ), 0 .f );
4959
5018
rope_glm_f32_cuda (src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
4960
5019
} else {
4961
- const float p0 = (((mode & 1 ) == 0 ? n_past : 0 )) * freq_scale;
4962
- rope_f32_cuda (src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
5020
+ const float p0 = (mode & 1 ) == 0 ? n_past : 0 ;
5021
+ const float theta_ntk_scale = powf (freq_base * powf (freq_scale, (n_dims / (n_dims - 2 .0f ))), -2 .0f /n_dims);
5022
+ rope_corr_factors corr_factors;
5023
+ ggml_rope_ntkv2_corr_factors (n_dims, freq_base, corr_factors.v );
5024
+
5025
+ rope_f32_cuda (src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ntk_factor, ext_factor, theta_scale,
5026
+ theta_ntk_scale, p0, ne01, corr_factors, cudaStream_main);
4963
5027
}
4964
5028
4965
5029
(void ) src1;
0 commit comments