Skip to content

Commit a6b048e

Browse files
committed
server : reuse context chunks
ggml-ci
1 parent edc2656 commit a6b048e

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

examples/server/server.cpp

+52-2
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,7 @@ struct server_context {
800800
int slot_prompt_len = slot_prompt.size();
801801

802802
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
803-
int lcp_len = common_part(slot_prompt, prompt);
803+
int lcp_len = longest_common_prefix(slot_prompt, prompt);
804804

805805
// fraction of the common prefix length compared to the current slot's prompt length
806806
similarity = static_cast<float>(lcp_len) / slot_prompt_len;
@@ -2042,12 +2042,61 @@ struct server_context {
20422042

20432043
if (slot.params.cache_prompt) {
20442044
// reuse any previously computed tokens that are common with the new prompt
2045-
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
2045+
slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
20462046

20472047
// push the prompt into the sampling context (do not apply grammar)
20482048
for (int i = 0; i < slot.n_past; ++i) {
20492049
common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
20502050
}
2051+
2052+
// EXPERIMENTAL: reuse chunks from the cached prompt by shifting them to their new position
2053+
if (1) {
2054+
size_t head_c = slot.n_past; // cache
2055+
size_t head_p = slot.n_past; // current prompt
2056+
2057+
while (head_c < slot.cache_tokens.size() &&
2058+
head_p < prompt_tokens.size() &&
2059+
!llama_token_is_control(model, slot.cache_tokens[head_c]) &&
2060+
!llama_token_is_control(model, prompt_tokens[head_p])) {
2061+
2062+
size_t n_match = 0;
2063+
while (head_c + n_match < slot.cache_tokens.size() &&
2064+
head_p + n_match < prompt_tokens.size() &&
2065+
!llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
2066+
!llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
2067+
slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
2068+
n_match++;
2069+
}
2070+
2071+
if (n_match > 32) {
2072+
// shift the KV chunk [head_c, head_c + n_match) -> [head_p, head_p + n_match)
2073+
SLT_DBG(slot, "shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", head_c, head_c + n_match, head_p, head_p + n_match);
2074+
//for (size_t i = head_p; i < head_p + n_match; i++) {
2075+
// SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
2076+
//}
2077+
2078+
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
2079+
2080+
llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
2081+
llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
2082+
2083+
for (size_t i = 0; i < n_match; i++) {
2084+
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
2085+
2086+
common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
2087+
2088+
slot.n_past++;
2089+
}
2090+
2091+
head_c += n_match;
2092+
head_p += n_match;
2093+
} else {
2094+
head_c += 1;
2095+
}
2096+
}
2097+
2098+
SLT_DBG(slot, "new slot.n_past = %d, cache_tokens.size() = %zu\n", slot.n_past, slot.cache_tokens.size());
2099+
}
20512100
}
20522101
}
20532102

@@ -3257,6 +3306,7 @@ int main(int argc, char ** argv) {
32573306

32583307
ctx_server.queue_tasks.on_new_task(std::bind(
32593308
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
3309+
32603310
ctx_server.queue_tasks.on_update_slots(std::bind(
32613311
&server_context::update_slots, &ctx_server));
32623312

examples/server/utils.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,14 @@ static std::string gen_chatcmplid() {
195195
// other common utils
196196
//
197197

198-
static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
198+
static size_t longest_common_prefix(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
199199
size_t i;
200200
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
201201

202202
return i;
203203
}
204204

205-
static size_t common_part(const std::string & a, const std::string & b) {
205+
static size_t longest_common_prefix(const std::string & a, const std::string & b) {
206206
size_t i;
207207
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
208208

0 commit comments

Comments
 (0)