@@ -37,13 +37,16 @@ extern "C" {
37
37
// ====== Dataset ======
38
38
39
39
GGML_API ggml_opt_dataset_t ggml_opt_dataset_init (
40
- int64_t ne_datapoint , // number of elements per datapoint
41
- int64_t ne_label , // number of elements per label
42
- int64_t ndata , // total number of datapoints/labels
43
- int64_t ndata_shard ); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
40
+ enum ggml_type type_data , // the type for the internal data tensor
41
+ enum ggml_type type_label , // the type for the internal labels tensor
42
+ int64_t ne_datapoint , // number of elements per datapoint
43
+ int64_t ne_label , // number of elements per label
44
+ int64_t ndata , // total number of datapoints/labels
45
+ int64_t ndata_shard ); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
44
46
GGML_API void ggml_opt_dataset_free (ggml_opt_dataset_t dataset );
45
47
46
48
// get underlying tensors that store the data
49
+ GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset );
47
50
GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset ); // shape = [ne_datapoint, ndata]
48
51
GGML_API struct ggml_tensor * ggml_opt_dataset_labels (ggml_opt_dataset_t dataset ); // shape = [ne_label, ndata]
49
52
@@ -56,13 +59,19 @@ extern "C" {
56
59
struct ggml_tensor * data_batch , // shape = [ne_datapoint, ndata_batch]
57
60
struct ggml_tensor * labels_batch , // shape = [ne_label, ndata_batch]
58
61
int64_t ibatch );
62
+ GGML_API void ggml_opt_dataset_get_batch_host (
63
+ ggml_opt_dataset_t dataset ,
64
+ void * data_batch ,
65
+ size_t nb_data_batch ,
66
+ void * labels_batch ,
67
+ int64_t ibatch );
59
68
60
69
// ====== Model / Context ======
61
70
62
71
enum ggml_opt_build_type {
63
- GGML_OPT_BUILD_TYPE_FORWARD ,
64
- GGML_OPT_BUILD_TYPE_GRAD ,
65
- GGML_OPT_BUILD_TYPE_OPT ,
72
+ GGML_OPT_BUILD_TYPE_FORWARD = 10 ,
73
+ GGML_OPT_BUILD_TYPE_GRAD = 20 ,
74
+ GGML_OPT_BUILD_TYPE_OPT = 30 ,
66
75
};
67
76
68
77
// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
@@ -81,20 +90,22 @@ extern "C" {
81
90
// userdata can be used to pass arbitrary data
82
91
typedef struct ggml_opt_optimizer_params (* ggml_opt_get_optimizer_params )(void * userdata );
83
92
84
- // returns the default optimizer params (constant)
93
+ // returns the default optimizer params (constant, hard-coded values )
85
94
// userdata is not used
86
95
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params (void * userdata );
87
96
97
+ // casts userdata to ggml_opt_optimizer_params and returns it
98
+ GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params (void * userdata );
99
+
88
100
// parameters for initializing a new optimization context
89
101
struct ggml_opt_params {
90
102
ggml_backend_sched_t backend_sched ; // defines which backends are used to construct the compute graphs
91
103
92
- struct ggml_context * ctx_compute ; // created in user code, holds non-static tensors
93
-
94
- // the forward graph is defined by inputs and outputs
95
- // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
96
- struct ggml_tensor * inputs ;
97
- struct ggml_tensor * outputs ;
104
+ // by default the forward graph needs to be reconstructed for each eval
105
+ // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
106
+ struct ggml_context * ctx_compute ;
107
+ struct ggml_tensor * inputs ;
108
+ struct ggml_tensor * outputs ;
98
109
99
110
enum ggml_opt_loss_type loss_type ;
100
111
enum ggml_opt_build_type build_type ;
@@ -107,12 +118,9 @@ extern "C" {
107
118
108
119
// get parameters for an optimization context with defaults set where possible
109
120
// parameters for which no sensible defaults exist are supplied as arguments to this function
110
- GGML_API ggml_opt_params ggml_opt_default_params (
111
- ggml_backend_sched_t backend_sched ,
112
- struct ggml_context * ctx_compute ,
113
- struct ggml_tensor * inputs ,
114
- struct ggml_tensor * outputs ,
115
- enum ggml_opt_loss_type loss_type );
121
+ GGML_API struct ggml_opt_params ggml_opt_default_params (
122
+ ggml_backend_sched_t backend_sched ,
123
+ enum ggml_opt_loss_type loss_type );
116
124
117
125
GGML_API ggml_opt_context_t ggml_opt_init (struct ggml_opt_params params );
118
126
GGML_API void ggml_opt_free (ggml_opt_context_t opt_ctx );
@@ -121,18 +129,20 @@ extern "C" {
121
129
GGML_API void ggml_opt_reset (ggml_opt_context_t opt_ctx , bool optimizer );
122
130
123
131
// get underlying tensors that store data
132
+ // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
124
133
GGML_API struct ggml_tensor * ggml_opt_inputs ( ggml_opt_context_t opt_ctx ); // forward graph input tensor
125
134
GGML_API struct ggml_tensor * ggml_opt_outputs ( ggml_opt_context_t opt_ctx ); // forward graph output tensor
126
135
GGML_API struct ggml_tensor * ggml_opt_labels ( ggml_opt_context_t opt_ctx ); // labels to compare outputs against
127
136
GGML_API struct ggml_tensor * ggml_opt_loss ( ggml_opt_context_t opt_ctx ); // scalar tensor that contains the loss
128
137
GGML_API struct ggml_tensor * ggml_opt_pred ( ggml_opt_context_t opt_ctx ); // predictions made by outputs
129
138
GGML_API struct ggml_tensor * ggml_opt_ncorrect (ggml_opt_context_t opt_ctx ); // number of matching predictions between outputs and labels
130
139
140
+ // get the gradient accumulator for a node from the forward graph
131
141
GGML_API struct ggml_tensor * ggml_opt_grad_acc (ggml_opt_context_t opt_ctx , struct ggml_tensor * node );
132
142
133
143
// ====== Optimization Result ======
134
144
135
- GGML_API ggml_opt_result_t ggml_opt_result_init ();
145
+ GGML_API ggml_opt_result_t ggml_opt_result_init (void );
136
146
GGML_API void ggml_opt_result_free (ggml_opt_result_t result );
137
147
GGML_API void ggml_opt_result_reset (ggml_opt_result_t result );
138
148
@@ -144,11 +154,20 @@ extern "C" {
144
154
145
155
// ====== Computation ======
146
156
147
- // do forward pass, increment result if not NULL
148
- GGML_API void ggml_opt_forward (ggml_opt_context_t opt_ctx , ggml_opt_result_t result );
157
+ // if not using static graphs, this function must be called prior to ggml_opt_alloc
158
+ GGML_API void ggml_opt_prepare_alloc (
159
+ ggml_opt_context_t opt_ctx ,
160
+ struct ggml_context * ctx_compute ,
161
+ struct ggml_cgraph * gf ,
162
+ struct ggml_tensor * inputs ,
163
+ struct ggml_tensor * outputs );
164
+
165
+ // allocate the next graph for evaluation, either forward or forward + backward
166
+ // must be called exactly once prior to calling ggml_opt_eval
167
+ GGML_API void ggml_opt_alloc (ggml_opt_context_t opt_ctx , bool backward );
149
168
150
- // do forward pass, increment result if not NULL, do backward pass
151
- GGML_API void ggml_opt_forward_backward (ggml_opt_context_t opt_ctx , ggml_opt_result_t result );
169
+ // do forward pass, increment result if not NULL, do backward pass if allocated
170
+ GGML_API void ggml_opt_eval (ggml_opt_context_t opt_ctx , ggml_opt_result_t result );
152
171
153
172
// ############################################################################
154
173
// ## The high-level functions start here. They do not depend on any private ##
@@ -200,9 +219,9 @@ extern "C" {
200
219
// fit model defined by inputs and outputs to dataset
201
220
GGML_API void ggml_opt_fit (
202
221
ggml_backend_sched_t backend_sched , // backend scheduler for constructing the compute graphs
203
- ggml_context * ctx_compute , // context with temporarily allocated tensors to calculate the outputs
204
- ggml_tensor * inputs , // input tensor with shape [ne_datapoint, ndata_batch]
205
- ggml_tensor * outputs , // output tensor, must have shape [ne_label, ndata_batch] if labels are used
222
+ struct ggml_context * ctx_compute , // context with temporarily allocated tensors to calculate the outputs
223
+ struct ggml_tensor * inputs , // input tensor with shape [ne_datapoint, ndata_batch]
224
+ struct ggml_tensor * outputs , // output tensor, must have shape [ne_label, ndata_batch] if labels are used
206
225
ggml_opt_dataset_t dataset , // dataset with data and optionally also labels
207
226
enum ggml_opt_loss_type loss_type , // loss to minimize
208
227
ggml_opt_get_optimizer_params get_opt_pars , // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
0 commit comments