
Commit b6e7f9b

llama : add api for getting/setting the complete state: rng, logits, embedding and kv_cache (#1105)

* reserve correct size for logits
* add functions to get and set the whole llama state: including rng, logits, embedding and kv_cache
* remove unused variables
* remove trailing whitespace
* fix comment
1 parent 50cb666 commit b6e7f9b
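The new calls are designed to round-trip: query the size, copy the state out, and later restore it into the same context. A minimal sketch of that flow (assuming an already initialized `llama_context * ctx`; `snapshot_and_restore` is a hypothetical helper name, not part of this commit):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>
#include "llama.h"

// snapshot the full state (rng, logits, embedding, kv_cache), then roll back to it
void snapshot_and_restore(struct llama_context * ctx) {
    const size_t n_state = llama_get_state_size(ctx); // fixed for a given context
    std::vector<uint8_t> state(n_state);
    const size_t n_copied = llama_copy_state_data(ctx, state.data());
    // ... run some speculative evaluation on ctx here ...
    const size_t n_read = llama_set_state_data(ctx, state.data());
    // in this version both counts equal n_state (the functions assert this internally)
    (void) n_copied; (void) n_read;
}
```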

File tree: 2 files changed, +133 -1 lines changed


llama.cpp (+121 -1)
```diff
@@ -27,6 +27,7 @@
 #include <thread>
 #include <atomic>
 #include <mutex>
+#include <sstream>
 
 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
@@ -1787,7 +1788,7 @@ struct llama_context * llama_init_from_file(
     if (params.logits_all) {
         ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
     } else {
-        ctx->logits.reserve(hparams.n_ctx);
+        ctx->logits.reserve(hparams.n_vocab);
     }
 
     if (params.embedding){
@@ -2252,3 +2253,122 @@ const char * llama_print_system_info(void) {
 std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx) {
     return ctx->model.tensors_by_name;
 }
+
+// Returns the size of the state
+size_t llama_get_state_size(struct llama_context * ctx) {
+    const size_t s_bool = sizeof(int32_t);
+    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
+    // for reference, std::mt19937(1337) serializes to 6701 bytes.
+    const size_t s_rng_size        = sizeof(size_t);
+    const size_t s_rng             = 64*1024;
+    const size_t s_logits_capacity = sizeof(size_t);
+    const size_t s_logits_size     = sizeof(size_t);
+    const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
+    const size_t s_embedding_size  = sizeof(size_t);
+    const size_t s_embedding      = ctx->embedding.size() * sizeof(float);
+    const size_t s_kv_size         = sizeof(size_t);
+    const size_t s_kv_ntok         = sizeof(int);
+    const size_t s_kv              = llama_get_kv_cache_size(ctx);
+    const size_t s_total = (
+        + s_rng_size
+        + s_rng
+        + s_logits_capacity
+        + s_logits_size
+        + s_logits
+        + s_embedding_size
+        + s_embedding
+        + s_kv_size
+        + s_kv_ntok
+        + s_kv
+    );
+    return s_total;
+}
+
+// Copies the state to the specified destination address
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
+    std::stringstream rng_ss;
+    rng_ss << ctx->rng;
+    const size_t rng_size = rng_ss.str().size();
+    char rng_buf[64*1024];
+    memset(&rng_buf[0], 0, 64*1024);
+    memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
+    const size_t logits_capacity = ctx->logits.capacity();
+    const size_t logits_size     = ctx->logits.size();
+    const size_t embedding_size  = ctx->embedding.size();
+    const size_t kv_size         = llama_get_kv_cache_size(ctx);
+    const int    kv_ntok         = llama_get_kv_cache_token_count(ctx);
+
+    uint8_t * out = dest;
+    memcpy(out, &rng_size,   sizeof(size_t)); out += sizeof(size_t);
+    memcpy(out, &rng_buf[0], 64*1024);        out += 64*1024;
+    memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t);
+    memcpy(out, &logits_size,     sizeof(size_t)); out += sizeof(size_t);
+    if (logits_size) {
+        memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
+    }
+    out += logits_capacity * sizeof(float);
+    memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t);
+    if (embedding_size) {
+        memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float);
+    }
+    memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t);
+    memcpy(out, &kv_ntok, sizeof(int));    out += sizeof(int);
+    if (kv_size) {
+        memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size;
+    }
+    const size_t written  = out - dest;
+    const size_t expected = llama_get_state_size(ctx);
+    LLAMA_ASSERT(written == expected);
+    return written;
+}
+
+// Sets the state reading from the specified source address
+size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
+    size_t rng_size;
+    char   rng_buf[64*1024];
+    std::stringstream rng_ss;
+
+    const uint8_t * in = src;
+    memcpy(&rng_size,   in, sizeof(size_t)); in += sizeof(size_t);
+    memcpy(&rng_buf[0], in, 64*1024);        in += 64*1024;
+    rng_ss.str(std::string(&rng_buf[0], rng_size));
+    rng_ss >> ctx->rng;
+    LLAMA_ASSERT(rng_ss.fail() == false);
+
+    size_t logits_capacity;
+    size_t logits_size;
+    size_t embedding_size;
+    size_t kv_size;
+    int    kv_ntok;
+
+    memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t);
+    memcpy(&logits_size,     in, sizeof(size_t)); in += sizeof(size_t);
+    LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity);
+    if (logits_size) {
+        ctx->logits.resize(logits_size);
+        memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
+    }
+    in += logits_capacity * sizeof(float);
+    memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t);
+    LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
+    if (embedding_size) {
+        memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
+        in += embedding_size * sizeof(float);
+    }
+    memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t);
+    memcpy(&kv_ntok, in, sizeof(int));    in += sizeof(int);
+    if (kv_size) {
+        LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
+        void * k_data = ctx->model.kv_self.k->data; // remember data pointers
+        void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+        memcpy(ctx->model.kv_self.buf.addr, in, kv_size);
+        ctx->model.kv_self.k->data = k_data; // restore correct data pointers
+        ctx->model.kv_self.v->data = v_data;
+        in += kv_size;
+    }
+    ctx->model.kv_self.n = kv_ntok;
+    const size_t nread    = in - src;
+    const size_t expected = llama_get_state_size(ctx);
+    LLAMA_ASSERT(nread == expected);
+    return nread;
+}
```
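Because `llama_copy_state_data` emits one flat, length-prefixed buffer (rng size plus a fixed 64 KiB rng block, logits capacity/size plus data, embedding size plus data, kv size and token count plus data), the buffer can be written to disk verbatim. A hedged sketch of file persistence built on these functions (`save_state`/`load_state` are hypothetical helpers, not part of this commit; no format versioning or endianness handling, and the restoring context must be created with the same model and parameters since `llama_set_state_data` asserts matching capacities):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>
#include "llama.h"

bool save_state(struct llama_context * ctx, const char * path) {
    std::vector<uint8_t> buf(llama_get_state_size(ctx));
    llama_copy_state_data(ctx, buf.data()); // serialize rng, logits, embedding, kv_cache
    FILE * f = fopen(path, "wb");
    if (!f) return false;
    const bool ok = fwrite(buf.data(), 1, buf.size(), f) == buf.size();
    fclose(f);
    return ok;
}

bool load_state(struct llama_context * ctx, const char * path) {
    std::vector<uint8_t> buf(llama_get_state_size(ctx)); // size is fixed per context
    FILE * f = fopen(path, "rb");
    if (!f) return false;
    const bool ok = fread(buf.data(), 1, buf.size(), f) == buf.size();
    fclose(f);
    if (ok) llama_set_state_data(ctx, buf.data());
    return ok;
}
```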

llama.h (+12)
```diff
@@ -129,6 +129,18 @@ extern "C" {
             size_t   n_size,
                int   n_token_count);
 
+    // Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
+    LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);
+
+    // Copies the state to the specified destination address.
+    // Destination needs to have allocated enough memory.
+    // Returns the number of bytes copied
+    LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest);
+
+    // Set the state reading from the specified address
+    // Returns the number of bytes read
+    LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src);
+
     // Run the llama inference to obtain the logits and probabilities for the next token.
     // tokens + n_tokens is the provided batch of new tokens to process
     // n_past is the number of tokens to use from previous eval calls
```
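Since `llama_set_state_data` also restores the kv cache token count (`kv_self.n`), a restored context can resume generation by passing the recovered count as `n_past`. A small sketch, under the assumption that `llama_eval` and `llama_get_kv_cache_token_count` keep the signatures already declared in this header (`resume_after_restore` is a hypothetical helper, and `next_token` a placeholder for whatever the caller samples from the restored logits):

```cpp
#include "llama.h"

// hypothetical: continue generation right after llama_set_state_data
void resume_after_restore(struct llama_context * ctx, int n_threads) {
    const int n_past = llama_get_kv_cache_token_count(ctx); // tokens already in the restored kv cache
    llama_token next_token = 0; // placeholder: normally sampled from the restored logits
    llama_eval(ctx, &next_token, 1, n_past, n_threads);
}
```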
