Skip to content

Commit ce59171

Browse files
committed
initial CUDA implementation
1 parent b43bfe8 commit ce59171

File tree

1 file changed

+73
-7
lines changed

1 file changed

+73
-7
lines changed

Diff for: ggml-cuda.cu

+73-7
Original file line numberDiff line numberDiff line change
@@ -1875,8 +1875,53 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
18751875
cpy_1(cx + x_offset, cdst + dst_offset);
18761876
}
18771877

1878+
// NTKv2 ramp: writes 1.0f through `out` while i0/2 <= low, 0.0f once i0/2 >= high,
// and a linear descent in between. `i0` is the (even) element index within the head;
// integer division by 2 maps it to the rotation-pair index.
static __device__ void ntkv2_ramp(const float low, const float high, const int i0, float *out) {
    // fmaxf guards against division by zero when high == low.
    // BUG FIX: the previous code used min(0.001f, high - low), which capped the
    // denominator at 0.001f and collapsed the linear ramp into a hard step.
    const float y = (i0 / 2 - low) / fmaxf(0.001f, high - low);
    *out = 1.0f - fminf(1.0f, fmaxf(0.0f, y));
}
1882+
1883+
// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
1884+
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
1885+
// Computes the final rotation angle for one RoPE element under the NTKv2 scheme.
// Three candidate angles are blended per dimension:
//   - theta_linear = freq_scale * theta_base  (plain linear interpolation)
//   - theta_ntk                               (NTK-aware scaled angle)
//   - theta_base                              (unscaled / extrapolation angle)
// Two ramps (computed by ntkv2_ramp) select the blend per element index i0,
// weighted by ntk_factor and extrapolation_factor. Result is written to *theta.
static __device__ void compute_ntkv2(
        float theta_base,
        float theta_ntk,
        float dims_over_base,
        float freq_scale,
        int64_t i0,
        float ntk_factor,
        float extrapolation_factor,
        int n_dims,
        float *theta) {
    // Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
    // Do not change unless there is a good reason for doing so!
    // These are precomputed because CUDA doesn't allow dynamic init of device constants
    static const float low_1p  = 2.6135630f;
    static const float high_1p = 2.7817991f;
    static const float low_2p  = 1.5070765f;
    static const float high_2p = 2.5467973f;

    // Correction windows (in pair-index units), clamped to [0, n_dims - 1]
    const float low_1  = fmaxf(0.0f, floorf(low_1p * dims_over_base));
    const float high_1 = fminf(n_dims - 1.0f, ceilf(high_1p * dims_over_base));
    const float low_2  = fmaxf(0.0f, floorf(low_2p * dims_over_base));
    const float high_2 = fminf(n_dims - 1.0f, ceilf(high_2p * dims_over_base));

    const float theta_linear = freq_scale * theta_base;

    // First blend: linear-interpolated angle vs. NTK-aware angle
    float ntk_mix;
    ntkv2_ramp(low_1, high_1, i0, &ntk_mix);
    ntk_mix *= ntk_factor;
    const float theta_mix = theta_linear * (1 - ntk_mix) + theta_ntk * ntk_mix;

    // Second blend: mixed angle vs. unscaled (extrapolation) angle
    float extrap_mix;
    ntkv2_ramp(low_2, high_2, i0, &extrap_mix);
    extrap_mix *= extrapolation_factor;
    *theta = theta_mix * (1 - extrap_mix) + theta_base * extrap_mix;
}
1919+
18781920
// rope == RoPE == rotary positional embedding
1879-
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
1921+
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int n_dims, const float freq_base,
1922+
const float freq_scale, const float ntk_factor, const float extrapolation_factor, const float theta_scale,
1923+
const float theta_ntk_scale, const float dims_over_base, const float p) {
1924+
18801925
const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
18811926

18821927
if (col >= ncols) {
@@ -1886,7 +1931,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
18861931
const int row = blockDim.y*blockIdx.y + threadIdx.y;
18871932
const int i = row*ncols + col;
18881933

1889-
const float theta = p*powf(theta_scale, col/2);
1934+
const float theta_base = p*powf(theta_scale, col/2);
1935+
const float theta_ntk = p*powf(theta_ntk_scale, col/2);
1936+
float theta;
1937+
compute_ntkv2(theta_base, theta_ntk, dims_over_base,
1938+
freq_scale, col, ntk_factor, extrapolation_factor, n_dims, &theta);
18901939
const float sin_theta = sinf(theta);
18911940
const float cos_theta = cosf(theta);
18921941

@@ -2365,12 +2414,17 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
23652414
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
23662415
}
23672416

2368-
// Host-side launcher for the rope_f32 kernel over an ncols x nrows slice.
// Grid layout: one thread rotates one column pair, so each block covers
// 2*CUDA_ROPE_BLOCK_SIZE columns; the grid's y dimension spans the rows.
static void rope_f32_cuda(
    const float * x, float * dst, const int ncols, const int nrows, const int n_dims, const float freq_base,
    const float freq_scale, const float ntk_factor, const float extrapolation_factor, const float theta_scale,
    const float theta_ntk_scale, const float dims_over_base, const float p, cudaStream_t stream) {

    // NOTE(review): the kernel pairs adjacent *columns* (col = 2*tid), so this
    // looks like it should assert ncols % 2 == 0 rather than nrows — confirm
    // against callers before changing.
    GGML_ASSERT(nrows % 2 == 0);
    const dim3 threads(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
    const int blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);  // ceil-div over column pairs
    const dim3 blocks(blocks_x, nrows, 1);
    rope_f32<<<blocks, threads, 0, stream>>>(
        x, dst, ncols, n_dims, freq_base, freq_scale, ntk_factor,
        extrapolation_factor, theta_scale, theta_ntk_scale, dims_over_base, p);
}
23752429

23762430
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
@@ -2947,12 +3001,23 @@ inline void ggml_cuda_op_rope(
29473001
const int64_t ne00 = src0->ne[0];
29483002
const int64_t i01_diff = i01_high - i01_low;
29493003

3004+
float freq_base;
3005+
float freq_scale;
3006+
float ntk_factor;
3007+
float extrapolation_factor;
3008+
29503009
const int n_past = ((int32_t *) src1->data)[0];
29513010
const int n_dims = ((int32_t *) src1->data)[1];
29523011
const int mode = ((int32_t *) src1->data)[2];
29533012
const int n_ctx = ((int32_t *) src1->data)[3];
2954-
2955-
const float theta_scale = powf(10000.0, -2.0f/n_dims);
3013+
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
3014+
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
3015+
memcpy(&ntk_factor, (int32_t *) src1->data + 6, sizeof(float));
3016+
memcpy(&extrapolation_factor, (int32_t *) src1->data + 7, sizeof(float));
3017+
3018+
const float theta_scale = powf(freq_base, -2.0f/n_dims);
3019+
const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
3020+
const float dims_over_base = n_dims / logf(freq_base);
29563021
const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
29573022

29583023
bool is_glm = mode & 4;
@@ -2963,7 +3028,8 @@ inline void ggml_cuda_op_rope(
29633028
const float block_p = max(p - (n_ctx - 2.f), 0.f);
29643029
rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
29653030
} else {
2966-
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
3031+
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, n_dims, freq_base, freq_scale, ntk_factor,
3032+
extrapolation_factor, theta_scale, theta_ntk_scale, dims_over_base, p, cudaStream_main);
29673033
}
29683034

29693035
(void) dst;

0 commit comments

Comments
 (0)