@@ -155,15 +155,23 @@ struct task_server {
     json data;
     bool infill_mode = false;
     bool embedding_mode = false;
+    int multitask_id = -1;
 };
 
 struct task_result {
     int id;
+    int multitask_id = -1;
     bool stop;
     bool error;
     json result_json;
 };
 
+struct task_multi {
+    int id;
+    std::set<int> subtasks_remaining{};
+    std::vector<task_result> results{};
+};
+
 // TODO: can become bool if we can't find use of more states
 enum slot_state
 {
@@ -406,6 +414,9 @@ struct llama_client_slot
     double t_prompt_processing; // ms
     double t_token_generation; // ms
 
+    // multitasks
+    int multitask_id = -1;
+
     void reset() {
         num_prompt_tokens = 0;
         generated_text = "";
@@ -529,7 +540,8 @@ struct llama_server_context
 
     std::vector<task_server> queue_tasks;
     std::vector<task_result> queue_results;
-    std::mutex mutex_tasks;
+    std::vector<task_multi> queue_multitasks;
+    std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
     std::mutex mutex_results;
 
     ~llama_server_context()
@@ -1112,17 +1124,40 @@ struct llama_server_context
         return slot.images.size() > 0;
     }
 
-    void send_error(int id, std::string error)
+    void send_error(task_server& task, std::string error)
     {
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
-        res.id = id;
+        res.id = task.id;
+        res.multitask_id = task.multitask_id;
         res.stop = false;
         res.error = true;
         res.result_json = { { "content", error } };
         queue_results.push_back(res);
     }
 
+    void add_multi_task(int id, std::vector<int>& sub_ids)
+    {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        task_multi multi;
+        multi.id = id;
+        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
+        queue_multitasks.push_back(multi);
+    }
+
+    void update_multi_task(int multitask_id, int subtask_id, task_result& result)
+    {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        for (auto & multitask : queue_multitasks)
+        {
+            if (multitask.id == multitask_id)
+            {
+                multitask.subtasks_remaining.erase(subtask_id);
+                multitask.results.push_back(result);
+            }
+        }
+    }
+
     json get_model_props()
     {
         return get_formated_generation(slots[0]);
@@ -1167,6 +1202,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = false;
 
@@ -1206,6 +1242,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = true;
 
@@ -1251,6 +1288,12 @@ struct llama_server_context
             res.result_json["model"] = slot.oaicompat_model;
         }
 
+        // parent multitask, if any, needs to be updated
+        if (slot.multitask_id != -1)
+        {
+            update_multi_task(slot.multitask_id, slot.task_id, res);
+        }
+
         queue_results.push_back(res);
     }
 
@@ -1259,6 +1302,7 @@ struct llama_server_context
         std::lock_guard<std::mutex> lock(mutex_results);
         task_result res;
         res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
         res.error = false;
         res.stop = true;
 
@@ -1285,16 +1329,26 @@ struct llama_server_context
         queue_results.push_back(res);
     }
 
-    int request_completion(json data, bool infill, bool embedding)
+    int request_completion(json data, bool infill, bool embedding, int multitask_id)
     {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
+        std::unique_lock<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.target_id = 0;
         task.data = std::move(data);
         task.infill_mode = infill;
         task.embedding_mode = embedding;
         task.type = COMPLETION_TASK;
+        task.multitask_id = multitask_id;
+
+        // when a completion task's prompt array is not a singleton, we split it into multiple requests
+        if (task.data.at("prompt").size() > 1)
+        {
+            lock.unlock(); // entering new func scope
+            return split_multiprompt_task(task);
+        }
+
+        // otherwise, it's a single-prompt task, we actually queue it
         queue_tasks.push_back(task);
         return task.id;
     }
@@ -1313,8 +1367,17 @@ struct llama_server_context
 
             for (int i = 0; i < (int) queue_results.size(); i++)
             {
+                // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
+                if (queue_results[i].multitask_id == task_id)
+                {
+                    update_multi_task(task_id, queue_results[i].id, queue_results[i]);
+                    queue_results.erase(queue_results.begin() + i);
+                    continue;
+                }
+
                 if (queue_results[i].id == task_id)
                 {
+                    assert(queue_results[i].multitask_id == -1);
                     task_result res = queue_results[i];
                     queue_results.erase(queue_results.begin() + i);
                     return res;
@@ -1404,6 +1467,27 @@ struct llama_server_context
         queue_tasks.push_back(task);
     }
 
+    int split_multiprompt_task(task_server& multiprompt_task)
+    {
+        auto prompt_count = multiprompt_task.data.at("prompt").size();
+        assert(prompt_count > 1);
+
+        int multitask_id = id_gen++;
+        std::vector<int> subtask_ids(prompt_count);
+        for (int i = 0; i < prompt_count; i++)
+        {
+            json subtask_data = multiprompt_task.data;
+            subtask_data["prompt"] = subtask_data["prompt"][i];
+
+            // subtasks inherit everything else (infill mode, embedding mode, etc.)
+            subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+        }
+
+        // queue up the multitask so we can track its subtask progression
+        add_multi_task(multitask_id, subtask_ids);
+        return multitask_id;
+    }
+
     void process_tasks()
     {
         std::lock_guard<std::mutex> lock(mutex_tasks);
@@ -1419,7 +1503,7 @@ struct llama_server_context
                     {
                         LOG_TEE("slot unavailable\n");
                         // send error result
-                        send_error(task.id, "slot unavailable");
+                        send_error(task, "slot unavailable");
                         return;
                     }
 
@@ -1433,11 +1517,12 @@ struct llama_server_context
                     slot->infill = task.infill_mode;
                     slot->embedding = task.embedding_mode;
                     slot->task_id = task.id;
+                    slot->multitask_id = task.multitask_id;
 
                     if (!launch_slot_with_data(slot, task.data))
                     {
                         // send error result
-                        send_error(task.id, "internal_error");
+                        send_error(task, "internal_error");
                         break;
                     }
                 } break;
@@ -1453,6 +1538,38 @@ struct llama_server_context
                 } break;
             }
         }
+
+        // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue
+        auto queue_iterator = queue_multitasks.begin();
+        while (queue_iterator != queue_multitasks.end())
+        {
+            if (queue_iterator->subtasks_remaining.empty())
+            {
+                // all subtasks done == multitask is done
+                task_result aggregate_result;
+                aggregate_result.id = queue_iterator->id;
+                aggregate_result.stop = true;
+                aggregate_result.error = false;
+
+                // collect json results into one json result
+                std::vector<json> result_jsons;
+                for (auto & subres : queue_iterator->results)
+                {
+                    result_jsons.push_back(subres.result_json);
+                    aggregate_result.error = aggregate_result.error && subres.error;
+                }
+                aggregate_result.result_json = json{ "results", result_jsons };
+
+                std::lock_guard<std::mutex> lock(mutex_results);
+                queue_results.push_back(aggregate_result);
+
+                queue_iterator = queue_multitasks.erase(queue_iterator);
+            }
+            else
+            {
+                ++queue_iterator;
+            }
+        }
     }
 
     bool update_slots() {
@@ -2596,7 +2713,7 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false, false);
+                const int task_id = llama.request_completion(data, false, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2685,7 +2802,7 @@ int main(int argc, char **argv)
             {
                 json data = oaicompat_completion_params_parse(json::parse(req.body));
 
-                const int task_id = llama.request_completion(data, false, false);
+                const int task_id = llama.request_completion(data, false, false, -1);
 
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
@@ -2754,7 +2871,7 @@ int main(int argc, char **argv)
     svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true, false);
+                const int task_id = llama.request_completion(data, true, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2858,7 +2975,7 @@ int main(int argc, char **argv)
                 {
                     prompt = "";
                 }
-                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true);
+                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true, -1);
                 task_result result = llama.next_result(task_id);
                 return res.set_content(result.result_json.dump(), "application/json");
            });
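
A minimal usage sketch of the new multitask path, not taken from the diff itself: it assumes a constructed llama_server_context instance named llama, as in the HTTP handlers above, and the prompt strings and n_predict value are illustrative only.

// Passing a JSON array as "prompt" makes request_completion() call
// split_multiprompt_task(), so the returned id refers to a task_multi
// rather than a single task_server entry.
json data = {
    {"prompt", json::array({"Hello", "Bonjour"})}, // hypothetical prompts
    {"n_predict", 16}                              // hypothetical value
};
const int task_id = llama.request_completion(data, false, false, -1); // -1: no parent multitask

// next_result() folds finished subtask results into the parent via
// update_multi_task(); once subtasks_remaining is empty, process_tasks()
// pushes one aggregate task_result whose result_json bundles the
// per-prompt results, and next_result() returns it.
task_result result = llama.next_result(task_id);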