@@ -16,7 +16,8 @@
 #include <climits>
 #include <memory>
 #include <algorithm>
-#include <initializer_list>
+#include <sstream>
+

 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
@@ -1931,3 +1932,126 @@ const char * llama_print_system_info(void) {
 std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx) {
     return ctx->model.tensors_by_name;
 }
+
+// Returns the size of the state
+size_t llama_get_state_size(struct llama_context * ctx) {
+    const size_t s_bool = sizeof(int32_t);
+    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
+    // for reference, std::mt19937(1337) serializes to 6701 bytes.
+    const size_t s_rng_size = sizeof(size_t);
+    const size_t s_rng = 64*1024;
+    const size_t s_logits_capacity = sizeof(size_t);
+    const size_t s_logits_size = sizeof(size_t);
+    const size_t s_logits = ctx->logits.capacity() * sizeof(float);
+    const size_t s_embedding_size = sizeof(size_t);
+    const size_t s_embedding = ctx->embedding.size() * sizeof(float);
+    const size_t s_kv_size = sizeof(size_t);
+    const size_t s_kv_ntok = sizeof(int);
+    const size_t s_kv = llama_get_kv_cache_size(ctx);
+    const size_t s_total = (
+        + s_rng_size
+        + s_rng
+        + s_logits_capacity
+        + s_logits_size
+        + s_logits
+        + s_embedding_size
+        + s_embedding
+        + s_kv_size
+        + s_kv_ntok
+        + s_kv
+    );
+    return s_total;
+}
+
+// Copies the state to the specified destination address
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
+    std::stringstream rng_ss;
+    rng_ss << ctx->rng;
+    const size_t rng_size = rng_ss.str().size();
+    char rng_buf[64*1024];
+    memset(&rng_buf[0], 0, 64*1024);
+    memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
+    const int32_t has_evaluated_once = ctx->has_evaluated_once ? 1 : 0;
+    const int32_t logits_all = ctx->logits_all ? 1 : 0;
+    const size_t logits_capacity = ctx->logits.capacity();
+    const size_t logits_size = ctx->logits.size();
+    const size_t embedding_size = ctx->embedding.size();
+    const size_t kv_size = llama_get_kv_cache_size(ctx);
+    const int kv_ntok = llama_get_kv_cache_token_count(ctx);
+
+    uint8_t * out = dest;
+    memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t);
+    memcpy(out, &rng_buf[0], 64*1024); out += 64*1024;
+    memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t);
+    memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t);
+    if (logits_size) {
+        memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
+    }
+    out += logits_capacity * sizeof(float);
+    memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t);
+    if (embedding_size) {
+        memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float);
+    }
+    memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t);
+    memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int);
+    if (kv_size) {
+        memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size;
+    }
+    const size_t written = out - dest;
+    const size_t expected = llama_get_state_size(ctx);
+    LLAMA_ASSERT(written == expected);
+    return written;
+}
+
+// Sets the state reading from the specified source address
+size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
+    size_t rng_size;
+    char rng_buf[64*1024];
+    std::stringstream rng_ss;
+
+    const uint8_t * in = src;
+    memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t);
+    memcpy(&rng_buf[0], in, 64*1024); in += 64*1024;
+    rng_ss.str(std::string(&rng_buf[0], rng_size));
+    rng_ss >> ctx->rng;
+    LLAMA_ASSERT(rng_ss.fail() == false);
+
+    int32_t has_evaluated_once;
+    int32_t logits_all;
+    size_t logits_capacity;
+    size_t logits_size;
+    size_t embedding_size;
+    size_t kv_size;
+    int kv_ntok;
+
+    memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t);
+    memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t);
+    LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity);
+    if (logits_size) {
+        ctx->logits.resize(logits_size);
+        memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
+    }
+    in += logits_capacity * sizeof(float);
+    memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t);
+    LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
+    if (embedding_size) {
+        memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
+        in += embedding_size * sizeof(float);
+    }
+    memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t);
+    memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int);
+    if (kv_size) {
+        LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
+        void * k_data = ctx->model.kv_self.k->data; // remember data pointers
+        void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+        memcpy(ctx->model.kv_self.buf.addr, in, kv_size);
+        ctx->model.kv_self.k->data = k_data; // restore correct data pointers
+        ctx->model.kv_self.v->data = v_data;
+        in += kv_size;
+    }
+    ctx->model.kv_self.n = kv_ntok;
+    const size_t nread = in - src;
+    const size_t expected = llama_get_state_size(ctx);
+    LLAMA_ASSERT(nread == expected);
+    return nread;
+}
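
Not part of the commit: a minimal sketch of how the three new functions might be combined to snapshot and restore a context. It assumes a valid `llama_context * ctx` created elsewhere; the `save_state`/`restore_state` helper names are hypothetical, not part of the llama.cpp API.

// Illustrative helpers around the state API added above.
#include <cstdint>
#include <vector>
#include "llama.h"

// Hypothetical helper: serialize the context state into a heap buffer.
static std::vector<uint8_t> save_state(struct llama_context * ctx) {
    // ask how large the serialized state can be, then copy it out;
    // llama_copy_state_data returns the number of bytes actually written
    std::vector<uint8_t> buf(llama_get_state_size(ctx));
    const size_t written = llama_copy_state_data(ctx, buf.data());
    buf.resize(written);
    return buf;
}

// Hypothetical helper: load a previously saved state (rng, logits,
// embeddings, kv cache) back into the same context.
static void restore_state(struct llama_context * ctx, const std::vector<uint8_t> & buf) {
    llama_set_state_data(ctx, buf.data());
}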