
Commit beea6e1

kaetemi, martindevans, and ggerganov authored
llama : save and restore kv cache for single seq id (#6341)
* llama : save and restore kv cache for single seq id
* remove trailing whitespace
* respond error in case there's no space in the kv cache
* add kv seq save restore to test case
* add --slot-save-path arg to enable save restore and restrict save location
* Returning 0 for some cases, instead of asserting.
* cleanup error cases
* rename sequence state functions
* rename state get set functions
* add previous function names back in with DEPRECATED notice
* update doc
* adjust endpoints to preferred style
* fix restoring zero cell count
* handle seq rm return value
* unused param
* keep in the size check
* fix return types
* add server test case for slot save restore
* cleanup
* add cake
* cleanup style
* add special
* removing a whole sequence never fails
* move sequence state file functionality from server to llama to match session api and add version tags
* catch exceptions on save as well
* error log messages
* check types for stricter restore
* update server doc
* readme : update API changes date
* strict filename validation
* move include, reject bom as well
* also reject empty filename
* reject whitespace and trailing dot

---------

Co-authored-by: Martin Evans <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 87fb5b4 commit beea6e1
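
The headline change is a per-sequence counterpart to the whole-context state API. A minimal sketch of the new calls, assuming a `ctx` that already holds a decoded sequence `seq_src` (the same pattern the updated save-load-state example below exercises):

```cpp
#include <cstdint>
#include <vector>

#include "llama.h"

// Serialize the kv cache of one sequence out of the context, then
// re-insert it under another seq id. A sketch; assumes `ctx` already
// holds a decoded sequence `seq_src`.
static bool copy_seq_state(llama_context * ctx, llama_seq_id seq_src, llama_seq_id seq_dst) {
    std::vector<uint8_t> buf(llama_state_seq_get_size(ctx, seq_src));

    if (llama_state_seq_get_data(ctx, buf.data(), seq_src) != buf.size()) {
        return false; // sequence could not be serialized
    }
    // per the commit message, restore returns 0 instead of asserting,
    // e.g. when there is no space left in the kv cache
    if (llama_state_seq_set_data(ctx, buf.data(), seq_dst) != buf.size()) {
        return false;
    }
    return true;
}
```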

File tree

11 files changed (+1086, -31 lines)

README.md (+1)

@@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
 
 ### Recent API changes
 
+- [2024 Apr 4] State and session file functions reorganized under `llama_state_*` https://github.com/ggerganov/llama.cpp/pull/6341
 - [2024 Mar 26] Logits and embeddings API updated for compactness https://github.com/ggerganov/llama.cpp/pull/6122
 - [2024 Mar 13] Add `llama_synchronize()` + `llama_context_params.n_ubatch` https://github.com/ggerganov/llama.cpp/pull/6017
 - [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
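
For quick reference, the renames this entry refers to, as they appear in the diffs below (the old names are kept as DEPRECATED wrappers, per the commit message; exact signatures live in `llama.h`, which is not part of this excerpt):

- `llama_get_state_size` → `llama_state_get_size`
- `llama_copy_state_data` → `llama_state_get_data`
- `llama_set_state_data` → `llama_state_set_data`
- `llama_load_session_file` → `llama_state_load_file`
- `llama_save_session_file` → `llama_state_save_file`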

common/common.cpp (+72, -1)

@@ -16,6 +16,7 @@
 #include <unordered_set>
 #include <vector>
 #include <cinttypes>
+#include <codecvt>
 
 #if defined(__APPLE__) && defined(__MACH__)
 #include <sys/types.h>
@@ -27,7 +28,6 @@
 #ifndef NOMINMAX
 # define NOMINMAX
 #endif
-#include <codecvt>
 #include <locale>
 #include <windows.h>
 #include <fcntl.h>
@@ -1500,6 +1500,77 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
     GGML_UNREACHABLE();
 }
 
+// Validate if a filename is safe to use
+// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
+bool validate_file_name(const std::string & filename) {
+    if (!filename.length()) {
+        // Empty filename invalid
+        return false;
+    }
+    if (filename.length() > 255) {
+        // Limit at common largest possible filename on Linux filesystems
+        // to avoid unnecessary further validation
+        // (On systems with smaller limits it will be caught by the OS)
+        return false;
+    }
+
+    std::u32string filename_utf32;
+    try {
+        std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
+        filename_utf32 = converter.from_bytes(filename);
+
+        // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
+        // or invalid encodings were encountered. Reject such attempts
+        std::string filename_reencoded = converter.to_bytes(filename_utf32);
+        if (filename_reencoded != filename) {
+            return false;
+        }
+    } catch (const std::exception &) {
+        return false;
+    }
+
+    // Check for forbidden codepoints:
+    // - Control characters
+    // - Unicode equivalents of illegal characters
+    // - UTF-16 surrogate pairs
+    // - UTF-8 replacement character
+    // - Byte order mark (BOM)
+    // - Illegal characters: / \ : * ? " < > |
+    for (char32_t c : filename_utf32) {
+        if (c <= 0x1F // Control characters (C0)
+            || c == 0x7F // Control characters (DEL)
+            || (c >= 0x80 && c <= 0x9F) // Control characters (C1)
+            || c == 0xFF0E // Fullwidth Full Stop (period equivalent)
+            || c == 0x2215 // Division Slash (forward slash equivalent)
+            || c == 0x2216 // Set Minus (backslash equivalent)
+            || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
+            || c == 0xFFFD // Replacement Character (UTF-8)
+            || c == 0xFEFF // Byte Order Mark (BOM)
+            || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
+            || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
+            return false;
+        }
+    }
+
+    // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
+    // Unicode and other whitespace is not affected, only 0x20 space
+    if (filename.front() == ' ' || filename.back() == ' ' || filename.back() == '.') {
+        return false;
+    }
+
+    // Reject any ".." (currently stricter than necessary, it should be fine to just check for == ".." instead)
+    if (filename.find("..") != std::string::npos) {
+        return false;
+    }
+
+    // Reject "."
+    if (filename == ".") {
+        return false;
+    }
+
+    return true;
+}
+
 //
 // String utils
 //
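
The function's header comment says full paths should be validated one component at a time. A minimal sketch of that usage, assuming POSIX-style separators; the `split_path` helper is illustrative and not part of this commit:

```cpp
#include <sstream>
#include <string>
#include <vector>

#include "common.h" // for validate_file_name()

// Hypothetical helper: split a path on a separator (not part of this commit).
static std::vector<std::string> split_path(const std::string & path, char sep) {
    std::vector<std::string> parts;
    std::stringstream ss(path);
    std::string part;
    while (std::getline(ss, part, sep)) {
        parts.push_back(part);
    }
    return parts;
}

// Validate each component of a relative path; empty components (e.g. from
// "a//b") are rejected by validate_file_name itself.
static bool validate_relative_path(const std::string & path) {
    for (const std::string & part : split_path(path, '/')) {
        if (!validate_file_name(part)) {
            return false;
        }
    }
    return true;
}
```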

common/common.h (+2)

@@ -179,6 +179,8 @@ std::string gpt_random_prompt(std::mt19937 & rng);
 
 void process_escapes(std::string& input);
 
+bool validate_file_name(const std::string & filename);
+
 //
 // String utils
 //

examples/main/main.cpp (+3, -3)

@@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
             // The file exists and is not empty
             session_tokens.resize(n_ctx);
             size_t n_token_count_out = 0;
-            if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
+            if (!llama_state_load_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
                 LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
                 return 1;
             }
@@ -693,7 +693,7 @@ int main(int argc, char ** argv) {
            // optionally save the session on first sample (for faster prompt loading next time)
            if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
                need_to_save_session = false;
-               llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
+               llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
 
                LOG("saved session to %s\n", path_session.c_str());
            }
@@ -935,7 +935,7 @@ int main(int argc, char ** argv) {
 
     if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
         LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
-        llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
+        llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
     }
 
     llama_print_timings(ctx);
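
A minimal sketch of the renamed session-file round trip that main.cpp now uses, wrapped as a helper; assumes a `ctx` created as above and a token vector from the prompt evaluation, with error handling trimmed:

```cpp
#include <cstdio>
#include <vector>

#include "llama.h"

// Round-trip a context's full state through a session file using the
// renamed llama_state_* API. A sketch; `ctx` and `tokens` are assumed
// to be set up as in main.cpp.
static bool session_round_trip(llama_context * ctx, std::vector<llama_token> & tokens, const char * path) {
    // save: tokens + full state (rng, logits, embeddings, kv cache)
    if (!llama_state_save_file(ctx, path, tokens.data(), tokens.size())) {
        return false;
    }

    // load: size the token buffer first, then read the real count back out
    tokens.resize(llama_n_ctx(ctx));
    size_t n_token_count_out = 0;
    if (!llama_state_load_file(ctx, path, tokens.data(), tokens.capacity(), &n_token_count_out)) {
        fprintf(stderr, "failed to load session file '%s'\n", path);
        return false;
    }
    tokens.resize(n_token_count_out);
    return true;
}
```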

examples/save-load-state/save-load-state.cpp (+95, -6)

@@ -24,6 +24,7 @@ int main(int argc, char ** argv) {
 
     std::string result0;
     std::string result1;
+    std::string result2;
 
     // init
     llama_model * model;
@@ -44,8 +45,8 @@
 
     // save state (rng, logits, embedding and kv_cache) to file
     {
-        std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
-        const size_t written = llama_copy_state_data(ctx, state_mem.data());
+        std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
+        const size_t written = llama_state_get_data(ctx, state_mem.data());
 
         FILE *fp_write = fopen("dump_state.bin", "wb");
         fwrite(state_mem.data(), 1, written, fp_write);
@@ -97,13 +98,13 @@
 
     // load state (rng, logits, embedding and kv_cache) from file
     {
-        std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
+        std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
 
         FILE * fp_read = fopen("dump_state.bin", "rb");
         const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
         fclose(fp_read);
 
-        if (read != llama_set_state_data(ctx2, state_mem.data())) {
+        if (read != llama_state_set_data(ctx2, state_mem.data())) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
@@ -141,16 +142,104 @@
         n_past += 1;
     }
 
-    printf("\n");
+    printf("\n\n");
 
     llama_free(ctx2);
-    llama_free_model(model);
 
     if (result0 != result1) {
         fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
         return 1;
     }
 
+    // make new context
+    auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
+
+    printf("\nsingle seq run: %s", params.prompt.c_str());
+
+    // load state (rng, logits, embedding and kv_cache) from file
+    {
+        std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
+
+        FILE * fp_read = fopen("dump_state.bin", "rb");
+        const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
+        fclose(fp_read);
+
+        if (read != llama_state_set_data(ctx3, state_mem.data())) {
+            fprintf(stderr, "\n%s : failed to read state\n", __func__);
+            llama_free(ctx3);
+            llama_free_model(model);
+            return 1;
+        }
+
+        fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
+    }
+
+    // restore state (last tokens)
+    n_past = n_past_saved;
+
+    // save seq 0 and load into seq 1
+    {
+        // save kv of seq 0
+        std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
+        const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
+        if (ncopy != seq_store.size()) {
+            fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
+            llama_free(ctx3);
+            llama_free_model(model);
+            return 1;
+        }
+        fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
+
+        // erase whole kv
+        llama_kv_cache_clear(ctx3);
+        fprintf(stderr, "%s : kv cache cleared\n", __func__);
+
+        // restore kv into seq 1
+        const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
+        if (nset != seq_store.size()) {
+            fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
+            llama_free(ctx3);
+            llama_free_model(model);
+            return 1;
+        }
+        fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset);
+    }
+
+    // third run with seq 1 instead of 0
+    for (auto i = 0; i < params.n_predict; i++) {
+        auto * logits = llama_get_logits(ctx3);
+        auto n_vocab = llama_n_vocab(model);
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+        auto next_token = llama_sample_token(ctx3, &candidates_p);
+        auto next_token_str = llama_token_to_piece(ctx3, next_token);
+
+        printf("%s", next_token_str.c_str());
+        result2 += next_token_str;
+
+        if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) {
+            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+            llama_free(ctx3);
+            llama_free_model(model);
+            return 1;
+        }
+        n_past += 1;
+    }
+
+    printf("\n");
+
+    llama_free(ctx3);
+    llama_free_model(model);
+
+    if (result0 != result2) {
+        fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
+        return 1;
+    }
+
     fprintf(stderr, "\n%s : success\n", __func__);
 
     return 0;
examples/server/README.md (+52)

@@ -57,6 +57,7 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/
 - `-n N, --n-predict N`: Set the maximum tokens to predict. Default: `-1`
 - `--slots-endpoint-disable`: To disable slots state monitoring endpoint. Slots state may contain user data, prompts included.
 - `--metrics`: enable prometheus `/metrics` compatible endpoint. Default: disabled
+- `--slot-save-path PATH`: Specifies the path where the state of slots (the prompt cache) can be stored. If not provided, the slot management endpoints will be disabled.
 - `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name. Default: template taken from model's metadata. We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template)
 - `--log-disable`: Output logs to stdout only, not to `llama.log`. Default: enabled
 - `--log-format FORMAT`: Define the log output to FORMAT: json or text Default: `json`
@@ -517,6 +518,57 @@ Available metrics:
 - `llamacpp:requests_processing`: Number of requests processing.
 - `llamacpp:requests_deferred`: Number of requests deferred.
 
+- **POST** `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file.
+
+    *Options:*
+
+    `filename`: Name of the file to save the slot's prompt cache. The file will be saved in the directory specified by the `--slot-save-path` server parameter.
+
+### Result JSON
+
+```json
+{
+    "id_slot": 0,
+    "filename": "slot_save_file.bin",
+    "n_saved": 1745,
+    "n_written": 14309796,
+    "timings": {
+        "save_ms": 49.865
+    }
+}
+```
+
+- **POST** `/slots/{id_slot}?action=restore`: Restore the prompt cache of the specified slot from a file.
+
+    *Options:*
+
+    `filename`: Name of the file to restore the slot's prompt cache from. The file should be located in the directory specified by the `--slot-save-path` server parameter.
+
+### Result JSON
+
+```json
+{
+    "id_slot": 0,
+    "filename": "slot_save_file.bin",
+    "n_restored": 1745,
+    "n_read": 14309796,
+    "timings": {
+        "restore_ms": 42.937
+    }
+}
+```
+
+- **POST** `/slots/{id_slot}?action=erase`: Erase the prompt cache of the specified slot.
+
+### Result JSON
+
+```json
+{
+    "id_slot": 0,
+    "n_erased": 1745
+}
+```
+
 ## More examples
 
 ### Change system prompt on runtime
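
For context, a sketch of a client driving the new slot endpoints, assuming a server started with `--slot-save-path`, the `httplib.h` client header vendored with the server example, and that `filename` travels in the JSON request body (an assumption based on the Options wording above); host, port, and filename are illustrative:

```cpp
#include <cstdio>
#include <string>

#include "httplib.h" // cpp-httplib, vendored with the server example

int main() {
    httplib::Client cli("http://localhost:8080"); // illustrative host/port

    const std::string body = R"({"filename": "slot_save_file.bin"})";

    // save slot 0's prompt cache under --slot-save-path
    if (auto res = cli.Post("/slots/0?action=save", body, "application/json")) {
        printf("save: HTTP %d\n%s\n", res->status, res->body.c_str());
    }

    // later, restore it into slot 0
    if (auto res = cli.Post("/slots/0?action=restore", body, "application/json")) {
        printf("restore: HTTP %d\n%s\n", res->status, res->body.c_str());
    }

    // or erase the slot's prompt cache entirely
    if (auto res = cli.Post("/slots/0?action=erase", "", "application/json")) {
        printf("erase: HTTP %d\n%s\n", res->status, res->body.c_str());
    }

    return 0;
}
```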
