
Commit 397cdcb

ggerganov authored and hodlen committed
backend : add eval callback (ggml-org#4935)
* backend : add eval callback (ggml-ci)
* backend : group nodes in a single compute when the user doesn't need them
* backend : clean up the implementation (ggml-ci)
* simple : do not perform tensor data copy if not needed
* simple : fix
* simple : no need for ggml_is_contiguous + fix bool parse
* llama : fix callback placement in llama_context_params
* backend : avoid double-ask callback calls
* simple : restore examples, imatrix will serve as a demo
1 parent 14e421c · commit 397cdcb

4 files changed: +64 -2 lines

ggml-backend.c (+40, -2)
@@ -802,6 +802,9 @@ struct ggml_backend_sched {
     __attribute__((aligned(GGML_MEM_ALIGN)))
 #endif
     char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
+
+    ggml_backend_sched_eval_callback callback_eval;
+    void * callback_eval_user_data;
 };
 
 #define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
@@ -1324,9 +1327,38 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
         ggml_graph_dump_dot(split->graph, NULL, split_filename);
 #endif
 
+
         uint64_t compute_start_us = ggml_time_us();
-        ggml_backend_graph_compute(split_backend, &split->graph);
-        //ggml_backend_synchronize(split_backend); // necessary to measure compute time
+        if (!sched->callback_eval) {
+            ggml_backend_graph_compute(split_backend, &split->graph);
+            //ggml_backend_synchronize(split_backend); // necessary to measure compute time
+        } else {
+            // similar to ggml_backend_compare_graph_backend
+            for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {
+                struct ggml_tensor * t = split->graph.nodes[j0];
+
+                // check if the user needs data from this node
+                bool need = sched->callback_eval(t, true, sched->callback_eval_user_data);
+
+                int j1 = j0;
+
+                // determine the range [j0, j1] of nodes that can be computed together
+                while (!need && j1 < split->graph.n_nodes - 1) {
+                    t = split->graph.nodes[++j1];
+                    need = sched->callback_eval(t, true, sched->callback_eval_user_data);
+                }
+
+                struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
+
+                ggml_backend_graph_compute(split_backend, &gv);
+
+                if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
+                    break;
+                }
+
+                j0 = j1;
+            }
+        }
         uint64_t compute_end_us = ggml_time_us();
         compute_us[split_backend_id] += compute_end_us - compute_start_us;
     }
@@ -1431,6 +1463,12 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
     sched_reset(sched);
 }
 
+
+void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
+    sched->callback_eval = callback;
+    sched->callback_eval_user_data = user_data;
+}
+
 int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
     return sched->n_splits;
 }
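For reference, the grouping loop in sched_compute_splits above preserves the old behavior when the callback never requests any node: the [j0, j1] range grows to cover the whole split and the graph is still computed in a single ggml_backend_graph_compute call. A minimal sketch of such a pass-through callback (the function name is illustrative, not part of this commit; leaving the callback unset/NULL skips the per-node path entirely, as the if (!sched->callback_eval) branch shows):

    #include "ggml.h"
    #include "ggml-backend.h"

    // A callback that never asks to observe any node.
    // In the ask phase (ask == true) it always returns false, so the scheduler
    // batches every node of the split into one ggml_backend_graph_compute call.
    static bool eval_cb_noop(struct ggml_tensor * t, bool ask, void * user_data) {
        (void) t;
        (void) user_data;
        if (ask) {
            return false; // we do not need this node's data
        }
        return true; // continue computing (not reached when we never ask)
    }

    // registration on an existing scheduler:
    // ggml_backend_sched_set_eval_callback(sched, eval_cb_noop, NULL);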

ggml-backend.h (+11)
@@ -148,6 +148,14 @@ extern "C" {
     struct ggml_backend_sched;
     typedef struct ggml_backend_sched * ggml_backend_sched_t;
 
+    // when ask == true, the scheduler wants to know if the user wants to observe this node
+    // this allows the scheduler to batch nodes together in order to evaluate them in a single call
+    //
+    // when ask == false, the scheduler is passing the node tensor to the user for observation
+    // if the user returns false, the scheduler will cancel the graph compute
+    //
+    typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
+
     // Initialize a backend scheduler
     GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
     GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
@@ -168,6 +176,9 @@ extern "C" {
     // Reset all assignments and allocators - must be called before using the sched allocators to allocate inputs
     GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
 
+    // Set a callback to be called for each resulting node during graph compute
+    GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
+
     //
     // Utils
     //
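To illustrate the two-phase protocol documented in the typedef comment above, here is a rough sketch of an observer callback. The eval_ctx struct and the tensor-name filter are hypothetical, not part of this change; only the ask/observe semantics come from the header:

    #include <stdio.h>
    #include <string.h>
    #include "ggml.h"
    #include "ggml-backend.h"

    // hypothetical user state, passed through user_data
    struct eval_ctx {
        const char * target_name; // only observe tensors with this name
    };

    static bool my_eval_cb(struct ggml_tensor * t, bool ask, void * user_data) {
        struct eval_ctx * ctx = (struct eval_ctx *) user_data;

        if (ask) {
            // ask phase: return true only for nodes we want to observe;
            // everything else can be batched into a single compute call
            return strcmp(t->name, ctx->target_name) == 0;
        }

        // observe phase: the node has been computed and can be inspected
        // (use ggml_backend_tensor_get to copy the data out if it lives on a GPU backend)
        fprintf(stderr, "observed %s: op=%s, ne=[%lld, %lld, %lld, %lld]\n",
                t->name, ggml_op_name(t->op),
                (long long) t->ne[0], (long long) t->ne[1],
                (long long) t->ne[2], (long long) t->ne[3]);

        return true; // return false here to cancel the rest of the graph compute
    }

It would be registered with ggml_backend_sched_set_eval_callback(sched, my_eval_cb, &ctx), where ctx holds the name of the tensor to watch.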

llama.cpp (+9)
@@ -1393,6 +1393,9 @@ struct llama_cparams {
 
     bool mul_mat_q;
     bool offload_kqv;
+
+    ggml_backend_sched_eval_callback cb_eval;
+    void * cb_eval_user_data;
 };
 
 struct llama_layer {
@@ -6254,6 +6257,7 @@ static int llama_decode_internal(
     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
     ggml_backend_sched_reset(lctx.sched);
+    ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
     ggml_cgraph * gf = llama_build_graph(lctx, batch);
 
@@ -9276,6 +9280,8 @@ struct llama_context_params llama_context_default_params() {
         /*.yarn_beta_fast    =*/ 32.0f,
         /*.yarn_beta_slow    =*/ 1.0f,
         /*.yarn_orig_ctx     =*/ 0,
+        /*.cb_eval           =*/ nullptr,
+        /*.cb_eval_user_data =*/ nullptr,
         /*.type_k            =*/ GGML_TYPE_F16,
         /*.type_v            =*/ GGML_TYPE_F16,
         /*.mul_mat_q         =*/ true,
@@ -9416,6 +9422,9 @@ struct llama_context * llama_new_context_with_model(
                                hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
                                                               hparams.n_ctx_train;
 
+    cparams.cb_eval           = params.cb_eval;
+    cparams.cb_eval_user_data = params.cb_eval_user_data;
+
     auto rope_scaling_type = params.rope_scaling_type;
     if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
         rope_scaling_type = hparams.rope_scaling_type_train;

llama.h (+4)
@@ -2,6 +2,7 @@
 #define LLAMA_H
 
 #include "ggml.h"
+#include "ggml-backend.h"
 #ifdef GGML_USE_CUBLAS
 #include "ggml-cuda.h"
 #define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
@@ -231,6 +232,9 @@ extern "C" {
         float    yarn_beta_slow; // YaRN high correction dim
         uint32_t yarn_orig_ctx;  // YaRN original context size
 
+        ggml_backend_sched_eval_callback cb_eval;
+        void * cb_eval_user_data;
+
         enum ggml_type type_k; // data type for K cache
         enum ggml_type type_v; // data type for V cache
 
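At the llama.cpp API level, the callback is plugged in through these two new llama_context_params fields before the context is created; llama_decode then forwards it to the scheduler as shown in the llama.cpp hunk above. A rough sketch, where my_eval_cb and my_state are placeholders for the callback and user state from the earlier example:

    #include "llama.h"

    // callback as sketched above (placeholder)
    extern bool my_eval_cb(struct ggml_tensor * t, bool ask, void * user_data);

    struct llama_context * new_ctx_with_eval_cb(struct llama_model * model, void * my_state) {
        struct llama_context_params cparams = llama_context_default_params();

        cparams.cb_eval           = my_eval_cb; // invoked for scheduled nodes during decode
        cparams.cb_eval_user_data = my_state;

        return llama_new_context_with_model(model, cparams);
    }

This mirrors how the imatrix example, named in the commit message as the demo for this callback, hooks into graph evaluation.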
