From eec6b66ac94fd2ff86951710da607e1d1c77c15a Mon Sep 17 00:00:00 2001 From: slaren Date: Mon, 18 Sep 2023 23:48:34 +0200 Subject: [PATCH 1/4] ggml-cuda : update rope implementation for parallel decoding --- ggml-cuda.cu | 50 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 08428ea3fab3b..9ead57648a665 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5,6 +5,7 @@ #include #include #include +#include #if defined(GGML_USE_HIPBLAS) #include @@ -4355,7 +4356,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, } // rope == RoPE == rotary positional embedding -static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0, +static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float * p0, const float p_delta, const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -4365,8 +4366,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c const int row = blockDim.x*blockIdx.x + threadIdx.x; const int i = row*ncols + col; + const int i2 = row/p_delta_rows; - const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2); + const float theta = (p0[i2] + p_delta*i2)*powf(theta_scale, col/2); const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -4377,7 +4379,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c dst[i + 1] = x0*sin_theta + x1*cos_theta; } -static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0, +static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float * p0, const float p_delta, const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -4387,8 +4389,9 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco const int row = blockDim.x*blockIdx.x + threadIdx.x; const int i = row*ncols + col/2; + const int i2 = row/p_delta_rows; - const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2); + const float theta = (p0[i2] + p_delta*i2)*powf(theta_scale, col/2); const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -4399,7 +4402,7 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; } -static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0, +static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float * p0, const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) { const int col = blockDim.x*blockIdx.x + threadIdx.x; const int half_n_dims = ncols/4; @@ -4410,9 +4413,10 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol const int row = blockDim.y*blockIdx.y + threadIdx.y; const int i = row*ncols + col; + const int i2 = row/p_delta_rows; const float col_theta_scale = powf(theta_scale, col); - const float p = p0 + p_delta*(row/p_delta_rows); + const float p = p0[i2] + p_delta*i2; const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale; const float sin_theta = sinf(theta); @@ -5361,7 +5365,7 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons scale_f32<<>>(x, dst, scale, k); } -static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0, +static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0, const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); @@ -5370,7 +5374,7 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i rope_f32<<>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale); } -static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0, +static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0, const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); @@ -5379,7 +5383,7 @@ static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, co rope_neox_f32<<>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale); } -static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0, +static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0, const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) { GGML_ASSERT(ncols % 4 == 0); const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1); @@ -6069,9 +6073,10 @@ inline void ggml_cuda_op_rope( const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; + const int64_t ne2 = dst->ne[2]; const int64_t nrows = ggml_nrows(src0); - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; @@ -6082,21 +6087,38 @@ inline void ggml_cuda_op_rope( memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); const float theta_scale = powf(freq_base, -2.0f/n_dims); - const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; + //const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; + + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(src1->ne[0] == ne2); + + std::vector p0s(ne2); + for (int64_t i = 0; i < ne2; ++i) { + int n_past = ((int32_t *) src1->data)[i]; + p0s[i] = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; + } + + size_t p0d_as = 0; + float * p0d; + + p0d = (float *) ggml_cuda_pool_malloc(ne2 * sizeof(float), &p0d_as); + CUDA_CHECK(cudaMemcpyAsync(p0d, p0s.data(), ne2 * sizeof(float), cudaMemcpyHostToDevice, main_stream)); const bool is_neox = mode & 2; const bool is_glm = mode & 4; // compute if (is_glm) { - rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, n_ctx, main_stream); + rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, n_ctx, main_stream); } else if (is_neox) { GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); - rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream); + rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, main_stream); } else { - rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream); + rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, main_stream); } + ggml_cuda_pool_free(p0d, p0d_as); + (void) src1; (void) dst; (void) src1_dd; From fb92acdd6b763ba6f206010966db9156b65ba476 Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 19 Sep 2023 00:17:55 +0200 Subject: [PATCH 2/4] better solution for p0 computation --- ggml-cuda.cu | 28 +++++++++++++++++++--------- llama.cpp | 6 ++++++ 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 9ead57648a665..ef6e9fd59d539 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5,7 +5,6 @@ #include #include #include -#include #if defined(GGML_USE_HIPBLAS) #include @@ -440,6 +439,7 @@ static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullpt struct ggml_tensor_extra_gpu { void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs + bool copied; }; // this is faster on Windows @@ -4356,6 +4356,14 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, } // rope == RoPE == rotary positional embedding +static __global__ void compute_rope_p0(const int32_t * pos, float * p0, int n, int mode, float freq_scale) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { + int p = pos[i]; + p0[i] = (((mode & 1) == 0 ? p : 0)) * freq_scale; + } +} + static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float * p0, const float p_delta, const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -6091,18 +6099,20 @@ inline void ggml_cuda_op_rope( GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(src1->ne[0] == ne2); + GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); - std::vector p0s(ne2); - for (int64_t i = 0; i < ne2; ++i) { - int n_past = ((int32_t *) src1->data)[i]; - p0s[i] = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; + int id; + CUDA_CHECK(cudaGetDevice(&id)); + + struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + if (!src1_extra->copied) { + CUDA_CHECK(cudaMemcpyAsync(src1_extra->data_device[id], src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream)); + src1_extra->copied = true; } size_t p0d_as = 0; - float * p0d; - - p0d = (float *) ggml_cuda_pool_malloc(ne2 * sizeof(float), &p0d_as); - CUDA_CHECK(cudaMemcpyAsync(p0d, p0s.data(), ne2 * sizeof(float), cudaMemcpyHostToDevice, main_stream)); + float * p0d = (float *) ggml_cuda_pool_malloc(ne2 * sizeof(float), &p0d_as); + compute_rope_p0<<<(ne2 + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE, CUDA_ROPE_BLOCK_SIZE, 0, main_stream>>>((int32_t*)src1_extra->data_device[id], p0d, ne2, mode, freq_scale); const bool is_neox = mode & 2; const bool is_glm = mode & 4; diff --git a/llama.cpp b/llama.cpp index 3e54fed7c2253..a30bb5dede688 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2705,6 +2705,7 @@ static struct ggml_cgraph * llm_build_llama( // KQ_pos - contains the positions struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + offload_func_kq(KQ_pos); ggml_allocr_alloc(lctx.alloc, KQ_pos); if (!ggml_allocr_is_measure(lctx.alloc)) { int * data = (int *) KQ_pos->data; @@ -2715,6 +2716,7 @@ static struct ggml_cgraph * llm_build_llama( // K_shift struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + offload_func_kq(K_shift); ggml_allocr_alloc(lctx.alloc, K_shift); if (!ggml_allocr_is_measure(lctx.alloc)) { int * data = (int *) K_shift->data; @@ -3087,6 +3089,7 @@ static struct ggml_cgraph * llm_build_baichaun( // KQ_pos - contains the positions struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + offload_func_kq(KQ_pos); ggml_allocr_alloc(lctx.alloc, KQ_pos); if (!ggml_allocr_is_measure(lctx.alloc)) { int * data = (int *) KQ_pos->data; @@ -3097,6 +3100,7 @@ static struct ggml_cgraph * llm_build_baichaun( // K_shift struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + offload_func_kq(K_shift); ggml_allocr_alloc(lctx.alloc, K_shift); if (!ggml_allocr_is_measure(lctx.alloc)) { int * data = (int *) K_shift->data; @@ -3486,6 +3490,7 @@ static struct ggml_cgraph * llm_build_falcon( // KQ_pos - contains the positions struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + offload_func_kq(KQ_pos); ggml_allocr_alloc(lctx.alloc, KQ_pos); if (!ggml_allocr_is_measure(lctx.alloc)) { int * data = (int *) KQ_pos->data; @@ -3496,6 +3501,7 @@ static struct ggml_cgraph * llm_build_falcon( // K_shift struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + offload_func_kq(K_shift); ggml_allocr_alloc(lctx.alloc, K_shift); if (!ggml_allocr_is_measure(lctx.alloc)) { int * data = (int *) K_shift->data; From cbe2bac281b2889fdc118050f9648fbcecbb62f8 Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 19 Sep 2023 02:16:07 +0200 Subject: [PATCH 3/4] fix rope --- ggml-cuda.cu | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ef6e9fd59d539..268b3666a51af 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4365,7 +4365,7 @@ static __global__ void compute_rope_p0(const int32_t * pos, float * p0, int n, i } static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float * p0, - const float p_delta, const int p_delta_rows, const float theta_scale) { + const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (col >= ncols) { @@ -4376,7 +4376,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c const int i = row*ncols + col; const int i2 = row/p_delta_rows; - const float theta = (p0[i2] + p_delta*i2)*powf(theta_scale, col/2); + const float theta = p0[i2]*powf(theta_scale, col/2); const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -4388,7 +4388,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c } static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float * p0, - const float p_delta, const int p_delta_rows, const float theta_scale) { + const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (col >= ncols) { @@ -4399,7 +4399,7 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco const int i = row*ncols + col/2; const int i2 = row/p_delta_rows; - const float theta = (p0[i2] + p_delta*i2)*powf(theta_scale, col/2); + const float theta = p0[i2]*powf(theta_scale, col/2); const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -4424,7 +4424,8 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol const int i2 = row/p_delta_rows; const float col_theta_scale = powf(theta_scale, col); - const float p = p0[i2] + p_delta*i2; + // FIXME: this is likely wrong + const float p = p0[i2]; const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale; const float sin_theta = sinf(theta); @@ -5374,21 +5375,21 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons } static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0, - const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { + const int p_delta_rows, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nrows, num_blocks_x, 1); - rope_f32<<>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale); + rope_f32<<>>(x, dst, ncols, p0, p_delta_rows, theta_scale); } static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0, - const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { + const int p_delta_rows, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nrows, num_blocks_x, 1); - rope_neox_f32<<>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale); + rope_neox_f32<<>>(x, dst, ncols, p0, p_delta_rows, theta_scale); } static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0, @@ -6095,7 +6096,7 @@ inline void ggml_cuda_op_rope( memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); const float theta_scale = powf(freq_base, -2.0f/n_dims); - //const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; + // const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(src1->ne[0] == ne2); @@ -6110,7 +6111,7 @@ inline void ggml_cuda_op_rope( src1_extra->copied = true; } - size_t p0d_as = 0; + size_t p0d_as; float * p0d = (float *) ggml_cuda_pool_malloc(ne2 * sizeof(float), &p0d_as); compute_rope_p0<<<(ne2 + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE, CUDA_ROPE_BLOCK_SIZE, 0, main_stream>>>((int32_t*)src1_extra->data_device[id], p0d, ne2, mode, freq_scale); @@ -6119,12 +6120,13 @@ inline void ggml_cuda_op_rope( // compute if (is_glm) { + GGML_ASSERT(false); rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, n_ctx, main_stream); } else if (is_neox) { GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); - rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, main_stream); + rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, ne01, theta_scale, main_stream); } else { - rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, main_stream); + rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, ne01, theta_scale, main_stream); } ggml_cuda_pool_free(p0d, p0d_as); From aa18b939802a1be7f65f6d77ccc032434e5b5e01 Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 19 Sep 2023 08:51:05 +0200 Subject: [PATCH 4/4] simpler rope implementation --- ggml-cuda.cu | 69 ++++++++++++++++++++++++---------------------------- 1 file changed, 32 insertions(+), 37 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 268b3666a51af..14b1ecf7d2cf3 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4356,15 +4356,8 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, } // rope == RoPE == rotary positional embedding -static __global__ void compute_rope_p0(const int32_t * pos, float * p0, int n, int mode, float freq_scale) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < n) { - int p = pos[i]; - p0[i] = (((mode & 1) == 0 ? p : 0)) * freq_scale; - } -} -static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float * p0, +static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -4376,7 +4369,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c const int i = row*ncols + col; const int i2 = row/p_delta_rows; - const float theta = p0[i2]*powf(theta_scale, col/2); + const int p = pos != nullptr ? pos[i2] : 0; + const float p0 = p * freq_scale; + const float theta = p0*powf(theta_scale, col/2); const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -4387,8 +4382,8 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c dst[i + 1] = x0*sin_theta + x1*cos_theta; } -static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float * p0, - const int p_delta_rows, const float theta_scale) { +static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (col >= ncols) { @@ -4399,7 +4394,9 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco const int i = row*ncols + col/2; const int i2 = row/p_delta_rows; - const float theta = p0[i2]*powf(theta_scale, col/2); + const int p = pos != nullptr ? pos[i2] : 0; + const float p0 = p * freq_scale; + const float theta = p0*powf(theta_scale, col/2); const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -4410,8 +4407,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; } -static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float * p0, - const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) { +static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale, const int n_ctx) { const int col = blockDim.x*blockIdx.x + threadIdx.x; const int half_n_dims = ncols/4; @@ -4425,9 +4422,9 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol const float col_theta_scale = powf(theta_scale, col); // FIXME: this is likely wrong - const float p = p0[i2]; + const int p = pos != nullptr ? pos[i2] : 0; - const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale; + const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale; const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -4437,7 +4434,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol dst[i + 0] = x0*cos_theta - x1*sin_theta; dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta; - const float block_theta = max(p - p_delta*(n_ctx - 2), 0.f)*col_theta_scale; + const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale; const float sin_block_theta = sinf(block_theta); const float cos_block_theta = cosf(block_theta); @@ -5374,31 +5371,31 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons scale_f32<<>>(x, dst, scale, k); } -static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0, +static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nrows, num_blocks_x, 1); - rope_f32<<>>(x, dst, ncols, p0, p_delta_rows, theta_scale); + rope_f32<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); } -static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0, +static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nrows, num_blocks_x, 1); - rope_neox_f32<<>>(x, dst, ncols, p0, p_delta_rows, theta_scale); + rope_neox_f32<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); } -static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0, - const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) { +static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) { GGML_ASSERT(ncols % 4 == 0); const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1); const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE; const dim3 block_nums(num_blocks_x, nrows, 1); - rope_glm_f32<<>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale, n_ctx); + rope_glm_f32<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx); } static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, @@ -6105,32 +6102,30 @@ inline void ggml_cuda_op_rope( int id; CUDA_CHECK(cudaGetDevice(&id)); - struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; - if (!src1_extra->copied) { - CUDA_CHECK(cudaMemcpyAsync(src1_extra->data_device[id], src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream)); - src1_extra->copied = true; + int * pos = nullptr; + if ((mode & 1) == 0) { + struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + pos = (int *) src1_extra->data_device[id]; + if (!src1_extra->copied) { + CUDA_CHECK(cudaMemcpyAsync(pos, src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream)); + src1_extra->copied = true; + } } - size_t p0d_as; - float * p0d = (float *) ggml_cuda_pool_malloc(ne2 * sizeof(float), &p0d_as); - compute_rope_p0<<<(ne2 + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE, CUDA_ROPE_BLOCK_SIZE, 0, main_stream>>>((int32_t*)src1_extra->data_device[id], p0d, ne2, mode, freq_scale); - const bool is_neox = mode & 2; const bool is_glm = mode & 4; // compute if (is_glm) { GGML_ASSERT(false); - rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, n_ctx, main_stream); + rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream); } else if (is_neox) { GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); - rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, ne01, theta_scale, main_stream); + rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); } else { - rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, ne01, theta_scale, main_stream); + rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); } - ggml_cuda_pool_free(p0d, p0d_as); - (void) src1; (void) dst; (void) src1_dd;