Skip to content

Commit f3b9eae

Browse files
committed
llama: improve NTKv2 CUDA implementation
Precompute what we can on the host to make the device kernel smaller, and to avoid magic constants.
1 parent 2a9ba48 commit f3b9eae

File tree

2 files changed

+63
-52
lines changed

2 files changed

+63
-52
lines changed

Diff for: ggml-cuda.cu

+50-44
Original file line numberDiff line numberDiff line change
@@ -1875,52 +1875,36 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
18751875
cpy_1(cx + x_offset, cdst + dst_offset);
18761876
}
18771877

1878-
static __device__ void ntkv2_ramp(const float low, const float high, const int i0, float *out) {
1878+
static __device__ float ntkv2_ramp(const float low, const float high, const int i0) {
18791879
const float y = (i0 / 2 - low) / min(0.001f, high - low);
1880-
*out = 1.0f - min(1.0f, max(0.0f, y));
1880+
return 1.0f - min(1.0f, max(0.0f, y));
18811881
}
18821882

18831883
// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
18841884
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
1885-
static __device__ void compute_ntkv2(
1885+
static __device__ float compute_ntkv2(
18861886
float theta_base,
1887+
float theta_linear,
18871888
float theta_ntk,
1888-
float dims_over_base,
1889-
float freq_scale,
1889+
const float corr_factors[4],
18901890
int64_t i0,
18911891
float ntk_factor,
1892-
float extrapolation_factor,
1893-
int n_dims,
1894-
float *theta) {
1895-
// Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
1896-
// Do not change unless there is a good reason for doing so!
1897-
// These are precomputed because CUDA doesn't allow dynamic init of device constants
1898-
static const float low_1p = 2.6135630f;
1899-
static const float high_1p = 2.7817991f;
1900-
static const float low_2p = 1.5070765f;
1901-
static const float high_2p = 2.5467973f;
1902-
1903-
// start and end correction factors
1904-
const float low_1 = max(0.0f, floorf(low_1p * dims_over_base));
1905-
const float high_1 = min(n_dims - 1.0f, ceilf(high_1p * dims_over_base));
1906-
const float low_2 = max(0.0f, floorf(low_2p * dims_over_base));
1907-
const float high_2 = min(n_dims - 1.0f, ceilf(high_2p * dims_over_base));
1908-
1892+
float extrapolation_factor) {
19091893
float ramp_mix;
1894+
float theta;
19101895

1911-
const float theta_linear = freq_scale * theta_base;
1912-
ntkv2_ramp(low_1, high_1, i0, &ramp_mix);
1913-
ramp_mix *= ntk_factor;
1914-
const float theta_mix = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
1915-
ntkv2_ramp(low_2, high_2, i0, &ramp_mix);
1916-
ramp_mix *= extrapolation_factor;
1917-
*theta = theta_mix * (1 - ramp_mix) + theta_base * ramp_mix;
1896+
ramp_mix = ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor;
1897+
theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
1898+
1899+
ramp_mix = ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * extrapolation_factor;
1900+
theta = theta * (1 - ramp_mix) + theta_base * ramp_mix;
1901+
return theta;
19181902
}
19191903

19201904
// rope == RoPE == rotary positional embedding
1921-
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int n_dims, const float freq_base,
1905+
static __global__ void rope_f32(const float * x, float * dst, const int ncols,
19221906
const float freq_scale, const float ntk_factor, const float extrapolation_factor, const float theta_scale,
1923-
const float theta_ntk_scale, const float dims_over_base, const float p) {
1907+
const float theta_ntk_scale, const float p, const float corr_factors[4]) {
19241908

19251909
const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
19261910

@@ -1931,11 +1915,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
19311915
const int row = blockDim.y*blockIdx.y + threadIdx.y;
19321916
const int i = row*ncols + col;
19331917

1934-
const float theta_base = p*powf(theta_scale, col/2);
1935-
const float theta_ntk = p*powf(theta_ntk_scale, col/2);
1936-
float theta;
1937-
compute_ntkv2(theta_base, theta_ntk, dims_over_base,
1938-
freq_scale, col, ntk_factor, extrapolation_factor, n_dims, &theta);
1918+
const float theta_base = p*powf(theta_scale, col/2);
1919+
const float theta_linear = freq_scale * theta_base;
1920+
const float theta_ntk = p*powf(theta_ntk_scale, col/2);
1921+
const float theta = compute_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors, col, ntk_factor,
1922+
extrapolation_factor);
19391923
const float sin_theta = sinf(theta);
19401924
const float cos_theta = cosf(theta);
19411925

@@ -2415,16 +2399,16 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
24152399
}
24162400

24172401
static void rope_f32_cuda(
2418-
const float * x, float * dst, const int ncols, const int nrows, const int n_dims, const float freq_base,
2402+
const float * x, float * dst, const int ncols, const int nrows,
24192403
const float freq_scale, const float ntk_factor, const float extrapolation_factor, const float theta_scale,
2420-
const float theta_ntk_scale, const float dims_over_base, const float p, cudaStream_t stream) {
2404+
const float theta_ntk_scale, const float p, const float corr_factors[4], cudaStream_t stream) {
24212405

24222406
GGML_ASSERT(nrows % 2 == 0);
24232407
const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
24242408
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
24252409
const dim3 block_nums(num_blocks_x, nrows, 1);
2426-
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, n_dims, freq_base, freq_scale, ntk_factor,
2427-
extrapolation_factor, theta_scale, theta_ntk_scale, dims_over_base, p);
2410+
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, freq_scale, ntk_factor,
2411+
extrapolation_factor, theta_scale, theta_ntk_scale, p, corr_factors);
24282412
}
24292413

24302414
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
@@ -2990,6 +2974,13 @@ inline void ggml_cuda_op_mul_mat_cublas(
29902974
(void) i1;
29912975
}
29922976

2977+
// Solving `n_rot = max_pos_emb / (2pi * base^((2 * x) / n_dims))` for x, we get
2978+
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
2979+
static float ntkv2_correction_factor(const int n_dims, const float n_rot, const float base) {
2980+
static const float max_pos_emb = 2048;
2981+
return n_dims * logf(max_pos_emb / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
2982+
}
2983+
29932984
inline void ggml_cuda_op_rope(
29942985
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
29952986
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -3016,8 +3007,6 @@ inline void ggml_cuda_op_rope(
30163007
memcpy(&extrapolation_factor, (int32_t *) src1->data + 7, sizeof(float));
30173008

30183009
const float theta_scale = powf(freq_base, -2.0f/n_dims);
3019-
const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
3020-
const float dims_over_base = n_dims / logf(freq_base);
30213010
const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
30223011

30233012
bool is_glm = mode & 4;
@@ -3028,8 +3017,25 @@ inline void ggml_cuda_op_rope(
30283017
const float block_p = max(p - (n_ctx - 2.f), 0.f);
30293018
rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
30303019
} else {
3031-
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, n_dims, freq_base, freq_scale, ntk_factor,
3032-
extrapolation_factor, theta_scale, theta_ntk_scale, dims_over_base, p, cudaStream_main);
3020+
const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
3021+
3022+
// Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
3023+
// Do not change unless there is a good reason for doing so!
3024+
static const float BETA_0 = 1.75f;
3025+
static const float BETA_1 = 1.25f;
3026+
static const float GAMMA_0 = 16.0f;
3027+
static const float GAMMA_1 = 2.0f;
3028+
3029+
// start and end correction factors
3030+
const float corr_factors[4] = {
3031+
max(0.0f, floorf(ntkv2_correction_factor(n_dims, BETA_0, freq_base))),
3032+
min(n_dims - 1.0f, ceilf(ntkv2_correction_factor(n_dims, BETA_1, freq_base))),
3033+
max(0.0f, floorf(ntkv2_correction_factor(n_dims, GAMMA_0, freq_base))),
3034+
min(n_dims - 1.0f, ceilf(ntkv2_correction_factor(n_dims, GAMMA_1, freq_base))),
3035+
};
3036+
3037+
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ntk_factor,
3038+
extrapolation_factor, theta_scale, theta_ntk_scale, p, corr_factors, cudaStream_main);
30333039
}
30343040

30353041
(void) dst;

Diff for: ggml.c

+13-8
Original file line numberDiff line numberDiff line change
@@ -12129,16 +12129,21 @@ static float compute_ntkv2(
1212912129
static const float high_2p = NTKV2_CORRECTION_FACTOR(GAMMA_1);
1213012130

1213112131
// start and end correction factors
12132-
const float low_1 = maxf(0, floorf(low_1p * dims_over_base));
12133-
const float high_1 = minf(n_dims - 1, ceilf(high_1p * dims_over_base));
12134-
const float low_2 = maxf(0, floorf(low_2p * dims_over_base));
12135-
const float high_2 = minf(n_dims - 1, ceilf(high_2p * dims_over_base));
12132+
const float low_1 = maxf(0.0f, floorf(low_1p * dims_over_base));
12133+
const float high_1 = minf(n_dims - 1.0f, ceilf(high_1p * dims_over_base));
12134+
const float low_2 = maxf(0.0f, floorf(low_2p * dims_over_base));
12135+
const float high_2 = minf(n_dims - 1.0f, ceilf(high_2p * dims_over_base));
1213612136

1213712137
const float theta_linear = freq_scale * theta_base;
12138-
const float ramp_mix = ntkv2_ramp(low_1, high_1, i0) * ntk_factor;
12139-
const float theta_mix = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
12140-
const float ramp_final = ntkv2_ramp(low_2, high_2, i0) * extrapolation_factor;
12141-
return theta_mix * (1 - ramp_final) + theta_base * ramp_final;
12138+
float ramp_mix;
12139+
float theta;
12140+
12141+
ramp_mix = ntkv2_ramp(low_1, high_1, i0) * ntk_factor;
12142+
theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
12143+
12144+
ramp_mix = ntkv2_ramp(low_2, high_2, i0) * extrapolation_factor;
12145+
theta = theta * (1 - ramp_mix) + theta_base * ramp_mix;
12146+
return theta;
1214212147
}
1214312148

1214412149
static void ggml_compute_forward_rope_f32(

0 commit comments

Comments
 (0)