@@ -193,21 +193,15 @@ struct server_slot {

    llama_token sampled;

-    int32_t ga_i = 0;   // group-attention state
-    int32_t ga_n = 1;   // group-attention factor
-    int32_t ga_w = 512; // group-attention width
-
-    int32_t n_past_se = 0; // self-extend
-
    // stats
-    size_t n_sent_text = 0; // number of sent text character
+    size_t n_sent_text = 0; // number of sent text character
    size_t n_sent_token_probs = 0;

    int64_t t_start_process_prompt;
    int64_t t_start_generation;

    double t_prompt_processing; // ms
-    double t_token_generation; // ms
+    double t_token_generation; // ms

    std::function<void(int)> callback_on_release;

@@ -225,8 +219,6 @@ struct server_slot {
        n_sent_text = 0;
        n_sent_token_probs = 0;
        cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
-        ga_i = 0;
-        n_past_se = 0;

        generated_token_probs.clear();
    }
@@ -705,22 +697,6 @@ struct server_context {

            SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);

-            const int ga_n = params.grp_attn_n;
-            const int ga_w = params.grp_attn_w;
-
-            if (ga_n != 1) {
-                GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
-                GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
-                // GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
-                // GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
-
-                SLT_INF(slot, "slot self-extend: ga_n = %d, ga_w = %d\n", ga_n, ga_w);
-            }
-
-            slot.ga_i = 0;
-            slot.ga_n = ga_n;
-            slot.ga_w = ga_w;
-
            slot.sparams = params.sparams;

            slot.callback_on_release = [this](int) {
@@ -906,19 +882,14 @@ struct server_context {
        }
        if (data.contains("json_schema") && !data.contains("grammar")) {
            try {
-                auto schema = json_value(data, "json_schema", json::object());
-                slot.sparams.grammar = json_schema_to_grammar(schema);
+                auto schema = json_value(data, "json_schema", json::object());
+                slot.sparams.grammar = json_schema_to_grammar(schema);
            } catch (const std::exception & e) {
                send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
                return false;
            }
        } else {
-            slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
-        }
-
-        if (slot.params.cache_prompt && slot.ga_n != 1) {
-            slot.params.cache_prompt = false;
-            SLT_WRN(slot, "%s", "group-attention is not supported with prompt caching. disabling cache\n");
+            slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
        }

        if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -1131,12 +1102,13 @@ struct server_context {
        }

        // if context shift is disabled, we stop when it reaches the context limit
-        if (slot.n_decoded >= slot.n_ctx) {
+        if (slot.n_past >= slot.n_ctx) {
            slot.truncated = true;
            slot.stopped_limit = true;
            slot.has_next_token = false;

-            SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
+            SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
+                    slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
        }

        if (llama_token_is_eog(model, result.tok)) {
@@ -1148,13 +1120,13 @@ struct server_context {

        const auto n_ctx_train = llama_n_ctx_train(model);

-        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
            slot.truncated = true;
            slot.stopped_limit = true;
            slot.has_next_token = false; // stop prediction

            SLT_WRN(slot,
-                    "n_predict (%d) is not set and self-context extend is disabled. "
+                    "n_predict (%d) is set for infinite generation. "
                    "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
                    slot.params.n_predict, n_ctx_train);
        }
@@ -1826,38 +1798,36 @@ struct server_context {
        // apply context-shift if needed
        // TODO: simplify and improve
        for (server_slot & slot : slots) {
-            if (slot.ga_n == 1) {
-                if (slot.is_processing() && slot.n_past >= slot.n_ctx - 1) {
-                    if (!params.ctx_shift) {
-                        // this check is redundant (for good)
-                        // we should never get here, because generation should already stopped in process_token()
-                        slot.release();
-                        send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
-                        continue;
-                    }
-
-                    // Shift context
-                    const int n_keep = slot.params.n_keep + add_bos_token;
-                    const int n_left = slot.n_past - n_keep;
-                    const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
+            if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
+                if (!params.ctx_shift) {
+                    // this check is redundant (for good)
+                    // we should never get here, because generation should already stopped in process_token()
+                    slot.release();
+                    send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+                    continue;
+                }

-                    SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
+                // Shift context
+                const int n_keep = slot.params.n_keep + add_bos_token;
+                const int n_left = slot.n_past - n_keep;
+                const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);

-                    llama_kv_cache_seq_rm(ctx, slot.id + 1, n_keep, n_keep + n_discard);
-                    llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);
+                SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);

-                    if (slot.params.cache_prompt) {
-                        for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
-                            slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
-                        }
+                llama_kv_cache_seq_rm(ctx, slot.id + 1, n_keep, n_keep + n_discard);
+                llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);

-                        slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+                if (slot.params.cache_prompt) {
+                    for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
+                        slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
                    }

-                    slot.n_past -= n_discard;
-
-                    slot.truncated = true;
+                    slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
                }
+
+                slot.n_past -= n_discard;
+
+                slot.truncated = true;
            }
        }

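For reference, the shift arithmetic in the hunk above can be reproduced on its own. The sketch below uses made-up slot values (n_ctx, n_keep, n_past are hypothetical) and only illustrates how n_left, n_discard and the two KV-cache ranges relate to each other; it is not part of the patch.

// Illustrative sketch, not part of the patch: the context-shift bookkeeping
// for a hypothetical slot state.
#include <cstdio>

int main() {
    const int n_ctx  = 512; // hypothetical slot context size
    const int n_keep = 32;  // tokens kept at the start of the context
    const int n_past = 511; // shift triggers once n_past + 1 >= n_ctx

    const int n_left    = n_past - n_keep;
    const int n_discard = n_left / 2; // default when params.n_discard == 0

    // mirrors the llama_kv_cache_seq_rm / llama_kv_cache_seq_add calls above:
    // drop cells [n_keep, n_keep + n_discard), then move the rest down by n_discard
    printf("discard [%d, %d), shift [%d, %d) by -%d, new n_past = %d\n",
           n_keep, n_keep + n_discard,
           n_keep + n_discard, n_past, n_discard,
           n_past - n_discard);
    return 0;
}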
@@ -1872,9 +1842,7 @@ struct server_context {

            slot.i_batch = batch.n_tokens;

-            const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-            common_batch_add(batch, slot.sampled, slot_npast, { slot.id + 1 }, true);
+            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true);

            slot.n_past += 1;

@@ -1993,6 +1961,8 @@ struct server_context {
                } else {
                    if (!params.ctx_shift) {
                        // if context shift is disabled, we make sure prompt size is smaller than KV size
+                        // TODO: there should be a separate parameter that control prompt truncation
+                        //       context shift should be applied only during the generation phase
                        if (slot.n_prompt_tokens >= slot.n_ctx) {
                            slot.release();
                            send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
@@ -2005,7 +1975,7 @@ struct server_context {
                    slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);

                    // if input prompt is too big, truncate it (if group attention self-extend is disabled)
-                    if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
+                    if (slot.n_prompt_tokens >= slot.n_ctx) {
                        const int n_left = slot.n_ctx - slot.params.n_keep;

                        const int n_block_size = n_left / 2;
@@ -2032,12 +2002,7 @@ struct server_context {

                common_sampler_reset(slot.smpl);

-                if (!slot.params.cache_prompt) {
-                    slot.n_past_se = 0;
-                    slot.ga_i = 0;
-                } else {
-                    GGML_ASSERT(slot.ga_n == 1);
-
+                if (slot.params.cache_prompt) {
                    // reuse any previously computed tokens that are common with the new prompt
                    slot.n_past = common_part(slot.cache_tokens, prompt_tokens);

@@ -2053,9 +2018,6 @@ struct server_context {
                        SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);

                        slot.n_past--;
-                        if (slot.ga_i > 0) {
-                            slot.n_past_se--;
-                        }
                    }

                    slot.n_prompt_tokens_processed = 0;
@@ -2081,52 +2043,31 @@ struct server_context {
                }

                // keep only the common part
-                int p0 = slot.n_past;
-
-                if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
+                if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
                    // could not partially delete (likely using a non-Transformer model)
                    llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);

-                    p0 = 0;
-
                    // there is no common part left
                    slot.n_past = 0;
-                    slot.n_past_se = 0;
-                    slot.ga_i = 0;

                    common_sampler_reset(slot.smpl);
                }

+                SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+
                // remove the non-common part from the cache
                slot.cache_tokens.resize(slot.n_past);

-                SLT_INF(slot, "kv cache rm [%d, end)\n", p0);
-
-                int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-                int32_t ga_i = slot.ga_i;
-                int32_t ga_n = slot.ga_n;
-                int32_t ga_w = slot.ga_w;
-
                // add prompt tokens for processing in the current batch
-                // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
-                for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
-                    if (slot.ga_n != 1) {
-                        while (slot_npast >= ga_i + ga_w) {
-                            const int bd = (ga_w/ga_n)*(ga_n - 1);
-                            slot_npast -= bd;
-                            ga_i += ga_w/ga_n;
-                        }
-                    }
-
-                    common_batch_add(batch, prompt_tokens[slot.n_past], slot_npast, { slot.id + 1 }, false);
+                while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                    common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);

                    if (slot.params.cache_prompt) {
                        slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
                    }

                    slot.n_prompt_tokens_processed++;
-                    slot_npast++;
+                    slot.n_past++;
                }

                SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
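For reference, the prefix reuse that slot.n_past = common_part(slot.cache_tokens, prompt_tokens) relies on amounts to counting the shared leading tokens of the cached and the new prompt. The sketch below is only an illustration with made-up token ids; common_prefix is a stand-in name, not the server's helper.

// Illustrative sketch, not part of the patch: prompt-prefix reuse when
// cache_prompt is enabled.
#include <cstdio>
#include <vector>

using llama_tokens = std::vector<int>; // token ids; the real code uses llama_token

// assumed to behave like the common_part() helper referenced in the diff
static size_t common_prefix(const llama_tokens & a, const llama_tokens & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}

int main() {
    const llama_tokens cached = {1, 15, 42, 7, 99};
    const llama_tokens prompt = {1, 15, 42, 8, 100, 3};

    // only tokens [n_past, n_prompt_tokens) need to be evaluated again
    const size_t n_past = common_prefix(cached, prompt);
    printf("n_past = %zu, tokens to (re)process = %zu\n", n_past, prompt.size() - n_past);
    return 0;
}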
@@ -2167,34 +2108,6 @@ struct server_context {
        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

-            for (auto & slot : slots) {
-                if (slot.ga_n != 1) {
-                    // context extension via Self-Extend
-                    // TODO: simplify and/or abstract this
-                    while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
-                        const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
-                        const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
-                        const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
-
-                        SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
-                        SLT_DBG(slot, "div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                        SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
-
-                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
-                        llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
-                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
-
-                        slot.n_past_se -= bd;
-
-                        slot.ga_i += slot.ga_w / slot.ga_n;
-
-                        SLT_DBG(slot, "\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
-                    }
-
-                    slot.n_past_se += n_tokens;
-                }
-            }
-
            llama_batch batch_view = {
                n_tokens,
                batch.token + i,
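For context on what the last hunk removes: the Self-Extend block repeatedly rebased the group-attention state before each decode. The standalone sketch below re-runs only that bookkeeping with hypothetical ga_n/ga_w values to show how ib, bd, dd, n_past_se and ga_i evolve; it does not touch a KV cache and is not part of the patch.

// Illustrative sketch, not part of the patch: the bookkeeping loop of the
// removed Self-Extend block, run with made-up values.
#include <cstdio>

int main() {
    const int ga_n = 4;   // group-attention factor (hypothetical)
    const int ga_w = 512; // group-attention width  (hypothetical)

    int ga_i      = 0;    // group-attention state
    int n_past_se = 1200; // self-extend position counter

    while (n_past_se >= ga_i + ga_w) {
        const int ib = (ga_n * ga_i) / ga_w;
        const int bd = (ga_w / ga_n) * (ga_n - 1);
        const int dd = (ga_w / ga_n) - ib * bd - ga_w;

        // in the server, this is where the KV cache positions were shifted,
        // divided and shifted again (seq_add / seq_div / seq_add)
        printf("ib = %d, bd = %d, dd = %d -> n_past_se: %d -> %d, ga_i: %d -> %d\n",
               ib, bd, dd, n_past_se, n_past_se - bd, ga_i, ga_i + ga_w / ga_n);

        n_past_se -= bd;
        ga_i      += ga_w / ga_n;
    }
    return 0;
}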