@@ -1039,7 +1039,12 @@ namespace chatllm
             while (!aborted && !completed && (n_past + (int)curr_input_ids.size() < gen_config.max_length))
             {
                 std::vector<float> lm_logits;
-                generate_next_token(curr_input_ids, gen_config, lm_logits);
+                if (!generate_next_token(curr_input_ids, gen_config, lm_logits))
+                {
+                    ggml::log(GGML_LOG_LEVEL_ERROR, "Out of memory");
+                    aborted = true;
+                    break;
+                }
 
                 if (first_call)
                 {
@@ -1113,29 +1118,35 @@ namespace chatllm
         void text_embedding(const GenerationConfig &gen_config, const std::vector<int> &input_ids,
                             std::vector<float> &embedding) override
         {
-            run_model(input_ids, gen_config, 0, embedding);
+            auto r = run_model(input_ids, gen_config, 0, embedding);
+            if (!r) ggml::log(GGML_LOG_LEVEL_ERROR, "Out of memory");
         }
 
         float qa_rank(const GenerationConfig &gen_config, const std::vector<int> &input_ids) override
         {
             std::vector<float> output;
-            run_model(input_ids, gen_config, 0, output);
+            auto r = run_model(input_ids, gen_config, 0, output);
+            if (!r) ggml::log(GGML_LOG_LEVEL_ERROR, "Out of memory");
             CHATLLM_CHECK(output.size() == 1) << "ouput must be scaler";
 
             return output[0];
         }
 
-        void generate_next_token(const std::vector<int> &input_ids, const GenerationConfig &gen_config, std::vector<float> &lm_logits) override
+        bool generate_next_token(const std::vector<int> &input_ids, const GenerationConfig &gen_config, std::vector<float> &lm_logits) override
         {
             if (batch_input)
             {
-                run_model(input_ids, gen_config, n_past + n_past_offset, lm_logits);
+                return run_model(input_ids, gen_config, n_past + n_past_offset, lm_logits);
             }
             else
             {
                 int past = n_past + n_past_offset;
                 for (size_t i = 0; (i < input_ids.size()) & !aborted; i++, past++)
-                    run_model({input_ids[i]}, gen_config, past, lm_logits);
+                {
+                    if (!run_model({input_ids[i]}, gen_config, past, lm_logits))
+                        return false;
+                }
+                return true;
             }
         }
 
@@ -1218,7 +1229,7 @@ namespace chatllm
             return s;
         }
 
-        virtual void run_model(const std::vector<int> &input_ids,
+        virtual bool run_model(const std::vector<int> &input_ids,
                                const GenerationConfig &gen_config,
                                int past,
                                std::vector<float> &output)
@@ -1228,7 +1239,8 @@ namespace chatllm
                 initial_run = true;
                 int past = gen_config.max_length - (int)input_ids.size();
                 if (past < 0) past = 0;
-                CHATLLM_CHECK(before_initial_run(input_ids, gen_config, past)) << "failed to reserve memory.";
+                if (!before_initial_run(input_ids, gen_config, past))
+                    return false;
             }
 
             ForwardContext ctx(&backend_context);
@@ -1255,7 +1267,7 @@ namespace chatllm
 
             output.resize(ggml::nbytes(r) / sizeof(output[0]));
 
-            CHATLLM_CHECK(ctx.allocate()) << "failed to allocate memory for graph";
+            if (!ctx.allocate()) return false;
 
             Backend::write_tensor_data(input_ids_tensor, input_ids.data());
 
@@ -1270,6 +1282,8 @@ namespace chatllm
             Backend::read_tensor_data(r, output.data());
 
             ctx.reset();
+
+            return true;
         }
 
         virtual bool is_output_terminated(const std::vector<int> &output_ids, int &keep_idx, int &pop_output)
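Taken together, these hunks replace fatal `CHATLLM_CHECK` assertions on the model-run path with `bool` returns, so an allocation failure propagates from `run_model` up through `generate_next_token` to the generation loop, where it is logged and the run is aborted rather than crashing the process. A minimal standalone sketch of that propagation pattern, using made-up stand-ins (`try_allocate`, `decode_step`) instead of the real chatllm classes:

```cpp
// Sketch only: illustrates the error-propagation pattern this commit applies
// (bool returns instead of fatal checks). try_allocate() and decode_step()
// are hypothetical stand-ins, not functions from this repository.
#include <cstdio>
#include <vector>

static bool try_allocate(size_t bytes)
{
    // Stand-in for ctx.allocate(): pretend anything above 1 GiB fails.
    return bytes <= (size_t(1) << 30);
}

static bool decode_step(const std::vector<int> &input_ids, std::vector<float> &logits)
{
    // Stand-in for run_model(): report failure to the caller instead of asserting.
    if (!try_allocate(input_ids.size() * 4096 * sizeof(float)))
        return false;
    logits.assign(32000, 0.0f);   // dummy logits
    return true;
}

int main()
{
    std::vector<int> curr_input_ids(16, 1);
    std::vector<float> lm_logits;
    bool aborted = false;

    // Mirrors the generation loop: on failure, log, set aborted, and break.
    for (int step = 0; step < 8 && !aborted; step++)
    {
        if (!decode_step(curr_input_ids, lm_logits))
        {
            std::fprintf(stderr, "Out of memory\n");
            aborted = true;
            break;
        }
        curr_input_ids = {0};   // after the first step, feed one token at a time
    }
    return aborted ? 1 : 0;
}
```

The design choice matches the diff: the inner helpers only report success or failure, and only the outermost loop decides how to react (log and abort).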