
Commit b1e5240

ziedbha authored and cebtenzzre committed
server : add single-client multi-prompt support (ggml-org#4232)
* add multiprompt support
* cleanup
* more cleanup
* remove atomicity of id_gen, and change lock_guard to unique_lock on completion requests
* remove all references to mutex_multitasks
* Update examples/server/server.cpp (×4) — Co-authored-by: Jared Van Bortel <[email protected]>
* change to set

---------

Co-authored-by: Jared Van Bortel <[email protected]>
1 parent 2382b64 commit b1e5240

File tree

1 file changed: examples/server/server.cpp (+128, -11 lines)
@@ -155,15 +155,23 @@ struct task_server {
     json data;
     bool infill_mode = false;
     bool embedding_mode = false;
+    int multitask_id = -1;
 };
 
 struct task_result {
     int id;
+    int multitask_id = -1;
     bool stop;
     bool error;
     json result_json;
 };
 
+struct task_multi {
+    int id;
+    std::set<int> subtasks_remaining{};
+    std::vector<task_result> results{};
+};
+
 // TODO: can become bool if we can't find use of more states
 enum slot_state
 {
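A brief aside on the new task_multi bookkeeping above (and the commit message's "change to set"): std::set lets the server erase a finished subtask by its id and test for completion with empty(). A minimal standalone sketch with made-up subtask ids, not part of the patch:

// Standalone sketch: how a subtasks_remaining-style set behaves.
#include <cstdio>
#include <set>

int main() {
    std::set<int> subtasks_remaining = {7, 8, 9};   // hypothetical pending subtask ids

    subtasks_remaining.erase(8);                    // a subtask finished: erase it by value
    subtasks_remaining.erase(8);                    // erasing an absent id is a harmless no-op

    // an empty set is the "multitask finished" signal used later in process_tasks()
    std::printf("pending=%zu done=%d\n", subtasks_remaining.size(), (int) subtasks_remaining.empty());
    return 0;
}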
@@ -406,6 +414,9 @@ struct llama_client_slot
     double t_prompt_processing; // ms
     double t_token_generation; // ms
 
+    // multitasks
+    int multitask_id = -1;
+
     void reset() {
         num_prompt_tokens = 0;
         generated_text = "";
@@ -529,7 +540,8 @@ struct llama_server_context
 
     std::vector<task_server> queue_tasks;
     std::vector<task_result> queue_results;
-    std::mutex mutex_tasks;
+    std::vector<task_multi> queue_multitasks;
+    std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
     std::mutex mutex_results;
 
     ~llama_server_context()
@@ -1112,17 +1124,40 @@ struct llama_server_context
         return slot.images.size() > 0;
     }
 
-    void send_error(int id, std::string error)
+    void send_error(task_server& task, std::string error)
     {
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
-        res.id = id;
+        res.id = task.id;
+        res.multitask_id = task.multitask_id;
         res.stop = false;
         res.error = true;
         res.result_json = { { "content", error } };
         queue_results.push_back(res);
     }
 
+    void add_multi_task(int id, std::vector<int>& sub_ids)
+    {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        task_multi multi;
+        multi.id = id;
+        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
+        queue_multitasks.push_back(multi);
+    }
+
+    void update_multi_task(int multitask_id, int subtask_id, task_result& result)
+    {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        for (auto& multitask : queue_multitasks)
+        {
+            if (multitask.id == multitask_id)
+            {
+                multitask.subtasks_remaining.erase(subtask_id);
+                multitask.results.push_back(result);
+            }
+        }
+    }
+
     json get_model_props()
     {
         return get_formated_generation(slots[0]);
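To make the intended flow of the two new helpers easier to follow, here is a compressed, self-contained sketch of the same register-then-update pattern. The names below (multi_record, register_multi, on_subtask_done, bookkeeping_mutex) are stand-ins, not names from server.cpp:

// Self-contained sketch of the add_multi_task / update_multi_task pattern.
#include <cstdio>
#include <mutex>
#include <set>
#include <string>
#include <vector>

struct sub_result { int id; std::string content; };

struct multi_record {
    int id;
    std::set<int> subtasks_remaining;
    std::vector<sub_result> results;
};

static std::mutex bookkeeping_mutex;            // plays the role of mutex_tasks
static std::vector<multi_record> multitasks;    // plays the role of queue_multitasks

// mirrors add_multi_task: remember which subtask ids belong to this multitask
void register_multi(int id, const std::vector<int>& sub_ids) {
    std::lock_guard<std::mutex> lock(bookkeeping_mutex);
    multi_record m;
    m.id = id;
    m.subtasks_remaining = std::set<int>(sub_ids.begin(), sub_ids.end());
    multitasks.push_back(std::move(m));
}

// mirrors update_multi_task: a subtask finished, record its result
void on_subtask_done(int multi_id, const sub_result& r) {
    std::lock_guard<std::mutex> lock(bookkeeping_mutex);
    for (auto& m : multitasks) {
        if (m.id == multi_id) {
            m.subtasks_remaining.erase(r.id);
            m.results.push_back(r);
        }
    }
}

int main() {
    register_multi(100, {101, 102});
    on_subtask_done(100, {101, "first"});
    on_subtask_done(100, {102, "second"});
    std::printf("multitask done: %d\n", (int) multitasks[0].subtasks_remaining.empty());
    return 0;
}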
@@ -1167,6 +1202,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = false;
 
@@ -1206,6 +1242,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = true;
 
@@ -1251,6 +1288,12 @@ struct llama_server_context
             res.result_json["model"] = slot.oaicompat_model;
         }
 
+        // parent multitask, if any, needs to be updated
+        if (slot.multitask_id != -1)
+        {
+            update_multi_task(slot.multitask_id, slot.task_id, res);
+        }
+
         queue_results.push_back(res);
     }
 
@@ -1259,6 +1302,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = true;
 
@@ -1285,16 +1329,26 @@ struct llama_server_context
         queue_results.push_back(res);
     }
 
-    int request_completion(json data, bool infill, bool embedding)
+    int request_completion(json data, bool infill, bool embedding, int multitask_id)
     {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
+        std::unique_lock<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.target_id = 0;
         task.data = std::move(data);
         task.infill_mode = infill;
         task.embedding_mode = embedding;
         task.type = COMPLETION_TASK;
+        task.multitask_id = multitask_id;
+
+        // when a completion task's prompt array is not a singleton, we split it into multiple requests
+        if (task.data.at("prompt").size() > 1)
+        {
+            lock.unlock(); // entering new func scope
+            return split_multiprompt_task(task);
+        }
+
+        // otherwise, it's a single-prompt task, we actually queue it
         queue_tasks.push_back(task);
         return task.id;
     }
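The singleton check above leans on nlohmann::json semantics: size() is 1 for a string-valued "prompt" and the element count for an array-valued one, so only genuine prompt arrays take the split path. A tiny sketch, assuming the single-header nlohmann/json library that server.cpp already uses (the include path below may differ from the repo's bundled json.hpp):

#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    json single_prompt = { {"prompt", "hello"} };
    json multi_prompt  = { {"prompt", json::array({"hello", "bonjour"})} };

    // string -> size() == 1, array -> size() == number of prompts
    std::cout << single_prompt.at("prompt").size() << "\n";  // prints 1
    std::cout << multi_prompt.at("prompt").size()  << "\n";  // prints 2

    // this mirrors the branch taken in request_completion
    std::cout << (multi_prompt.at("prompt").size() > 1 ? "split" : "queue") << "\n";
    return 0;
}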
@@ -1313,8 +1367,17 @@ struct llama_server_context
 
             for (int i = 0; i < (int) queue_results.size(); i++)
             {
+                // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
+                if (queue_results[i].multitask_id == task_id)
+                {
+                    update_multi_task(task_id, queue_results[i].id, queue_results[i]);
+                    queue_results.erase(queue_results.begin() + i);
+                    continue;
+                }
+
                 if (queue_results[i].id == task_id)
                 {
+                    assert(queue_results[i].multitask_id == -1);
                     task_result res = queue_results[i];
                     queue_results.erase(queue_results.begin() + i);
                     return res;
@@ -1404,6 +1467,27 @@ struct llama_server_context
         queue_tasks.push_back(task);
     }
 
+    int split_multiprompt_task(task_server& multiprompt_task)
+    {
+        auto prompt_count = multiprompt_task.data.at("prompt").size();
+        assert(prompt_count > 1);
+
+        int multitask_id = id_gen++;
+        std::vector<int> subtask_ids(prompt_count);
+        for (int i = 0; i < prompt_count; i++)
+        {
+            json subtask_data = multiprompt_task.data;
+            subtask_data["prompt"] = subtask_data["prompt"][i];
+
+            // subtasks inherit everything else (infill mode, embedding mode, etc.)
+            subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+        }
+
+        // queue up the multitask so we can track its subtask progression
+        add_multi_task(multitask_id, subtask_ids);
+        return multitask_id;
+    }
+
     void process_tasks()
     {
         std::lock_guard<std::mutex> lock(mutex_tasks);
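split_multiprompt_task copies the full request once per prompt and swaps a single prompt into each copy, so subtasks inherit n_predict, sampling settings, infill/embedding mode, and so on. A reduced sketch of just that JSON fan-out step, under the same nlohmann::json assumption as before (fan_out is a hypothetical free function, not part of server.cpp):

#include <iostream>
#include <vector>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// copy the request once per prompt, giving each copy exactly one prompt
std::vector<json> fan_out(const json& request) {
    std::vector<json> subrequests;
    for (const auto& prompt : request.at("prompt")) {
        json sub = request;       // every other field is inherited
        sub["prompt"] = prompt;   // only the prompt differs
        subrequests.push_back(std::move(sub));
    }
    return subrequests;
}

int main() {
    json request = { {"prompt", json::array({"first prompt", "second prompt"})}, {"n_predict", 16} };
    for (const auto& sub : fan_out(request)) {
        std::cout << sub.dump() << "\n";
    }
    return 0;
}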
@@ -1419,7 +1503,7 @@ struct llama_server_context
                 {
                     LOG_TEE("slot unavailable\n");
                     // send error result
-                    send_error(task.id, "slot unavailable");
+                    send_error(task, "slot unavailable");
                     return;
                 }
 
@@ -1433,11 +1517,12 @@ struct llama_server_context
                 slot->infill = task.infill_mode;
                 slot->embedding = task.embedding_mode;
                 slot->task_id = task.id;
+                slot->multitask_id = task.multitask_id;
 
                 if (!launch_slot_with_data(slot, task.data))
                 {
                     // send error result
-                    send_error(task.id, "internal_error");
+                    send_error(task, "internal_error");
                     break;
                 }
             } break;
@@ -1453,6 +1538,38 @@ struct llama_server_context
                 } break;
             }
         }
+
+        // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
+        auto queue_iterator = queue_multitasks.begin();
+        while (queue_iterator != queue_multitasks.end())
+        {
+            if (queue_iterator->subtasks_remaining.empty())
+            {
+                // all subtasks done == multitask is done
+                task_result aggregate_result;
+                aggregate_result.id = queue_iterator->id;
+                aggregate_result.stop = true;
+                aggregate_result.error = false;
+
+                // collect json results into one json result
+                std::vector<json> result_jsons;
+                for (auto& subres : queue_iterator->results)
+                {
+                    result_jsons.push_back(subres.result_json);
+                    aggregate_result.error = aggregate_result.error && subres.error;
+                }
+                aggregate_result.result_json = json{ "results", result_jsons };
+
+                std::lock_guard<std::mutex> lock(mutex_results);
+                queue_results.push_back(aggregate_result);
+
+                queue_iterator = queue_multitasks.erase(queue_iterator);
+            }
+            else
+            {
+                ++queue_iterator;
+            }
+        }
     }
 
     bool update_slots() {
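One nlohmann::json subtlety in the aggregation step above: a brace initializer of the form json{ "results", result_jsons } yields a two-element array (["results", [...]]) rather than an object keyed by "results", which would need a second pair of braces. A small sketch of the difference, independent of the patch:

#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    json parts = json::array({ {{"content", "one"}}, {{"content", "two"}} });

    json flat   = { "results", parts };     // -> ["results",[...]] : a 2-element array
    json nested = { { "results", parts } }; // -> {"results":[...]} : an object

    std::cout << flat.dump()   << "\n";
    std::cout << nested.dump() << "\n";
    return 0;
}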
@@ -2596,7 +2713,7 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false, false);
+                const int task_id = llama.request_completion(data, false, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
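With the multitask id threaded through (the -1 above means "not part of a multitask"), a client can send an array-valued "prompt" to /completion and get back one aggregated result. A hypothetical client-side sketch, assuming the server listens on localhost:8080 and that cpp-httplib and nlohmann::json are available; nothing here comes from the patch itself:

#include <iostream>
#include <httplib.h>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // an array-valued "prompt" is what triggers the new multi-prompt path on the server
    json body = {
        {"prompt", json::array({"Once upon a time", "The capital of France is"})},
        {"n_predict", 32}
    };

    httplib::Client cli("localhost", 8080);       // assumed host/port
    auto res = cli.Post("/completion", body.dump(), "application/json");
    if (res && res->status == 200) {
        std::cout << res->body << "\n";           // aggregated result covering both prompts
    } else {
        std::cerr << "request failed\n";
    }
    return 0;
}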
@@ -2685,7 +2802,7 @@ int main(int argc, char **argv)
             {
                 json data = oaicompat_completion_params_parse(json::parse(req.body));
 
-                const int task_id = llama.request_completion(data, false, false);
+                const int task_id = llama.request_completion(data, false, false, -1);
 
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
@@ -2754,7 +2871,7 @@ int main(int argc, char **argv)
     svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true, false);
+                const int task_id = llama.request_completion(data, true, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2858,7 +2975,7 @@ int main(int argc, char **argv)
                 {
                     prompt = "";
                 }
-                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true);
+                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true, -1);
                 task_result result = llama.next_result(task_id);
                 return res.set_content(result.result_json.dump(), "application/json");
             });
