
Commit 075f5f5

add functions to get and set the whole llama state (rng, logits, embedding and kv_cache)

1 parent 958a74e

File tree

3 files changed: +154 -1 lines changed

llama.cpp
llama.h
py/llama_cpp/llama.py

llama.cpp (+125 -1)
```diff
@@ -16,7 +16,8 @@
 #include <climits>
 #include <memory>
 #include <algorithm>
-#include <initializer_list>
+#include <sstream>
+
 
 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
```
```diff
@@ -1931,3 +1932,126 @@ const char * llama_print_system_info(void) {
 std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx) {
     return ctx->model.tensors_by_name;
 }
+
+// Returns the size of the state
+size_t llama_get_state_size(struct llama_context * ctx) {
+    const size_t s_bool = sizeof(int32_t);
+    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
+    // for reference, std::mt19937(1337) serializes to 6701 bytes.
+    const size_t s_rng_size = sizeof(size_t);
+    const size_t s_rng = 64*1024;
+    const size_t s_logits_capacity = sizeof(size_t);
+    const size_t s_logits_size = sizeof(size_t);
+    const size_t s_logits = ctx->logits.capacity() * sizeof(float);
+    const size_t s_embedding_size = sizeof(size_t);
+    const size_t s_embedding = ctx->embedding.size() * sizeof(float);
+    const size_t s_kv_size = sizeof(size_t);
+    const size_t s_kv_ntok = sizeof(int);
+    const size_t s_kv = llama_get_kv_cache_size(ctx);
+    const size_t s_total = (
+        + s_rng_size
+        + s_rng
+        + s_logits_capacity
+        + s_logits_size
+        + s_logits
+        + s_embedding_size
+        + s_embedding
+        + s_kv_size
+        + s_kv_ntok
+        + s_kv
+    );
+    return s_total;
+}
+
+// Copies the state to the specified destination address
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
+    std::stringstream rng_ss;
+    rng_ss << ctx->rng;
+    const size_t rng_size = rng_ss.str().size();
+    char rng_buf[64*1024];
+    memset(&rng_buf[0], 0, 64*1024);
+    memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
+    const int32_t has_evaluated_once = ctx->has_evaluated_once ? 1 : 0;
+    const int32_t logits_all = ctx->logits_all ? 1 : 0;
+    const size_t logits_capacity = ctx->logits.capacity();
+    const size_t logits_size = ctx->logits.size();
+    const size_t embedding_size = ctx->embedding.size();
+    const size_t kv_size = llama_get_kv_cache_size(ctx);
+    const int kv_ntok = llama_get_kv_cache_token_count(ctx);
+
+    uint8_t * out = dest;
+    memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t);
+    memcpy(out, &rng_buf[0], 64*1024); out += 64*1024;
+    memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t);
+    memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t);
+    if (logits_size) {
+        memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
+    }
+    out += logits_capacity * sizeof(float);
+    memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t);
+    if (embedding_size) {
+        memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float);
+    }
+    memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t);
+    memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int);
+    if (kv_size) {
+        memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size;
+    }
+    const size_t written = out - dest;
+    const size_t expected = llama_get_state_size(ctx);
+    LLAMA_ASSERT(written == expected);
+    return written;
+}
+
+// Sets the state reading from the specified source address
+size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
+    size_t rng_size;
+    char rng_buf[64*1024];
+    std::stringstream rng_ss;
+
+    const uint8_t * in = src;
+    memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t);
+    memcpy(&rng_buf[0], in, 64*1024); in += 64*1024;
+    rng_ss.str(std::string(&rng_buf[0], rng_size));
+    rng_ss >> ctx->rng;
+    LLAMA_ASSERT(rng_ss.fail() == false);
+
+    int32_t has_evaluated_once;
+    int32_t logits_all;
+    size_t logits_capacity;
+    size_t logits_size;
+    size_t embedding_size;
+    size_t kv_size;
+    int kv_ntok;
+
+    memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t);
+    memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t);
+    LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity);
+    if (logits_size) {
+        ctx->logits.resize(logits_size);
+        memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
+    }
+    in += logits_capacity * sizeof(float);
+    memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t);
+    LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
+    if (embedding_size) {
+        memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
+        in += embedding_size * sizeof(float);
+    }
+    memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t);
+    memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int);
+    if (kv_size) {
+        LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
+        void * k_data = ctx->model.kv_self.k->data; // remember data pointers
+        void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+        memcpy(ctx->model.kv_self.buf.addr, in, kv_size);
+        ctx->model.kv_self.k->data = k_data; // restore correct data pointers
+        ctx->model.kv_self.v->data = v_data;
+        in += kv_size;
+    }
+    ctx->model.kv_self.n = kv_ntok;
+    const size_t nread = in - src;
+    const size_t expected = llama_get_state_size(ctx);
+    LLAMA_ASSERT(nread == expected);
+    return nread;
+}
```
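The memcpy sequence in llama_copy_state_data fixes the byte layout of the state buffer, and llama_set_state_data reads it back in the same order. As a reading aid, the buffer looks like the following pseudo-struct (a descriptive sketch derived from the code above; it is not part of the commit):

```cpp
// Buffer layout written by llama_copy_state_data and consumed by
// llama_set_state_data (descriptive sketch only, not code from the commit):
//
//   size_t  rng_size;                  // length of the serialized RNG string
//   char    rng_buf[64*1024];          // zero-padded std::mt19937 stream state
//   size_t  logits_capacity;
//   size_t  logits_size;
//   float   logits[logits_capacity];   // only the first logits_size entries are valid
//   size_t  embedding_size;
//   float   embedding[embedding_size]; // omitted entirely when embedding_size == 0
//   size_t  kv_size;
//   int     kv_ntok;
//   uint8_t kv_cache[kv_size];         // raw kv_self buffer; omitted when kv_size == 0
```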

llama.h (+11)
```diff
@@ -104,6 +104,17 @@ extern "C" {
             size_t n_size,
             int n_token_count);
 
+    // Returns the size of the state
+    LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);
+
+    // Copies the state to the specified destination address
+    // Returns the number of bytes copied
+    LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest);
+
+    // Sets the state reading from the specified source address
+    // Returns the number of bytes read
+    LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src);
+
     // Run the llama inference to obtain the logits and probabilities for the next token.
     // tokens + n_tokens is the provided batch of new tokens to process
     // n_past is the number of tokens to use from previous eval calls
```
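A minimal usage sketch of the new API (not part of the commit; the helpers save_state and restore_state are hypothetical names, and ctx is assumed to come from llama_init_from_file):

```cpp
#include <cstdint>
#include <vector>
#include "llama.h"

// Snapshot the full context state (rng, logits, embedding, kv cache) into an
// owned buffer.
static std::vector<uint8_t> save_state(struct llama_context * ctx) {
    // llama_get_state_size is an upper bound; llama_copy_state_data asserts
    // internally that it writes exactly this many bytes.
    std::vector<uint8_t> buf(llama_get_state_size(ctx));
    llama_copy_state_data(ctx, buf.data());
    return buf;
}

// Rewind a context to a previously saved snapshot. The target context must
// have matching logits capacity, embedding size and kv cache buffer size,
// since llama_set_state_data enforces these with LLAMA_ASSERT.
static void restore_state(struct llama_context * ctx, const std::vector<uint8_t> & buf) {
    llama_set_state_data(ctx, buf.data());
}
```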

py/llama_cpp/llama.py (+18)
```diff
@@ -105,6 +105,15 @@ class llama_context_params(Structure):
 lib.llama_set_kv_cache.argtypes = [llama_context_p, c_ubyte_p, c_size_t, c_int]
 lib.llama_set_kv_cache.restype = None
 
+lib.llama_get_state_size.argtypes = [llama_context_p]
+lib.llama_get_state_size.restype = c_size_t
+
+lib.llama_copy_state_data.argtypes = [llama_context_p, c_ubyte_p]
+lib.llama_copy_state_data.restype = c_size_t
+
+lib.llama_set_state_data.argtypes = [llama_context_p, c_ubyte_p]
+lib.llama_set_state_data.restype = c_size_t
+
 # Python functions
 def llama_context_default_params() -> llama_context_params:
     params = lib.llama_context_default_params()
@@ -186,3 +195,12 @@ def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int:
 
 def llama_set_kv_cache(ctx: llama_context_p, data: c_ubyte_p, n_size:c_size_t, n_token_count:c_int):
     return lib.llama_set_kv_cache(ctx, data, n_size, n_token_count)
+
+def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
+    return lib.llama_get_state_size(ctx)
+
+def llama_copy_state_data(ctx: llama_context_p, dst: c_ubyte_p) -> c_size_t:
+    return lib.llama_copy_state_data(ctx, dst)
+
+def llama_set_state_data(ctx: llama_context_p, src: c_ubyte_p) -> c_size_t:
+    return lib.llama_set_state_data(ctx, src)
```
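A hedged round-trip sketch through the new bindings (not part of the commit; it assumes the module-level names above, including c_ubyte_p, are in scope, and that ctx came from the module's existing llama_init_from_file wrapper):

```python
from ctypes import c_ubyte, cast

n = llama_get_state_size(ctx)   # snapshot size in bytes (ctypes returns a Python int)
buf = (c_ubyte * n)()           # ctypes-owned byte buffer of that size
llama_copy_state_data(ctx, cast(buf, c_ubyte_p))

# ... llama_eval / sampling happens here ...

n_read = llama_set_state_data(ctx, cast(buf, c_ubyte_p))  # rewind to the snapshot
assert n_read == n
```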
