@@ -1875,52 +1875,36 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
1875
1875
cpy_1 (cx + x_offset, cdst + dst_offset);
1876
1876
}
1877
1877
1878
// Linear ramp used to blend between two RoPE angle candidates.
//
// Computes y = (i0/2 - low) / max(0.001f, high - low), clamps y to [0, 1] and
// returns 1 - y. The result is therefore 1 when i0/2 <= low, 0 when i0/2 >= high,
// and falls linearly in between.
//
//   low, high : start and end of the ramp, in units of i0/2
//   i0        : index whose ramp weight is wanted; note that i0 / 2 is an integer division
static __device__ float ntkv2_ramp(const float low, const float high, const int i0) {
    // The ramp width is floored at 0.001f so that high == low cannot divide by zero.
    // Using min() here would instead cap the width at 0.001f, collapsing the ramp into
    // a step and still dividing by zero (or a negative width) whenever high <= low.
    const float y = (i0 / 2 - low) / max(0.001f, high - low);
    return 1.0f - min(1.0f, max(0.0f, y));
}
1882
1882
1883
1883
// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
1884
1884
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
1885
// Blends three candidate rotation angles into the final RoPE angle.
//
//   theta_base           : unscaled angle
//   theta_linear         : linearly scaled angle
//   theta_ntk            : NTK-scaled angle
//   corr_factors         : {low, high} ramp bounds for the NTK blend in [0], [1] and
//                          {low, high} ramp bounds for the extrapolation blend in [2], [3]
//   i0                   : index forwarded to ntkv2_ramp (narrowed to int there)
//   ntk_factor           : weight applied to the first ramp
//   extrapolation_factor : weight applied to the second ramp
//
// Returns the blended angle.
static __device__ float compute_ntkv2(
        float theta_base,
        float theta_linear,
        float theta_ntk,
        const float corr_factors[4],
        int64_t i0,
        float ntk_factor,
        float extrapolation_factor) {
    // Step 1: mix the linear and NTK angles.
    const float ntk_mix     = ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor;
    const float theta_blend = theta_linear * (1 - ntk_mix) + theta_ntk * ntk_mix;

    // Step 2: mix the result of step 1 with the unscaled angle.
    const float extrap_mix = ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * extrapolation_factor;
    return theta_blend * (1 - extrap_mix) + theta_base * extrap_mix;
}
1919
1903
1920
1904
// rope == RoPE == rotary positional embedding
1921
- static __global__ void rope_f32 (const float * x, float * dst, const int ncols, const int n_dims, const float freq_base,
1905
+ static __global__ void rope_f32 (const float * x, float * dst, const int ncols,
1922
1906
const float freq_scale, const float ntk_factor, const float extrapolation_factor, const float theta_scale,
1923
- const float theta_ntk_scale, const float dims_over_base , const float p ) {
1907
+ const float theta_ntk_scale, const float p , const float corr_factors[ 4 ] ) {
1924
1908
1925
1909
const int col = 2 *(blockDim .x *blockIdx .x + threadIdx .x );
1926
1910
@@ -1931,11 +1915,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
1931
1915
const int row = blockDim .y *blockIdx .y + threadIdx .y ;
1932
1916
const int i = row*ncols + col;
1933
1917
1934
- const float theta_base = p*powf (theta_scale, col/2 );
1935
- const float theta_ntk = p* powf (theta_ntk_scale, col/ 2 ) ;
1936
- float theta ;
1937
- compute_ntkv2 (theta_base, theta_ntk, dims_over_base ,
1938
- freq_scale, col, ntk_factor, extrapolation_factor, n_dims, &theta );
1918
+ const float theta_base = p*powf (theta_scale, col/2 );
1919
+ const float theta_linear = freq_scale * theta_base ;
1920
+ const float theta_ntk = p* powf (theta_ntk_scale, col/ 2 ) ;
1921
+ const float theta = compute_ntkv2 (theta_base, theta_linear, theta_ntk, corr_factors, col, ntk_factor ,
1922
+ extrapolation_factor);
1939
1923
const float sin_theta = sinf (theta);
1940
1924
const float cos_theta = cosf (theta);
1941
1925
@@ -2415,16 +2399,16 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
2415
2399
}
2416
2400
2417
2401
// Host-side launcher for the rope_f32 kernel.
//
//   x, dst       : device input / output buffers handed straight to the kernel
//   ncols, nrows : launch extents; nrows must be even (asserted below)
//   corr_factors : four ramp bounds, given as a HOST array by the caller
//   stream       : stream on which the copy and the kernel are enqueued
//
// The remaining scalars are forwarded to the kernel unchanged.
static void rope_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows,
        const float freq_scale, const float ntk_factor, const float extrapolation_factor, const float theta_scale,
        const float theta_ntk_scale, const float p, const float corr_factors[4], cudaStream_t stream) {

    GGML_ASSERT(nrows % 2 == 0);
    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(num_blocks_x, nrows, 1);

    // An array parameter decays to a pointer, so passing `corr_factors` directly would
    // hand the kernel a host address. Stage the four floats in device memory first.
    // TODO: pass the factors by value (e.g. wrapped in a small struct) in the kernel
    // signature so that this per-call allocation and synchronization can be removed.
    float * corr_factors_dev = nullptr;
    cudaError_t err = cudaMalloc((void **) &corr_factors_dev, 4*sizeof(float));
    GGML_ASSERT(err == cudaSuccess);
    err = cudaMemcpyAsync(corr_factors_dev, corr_factors, 4*sizeof(float), cudaMemcpyHostToDevice, stream);
    GGML_ASSERT(err == cudaSuccess);

    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, freq_scale, ntk_factor,
        extrapolation_factor, theta_scale, theta_ntk_scale, p, corr_factors_dev);

    // The kernel must be finished with the staging buffer before it is released.
    err = cudaStreamSynchronize(stream);
    GGML_ASSERT(err == cudaSuccess);
    err = cudaFree(corr_factors_dev);
    GGML_ASSERT(err == cudaSuccess);
}
2429
2413
2430
2414
static void rope_glm_f32_cuda (const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
@@ -2990,6 +2974,13 @@ inline void ggml_cuda_op_mul_mat_cublas(
2990
2974
(void ) i1;
2991
2975
}
2992
2976
2977
// Maps a rotation count to the (fractional) rotary pair index that produces it.
//
// Pair index x rotates with angular frequency base^(-2x / n_dims), so over max_pos_emb
// positions it completes
//   n_rot = max_pos_emb / (2pi * base^(2x / n_dims))
// full rotations. Solving for x gives
//   corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))
//
//   n_dims      : number of rotary dimensions
//   n_rot       : target number of full rotations
//   base        : RoPE frequency base
//   max_pos_emb : context length the rotations are counted over; defaults to 2048,
//                 the value that was previously hard-coded here
static float ntkv2_correction_factor(const int n_dims, const float n_rot, const float base,
                                     const float max_pos_emb = 2048.0f) {
    return n_dims * logf(max_pos_emb / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
}
2983
+
2993
2984
inline void ggml_cuda_op_rope (
2994
2985
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
2995
2986
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -3016,8 +3007,6 @@ inline void ggml_cuda_op_rope(
3016
3007
memcpy (&extrapolation_factor, (int32_t *) src1->data + 7 , sizeof (float ));
3017
3008
3018
3009
const float theta_scale = powf (freq_base, -2 .0f /n_dims);
3019
- const float theta_ntk_scale = powf (freq_base * powf (freq_scale, (n_dims / (n_dims - 2 .0f ))), -2 .0f /n_dims);
3020
- const float dims_over_base = n_dims / logf (freq_base);
3021
3010
const float p = ((mode & 1 ) == 0 ? n_past + i02 : i02);
3022
3011
3023
3012
bool is_glm = mode & 4 ;
@@ -3028,8 +3017,25 @@ inline void ggml_cuda_op_rope(
3028
3017
const float block_p = max (p - (n_ctx - 2 .f ), 0 .f );
3029
3018
rope_glm_f32_cuda (src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
3030
3019
} else {
3031
- rope_f32_cuda (src0_ddf_i, dst_ddf_i, ne00, i01_diff, n_dims, freq_base, freq_scale, ntk_factor,
3032
- extrapolation_factor, theta_scale, theta_ntk_scale, dims_over_base, p, cudaStream_main);
3020
+ const float theta_ntk_scale = powf (freq_base * powf (freq_scale, (n_dims / (n_dims - 2 .0f ))), -2 .0f /n_dims);
3021
+
3022
+ // Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
3023
+ // Do not change unless there is a good reason for doing so!
3024
+ static const float BETA_0 = 1 .75f ;
3025
+ static const float BETA_1 = 1 .25f ;
3026
+ static const float GAMMA_0 = 16 .0f ;
3027
+ static const float GAMMA_1 = 2 .0f ;
3028
+
3029
+ // start and end correction factors
3030
+ const float corr_factors[4 ] = {
3031
+ max (0 .0f , floorf (ntkv2_correction_factor (n_dims, BETA_0, freq_base))),
3032
+ min (n_dims - 1 .0f , ceilf (ntkv2_correction_factor (n_dims, BETA_1, freq_base))),
3033
+ max (0 .0f , floorf (ntkv2_correction_factor (n_dims, GAMMA_0, freq_base))),
3034
+ min (n_dims - 1 .0f , ceilf (ntkv2_correction_factor (n_dims, GAMMA_1, freq_base))),
3035
+ };
3036
+
3037
+ rope_f32_cuda (src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ntk_factor,
3038
+ extrapolation_factor, theta_scale, theta_ntk_scale, p, corr_factors, cudaStream_main);
3033
3039
}
3034
3040
3035
3041
(void ) dst;
0 commit comments