@@ -4104,22 +4104,20 @@ static void llm_build_k_shift(
4104
4104
struct ggml_cgraph * graph,
4105
4105
llm_rope_type type,
4106
4106
int64_t n_ctx,
4107
- int n_rot,
4108
4107
float freq_base,
4109
4108
float freq_scale,
4110
4109
const llm_build_cb & cb) {
4111
4110
const int64_t n_layer = hparams.n_layer ;
4112
4111
const int64_t n_head_kv = hparams.n_head_kv ;
4113
4112
const int64_t n_embd_head_k = hparams.n_embd_head_k ;
4114
4113
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa ();
4114
+ const int32_t n_rot = hparams.n_rot ;
4115
4115
const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx ;
4116
4116
const float ext_factor = cparams.yarn_ext_factor ;
4117
4117
const float attn_factor = cparams.yarn_attn_factor ;
4118
4118
const float beta_fast = cparams.yarn_beta_fast ;
4119
4119
const float beta_slow = cparams.yarn_beta_slow ;
4120
4120
4121
- GGML_ASSERT(n_embd_head_k % n_rot == 0);
4122
-
4123
4121
struct ggml_tensor * K_shift = ggml_new_tensor_1d (ctx, GGML_TYPE_I32, n_ctx);
4124
4122
cb (K_shift, " K_shift" , -1 );
4125
4123
@@ -4523,7 +4521,7 @@ struct llm_build_context {
4523
4521
4524
4522
// shift the entire K-cache if needed
4525
4523
if (do_rope_shift) {
4526
- llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
4524
+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
4527
4525
}
4528
4526
4529
4527
for (int il = 0 ; il < n_layer; ++il) {
@@ -4561,14 +4559,14 @@ struct llm_build_context {
4561
4559
4562
4560
Qcur = ggml_rope_custom (
4563
4561
ctx0, ggml_reshape_3d (ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
4564
- n_embd_head , 0, 0, n_orig_ctx, freq_base, freq_scale,
4562
+ hparams. n_rot , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
4565
4563
ext_factor, attn_factor, beta_fast, beta_slow
4566
4564
);
4567
4565
cb (Qcur, " Qcur" , il);
4568
4566
4569
4567
Kcur = ggml_rope_custom (
4570
4568
ctx0, ggml_reshape_3d (ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
4571
- n_embd_head , 0, 0, n_orig_ctx, freq_base, freq_scale,
4569
+ hparams. n_rot , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
4572
4570
ext_factor, attn_factor, beta_fast, beta_slow
4573
4571
);
4574
4572
cb (Kcur, " Kcur" , il);
@@ -4691,6 +4689,7 @@ struct llm_build_context {
4691
4689
4692
4690
const int64_t n_embd_head = hparams.n_embd_head_v ;
4693
4691
GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
4692
+ GGML_ASSERT (n_embd_head == hparams.n_rot );
4694
4693
4695
4694
struct ggml_tensor * cur;
4696
4695
struct ggml_tensor * inpL;
@@ -4708,7 +4707,7 @@ struct llm_build_context {
4708
4707
4709
4708
// shift the entire K-cache if needed
4710
4709
if (do_rope_shift) {
4711
- llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
4710
+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
4712
4711
}
4713
4712
4714
4713
for (int il = 0 ; il < n_layer; ++il) {
@@ -4734,12 +4733,12 @@ struct llm_build_context {
4734
4733
case MODEL_7B:
4735
4734
Qcur = ggml_rope_custom (
4736
4735
ctx0, ggml_reshape_3d (ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
4737
- n_embd_head , 0, 0, n_orig_ctx, freq_base, freq_scale,
4736
+ hparams. n_rot , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
4738
4737
ext_factor, attn_factor, beta_fast, beta_slow
4739
4738
);
4740
4739
Kcur = ggml_rope_custom (
4741
4740
ctx0, ggml_reshape_3d (ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
4742
- n_embd_head , 0, 0, n_orig_ctx, freq_base, freq_scale,
4741
+ hparams. n_rot , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
4743
4742
ext_factor, attn_factor, beta_fast, beta_slow
4744
4743
);
4745
4744
break ;
@@ -4812,6 +4811,7 @@ struct llm_build_context {
4812
4811
const int64_t n_embd_head = hparams.n_embd_head_v ;
4813
4812
const int64_t n_embd_gqa = hparams.n_embd_v_gqa ();
4814
4813
GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
4814
+ GGML_ASSERT (n_embd_head == hparams.n_rot );
4815
4815
4816
4816
struct ggml_tensor * cur;
4817
4817
struct ggml_tensor * inpL;
@@ -4829,7 +4829,7 @@ struct llm_build_context {
4829
4829
4830
4830
// shift the entire K-cache if needed
4831
4831
if (do_rope_shift) {
4832
- llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
4832
+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
4833
4833
}
4834
4834
4835
4835
for (int il = 0 ; il < n_layer; ++il) {
@@ -4870,13 +4870,13 @@ struct llm_build_context {
4870
4870
4871
4871
// using mode = 2 for neox mode
4872
4872
Qcur = ggml_rope_custom (
4873
- ctx0, Qcur, inp_pos, n_embd_head , 2, 0, n_orig_ctx,
4873
+ ctx0, Qcur, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
4874
4874
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
4875
4875
);
4876
4876
cb (Qcur, " Qcur" , il);
4877
4877
4878
4878
Kcur = ggml_rope_custom (
4879
- ctx0, Kcur, inp_pos, n_embd_head , 2, 0, n_orig_ctx,
4879
+ ctx0, Kcur, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
4880
4880
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
4881
4881
);
4882
4882
cb (Kcur, " Kcur" , il);
@@ -5033,9 +5033,8 @@ struct llm_build_context {
5033
5033
struct ggml_cgraph * gf = ggml_new_graph_custom (ctx0, LLAMA_MAX_NODES, false );
5034
5034
5035
5035
const int64_t n_embd_head = hparams.n_embd_head_v ;
5036
- GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
5037
-
5038
- const int64_t n_rot = n_embd_head_k / 2;
5036
+ GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
5037
+ GGML_ASSERT (n_embd_head/2 == hparams.n_rot );
5039
5038
5040
5039
struct ggml_tensor * cur;
5041
5040
struct ggml_tensor * inpL;
@@ -5052,7 +5051,7 @@ struct llm_build_context {
5052
5051
cb (KQ_mask, " KQ_mask" , -1 );
5053
5052
5054
5053
if (do_rope_shift) {
5055
- llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
5054
+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
5056
5055
}
5057
5056
5058
5057
for (int il = 0 ; il < n_layer; ++il) {
@@ -5112,15 +5111,15 @@ struct llm_build_context {
5112
5111
5113
5112
// RoPE the first n_rot of q/k, pass the other half, and concat.
5114
5113
struct ggml_tensor * qrot = ggml_view_3d (
5115
- ctx0, tmpq, n_rot, n_head, n_tokens,
5114
+ ctx0, tmpq, hparams. n_rot , n_head, n_tokens,
5116
5115
ggml_element_size (tmpq) * n_embd_head,
5117
5116
ggml_element_size (tmpq) * n_embd_head * n_head,
5118
5117
0
5119
5118
);
5120
5119
cb (qrot, " qrot" , il);
5121
5120
5122
5121
struct ggml_tensor * krot = ggml_view_3d (
5123
- ctx0, tmpk, n_rot, n_head, n_tokens,
5122
+ ctx0, tmpk, hparams. n_rot , n_head, n_tokens,
5124
5123
ggml_element_size (tmpk) * n_embd_head,
5125
5124
ggml_element_size (tmpk) * n_embd_head * n_head,
5126
5125
0
@@ -5129,29 +5128,29 @@ struct llm_build_context {
5129
5128
5130
5129
// get the second half of tmpq, e.g tmpq[n_rot:, :, :]
5131
5130
struct ggml_tensor * qpass = ggml_view_3d (
5132
- ctx0, tmpq, n_rot, n_head, n_tokens,
5131
+ ctx0, tmpq, hparams. n_rot , n_head, n_tokens,
5133
5132
ggml_element_size (tmpq) * n_embd_head,
5134
5133
ggml_element_size (tmpq) * n_embd_head * n_head,
5135
- ggml_element_size(tmpq) * n_rot
5134
+ ggml_element_size (tmpq) * hparams. n_rot
5136
5135
);
5137
5136
cb (qpass, " qpass" , il);
5138
5137
5139
5138
struct ggml_tensor * kpass = ggml_view_3d (
5140
- ctx0, tmpk, n_rot, n_head, n_tokens,
5139
+ ctx0, tmpk, hparams. n_rot , n_head, n_tokens,
5141
5140
ggml_element_size (tmpk) * n_embd_head,
5142
5141
ggml_element_size (tmpk) * n_embd_head * n_head,
5143
- ggml_element_size(tmpk) * n_rot
5142
+ ggml_element_size (tmpk) * hparams. n_rot
5144
5143
);
5145
5144
cb (kpass, " kpass" , il);
5146
5145
5147
5146
struct ggml_tensor * qrotated = ggml_rope_custom (
5148
- ctx0, qrot, inp_pos, n_rot, 2, 0, n_orig_ctx,
5147
+ ctx0, qrot, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
5149
5148
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
5150
5149
);
5151
5150
cb (qrotated, " qrotated" , il);
5152
5151
5153
5152
struct ggml_tensor * krotated = ggml_rope_custom (
5154
- ctx0, krot, inp_pos, n_rot, 2, 0, n_orig_ctx,
5153
+ ctx0, krot, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
5155
5154
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
5156
5155
);
5157
5156
cb (krotated, " krotated" , il);
@@ -5531,6 +5530,7 @@ struct llm_build_context {
5531
5530
5532
5531
const int64_t n_embd_head = hparams.n_embd_head_v ;
5533
5532
GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
5533
+ GGML_ASSERT (n_embd_head == hparams.n_rot );
5534
5534
5535
5535
struct ggml_tensor * cur;
5536
5536
struct ggml_tensor * inpL;
@@ -5548,7 +5548,7 @@ struct llm_build_context {
5548
5548
5549
5549
// shift the entire K-cache if needed
5550
5550
if (do_rope_shift) {
5551
- llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams.n_rot, freq_base, freq_scale, cb);
5551
+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
5552
5552
}
5553
5553
5554
5554
for (int il = 0 ; il < n_layer; ++il) {
@@ -5661,7 +5661,7 @@ struct llm_build_context {
5661
5661
5662
5662
// shift the entire K-cache if needed
5663
5663
if (do_rope_shift) {
5664
- llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
5664
+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
5665
5665
}
5666
5666
5667
5667
for (int il = 0 ; il < n_layer; ++il) {
@@ -5693,13 +5693,13 @@ struct llm_build_context {
5693
5693
5694
5694
// using mode = 2 for neox mode
5695
5695
Qcur = ggml_rope_custom (
5696
- ctx0, Qcur, inp_pos, n_embd_head , 2, 0, n_orig_ctx,
5696
+ ctx0, Qcur, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
5697
5697
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
5698
5698
);
5699
5699
cb (Qcur, " Qcur" , il);
5700
5700
5701
5701
Kcur = ggml_rope_custom (
5702
- ctx0, Kcur, inp_pos, n_embd_head , 2, 0, n_orig_ctx,
5702
+ ctx0, Kcur, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
5703
5703
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
5704
5704
);
5705
5705
cb (Kcur, " Kcur" , il);
@@ -5778,7 +5778,7 @@ struct llm_build_context {
5778
5778
5779
5779
// shift the entire K-cache if needed
5780
5780
if (do_rope_shift) {
5781
- llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
5781
+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
5782
5782
}
5783
5783
5784
5784
for (int il = 0 ; il < n_layer; ++il) {
@@ -5874,6 +5874,7 @@ struct llm_build_context {
5874
5874
5875
5875
const int64_t n_embd_head = hparams.n_embd_head_v ;
5876
5876
GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
5877
+ GGML_ASSERT (n_embd_head == hparams.n_rot );
5877
5878
5878
5879
struct ggml_tensor * cur;
5879
5880
struct ggml_tensor * inpL;
@@ -5891,7 +5892,7 @@ struct llm_build_context {
5891
5892
5892
5893
// shift the entire K-cache if needed
5893
5894
if (do_rope_shift) {
5894
- llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
5895
+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
5895
5896
}
5896
5897
5897
5898
for (int il = 0 ; il < n_layer; ++il) {
@@ -5917,13 +5918,13 @@ struct llm_build_context {
5917
5918
cb (Vcur, " Vcur" , il);
5918
5919
5919
5920
Qcur = ggml_rope_custom (
5920
- ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head , n_head, n_tokens), inp_pos,
5921
+ ctx0, ggml_reshape_3d (ctx0, Qcur, hparams. n_rot , n_head, n_tokens), inp_pos,
5921
5922
n_embd_head, 2 , 0 , n_orig_ctx, freq_base, freq_scale,
5922
5923
ext_factor, attn_factor, beta_fast, beta_slow);
5923
5924
cb (Qcur, " Qcur" , il);
5924
5925
5925
5926
Kcur = ggml_rope_custom (
5926
- ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head , n_head_kv, n_tokens), inp_pos,
5927
+ ctx0, ggml_reshape_3d (ctx0, Kcur, hparams. n_rot , n_head_kv, n_tokens), inp_pos,
5927
5928
n_embd_head, 2 , 0 , n_orig_ctx, freq_base, freq_scale,
5928
5929
ext_factor, attn_factor, beta_fast, beta_slow);
5929
5930
cb (Kcur, " Kcur" , il);
0 commit comments