Skip to content

Commit 7e2b997

Browse files
slarenggerganov
andauthored
ggml-cuda : update rope implementation for parallel decoding (#3254)
* ggml-cuda : update rope implementation for parallel decoding * better solution for p0 computation * fix rope * simpler rope implementation --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent daf4c6d commit 7e2b997

File tree

2 files changed

+60
-25
lines changed

2 files changed

+60
-25
lines changed

ggml-cuda.cu

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullpt
439439
struct ggml_tensor_extra_gpu {
440440
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
441441
cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
442+
bool copied;
442443
};
443444

444445
// this is faster on Windows
@@ -4355,8 +4356,9 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
43554356
}
43564357

43574358
// rope == RoPE == rotary positional embedding
4358-
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
4359-
const float p_delta, const int p_delta_rows, const float theta_scale) {
4359+
4360+
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
4361+
const int p_delta_rows, const float theta_scale) {
43604362
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
43614363

43624364
if (col >= ncols) {
@@ -4365,8 +4367,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
43654367

43664368
const int row = blockDim.x*blockIdx.x + threadIdx.x;
43674369
const int i = row*ncols + col;
4370+
const int i2 = row/p_delta_rows;
43684371

4369-
const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
4372+
const int p = pos != nullptr ? pos[i2] : 0;
4373+
const float p0 = p * freq_scale;
4374+
const float theta = p0*powf(theta_scale, col/2);
43704375
const float sin_theta = sinf(theta);
43714376
const float cos_theta = cosf(theta);
43724377

@@ -4377,8 +4382,8 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
43774382
dst[i + 1] = x0*sin_theta + x1*cos_theta;
43784383
}
43794384

4380-
static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
4381-
const float p_delta, const int p_delta_rows, const float theta_scale) {
4385+
static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
4386+
const int p_delta_rows, const float theta_scale) {
43824387
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
43834388

43844389
if (col >= ncols) {
@@ -4387,8 +4392,11 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
43874392

43884393
const int row = blockDim.x*blockIdx.x + threadIdx.x;
43894394
const int i = row*ncols + col/2;
4395+
const int i2 = row/p_delta_rows;
43904396

4391-
const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
4397+
const int p = pos != nullptr ? pos[i2] : 0;
4398+
const float p0 = p * freq_scale;
4399+
const float theta = p0*powf(theta_scale, col/2);
43924400
const float sin_theta = sinf(theta);
43934401
const float cos_theta = cosf(theta);
43944402

@@ -4399,8 +4407,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
43994407
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
44004408
}
44014409

4402-
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0,
4403-
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) {
4410+
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
4411+
const int p_delta_rows, const float theta_scale, const int n_ctx) {
44044412
const int col = blockDim.x*blockIdx.x + threadIdx.x;
44054413
const int half_n_dims = ncols/4;
44064414

@@ -4410,11 +4418,13 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
44104418

44114419
const int row = blockDim.y*blockIdx.y + threadIdx.y;
44124420
const int i = row*ncols + col;
4421+
const int i2 = row/p_delta_rows;
44134422

44144423
const float col_theta_scale = powf(theta_scale, col);
4415-
const float p = p0 + p_delta*(row/p_delta_rows);
4424+
// FIXME: this is likely wrong
4425+
const int p = pos != nullptr ? pos[i2] : 0;
44164426

4417-
const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale;
4427+
const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
44184428
const float sin_theta = sinf(theta);
44194429
const float cos_theta = cosf(theta);
44204430

@@ -4424,7 +4434,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
44244434
dst[i + 0] = x0*cos_theta - x1*sin_theta;
44254435
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
44264436

4427-
const float block_theta = max(p - p_delta*(n_ctx - 2), 0.f)*col_theta_scale;
4437+
const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
44284438
const float sin_block_theta = sinf(block_theta);
44294439
const float cos_block_theta = cosf(block_theta);
44304440

@@ -5361,31 +5371,31 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
53615371
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
53625372
}
53635373

5364-
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
5365-
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
5374+
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
5375+
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
53665376
GGML_ASSERT(ncols % 2 == 0);
53675377
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
53685378
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
53695379
const dim3 block_nums(nrows, num_blocks_x, 1);
5370-
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
5380+
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
53715381
}
53725382

5373-
static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
5374-
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
5383+
static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
5384+
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
53755385
GGML_ASSERT(ncols % 2 == 0);
53765386
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
53775387
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
53785388
const dim3 block_nums(nrows, num_blocks_x, 1);
5379-
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
5389+
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
53805390
}
53815391

5382-
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
5383-
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
5392+
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
5393+
const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
53845394
GGML_ASSERT(ncols % 4 == 0);
53855395
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
53865396
const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
53875397
const dim3 block_nums(num_blocks_x, nrows, 1);
5388-
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale, n_ctx);
5398+
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx);
53895399
}
53905400

53915401
static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
@@ -6069,9 +6079,10 @@ inline void ggml_cuda_op_rope(
60696079

60706080
const int64_t ne00 = src0->ne[0];
60716081
const int64_t ne01 = src0->ne[1];
6082+
const int64_t ne2 = dst->ne[2];
60726083
const int64_t nrows = ggml_nrows(src0);
60736084

6074-
const int n_past = ((int32_t *) dst->op_params)[0];
6085+
//const int n_past = ((int32_t *) dst->op_params)[0];
60756086
const int n_dims = ((int32_t *) dst->op_params)[1];
60766087
const int mode = ((int32_t *) dst->op_params)[2];
60776088
const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -6082,19 +6093,37 @@ inline void ggml_cuda_op_rope(
60826093
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
60836094

60846095
const float theta_scale = powf(freq_base, -2.0f/n_dims);
6085-
const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
6096+
// const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
6097+
6098+
GGML_ASSERT(src1->type == GGML_TYPE_I32);
6099+
GGML_ASSERT(src1->ne[0] == ne2);
6100+
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
6101+
6102+
int id;
6103+
CUDA_CHECK(cudaGetDevice(&id));
6104+
6105+
int * pos = nullptr;
6106+
if ((mode & 1) == 0) {
6107+
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
6108+
pos = (int *) src1_extra->data_device[id];
6109+
if (!src1_extra->copied) {
6110+
CUDA_CHECK(cudaMemcpyAsync(pos, src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream));
6111+
src1_extra->copied = true;
6112+
}
6113+
}
60866114

60876115
const bool is_neox = mode & 2;
60886116
const bool is_glm = mode & 4;
60896117

60906118
// compute
60916119
if (is_glm) {
6092-
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, n_ctx, main_stream);
6120+
GGML_ASSERT(false);
6121+
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
60936122
} else if (is_neox) {
60946123
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
6095-
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
6124+
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
60966125
} else {
6097-
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream);
6126+
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
60986127
}
60996128

61006129
(void) src1;

llama.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2708,6 +2708,7 @@ static struct ggml_cgraph * llm_build_llama(
27082708

27092709
// KQ_pos - contains the positions
27102710
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
2711+
offload_func_kq(KQ_pos);
27112712
ggml_allocr_alloc(lctx.alloc, KQ_pos);
27122713
if (!ggml_allocr_is_measure(lctx.alloc)) {
27132714
int * data = (int *) KQ_pos->data;
@@ -2719,6 +2720,7 @@ static struct ggml_cgraph * llm_build_llama(
27192720
// shift the entire K-cache if needed
27202721
if (do_rope_shift) {
27212722
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
2723+
offload_func_kq(K_shift);
27222724
ggml_allocr_alloc(lctx.alloc, K_shift);
27232725
if (!ggml_allocr_is_measure(lctx.alloc)) {
27242726
int * data = (int *) K_shift->data;
@@ -3092,6 +3094,7 @@ static struct ggml_cgraph * llm_build_baichaun(
30923094

30933095
// KQ_pos - contains the positions
30943096
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
3097+
offload_func_kq(KQ_pos);
30953098
ggml_allocr_alloc(lctx.alloc, KQ_pos);
30963099
if (!ggml_allocr_is_measure(lctx.alloc)) {
30973100
int * data = (int *) KQ_pos->data;
@@ -3103,6 +3106,7 @@ static struct ggml_cgraph * llm_build_baichaun(
31033106
// shift the entire K-cache if needed
31043107
if (do_rope_shift) {
31053108
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
3109+
offload_func_kq(K_shift);
31063110
ggml_allocr_alloc(lctx.alloc, K_shift);
31073111
if (!ggml_allocr_is_measure(lctx.alloc)) {
31083112
int * data = (int *) K_shift->data;
@@ -3496,6 +3500,7 @@ static struct ggml_cgraph * llm_build_falcon(
34963500

34973501
// KQ_pos - contains the positions
34983502
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
3503+
offload_func_kq(KQ_pos);
34993504
ggml_allocr_alloc(lctx.alloc, KQ_pos);
35003505
if (!ggml_allocr_is_measure(lctx.alloc)) {
35013506
int * data = (int *) KQ_pos->data;
@@ -3507,6 +3512,7 @@ static struct ggml_cgraph * llm_build_falcon(
35073512
// shift the entire K-cache if needed
35083513
if (do_rope_shift) {
35093514
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
3515+
offload_func_kq(K_shift);
35103516
ggml_allocr_alloc(lctx.alloc, K_shift);
35113517
if (!ggml_allocr_is_measure(lctx.alloc)) {
35123518
int * data = (int *) K_shift->data;

0 commit comments

Comments
 (0)