@@ -6956,9 +6956,9 @@ struct ggml_tensor * ggml_rope_impl(
6956
6956
int n_past,
6957
6957
int n_dims,
6958
6958
int mode,
6959
+ int n_ctx,
6959
6960
float freq_base,
6960
6961
float freq_scale,
6961
- int n_ctx,
6962
6962
bool inplace) {
6963
6963
GGML_ASSERT(n_past >= 0);
6964
6964
bool is_node = false;
@@ -6997,7 +6997,7 @@ struct ggml_tensor * ggml_rope(
6997
6997
int n_dims,
6998
6998
int mode,
6999
6999
int n_ctx) {
7000
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx , false);
7000
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false);
7001
7001
}
7002
7002
7003
7003
struct ggml_tensor * ggml_rope_inplace(
@@ -7007,7 +7007,7 @@ struct ggml_tensor * ggml_rope_inplace(
7007
7007
int n_dims,
7008
7008
int mode,
7009
7009
int n_ctx) {
7010
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx , true);
7010
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
7011
7011
}
7012
7012
7013
7013
struct ggml_tensor * ggml_rope_custom_inplace(
@@ -7016,10 +7016,10 @@ struct ggml_tensor * ggml_rope_custom_inplace(
7016
7016
int n_past,
7017
7017
int n_dims,
7018
7018
int mode,
7019
+ int n_ctx,
7019
7020
float freq_base,
7020
- float freq_scale,
7021
- int n_ctx) {
7022
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
7021
+ float freq_scale) {
7022
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true);
7023
7023
}
7024
7024
7025
7025
// ggml_rope_back
@@ -7029,7 +7029,8 @@ struct ggml_tensor * ggml_rope_back(
7029
7029
struct ggml_tensor * a,
7030
7030
int n_past,
7031
7031
int n_dims,
7032
- int mode) {
7032
+ int mode,
7033
+ int n_ctx) {
7033
7034
GGML_ASSERT(n_past >= 0);
7034
7035
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
7035
7036
@@ -7043,12 +7044,13 @@ struct ggml_tensor * ggml_rope_back(
7043
7044
7044
7045
ggml_scratch_save(ctx);
7045
7046
7046
- struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3 );
7047
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4 );
7047
7048
ggml_set_name(b, "n_past, n_dims, mode");
7048
7049
7049
7050
((int32_t *) b->data)[0] = n_past;
7050
7051
((int32_t *) b->data)[1] = n_dims;
7051
7052
((int32_t *) b->data)[2] = mode;
7053
+ ((int32_t *) b->data)[3] = n_ctx;
7052
7054
7053
7055
ggml_scratch_load(ctx);
7054
7056
@@ -15740,13 +15742,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15740
15742
const int n_past = ((int32_t *) src1->data)[0];
15741
15743
const int n_dims = ((int32_t *) src1->data)[1];
15742
15744
const int mode = ((int32_t *) src1->data)[2];
15745
+ const int n_ctx = ((int32_t *) src1->data)[3];
15743
15746
src0->grad = ggml_add_impl(ctx,
15744
15747
src0->grad,
15745
15748
ggml_rope_back(ctx,
15746
15749
tensor->grad,
15747
15750
n_past,
15748
15751
n_dims,
15749
- mode),
15752
+ mode,
15753
+ n_ctx),
15750
15754
inplace);
15751
15755
}
15752
15756
if (src1->grad) {
@@ -15757,7 +15761,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15757
15761
{
15758
15762
if (src0->grad) {
15759
15763
assert(src1->type == GGML_TYPE_I32);
15760
- assert(ggml_nelements(src1) == 3 );
15764
+ assert(ggml_nelements(src1) == 4 );
15761
15765
const int n_past = ((int32_t *) src1->data)[0];
15762
15766
const int n_dims = ((int32_t *) src1->data)[1];
15763
15767
const int mode = ((int32_t *) src1->data)[2];
0 commit comments