
Commit 513f861

ggml : fix rope args order + assert (#2054)

1 parent 3973b25

File tree

examples/train-text-from-scratch/train-text-from-scratch.cpp
ggml.c
ggml.h
llama.cpp

4 files changed, +23 −18 lines changed


examples/train-text-from-scratch/train-text-from-scratch.cpp (+3 −3)

@@ -1434,7 +1434,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     gf->perf_time_us = 0;
 
     const auto & hparams = model->hparams;
-    //const int n_ctx = hparams.n_ctx;
+    const int n_ctx = hparams.n_ctx;
     const int n_vocab = hparams.n_vocab;
     const int n_embd = hparams.n_embd;
     const int n_layer = hparams.n_layer;
@@ -1863,10 +1863,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
     t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd)); assert_shape_2d(t11->grad, N*n_batch, n_embd);
     t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
-    t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
+    t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
     t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch)); assert_shape_2d(t08->grad, n_embd, N*n_batch);
     t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
-    t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
+    t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
     t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch)); assert_shape_2d(t05->grad, n_embd, N*n_batch);
     t04->grad = expand(gb, ggml_add_inplace(ctx0,
                 ggml_add_inplace(ctx0,
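The commented-out hparams read is restored here because the backward calls below now need a live n_ctx local to pass to ggml_rope_back. A minimal sketch of the resulting forward/backward pairing, using the post-commit signatures from this diff (the function name, tensor names, and values are ours, not the repo's):

    #include "ggml.h"

    /* Illustrative only: mirrors how the training example now pairs the
     * forward rope with its backward, passing the same n_ctx to both. */
    static void rope_fwd_bwd_sketch(struct ggml_context * ctx,
                                    struct ggml_tensor  * x,   /* forward input */
                                    struct ggml_tensor  * dy,  /* gradient flowing back */
                                    int n_past, int n_rot, int n_ctx) {
        const int rope_mode = 0; /* standard rotary mode, as in the training example */

        /* forward: the public ggml_rope signature is unchanged, n_ctx already trailed mode */
        struct ggml_tensor * y  = ggml_rope(ctx, x, n_past, n_rot, rope_mode, n_ctx);

        /* backward: ggml_rope_back now takes the matching trailing n_ctx */
        struct ggml_tensor * dx = ggml_rope_back(ctx, dy, n_past, n_rot, rope_mode, n_ctx);

        (void) y;
        (void) dx;
    }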

ggml.c (+14 −10)

@@ -6956,9 +6956,9 @@ struct ggml_tensor * ggml_rope_impl(
         int                   n_past,
         int                   n_dims,
         int                   mode,
+        int                   n_ctx,
         float                 freq_base,
         float                 freq_scale,
-        int                   n_ctx,
         bool                  inplace) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
@@ -6997,7 +6997,7 @@ struct ggml_tensor * ggml_rope(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -7007,7 +7007,7 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
 }
 
 struct ggml_tensor * ggml_rope_custom_inplace(
@@ -7016,10 +7016,10 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         int                   n_past,
         int                   n_dims,
         int                   mode,
+        int                   n_ctx,
         float                 freq_base,
-        float                 freq_scale,
-        int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
+        float                 freq_scale) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true);
 }
 
 // ggml_rope_back
@@ -7029,7 +7029,8 @@ struct ggml_tensor * ggml_rope_back(
         struct ggml_tensor  * a,
         int                   n_past,
         int                   n_dims,
-        int                   mode) {
+        int                   mode,
+        int                   n_ctx) {
     GGML_ASSERT(n_past >= 0);
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
@@ -7043,12 +7044,13 @@ struct ggml_tensor * ggml_rope_back(
 
     ggml_scratch_save(ctx);
 
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
     ggml_set_name(b, "n_past, n_dims, mode");
 
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_dims;
     ((int32_t *) b->data)[2] = mode;
+    ((int32_t *) b->data)[3] = n_ctx;
 
     ggml_scratch_load(ctx);
 
@@ -15740,13 +15742,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 const int n_past = ((int32_t *) src1->data)[0];
                 const int n_dims = ((int32_t *) src1->data)[1];
                 const int mode   = ((int32_t *) src1->data)[2];
+                const int n_ctx  = ((int32_t *) src1->data)[3];
                 src0->grad = ggml_add_impl(ctx,
                         src0->grad,
                         ggml_rope_back(ctx,
                             tensor->grad,
                             n_past,
                             n_dims,
-                            mode),
+                            mode,
+                            n_ctx),
                         inplace);
             }
             if (src1->grad) {
@@ -15757,7 +15761,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 if (src0->grad) {
                     assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 3);
+                    assert(ggml_nelements(src1) == 4);
                     const int n_past = ((int32_t *) src1->data)[0];
                     const int n_dims = ((int32_t *) src1->data)[1];
                     const int mode   = ((int32_t *) src1->data)[2];
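The assert bump from 3 to 4 elements mirrors the extra int32 that ggml_rope_back now serializes into its parameter tensor. A self-contained sketch of that pack/unpack invariant using the same ggml calls as the diff (the ggml_init sizes and the literal parameter values are illustrative assumptions):

    #include "ggml.h"
    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        /* illustrative context size; anything big enough for one tiny tensor works */
        struct ggml_init_params params = {
            /*.mem_size   =*/ 1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        /* pack, as ggml_rope_back does after this commit */
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
        ((int32_t *) b->data)[0] = 0;    /* n_past (illustrative) */
        ((int32_t *) b->data)[1] = 64;   /* n_dims (illustrative) */
        ((int32_t *) b->data)[2] = 0;    /* mode                  */
        ((int32_t *) b->data)[3] = 2048; /* n_ctx, the new fourth element */

        /* unpack, as ggml_compute_backward now expects */
        assert(ggml_nelements(b) == 4);  /* this count was 3 before the commit */
        const int n_ctx = ((int32_t *) b->data)[3];
        (void) n_ctx;

        ggml_free(ctx);
        return 0;
    }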

ggml.h (+4 −3)

@@ -1128,9 +1128,9 @@ extern "C" {
             int                   n_past,
             int                   n_dims,
             int                   mode,
+            int                   n_ctx,
             float                 freq_base,
-            float                 freq_scale,
-            int                   n_ctx);
+            float                 freq_scale);
 
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
@@ -1139,7 +1139,8 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   n_past,
             int                   n_dims,
-            int                   mode);
+            int                   mode,
+            int                   n_ctx);
 
     // alibi position embedding
     // in-place, returns view(a)
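One reason this reorder is worth the churn: in C, a prototyped call converts implicitly between int and float arguments, so a stale call site written against the old order still compiles against the new header and silently produces a nonsense rope. A hedged sketch, not from the commit (values are illustrative):

    #include "ggml.h"

    static void argument_order_pitfall(struct ggml_context * ctx,
                                       struct ggml_tensor  * a) {
        /* correct post-commit order: all ints first, then the floats */
        ggml_rope_custom_inplace(ctx, a, /*n_past=*/0, /*n_dims=*/64, /*mode=*/0,
                                 /*n_ctx=*/2048, 10000.0f, 1.0f);

        /* stale pre-commit order: still compiles at default warning levels,
         * but 10000.0f is truncated into n_ctx and 2048 becomes freq_scale */
        ggml_rope_custom_inplace(ctx, a, 0, 64, 0, 10000.0f, 1.0f, 2048);
    }

This is why the commit has to touch every caller by hand: the compiler offers no safety net, only the grouped int-then-float convention does.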

llama.cpp (+2 −2)

@@ -1452,11 +1452,11 @@ static bool llama_eval_internal(
         offload_func_kq(tmpq);
         ggml_set_name(tmpq, "tmpq");
 
-        struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0);
+        struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale);
         offload_func_kq(Kcur);
         ggml_set_name(Kcur, "Kcur");
 
-        struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0);
+        struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale);
         offload_func_kq(Qcur);
         ggml_set_name(Qcur, "Qcur");
 
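With mode and n_ctx now adjacent, the two bare 0 literals in these calls are easy to misread. An annotated restatement of the updated call site (the annotations, and the reading that only the ChatGLM rope variant consumes n_ctx, judging by the mode & 4 assert in ggml_rope_back, are ours, not part of the commit):

    #include "ggml.h"

    static struct ggml_tensor * rope_qk_sketch(struct ggml_context * ctx0,
                                               struct ggml_tensor  * t,
                                               int n_past, int n_rot,
                                               float freq_base, float freq_scale) {
        return ggml_rope_custom_inplace(ctx0, t,
                                        n_past,
                                        n_rot,
                                        /*mode  =*/ 0,
                                        /*n_ctx =*/ 0, /* unused on the standard rope path (assumption) */
                                        freq_base,
                                        freq_scale);
    }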
