|
27 | 27 | #include <thread>
|
28 | 28 | #include <atomic>
|
29 | 29 | #include <mutex>
|
| 30 | +#include <sstream> |
30 | 31 |
|
31 | 32 | #define LLAMA_USE_SCRATCH
|
32 | 33 | #define LLAMA_MAX_SCRATCH_BUFFERS 16
|
@@ -1787,7 +1788,7 @@ struct llama_context * llama_init_from_file(
|
1787 | 1788 | if (params.logits_all) {
|
1788 | 1789 | ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
|
1789 | 1790 | } else {
|
1790 |
| - ctx->logits.reserve(hparams.n_ctx); |
| 1791 | + ctx->logits.reserve(hparams.n_vocab); |
1791 | 1792 | }
|
1792 | 1793 |
|
1793 | 1794 | if (params.embedding){
|
@@ -2252,3 +2253,122 @@ const char * llama_print_system_info(void) {
|
2252 | 2253 | std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx) {
|
2253 | 2254 | return ctx->model.tensors_by_name;
|
2254 | 2255 | }
|
| 2256 | + |
| 2257 | +// Returns the size of the state |
| 2258 | +size_t llama_get_state_size(struct llama_context * ctx) { |
| 2259 | + const size_t s_bool = sizeof(int32_t); |
| 2260 | + // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. |
| 2261 | + // for reference, std::mt19937(1337) serializes to 6701 bytes. |
| 2262 | + const size_t s_rng_size = sizeof(size_t); |
| 2263 | + const size_t s_rng = 64*1024; |
| 2264 | + const size_t s_logits_capacity = sizeof(size_t); |
| 2265 | + const size_t s_logits_size = sizeof(size_t); |
| 2266 | + const size_t s_logits = ctx->logits.capacity() * sizeof(float); |
| 2267 | + const size_t s_embedding_size = sizeof(size_t); |
| 2268 | + const size_t s_embedding = ctx->embedding.size() * sizeof(float); |
| 2269 | + const size_t s_kv_size = sizeof(size_t); |
| 2270 | + const size_t s_kv_ntok = sizeof(int); |
| 2271 | + const size_t s_kv = llama_get_kv_cache_size(ctx); |
| 2272 | + const size_t s_total = ( |
| 2273 | + + s_rng_size |
| 2274 | + + s_rng |
| 2275 | + + s_logits_capacity |
| 2276 | + + s_logits_size |
| 2277 | + + s_logits |
| 2278 | + + s_embedding_size |
| 2279 | + + s_embedding |
| 2280 | + + s_kv_size |
| 2281 | + + s_kv_ntok |
| 2282 | + + s_kv |
| 2283 | + ); |
| 2284 | + return s_total; |
| 2285 | +} |
| 2286 | + |
| 2287 | +// Copies the state to the specified destination address |
| 2288 | +size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { |
| 2289 | + std::stringstream rng_ss; |
| 2290 | + rng_ss << ctx->rng; |
| 2291 | + const size_t rng_size = rng_ss.str().size(); |
| 2292 | + char rng_buf[64*1024]; |
| 2293 | + memset(&rng_buf[0], 0, 64*1024); |
| 2294 | + memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); |
| 2295 | + const size_t logits_capacity = ctx->logits.capacity(); |
| 2296 | + const size_t logits_size = ctx->logits.size(); |
| 2297 | + const size_t embedding_size = ctx->embedding.size(); |
| 2298 | + const size_t kv_size = llama_get_kv_cache_size(ctx); |
| 2299 | + const int kv_ntok = llama_get_kv_cache_token_count(ctx); |
| 2300 | + |
| 2301 | + uint8_t * out = dest; |
| 2302 | + memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t); |
| 2303 | + memcpy(out, &rng_buf[0], 64*1024); out += 64*1024; |
| 2304 | + memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t); |
| 2305 | + memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t); |
| 2306 | + if (logits_size) { |
| 2307 | + memcpy(out, ctx->logits.data(), logits_size * sizeof(float)); |
| 2308 | + } |
| 2309 | + out += logits_capacity * sizeof(float); |
| 2310 | + memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t); |
| 2311 | + if (embedding_size) { |
| 2312 | + memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float); |
| 2313 | + } |
| 2314 | + memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t); |
| 2315 | + memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int); |
| 2316 | + if (kv_size) { |
| 2317 | + memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size; |
| 2318 | + } |
| 2319 | + const size_t written = out - dest; |
| 2320 | + const size_t expected = llama_get_state_size(ctx); |
| 2321 | + LLAMA_ASSERT(written == expected); |
| 2322 | + return written; |
| 2323 | +} |
| 2324 | + |
| 2325 | +// Sets the state reading from the specified source address |
| 2326 | +size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { |
| 2327 | + size_t rng_size; |
| 2328 | + char rng_buf[64*1024]; |
| 2329 | + std::stringstream rng_ss; |
| 2330 | + |
| 2331 | + const uint8_t * in = src; |
| 2332 | + memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t); |
| 2333 | + memcpy(&rng_buf[0], in, 64*1024); in += 64*1024; |
| 2334 | + rng_ss.str(std::string(&rng_buf[0], rng_size)); |
| 2335 | + rng_ss >> ctx->rng; |
| 2336 | + LLAMA_ASSERT(rng_ss.fail() == false); |
| 2337 | + |
| 2338 | + size_t logits_capacity; |
| 2339 | + size_t logits_size; |
| 2340 | + size_t embedding_size; |
| 2341 | + size_t kv_size; |
| 2342 | + int kv_ntok; |
| 2343 | + |
| 2344 | + memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t); |
| 2345 | + memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t); |
| 2346 | + LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity); |
| 2347 | + if (logits_size) { |
| 2348 | + ctx->logits.resize(logits_size); |
| 2349 | + memcpy(ctx->logits.data(), in, logits_size * sizeof(float)); |
| 2350 | + } |
| 2351 | + in += logits_capacity * sizeof(float); |
| 2352 | + memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t); |
| 2353 | + LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size); |
| 2354 | + if (embedding_size) { |
| 2355 | + memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float)); |
| 2356 | + in += embedding_size * sizeof(float); |
| 2357 | + } |
| 2358 | + memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t); |
| 2359 | + memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int); |
| 2360 | + if (kv_size) { |
| 2361 | + LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size); |
| 2362 | + void * k_data = ctx->model.kv_self.k->data; // remember data pointers |
| 2363 | + void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy |
| 2364 | + memcpy(ctx->model.kv_self.buf.addr, in, kv_size); |
| 2365 | + ctx->model.kv_self.k->data = k_data; // restore correct data pointers |
| 2366 | + ctx->model.kv_self.v->data = v_data; |
| 2367 | + in += kv_size; |
| 2368 | + } |
| 2369 | + ctx->model.kv_self.n = kv_ntok; |
| 2370 | + const size_t nread = in - src; |
| 2371 | + const size_t expected = llama_get_state_size(ctx); |
| 2372 | + LLAMA_ASSERT(nread == expected); |
| 2373 | + return nread; |
| 2374 | +} |
0 commit comments