@@ -150,7 +150,7 @@ ModelLoadResult gptj_model_load(const std::string & fname, gptj_model & model, g
         params.mem_size   = ctx_size;
         params.mem_buffer = NULL;
         params.no_alloc   = false;
-
+
 
         model.ctx = ggml_init(params);
         if (!model.ctx) {
@@ -281,7 +281,7 @@ ModelLoadResult gptj_model_load(const std::string & fname, gptj_model & model, g
             fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
             return ModelLoadResult::FAIL;
         }
-
+
 
         if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
@@ -298,7 +298,7 @@ ModelLoadResult gptj_model_load(const std::string & fname, gptj_model & model, g
                     __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
             return ModelLoadResult::FAIL;
         }
-
+
     }
 
     // for debugging
@@ -367,8 +367,16 @@ bool gptj_eval(
     static size_t buf_size = 256u*1024*1024;
     static void * buf = malloc(buf_size);
 
-    if (mem_per_token > 0 && (mem_per_token*N*2 + 64u*1024*1024) > buf_size) {
-        const size_t buf_size_new = 320u*1024*1024 + 1.6*(mem_per_token*N); // add 10% to account for ggml object overhead
+    // use 2 scratch buffers
+    // TODO: very hacky solution - reimplement in a more elegant way
+    static size_t scr0_size = (n_ctx>1024?512u:256u)*1024*1024;
+    static void * scr0 = malloc(scr0_size);
+
+    static size_t scr1_size = (n_ctx>1024?512u:256u)*1024*1024;
+    static void * scr1 = malloc(scr1_size);
+
+    if (mem_per_token > 0 && mem_per_token*N*1.05 > buf_size) {
+        const size_t buf_size_new = 64u*1024*1024 + 1.15*(mem_per_token*N); // add 10% to account for ggml object overhead
         // printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
 
         // reallocate
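
Aside: the { 0, scr0_size, scr0, } initializers used below aggregate-initialize struct ggml_scratch. For reference, the declaration in ggml.h of the same era (the field comments are mine):

    struct ggml_scratch {
        size_t offs;   // write offset into the buffer; 0 rewinds to the top
        size_t size;   // total buffer size in bytes
        void * data;   // the buffer itself; NULL disables scratch allocation
    };

    GGML_API size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);

ggml_set_scratch swaps in the new scratch wholesale and returns the previous offset; this diff ignores the return value. Since every call here passes offs = 0, each call rewinds the buffer and lets subsequent allocations overwrite whatever it held before.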
@@ -388,7 +396,7 @@ bool gptj_eval(
     params.mem_size   = buf_size;
     params.mem_buffer = buf;
     params.no_alloc   = false;
-
+
 
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
@@ -403,6 +411,8 @@ bool gptj_eval(
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * cur;
 
+        ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
+
         // norm
         {
             cur = ggml_norm(ctx0, inpL);
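
While a scratch is installed, ggml's tensor constructors place only the tensor data in the scratch buffer; the ggml_tensor header still comes out of ctx0's main buffer. A simplified paraphrase of the relevant branch in ggml_new_tensor_impl (ggml.c of the same era; not a verbatim copy):

    if (ctx->scratch.data != NULL) {
        // carve the data out of the scratch buffer at the current offset
        if (ctx->scratch.offs + data_size > ctx->scratch.size) {
            // ggml aborts here: the scratch buffer was sized too small
        }
        data = (char *) ctx->scratch.data + ctx->scratch.offs;
        ctx->scratch.offs += data_size;
    } else {
        // no scratch: the data is placed right after the header in the main buffer
    }

This is why the loop alternates buffers: each phase writes its outputs into one scratch while its inputs, produced by the previous phase, sit untouched in the other, and rewinding offs to 0 on every ggml_set_scratch call recycles the same two buffers across all n_layer layers instead of accumulating per-layer intermediates in the main context.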
@@ -490,6 +500,8 @@ bool gptj_eval(
                 cur);
         }
 
+        ggml_set_scratch(ctx0, { 0, scr1_size, scr1, });
+
         struct ggml_tensor * inpFF = cur;
 
         // feed-forward network
@@ -525,6 +537,8 @@ bool gptj_eval(
         inpL = ggml_add(ctx0, cur, inpL);
     }
 
+    ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
+
     // norm
     {
         inpL = ggml_norm(ctx0, inpL);
@@ -537,6 +551,8 @@ bool gptj_eval(
                 ggml_repeat(ctx0, model.ln_f_b, inpL));
     }
 
+    ggml_set_scratch(ctx0, { 0, 0, nullptr, });
+
     // lm_head
     {
         inpL = ggml_mul_mat(ctx0, model.lmh_g, inpL);
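
A note on the final ggml_set_scratch(ctx0, { 0, 0, nullptr, }) above: a NULL data pointer turns scratch allocation back off, so the lm_head logits are allocated in ctx0's main buffer and remain readable after graph evaluation; anything left in scr0 or scr1 is transient and gets overwritten on the next call.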