
Commit e17c849

switched to NTK aware scaling

1 parent e19483c

File tree

4 files changed: +26 -25 lines changed


ggml-cuda.cu (+2 -2)

@@ -2223,10 +2223,10 @@ inline void ggml_cuda_op_rope(
     const int n_ctx = ((int32_t *) src1->data)[3];
     GGML_ASSERT(mode == 0);
 
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float theta_scale = get_theta_scale(n_dims,n_past,n_ctx);
     const float p0 = ((mode & 1) == 0 ? n_past + i02 : i02);
 
-    const float p = n_ctx <= GGML_TRAINING_CTX ? p0 : p0 * GGML_TRAINING_CTX / n_ctx;
+    const float p = p0;
 
     // compute
     rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);

ggml.c (+21 -16)

@@ -4242,6 +4242,22 @@ static inline int ggml_up(int n, int m) {
 #define ggml_assert_aligned(ptr) \
     GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
 
+float get_theta_scale(int n_dims,int n_past,int n_ctx)
+{
+    if(n_ctx<=2048) //normie mode
+    {
+        return powf(10000.0, -2.0f/n_dims);
+    }
+    else
+    {
+        //using scaled NTK aware ctx
+        float a = (n_ctx<=4096?4.0:8.0);
+        float m = powf(a, n_dims / (n_dims - 2.0));
+        float s = powf(10000.0 * m, -2.0f/n_dims);
+        return s;
+    }
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_context * ggml_init(struct ggml_init_params params) {
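For reference, the new helper implements the NTK-aware trick: instead of compressing positions, it inflates the RoPE frequency base from 10000 to 10000 * a^(n_dims / (n_dims - 2)), using a = 4 for contexts up to 4096 and a = 8 above that, and derives theta_scale from that larger base. Below is a minimal standalone sketch, not part of the diff; n_dims = 128 is just an illustrative head size, and n_past is dropped here since the committed function does not use it either.

#include <math.h>
#include <stdio.h>

/* Sketch of the committed get_theta_scale(): a larger n_ctx gives a larger
 * effective base, so theta_scale moves closer to 1 and the higher dimension
 * pairs rotate more slowly, covering longer position ranges. */
static float theta_scale_sketch(int n_dims, int n_ctx) {
    if (n_ctx <= 2048) {
        return powf(10000.0f, -2.0f / n_dims);      /* original, unscaled base */
    }
    float a = (n_ctx <= 4096 ? 4.0f : 8.0f);        /* scale factor per context bucket */
    float m = powf(a, n_dims / (n_dims - 2.0f));    /* base multiplier */
    return powf(10000.0f * m, -2.0f / n_dims);      /* theta_scale from the inflated base */
}

int main(void) {
    const int n_dims = 128;                         /* e.g. a LLaMA head dimension */
    const int ctxs[] = { 2048, 4096, 8192 };
    for (int i = 0; i < 3; ++i) {
        printf("n_ctx=%d -> theta_scale=%f\n", ctxs[i], theta_scale_sketch(n_dims, ctxs[i]));
    }
    return 0;
}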
@@ -12531,7 +12547,7 @@ static void ggml_compute_forward_rope_f32(
     // row index used to determine which thread to use
     int ir = 0;
 
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float theta_scale = get_theta_scale(n_dims,n_past,n_ctx);
 
     const bool is_neox = mode & 2;
     const bool is_glm = mode & 4;
@@ -12571,9 +12587,7 @@ static void ggml_compute_forward_rope_f32(
                     dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
                 }
             } else if (!is_neox) {
-                if (n_ctx > GGML_TRAINING_CTX) {
-                    theta = theta * GGML_TRAINING_CTX / n_ctx;
-                }
+
                 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                     const float cos_theta = cosf(theta);
                     const float sin_theta = sinf(theta);
@@ -12674,7 +12688,7 @@ static void ggml_compute_forward_rope_f16(
     // row index used to determine which thread to use
     int ir = 0;
 
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float theta_scale = get_theta_scale(n_dims,n_past,n_ctx);
 
     const bool is_neox = mode & 2;
     const bool is_glm = mode & 4;
@@ -12714,9 +12728,6 @@ static void ggml_compute_forward_rope_f16(
                     dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
                 }
             } if (!is_neox) {
-                if (n_ctx > GGML_TRAINING_CTX) {
-                    theta = theta * GGML_TRAINING_CTX / n_ctx;
-                }
                 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                     const float cos_theta = cosf(theta);
                     const float sin_theta = sinf(theta);
@@ -12842,7 +12853,7 @@ static void ggml_compute_forward_rope_back_f32(
     // row index used to determine which thread to use
     int ir = 0;
 
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float theta_scale = get_theta_scale(n_dims,n_past,n_ctx);
 
     const bool is_neox = mode & 2;
 
@@ -12856,9 +12867,6 @@ static void ggml_compute_forward_rope_back_f32(
             float theta = (float)p;
 
             if (!is_neox) {
-                if (n_ctx > GGML_TRAINING_CTX) {
-                    theta = theta * GGML_TRAINING_CTX / n_ctx;
-                }
                 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                     const float cos_theta = cosf(theta);
                     const float sin_theta = sinf(theta);
@@ -12959,7 +12967,7 @@ static void ggml_compute_forward_rope_back_f16(
     // row index used to determine which thread to use
     int ir = 0;
 
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float theta_scale = get_theta_scale(n_dims,n_past,n_ctx);
 
     const bool is_neox = mode & 2;
 
@@ -12973,9 +12981,6 @@ static void ggml_compute_forward_rope_back_f16(
            float theta = (float)p;
 
            if (!is_neox) {
-                if (n_ctx > GGML_TRAINING_CTX) {
-                    theta = theta * GGML_TRAINING_CTX / n_ctx;
-                }
                for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                    const float cos_theta = cosf(theta);
                    const float sin_theta = sinf(theta);
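The four rope kernels in ggml.c (forward and backward, f32 and f16), like the CUDA path above, now take theta_scale from get_theta_scale() and no longer rescale theta by GGML_TRAINING_CTX / n_ctx. Roughly, the factor is consumed once per dimension pair, as in this hypothetical, simplified helper; it follows the cos/sin pair pattern visible in the hunks, while the real kernels also handle the neox and glm modes and tensor strides.

#include <math.h>
#include <stdint.h>

/* Hypothetical helper, not in the diff: rotates one row using the commit's
 * get_theta_scale(). theta starts at the raw token position (no training-ctx
 * rescaling anymore) and is multiplied by theta_scale per dimension pair. */
extern float get_theta_scale(int n_dims, int n_past, int n_ctx);

static void rope_row_sketch(float *dst, const float *src, int64_t ne0,
                            int n_dims, int pos, int n_past, int n_ctx) {
    const float theta_scale = get_theta_scale(n_dims, n_past, n_ctx);
    float theta = (float) pos;
    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
        const float cos_theta = cosf(theta);
        const float sin_theta = sinf(theta);
        const float x0 = src[i0];
        const float x1 = src[i0 + 1];
        dst[i0]     = x0*cos_theta - x1*sin_theta;   /* rotate each (x0, x1) pair */
        dst[i0 + 1] = x0*sin_theta + x1*cos_theta;
        theta *= theta_scale;                        /* geometric step per pair */
    }
}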

ggml.h (+2 -6)

@@ -201,12 +201,6 @@
 #define GGML_MAX_NAME 48
 #define GGML_DEFAULT_N_THREADS 4
 
-// Maximum training context of the model in use
-// For the LLaMA models this is normally 2048, but somehow "stepping out" by 128 gives better results (tested at 7B and 13B)
-#ifndef GGML_TRAINING_CTX
-#define GGML_TRAINING_CTX 2176
-#endif
-
 #define GGML_ASSERT(x) \
     do { \
         if (!(x)) { \
@@ -510,6 +504,8 @@ extern "C" {
     // use this to compute the memory overhead of a tensor
     GGML_API size_t ggml_tensor_overhead(void);
 
+    GGML_API float get_theta_scale(int n_dims,int n_past,int n_ctx);
+
     // main
 
     GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);

llama.cpp (+1 -1)

@@ -2633,7 +2633,7 @@ struct llama_context * llama_new_context_with_model(
 
         ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type));
 
-        const size_t bigctxmul = (hparams.n_ctx>2048?2:1);
+        const size_t bigctxmul = (hparams.n_ctx>4096?3:(hparams.n_ctx>2048?2:1));
         ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0().at(ctx->model.type)*bigctxmul);
         ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type)*bigctxmul);
     }
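The scratch buffers have to grow with the NTK-extended context, so the multiplier gains a third tier. A quick illustration of the new bucketing follows; the base buffer sizes themselves still come from MEM_REQ_SCRATCH0/1 and depend on the model type.

#include <stdio.h>
#include <stddef.h>

/* Mirrors the new bigctxmul selection: 1x scratch up to 2048 ctx,
 * 2x up to 4096, 3x beyond that. */
static size_t big_ctx_mul(size_t n_ctx) {
    return n_ctx > 4096 ? 3 : (n_ctx > 2048 ? 2 : 1);
}

int main(void) {
    const size_t ctxs[] = { 2048, 4096, 8192 };
    for (int i = 0; i < 3; ++i) {
        printf("n_ctx=%zu -> scratch multiplier x%zu\n", ctxs[i], big_ctx_mul(ctxs[i]));
    }
    return 0;
}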
