Commit d92e0fa

server : remove hack for extra parallel slot
ggml-ci
1 parent b8deef0 commit d92e0fa
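
What changed: until this commit, load_model() grew params.n_parallel by one to reserve KV-cache sequence 0 "for extra features", so every slot addressed its sequence as slot.id + 1. The reserved sequence is gone and slot i now maps directly onto sequence i. The commit also tightens task hand-off: server_task is passed by value and moved into the callbacks instead of being lent by reference. A minimal sketch of the id-mapping rule (not code from the commit; it assumes only a slot type with an integer id):

```cpp
// Sketch (not from this commit): the slot -> KV-cache sequence mapping
// before and after, assuming a slot type with an integer `id`.
struct server_slot { int id; };

// before: sequence 0 was reserved at init, so slots were shifted by one
int seq_id_old(const server_slot & slot) { return slot.id + 1; }

// after: slot i owns sequence i
int seq_id_new(const server_slot & slot) { return slot.id; }
```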

File tree

1 file changed: +21 -26 lines changed
examples/server/server.cpp (+21, -26)
```diff
@@ -378,8 +378,8 @@ struct server_queue {
     std::condition_variable condition_tasks;
 
     // callback functions
-    std::function<void(server_task&)> callback_new_task;
-    std::function<void(void)>         callback_update_slots;
+    std::function<void(server_task)> callback_new_task;
+    std::function<void(void)>        callback_update_slots;
 
     // Add a new task to the end of the queue
     int post(server_task task, bool front = false) {
```

```diff
@@ -431,7 +431,7 @@ struct server_queue {
     }
 
     // Register function to process a new task
-    void on_new_task(std::function<void(server_task &)> callback) {
+    void on_new_task(std::function<void(server_task)> callback) {
         callback_new_task = std::move(callback);
     }
 
```

```diff
@@ -481,7 +481,7 @@ struct server_queue {
             lock.unlock();
 
             QUE_DBG("processing task, id = %d\n", task.id);
-            callback_new_task(task);
+            callback_new_task(std::move(task));
         }
 
         // all tasks in the current loop is processed, slots data is now ready
```
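
The queue hunks above, together with the process_single_task() hunk further down, switch task hand-off from lending a reference to transferring ownership: post() already takes the task by value, the loop now moves it into callback_new_task, and the handler receives its own server_task. A self-contained sketch of the idiom with a stand-in task type (nothing here is server.cpp code):

```cpp
#include <functional>
#include <string>
#include <utility>

// stand-in for server_task: an id plus a move-friendly payload
struct task {
    int id;
    std::string payload;
};

int main() {
    // the handler takes the task by value, mirroring std::function<void(server_task)>
    std::function<void(task)> on_task = [](task t) {
        // t is owned here: safe to keep, mutate, or move further down the chain
        (void) t;
    };

    task t{1, std::string(1 << 20, 'x')}; // a payload large enough to be worth moving
    on_task(std::move(t));                // moves the string instead of copying it
}
```

Passing by value costs nothing extra when the caller supplies an rvalue: the payload is moved, not copied, and the handler can stash or mutate the task without aliasing the queue's state.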
```diff
@@ -644,17 +644,12 @@ struct server_context {
     bool load_model(const common_params & params_) {
         params = params_;
 
-        // reserve one extra sequence (seq_id == 0) for extra features
-        params.n_parallel += 1;
-
         common_init_result llama_init = common_init_from_params(params);
 
         model = llama_init.model;
         ctx   = llama_init.context;
         loras = llama_init.lora_adapters;
 
-        params.n_parallel -= 1; // but be sneaky about it
-
         if (model == nullptr) {
             SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
             return false;
```
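
With the two removed lines, n_parallel means exactly what it says: the context is created with one KV-cache sequence per slot, with no hidden extra sequence allocated at init. A worked illustration with a made-up -np value:

```cpp
#include <cstdio>

int main() {
    const int n_parallel = 4; // e.g. the server was started with `-np 4`

    // before: load_model() bumped n_parallel, so init allocated sequences 0..4,
    // with sequence 0 reserved and slot i attached to sequence i + 1
    const int n_seq_old = n_parallel + 1;

    // after: exactly one KV-cache sequence per slot, ids 0..3
    const int n_seq_new = n_parallel;

    std::printf("old: %d sequences, new: %d sequences\n", n_seq_old, n_seq_new);
}
```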
```diff
@@ -1288,16 +1283,16 @@ struct server_context {
 
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
         server_task_result res;
-        res.id = slot.id_task;
-        res.error = false;
-        res.stop = true;
+        res.id    = slot.id_task;
+        res.error = false;
+        res.stop  = true;
 
         const int n_embd = llama_n_embd(model);
 
         std::vector<float> embd_res(n_embd, 0.0f);
 
         for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
                 continue;
             }
 
```

```diff
@@ -1332,12 +1327,12 @@ struct server_context {
 
     void send_rerank(const server_slot & slot, const llama_batch & batch) {
         server_task_result res;
-        res.id = slot.id_task;
-        res.error = false;
-        res.stop = true;
+        res.id    = slot.id_task;
+        res.error = false;
+        res.stop  = true;
 
         for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
                 continue;
             }
 
```
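
Both send_embedding() and send_rerank() scan the batch for the output rows that belong to this slot: a row counts only if an output was requested for it (batch.logits[i]) and its first sequence id matches. With the reserved sequence gone, the comparison is now against slot.id directly. A minimal sketch of that filter with toy stand-ins for the llama_batch fields (the real batch stores an array of sequence ids per token, hence seq_id[i][0]):

```cpp
#include <vector>

// toy stand-ins for the llama_batch fields the filter reads
struct toy_batch {
    int n_tokens = 0;
    std::vector<bool> has_output; // per token: was an output (logits) requested?
    std::vector<int>  seq_id0;    // per token: first sequence id
};

// count the output rows in a batch that belong to one slot's sequence
int outputs_for_seq(const toy_batch & batch, int seq) {
    int n = 0;
    for (int i = 0; i < batch.n_tokens; ++i) {
        if (!batch.has_output[i] || batch.seq_id0[i] != seq) {
            continue; // no output requested, or the row belongs to another slot
        }
        n++;
    }
    return n;
}
```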

```diff
@@ -1510,7 +1505,7 @@ struct server_context {
     // Functions to process the task
     //
 
-    void process_single_task(const server_task & task) {
+    void process_single_task(server_task task) {
         switch (task.type) {
             case SERVER_TASK_TYPE_INFERENCE:
                 {
```

```diff
@@ -1808,8 +1803,8 @@ struct server_context {
 
                     SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
 
-                    llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
-                    llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);
+                    llama_kv_cache_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
+                    llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
 
                     if (slot.params.cache_prompt) {
                         for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
```
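
These two calls are the context shift itself: drop the slot's KV cells for positions [n_keep, n_keep + n_discard), then slide the surviving tail back by n_discard so generation can continue. A runnable simulation of just the position bookkeeping, with made-up numbers:

```cpp
#include <cstdio>
#include <vector>

// simulates only the position bookkeeping of a context shift; the real work
// is done by llama_kv_cache_seq_rm/_add on the KV cache (made-up numbers)
int main() {
    const int n_past = 100, n_keep = 10, n_discard = 45;

    std::vector<int> pos;
    for (int p = 0; p < n_past; ++p) pos.push_back(p);

    // llama_kv_cache_seq_rm(ctx, slot.id, n_keep, n_keep + n_discard):
    pos.erase(pos.begin() + n_keep, pos.begin() + n_keep + n_discard);

    // llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, n_past, -n_discard):
    for (size_t i = n_keep; i < pos.size(); ++i) pos[i] -= n_discard;

    // 55 cells remain and generation resumes at position 55
    std::printf("cells left: %zu, next position: %d\n", pos.size(), pos.back() + 1);
}
```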
```diff
@@ -1836,7 +1831,7 @@ struct server_context {
 
             slot.i_batch = batch.n_tokens;
 
-            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true);
+            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
 
```

```diff
@@ -1983,8 +1978,8 @@ struct server_context {
 
                                 const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
 
-                                llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
-                                llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
+                                llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
+                                llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift);
 
                                 for (size_t i = 0; i < n_match; i++) {
                                     slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
```
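
This hunk relocates a reusable chunk of cached tokens: the stale cells at [head_p, head_c) are removed, then everything from head_c onward is shifted by kv_shift = head_p - head_c so the matching chunk lands where the new prompt expects it. A tiny worked example with made-up positions:

```cpp
#include <cstdio>

int main() {
    // made-up positions: a chunk of reusable cached tokens sits at head_c,
    // but the new prompt needs it at head_p
    const int head_p = 12, head_c = 20;

    const long long kv_shift = (long long) head_p - (long long) head_c; // -8

    // llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c)       drops the stale cells in [12, 20);
    // llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift) then shifts every cell at
    // position >= 20 left by 8, so the reused chunk starts at position 12
    std::printf("chunk moves from position %d to %lld\n", head_c, head_c + kv_shift);
}
```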
```diff
@@ -2033,9 +2028,9 @@ struct server_context {
                     }
 
                     // keep only the common part
-                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
+                    if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
                         // could not partially delete (likely using a non-Transformer model)
-                        llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
+                        llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
 
                         // there is no common part left
                         slot.n_past = 0;
```
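
llama_kv_cache_seq_rm() fails when the model cannot delete an arbitrary position range (per the comment above, likely a non-Transformer model), and the fallback wipes the whole sequence so the prompt is re-processed from position 0. A small sketch of that fallback pattern, with a stand-in for the cache call (the real function takes the context and a llama_seq_id):

```cpp
// stand-in for llama_kv_cache_seq_rm: remove positions [p0, p1) from a
// sequence, where -1 means an open end; returns false when only whole-sequence
// deletion is supported (as with non-Transformer models)
static bool seq_rm(int seq, int p0, int p1, bool can_partial_delete) {
    (void) seq;
    const bool whole_sequence = (p0 < 0 && p1 < 0);
    return whole_sequence || can_partial_delete;
}

// keep the first n_past cached positions if possible, else start over
static int keep_common_prefix(int seq, int n_past, bool can_partial_delete) {
    if (!seq_rm(seq, n_past, -1, can_partial_delete)) {
        seq_rm(seq, -1, -1, can_partial_delete); // wipe the whole sequence
        return 0;                                // no common part left
    }
    return n_past;
}
```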
```diff
@@ -2048,7 +2043,7 @@ struct server_context {
 
                 // add prompt tokens for processing in the current batch
                 while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-                    common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
+                    common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
 
                     if (slot.params.cache_prompt) {
                         slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
```
