
Commit 4870e45

Fix memory allocation issues and seg faults
1 parent 483bab2 commit 4870e45

1 file changed: +16 -18 lines changed

llama.cpp (+16 -18)
@@ -102,6 +102,9 @@ struct llama_context {
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
     bool logits_all = false;
+
+    // work buffer for transformer evaluation
+    std::vector<uint8_t> buf_eval;
 };
 
 struct llama_context_params llama_context_default_params() {
@@ -627,27 +630,19 @@ static bool llama_eval_internal(
     const int n_rot = hparams.n_embd/hparams.n_head;
 
     auto & mem_per_token = lctx.mem_per_token;
+    auto & buf_eval = lctx.buf_eval;
 
-    // TODO: fix this hardcoded size
-    static size_t buf_size = 512u*1024*1024;
-    static void * buf = malloc(buf_size);
+    if (mem_per_token*(n_past + N + 16) > buf_eval.size()) {
+        const size_t buf_size_new = 1.618*buf_eval.size();
 
-    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
-        const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead
-        //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
+        //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_eval.size(), buf_size_new);
 
-        // reallocate
-        buf_size = buf_size_new;
-        buf = realloc(buf, buf_size);
-        if (buf == nullptr) {
-            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
-            return false;
-        }
+        buf_eval.resize(buf_size_new);
     }
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ buf,
+        /*.mem_size   =*/ buf_eval.size(),
+        /*.mem_buffer =*/ buf_eval.data(),
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
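
Note on the hunk above: the old check scaled the buffer only with the batch size N, while attention scratch also grows with n_past, so long generations could outrun the 512 MB static buffer, a likely source of the seg faults named in the commit title. The new check budgets for the whole sequence and grows the per-context std::vector by a golden-ratio factor when it falls short. A minimal sketch of that growth policy, assuming a non-zero starting size (the commit seeds 512 MB in llama_init_from_file); ensure_eval_buffer is an illustrative name, not llama.cpp API:

#include <cstddef>
#include <cstdint>
#include <vector>

// Grow buf geometrically until the projected requirement fits.
// The +16 tokens of slack mirror the diff's allowance for ggml
// object overhead. Assumes buf.size() > 0, so the loop terminates.
static void ensure_eval_buffer(std::vector<uint8_t> & buf,
                               size_t mem_per_token, int n_past, int N) {
    while (mem_per_token*(n_past + N + 16) > buf.size()) {
        buf.resize(static_cast<size_t>(1.618*buf.size()));
    }
}

Growing by ~1.618x keeps the number of reallocations logarithmic in the final size, and a failed resize throws std::bad_alloc instead of the silent nullptr the old realloc path had to check for by hand.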
@@ -832,10 +827,11 @@ static bool llama_eval_internal(
         memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
     }
 
-    if (mem_per_token == 0) {
-        mem_per_token = ggml_used_mem(ctx0)/N;
+    if (N == 1) {
+        mem_per_token = ggml_used_mem(ctx0)/(n_past + N);
     }
-    //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
+
+    //fprintf(stderr, "\nused_mem = %zu, %zu MB\n", ggml_used_mem(ctx0), ggml_used_mem(ctx0)/1024/1024);
 
     ggml_free(ctx0);
 
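On the re-estimation above: the trigger changes from "first call ever" (mem_per_token == 0) to every single-token decode (N == 1), and the divisor changes from the batch size N to the full sequence length n_past + N, since attention cost scales with the whole context. A worked example with made-up numbers, not measurements:

#include <cstddef>
#include <cstdio>

int main() {
    // Illustrative figures: a single-token eval (N == 1) over
    // n_past == 127 cached tokens that used 64 MB of ggml memory.
    const size_t used_mem = 64u*1024u*1024u;  // stand-in for ggml_used_mem(ctx0)
    const int n_past = 127, N = 1;
    const size_t mem_per_token = used_mem/(n_past + N);  // 524288 bytes = 512 KiB
    std::printf("mem_per_token = %zu bytes\n", mem_per_token);
    return 0;
}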
@@ -1416,6 +1412,8 @@ struct llama_context * llama_init_from_file(
         return nullptr;
     }
 
+    ctx->buf_eval.resize(512u*1024u*1024u);
+
     return ctx;
 }
 
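Taken together: llama_init_from_file seeds buf_eval at the old hardcoded 512 MB, now owned by the context and freed with it; each single-token eval refreshes mem_per_token; and the next eval's growth check resizes the buffer before ggml_init maps it. A self-contained driver sketch under those assumptions, with the transformer eval itself stubbed out:

#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
    std::vector<uint8_t> buf_eval(512u*1024u*1024u);  // seed, as in llama_init_from_file
    size_t mem_per_token = 0;

    int n_past = 0;
    for (int step = 0; step < 4; ++step) {
        const int N = 1;  // single-token decode

        // growth check from llama_eval_internal
        if (mem_per_token*(n_past + N + 16) > buf_eval.size()) {
            buf_eval.resize(static_cast<size_t>(1.618*buf_eval.size()));
        }

        // ... ggml_init over buf_eval and the transformer graph go here ...
        const size_t used_mem = 512u*1024u*(n_past + N);  // stand-in for ggml_used_mem(ctx0)

        if (N == 1) {
            mem_per_token = used_mem/(n_past + N);  // refresh the estimate
        }
        n_past += N;
    }
    return 0;
}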
