@@ -439,6 +439,7 @@ static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullpt
439
439
// Per-tensor GPU bookkeeping, attached to a ggml tensor via its `extra` pointer.
struct ggml_tensor_extra_gpu {
    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
    cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
    // Set after the tensor's host data has been uploaded once, so repeated ops can
    // skip the host->device copy (see the pos upload in ggml_cuda_op_rope).
    // NOTE(review): correctness relies on this struct being zero-initialized when
    // allocated — the allocation site is not visible here, confirm it zeroes extras.
    bool copied;
};
443
444
444
445
// this is faster on Windows
@@ -4355,8 +4356,9 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
4355
4356
}
4356
4357
4357
4358
// rope == RoPE == rotary positional embedding
4358
- static __global__ void rope_f32 (const float * x, float * dst, const int ncols, const float p0,
4359
- const float p_delta, const int p_delta_rows, const float theta_scale) {
4359
+
4360
+ static __global__ void rope_f32 (const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
4361
+ const int p_delta_rows, const float theta_scale) {
4360
4362
const int col = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
4361
4363
4362
4364
if (col >= ncols) {
@@ -4365,8 +4367,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
4365
4367
4366
4368
const int row = blockDim .x *blockIdx .x + threadIdx .x ;
4367
4369
const int i = row*ncols + col;
4370
+ const int i2 = row/p_delta_rows;
4368
4371
4369
- const float theta = (p0 + p_delta * (row/p_delta_rows))*powf (theta_scale, col/2 );
4372
+ const int p = pos != nullptr ? pos[i2] : 0 ;
4373
+ const float p0 = p * freq_scale;
4374
+ const float theta = p0*powf (theta_scale, col/2 );
4370
4375
const float sin_theta = sinf (theta);
4371
4376
const float cos_theta = cosf (theta);
4372
4377
@@ -4377,8 +4382,8 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
4377
4382
dst[i + 1 ] = x0*sin_theta + x1*cos_theta;
4378
4383
}
4379
4384
4380
- static __global__ void rope_neox_f32 (const float * x, float * dst, const int ncols, const float p0 ,
4381
- const float p_delta, const int p_delta_rows, const float theta_scale) {
4385
+ static __global__ void rope_neox_f32 (const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale ,
4386
+ const int p_delta_rows, const float theta_scale) {
4382
4387
const int col = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
4383
4388
4384
4389
if (col >= ncols) {
@@ -4387,8 +4392,11 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
4387
4392
4388
4393
const int row = blockDim .x *blockIdx .x + threadIdx .x ;
4389
4394
const int i = row*ncols + col/2 ;
4395
+ const int i2 = row/p_delta_rows;
4390
4396
4391
- const float theta = (p0 + p_delta * (row/p_delta_rows))*powf (theta_scale, col/2 );
4397
+ const int p = pos != nullptr ? pos[i2] : 0 ;
4398
+ const float p0 = p * freq_scale;
4399
+ const float theta = p0*powf (theta_scale, col/2 );
4392
4400
const float sin_theta = sinf (theta);
4393
4401
const float cos_theta = cosf (theta);
4394
4402
@@ -4399,8 +4407,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
4399
4407
dst[i + ncols/2 ] = x0*sin_theta + x1*cos_theta;
4400
4408
}
4401
4409
4402
- static __global__ void rope_glm_f32 (const float * x, float * dst, const int ncols, const float p0 ,
4403
- const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) {
4410
+ static __global__ void rope_glm_f32 (const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale ,
4411
+ const int p_delta_rows, const float theta_scale, const int n_ctx) {
4404
4412
const int col = blockDim .x *blockIdx .x + threadIdx .x ;
4405
4413
const int half_n_dims = ncols/4 ;
4406
4414
@@ -4410,11 +4418,13 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
4410
4418
4411
4419
const int row = blockDim .y *blockIdx .y + threadIdx .y ;
4412
4420
const int i = row*ncols + col;
4421
+ const int i2 = row/p_delta_rows;
4413
4422
4414
4423
const float col_theta_scale = powf (theta_scale, col);
4415
- const float p = p0 + p_delta*(row/p_delta_rows);
4424
+ // FIXME: this is likely wrong
4425
+ const int p = pos != nullptr ? pos[i2] : 0 ;
4416
4426
4417
- const float theta = min (p, p_delta*( n_ctx - 2 )) *col_theta_scale;
4427
+ const float theta = min (p, n_ctx - 2 )*freq_scale *col_theta_scale;
4418
4428
const float sin_theta = sinf (theta);
4419
4429
const float cos_theta = cosf (theta);
4420
4430
@@ -4424,7 +4434,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
4424
4434
dst[i + 0 ] = x0*cos_theta - x1*sin_theta;
4425
4435
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
4426
4436
4427
- const float block_theta = max (p - p_delta*( n_ctx - 2 ) , 0 . f )*col_theta_scale;
4437
+ const float block_theta = (( float ) max (p - n_ctx - 2 , 0 ) )*col_theta_scale;
4428
4438
const float sin_block_theta = sinf (block_theta);
4429
4439
const float cos_block_theta = cosf (block_theta);
4430
4440
@@ -5361,31 +5371,31 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
5361
5371
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0 , stream>>> (x, dst, scale, k);
5362
5372
}
5363
5373
5364
// Host-side launcher for the rope_f32 kernel.
// Grid layout: one grid row per tensor row, and each thread covers a pair of
// columns, so the y-dimension is ceil(ncols / (2*CUDA_ROPE_BLOCK_SIZE)) blocks.
// Preconditions: ncols is even; `pos` is a device pointer (or nullptr).
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
                          const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
    GGML_ASSERT(ncols % 2 == 0);
    const int cols_per_block = 2*CUDA_ROPE_BLOCK_SIZE;                            // 2 columns handled per thread
    const int nblocks_y      = (ncols + cols_per_block - 1) / cols_per_block;     // ceil-div over column pairs
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const dim3 grid_dims (nrows, nblocks_y, 1);
    rope_f32<<<grid_dims, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
}
5372
5382
5373
// Host-side launcher for the rope_neox_f32 kernel (neox-style interleaving).
// Grid: x = rows, y = ceil(ncols / (2*CUDA_ROPE_BLOCK_SIZE)) since each thread
// processes a column pair. Requires even ncols; `pos` must be a device pointer
// (or nullptr).
static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
                               const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
    GGML_ASSERT(ncols % 2 == 0);
    // one thread per column pair -> ceil over 2*CUDA_ROPE_BLOCK_SIZE columns per block
    const int num_col_blocks = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 threads(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const dim3 blocks (nrows, num_col_blocks, 1);
    rope_neox_f32<<<blocks, threads, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
}
5381
5391
5382
// Host-side launcher for the rope_glm_f32 kernel (GLM-style RoPE).
// Grid: x = ceil(ncols / CUDA_ROPE_BLOCK_SIZE) blocks of CUDA_ROPE_BLOCK_SIZE/4
// threads, y = rows. Requires ncols divisible by 4; `pos` must be a device
// pointer (or nullptr).
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
                              const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
    GGML_ASSERT(ncols % 4 == 0);
    const int x_blocks = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE; // ceil-div over columns
    const dim3 block(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
    const dim3 grid (x_blocks, nrows, 1);
    rope_glm_f32<<<grid, block, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx);
}
5390
5400
5391
5401
static void alibi_f32_cuda (const float * x, float * dst, const int ncols, const int nrows,
@@ -6069,9 +6079,10 @@ inline void ggml_cuda_op_rope(
6069
6079
6070
6080
const int64_t ne00 = src0->ne [0 ];
6071
6081
const int64_t ne01 = src0->ne [1 ];
6082
+ const int64_t ne2 = dst->ne [2 ];
6072
6083
const int64_t nrows = ggml_nrows (src0);
6073
6084
6074
- const int n_past = ((int32_t *) dst->op_params )[0 ];
6085
+ // const int n_past = ((int32_t *) dst->op_params)[0];
6075
6086
const int n_dims = ((int32_t *) dst->op_params )[1 ];
6076
6087
const int mode = ((int32_t *) dst->op_params )[2 ];
6077
6088
const int n_ctx = ((int32_t *) dst->op_params )[3 ];
@@ -6082,19 +6093,37 @@ inline void ggml_cuda_op_rope(
6082
6093
memcpy (&freq_scale, (int32_t *) dst->op_params + 5 , sizeof (float ));
6083
6094
6084
6095
const float theta_scale = powf (freq_base, -2 .0f /n_dims);
6085
- const float p0 = (((mode & 1 ) == 0 ? n_past : 0 )) * freq_scale;
6096
+ // const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
6097
+
6098
+ GGML_ASSERT (src1->type == GGML_TYPE_I32);
6099
+ GGML_ASSERT (src1->ne [0 ] == ne2);
6100
+ GGML_ASSERT (src1->backend == GGML_BACKEND_GPU);
6101
+
6102
+ int id;
6103
+ CUDA_CHECK (cudaGetDevice (&id));
6104
+
6105
+ int * pos = nullptr ;
6106
+ if ((mode & 1 ) == 0 ) {
6107
+ struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra ;
6108
+ pos = (int *) src1_extra->data_device [id];
6109
+ if (!src1_extra->copied ) {
6110
+ CUDA_CHECK (cudaMemcpyAsync (pos, src1->data , ggml_nbytes (src1), cudaMemcpyHostToDevice, main_stream));
6111
+ src1_extra->copied = true ;
6112
+ }
6113
+ }
6086
6114
6087
6115
const bool is_neox = mode & 2 ;
6088
6116
const bool is_glm = mode & 4 ;
6089
6117
6090
6118
// compute
6091
6119
if (is_glm) {
6092
- rope_glm_f32_cuda (src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, n_ctx, main_stream);
6120
+ GGML_ASSERT (false );
6121
+ rope_glm_f32_cuda (src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
6093
6122
} else if (is_neox) {
6094
6123
GGML_ASSERT (ne00 == n_dims && " ne00 != n_dims is not implemented for CUDA yet" );
6095
- rope_neox_f32_cuda (src0_dd, dst_dd, ne00, nrows, p0 , freq_scale, ne01, theta_scale, main_stream);
6124
+ rope_neox_f32_cuda (src0_dd, dst_dd, ne00, nrows, pos , freq_scale, ne01, theta_scale, main_stream);
6096
6125
} else {
6097
- rope_f32_cuda (src0_dd, dst_dd, ne00, nrows, p0 , freq_scale, ne01, theta_scale, main_stream);
6126
+ rope_f32_cuda (src0_dd, dst_dd, ne00, nrows, pos , freq_scale, ne01, theta_scale, main_stream);
6098
6127
}
6099
6128
6100
6129
(void ) src1;
0 commit comments