@@ -800,7 +800,7 @@ struct server_context {
800
800
int slot_prompt_len = slot_prompt.size ();
801
801
802
802
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
803
- int lcp_len = common_part (slot_prompt, prompt);
803
+ int lcp_len = longest_common_prefix (slot_prompt, prompt);
804
804
805
805
// fraction of the common prefix length relative to the current slot's prompt length
806
806
similarity = static_cast <float >(lcp_len) / slot_prompt_len;
@@ -2042,12 +2042,61 @@ struct server_context {
2042
2042
2043
2043
if (slot.params .cache_prompt ) {
2044
2044
// reuse any previously computed tokens that are common with the new prompt
2045
- slot.n_past = common_part (slot.cache_tokens , prompt_tokens);
2045
+ slot.n_past = longest_common_prefix (slot.cache_tokens , prompt_tokens);
2046
2046
2047
2047
// push the prompt into the sampling context (do not apply grammar)
2048
2048
for (int i = 0 ; i < slot.n_past ; ++i) {
2049
2049
common_sampler_accept (slot.smpl , slot.cache_tokens [i], false );
2050
2050
}
2051
+
2052
+ // EXPERIMENTAL: reuse chunks from the cached prompt by shifting them to their new position
2053
+ if (1 ) {
2054
+ size_t head_c = slot.n_past ; // cache
2055
+ size_t head_p = slot.n_past ; // current prompt
2056
+
2057
+ while (head_c < slot.cache_tokens .size () &&
2058
+ head_p < prompt_tokens.size () &&
2059
+ !llama_token_is_control (model, slot.cache_tokens [head_c]) &&
2060
+ !llama_token_is_control (model, prompt_tokens[head_p])) {
2061
+
2062
+ size_t n_match = 0 ;
2063
+ while (head_c + n_match < slot.cache_tokens .size () &&
2064
+ head_p + n_match < prompt_tokens.size () &&
2065
+ !llama_token_is_control (model, slot.cache_tokens [head_c + n_match]) &&
2066
+ !llama_token_is_control (model, prompt_tokens[head_p + n_match]) &&
2067
+ slot.cache_tokens [head_c + n_match] == prompt_tokens[head_p + n_match]) {
2068
+ n_match++;
2069
+ }
2070
+
2071
+ if (n_match > 32 ) {
2072
+ // shift the KV chunk [head_c, head_c + n_match) -> [head_p, head_p + n_match)
2073
+ SLT_DBG (slot, " shifting KV cache [%zu, %zu) -> [%zu, %zu)\n " , head_c, head_c + n_match, head_p, head_p + n_match);
2074
+ // for (size_t i = head_p; i < head_p + n_match; i++) {
2075
+ // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
2076
+ // }
2077
+
2078
+ const int64_t kv_shift = (int64_t ) head_p - (int64_t ) head_c;
2079
+
2080
+ llama_kv_cache_seq_rm (ctx, slot.id + 1 , head_p, head_c);
2081
+ llama_kv_cache_seq_add (ctx, slot.id + 1 , head_c, -1 , kv_shift);
2082
+
2083
+ for (size_t i = 0 ; i < n_match; i++) {
2084
+ slot.cache_tokens [head_p + i] = slot.cache_tokens [head_c + i];
2085
+
2086
+ common_sampler_accept (slot.smpl , slot.cache_tokens [head_p + i], false );
2087
+
2088
+ slot.n_past ++;
2089
+ }
2090
+
2091
+ head_c += n_match;
2092
+ head_p += n_match;
2093
+ } else {
2094
+ head_c += 1 ;
2095
+ }
2096
+ }
2097
+
2098
+ SLT_DBG (slot, " new slot.n_past = %d, cache_tokens.size() = %zu\n " , slot.n_past , slot.cache_tokens .size ());
2099
+ }
2051
2100
}
2052
2101
}
2053
2102
@@ -3257,6 +3306,7 @@ int main(int argc, char ** argv) {
3257
3306
3258
3307
ctx_server.queue_tasks .on_new_task (std::bind (
3259
3308
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
3309
+
3260
3310
ctx_server.queue_tasks .on_update_slots (std::bind (
3261
3311
&server_context::update_slots, &ctx_server));
3262
3312
0 commit comments