Skip to content

Commit 83f3d7a

Browse files
committed
backend : clean-up the implementation
ggml-ci
1 parent 01b6f68 commit 83f3d7a

File tree

3 files changed

+22
-20
lines changed

3 files changed

+22
-20
lines changed

Diff for: examples/simple/simple.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,20 @@
88

99
// a function that can be called for every computed node during graph evaluation
1010
// the user can choose whether to observe the data of the node depending on the tensor parameters
11-
static bool observe_compute(int node_index, struct ggml_tensor * t, bool ask, void * user_data) {
11+
static bool observe_compute(struct ggml_tensor * t, bool ask, void * user_data) {
1212
GGML_UNUSED(user_data);
1313

1414
// the scheduler is asking us if we want to observe this node
1515
if (ask) {
16-
// check if name contains soft_max
16+
// check if name contains soft_max (customize to your needs)
1717
return strstr(t->name, "soft_max") != 0;
1818
}
1919

20-
// print the node data
21-
printf("%s: node_index = %5d, t->name = %32s, t->op = %12s, [%5d, %5d, %5d, %5d]\n",
22-
__func__, node_index, t->name, ggml_op_name(t->op), (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
20+
// print the node info
21+
printf("%s: t->name = %32s, t->op = %12s, [%5d, %5d, %5d, %5d]\n",
22+
__func__, t->name, ggml_op_name(t->op), (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
2323

24+
// this will copy the data to host memory (if needed)
2425
std::vector<float> t_data(ggml_nelements(t));
2526
ggml_backend_tensor_get(t, t_data.data(), 0, ggml_nbytes(t));
2627

Diff for: ggml-backend.c

+15-12
Original file line numberDiff line numberDiff line change
@@ -1334,28 +1334,31 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
13341334
//ggml_backend_synchronize(split_backend); // necessary to measure compute time
13351335
} else {
13361336
// similar to ggml_backend_compare_graph_backend
1337-
for (int j = 0; j < split->graph.n_nodes; j++) {
1338-
struct ggml_tensor * t = split->graph.nodes[j];
1337+
for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {
1338+
struct ggml_tensor * t = split->graph.nodes[j0];
13391339

1340-
int k = j;
1340+
int j1 = j0;
13411341

1342-
// check if the user needs data from this node
1343-
while (!sched->callback_eval(k, t, true, sched->callback_eval_user_data) && k < split->graph.n_nodes - 1) {
1344-
t = split->graph.nodes[++k];
1342+
// determine the range [j0, j1] of nodes that can be computed together
1343+
while (j1 < split->graph.n_nodes - 1) {
1344+
// check if the user needs data from this node
1345+
if (sched->callback_eval(t, true, sched->callback_eval_user_data)) {
1346+
break;
1347+
}
1348+
1349+
t = split->graph.nodes[++j1];
13451350
}
13461351

1347-
struct ggml_cgraph gv = ggml_graph_view(&split->graph, j, k + 1);
1352+
struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
13481353

13491354
ggml_backend_graph_compute(split_backend, &gv);
13501355

1351-
// TODO: k is node index in the split, not in the original graph
1352-
// TODO: avoid the ask == true call here
1353-
if (sched->callback_eval(k, t, true, sched->callback_eval_user_data) &&
1354-
!sched->callback_eval(k, t, false, sched->callback_eval_user_data)) {
1356+
if (sched->callback_eval(t, true, sched->callback_eval_user_data) && // ask
1357+
!sched->callback_eval(t, false, sched->callback_eval_user_data)) { // eval
13551358
break;
13561359
}
13571360

1358-
j = k;
1361+
j0 = j1;
13591362
}
13601363
}
13611364
uint64_t compute_end_us = ggml_time_us();

Diff for: ggml-backend.h

+1-3
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,7 @@ extern "C" {
154154
// when ask == false, the scheduler is passing the node tensor to the user for observation
155155
// if the user returns false, the scheduler will cancel the graph compute
156156
//
157-
// TODO: propose to rename to ggml_backend_sched_callback_eval
158-
typedef bool (*ggml_backend_sched_eval_callback)(int node_index, struct ggml_tensor * t, bool ask, void * user_data);
157+
typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
159158

160159
// Initialize a backend scheduler
161160
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
@@ -195,7 +194,6 @@ extern "C" {
195194
GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
196195
GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
197196

198-
// TODO: propose to rename this to ggml_backend_callback_compare
199197
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
200198

201199
// Compare the output of two backends

0 commit comments

Comments
 (0)