@@ -102,6 +102,9 @@ struct llama_context {
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
     bool logits_all = false;
+
+    // work buffer for transformer evaluation
+    std::vector<uint8_t> buf_eval;
 };
 
 struct llama_context_params llama_context_default_params() {
@@ -627,27 +630,19 @@ static bool llama_eval_internal(
     const int n_rot   = hparams.n_embd/hparams.n_head;
 
     auto & mem_per_token = lctx.mem_per_token;
+    auto & buf_eval      = lctx.buf_eval;
 
-    // TODO: fix this hardcoded size
-    static size_t buf_size = 512u*1024*1024;
-    static void * buf = malloc(buf_size);
+    if (mem_per_token*(n_past + N + 16) > buf_eval.size()) {
+        const size_t buf_size_new = 1.618*buf_eval.size();
 
-    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
-        const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead
-        //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
+        //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_eval.size(), buf_size_new);
 
-        // reallocate
-        buf_size = buf_size_new;
-        buf = realloc(buf, buf_size);
-        if (buf == nullptr) {
-            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
-            return false;
-        }
+        buf_eval.resize(buf_size_new);
     }
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ buf,
+        /*.mem_size   =*/ buf_eval.size(),
+        /*.mem_buffer =*/ buf_eval.data(),
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
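Taken together, this hunk replaces the function-local static 512 MB malloc/realloc scratch buffer with the per-context buf_eval vector, which is grown by a factor of 1.618 whenever the estimated requirement mem_per_token*(n_past + N + 16) (a 16-token safety margin) outruns the current capacity. Below is a minimal self-contained sketch of the same grow-on-demand pattern; ensure_eval_buffer and every number in it are illustrative assumptions, not part of the patch (the patch grows by a single 1.618 step per eval, while the sketch loops until the target fits):

// grow_on_demand.cpp - illustrative sketch only, not code from the commit
#include <cstdint>
#include <cstdio>
#include <vector>

// Grow `buf` geometrically until it can hold `needed` bytes. Geometric
// growth keeps the number of reallocations logarithmic in the final size.
static void ensure_eval_buffer(std::vector<uint8_t> & buf, size_t needed) {
    size_t size = buf.size();
    while (size < needed) {
        size = (size_t)(1.618*size) + 1;
    }
    if (size != buf.size()) {
        buf.resize(size);
    }
}

int main() {
    // initial reservation, mirroring the 512 MB resize in llama_init_from_file
    std::vector<uint8_t> buf_eval(512u*1024u*1024u);

    // hypothetical per-token estimate and context position
    const size_t mem_per_token = 14u*1024u*1024u;
    const int n_past = 100;
    const int N      = 8;

    ensure_eval_buffer(buf_eval, mem_per_token*(n_past + N + 16));
    printf("buf_eval grown to %zu MB\n", buf_eval.size()/1024/1024);

    return 0;
}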
@@ -832,10 +827,11 @@ static bool llama_eval_internal(
         memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
     }
 
-    if (mem_per_token == 0) {
-        mem_per_token = ggml_used_mem(ctx0)/N;
+    if (N == 1) {
+        mem_per_token = ggml_used_mem(ctx0)/(n_past + N);
     }
-    //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
+
+    //fprintf(stderr, "\nused_mem = %zu, %zu MB\n", ggml_used_mem(ctx0), ggml_used_mem(ctx0)/1024/1024);
 
     ggml_free(ctx0);
 
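This hunk changes when and how the per-token estimate is taken: instead of measuring once on the first eval (mem_per_token == 0) and dividing by the batch size N, it re-measures on every single-token eval (N == 1) and divides by the full context length n_past + N, reflecting that the graph's memory use grows with the context. A back-of-envelope illustration of the sizing arithmetic; all numbers here are made-up assumptions, not measurements:

// sizing_arithmetic.cpp - illustrative numbers only
#include <cstddef>
#include <cstdio>

int main() {
    // suppose a single-token eval (N == 1) at n_past == 127 used ~1792 MB
    const size_t used = 1792u*1024u*1024u;
    const int n_past = 127;
    const int N      = 1;

    // the estimate the patch would record
    const size_t mem_per_token = used/(n_past + N); // ~14 MB per context token

    // buffer an 8-token follow-up eval would require, with the 16-token margin
    const size_t needed = mem_per_token*(n_past + 8 + 16);
    printf("mem_per_token = %zu MB, next eval needs >= %zu MB\n",
           mem_per_token/1024/1024, needed/1024/1024);

    return 0;
}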
@@ -1416,6 +1412,8 @@ struct llama_context * llama_init_from_file(
        return nullptr;
    }
 
+    ctx->buf_eval.resize(512u*1024u*1024u);
+
    return ctx;
}
 