
Commit f572bd7

Author: Judd (committed)
report error when OOM. don't exit.
1 parent 193cf74 commit f572bd7

File tree (2 files changed: +26 −12 lines)

  src/chat.h
  src/models.cpp

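The change follows a simple pattern: allocation failures that previously triggered CHATLLM_CHECK (and killed the process) are now reported as a bool return value that each layer propagates upward, so the generation loop can log the error and stop cleanly instead of exiting. A minimal, self-contained sketch of that pattern follows; the names (try_allocate, generate_step) are illustrative stand-ins, not the actual chatllm functions.

// Sketch of the report-and-continue pattern used by this commit (hypothetical names).
#include <cstdio>
#include <new>
#include <vector>

static bool try_allocate(std::vector<float> &buf, size_t n)
{
    try
    {
        buf.resize(n);                // may throw std::bad_alloc on OOM
        return true;
    }
    catch (const std::bad_alloc &)
    {
        return false;                 // report failure instead of aborting
    }
}

// Mirrors run_model() -> generate_next_token(): just forward the result upward.
static bool generate_step(std::vector<float> &logits)
{
    return try_allocate(logits, 32000);
}

int main()
{
    std::vector<float> logits;
    bool aborted = false;

    while (!aborted)
    {
        if (!generate_step(logits))
        {
            std::fprintf(stderr, "Out of memory\n");   // log the error ...
            aborted = true;                            // ... and stop this generation,
            break;                                     // but keep the process alive
        }
        break;                                         // single demo iteration
    }
    return 0;
}

The two diffs below do exactly this at the library level: run_model() and generate_next_token() now return bool, and the generation loop logs "Out of memory" and sets aborted rather than exiting.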

src/chat.h

Lines changed: 3 additions & 3 deletions
@@ -651,7 +651,7 @@ namespace chatllm
int gen_max_tokens,
BaseStreamer *streamer = nullptr) = 0;

- virtual void generate_next_token(const std::vector<int> &input_ids, const GenerationConfig &gen_config, std::vector<float> &lm_logits) {};
+ virtual bool generate_next_token(const std::vector<int> &input_ids, const GenerationConfig &gen_config, std::vector<float> &lm_logits) { return true; };

virtual void abort_generation(void) = 0;

@@ -719,9 +719,9 @@ namespace chatllm
return model->generate(input_ids, gen_config, continuous, completed, performance, gen_max_tokens, streamer);
}

- void generate_next_token(const std::vector<int> &input_ids, const GenerationConfig &gen_config, std::vector<float> &lm_logits) override
+ bool generate_next_token(const std::vector<int> &input_ids, const GenerationConfig &gen_config, std::vector<float> &lm_logits) override
{
-     model->generate_next_token(input_ids, gen_config, lm_logits);
+     return model->generate_next_token(input_ids, gen_config, lm_logits);
}

void abort_generation(void) override { model->abort_generation(); }
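Note that the base-class default in chat.h now simply returns true, so model classes that never allocate inside generate_next_token (or that have not been updated yet) keep their old behaviour; only overrides that can actually fail need to report it. A rough sketch of that contract, using simplified stand-in types rather than the real chatllm declarations:

#include <new>
#include <vector>

struct GenConfigStub {};            // stand-in for chatllm::GenerationConfig

struct ModelSketch
{
    virtual ~ModelSketch() = default;

    // Default: nothing here that can fail, so report success unconditionally.
    virtual bool generate_next_token(const std::vector<int> &, const GenConfigStub &,
                                     std::vector<float> &) { return true; }
};

struct FallibleModelSketch : ModelSketch
{
    bool generate_next_token(const std::vector<int> &input_ids, const GenConfigStub &cfg,
                             std::vector<float> &logits) override
    {
        (void)input_ids; (void)cfg;
        try { logits.resize(1u << 20); }                  // a real model would build/run its graph here
        catch (const std::bad_alloc &) { return false; }  // propagate OOM instead of asserting
        return true;
    }
};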

src/models.cpp

Lines changed: 23 additions & 9 deletions
@@ -1039,7 +1039,12 @@ namespace chatllm
while (!aborted && !completed && (n_past + (int)curr_input_ids.size() < gen_config.max_length))
{
std::vector<float> lm_logits;
- generate_next_token(curr_input_ids, gen_config, lm_logits);
+ if (!generate_next_token(curr_input_ids, gen_config, lm_logits))
+ {
+     ggml::log(GGML_LOG_LEVEL_ERROR, "Out of memory");
+     aborted = true;
+     break;
+ }

if (first_call)
{

@@ -1113,29 +1118,35 @@ namespace chatllm
void text_embedding(const GenerationConfig &gen_config, const std::vector<int> &input_ids,
std::vector<float> &embedding) override
{
- run_model(input_ids, gen_config, 0, embedding);
+ auto r = run_model(input_ids, gen_config, 0, embedding);
+ if (!r) ggml::log(GGML_LOG_LEVEL_ERROR, "Out of memory");
}

float qa_rank(const GenerationConfig &gen_config, const std::vector<int> &input_ids) override
{
std::vector<float> output;
- run_model(input_ids, gen_config, 0, output);
+ auto r = run_model(input_ids, gen_config, 0, output);
+ if (!r) ggml::log(GGML_LOG_LEVEL_ERROR, "Out of memory");
CHATLLM_CHECK(output.size() == 1) << "ouput must be scaler";

return output[0];
}

- void generate_next_token(const std::vector<int> &input_ids, const GenerationConfig &gen_config, std::vector<float> &lm_logits) override
+ bool generate_next_token(const std::vector<int> &input_ids, const GenerationConfig &gen_config, std::vector<float> &lm_logits) override
{
if (batch_input)
{
- run_model(input_ids, gen_config, n_past + n_past_offset, lm_logits);
+ return run_model(input_ids, gen_config, n_past + n_past_offset, lm_logits);
}
else
{
int past = n_past + n_past_offset;
for (size_t i = 0 ; (i < input_ids.size()) & !aborted; i++, past++)
- run_model({input_ids[i]}, gen_config, past, lm_logits);
+ {
+     if (!run_model({input_ids[i]}, gen_config, past, lm_logits))
+         return false;
+ }
+ return true;
}
}

@@ -1218,7 +1229,7 @@ namespace chatllm
return s;
}

- virtual void run_model(const std::vector<int> &input_ids,
+ virtual bool run_model(const std::vector<int> &input_ids,
const GenerationConfig &gen_config,
int past,
std::vector<float> &output)

@@ -1228,7 +1239,8 @@ namespace chatllm
initial_run = true;
int past = gen_config.max_length - (int)input_ids.size();
if (past < 0) past = 0;
- CHATLLM_CHECK(before_initial_run(input_ids, gen_config, past)) << "failed to reserve memory.";
+ if (!before_initial_run(input_ids, gen_config, past))
+     return false;
}

ForwardContext ctx(&backend_context);

@@ -1255,7 +1267,7 @@ namespace chatllm

output.resize(ggml::nbytes(r) / sizeof(output[0]));

- CHATLLM_CHECK(ctx.allocate()) << "failed to allocate memory for graph";
+ if (!ctx.allocate()) return false;

Backend::write_tensor_data(input_ids_tensor, input_ids.data());

@@ -1270,6 +1282,8 @@ namespace chatllm
Backend::read_tensor_data(r, output.data());

ctx.reset();
+
+ return true;
}

virtual bool is_output_terminated(const std::vector<int> &output_ids, int &keep_idx, int &pop_output)
