
server : remove hack for extra parallel slot #10187


Merged 1 commit on Nov 6, 2024
53 changes: 24 additions & 29 deletions examples/server/server.cpp
@@ -378,8 +378,8 @@ struct server_queue {
std::condition_variable condition_tasks;

// callback functions
- std::function<void(server_task&)> callback_new_task;
- std::function<void(void)> callback_update_slots;
+ std::function<void(server_task)> callback_new_task;
Member Author: Note, changing the signature here so we can move the task into the callback and mutate it. (A minimal sketch of this pass-by-value move pattern follows this hunk.)

+ std::function<void(void)> callback_update_slots;

// Add a new task to the end of the queue
int post(server_task task, bool front = false) {
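
To illustrate the author's note above: taking the task by value lets the queue hand ownership to the callback with `std::move`, so the handler can mutate or consume the task without copying the original. A minimal, self-contained sketch of the pattern (the `task` type here is a hypothetical stand-in, not the server's `server_task`):

```cpp
#include <cstdio>
#include <functional>
#include <string>
#include <utility>

// Hypothetical stand-in for server_task, just to show the ownership transfer.
struct task {
    int id;
    std::string payload;
};

int main() {
    // The callback takes the task by value, so the caller can move into it.
    std::function<void(task)> on_new_task = [](task t) {
        t.payload += " (mutated inside the callback)"; // safe: t is owned here
        std::printf("task %d: %s\n", t.id, t.payload.c_str());
    };

    task t{42, "hello"};
    on_new_task(std::move(t)); // the string is moved, not copied
    // t.payload is now in a valid but unspecified (typically empty) state.
}
```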
@@ -431,7 +431,7 @@ struct server_queue {
}

// Register function to process a new task
- void on_new_task(std::function<void(server_task &)> callback) {
+ void on_new_task(std::function<void(server_task)> callback) {
callback_new_task = std::move(callback);
}

@@ -481,7 +481,7 @@ struct server_queue {
lock.unlock();

QUE_DBG("processing task, id = %d\n", task.id);
- callback_new_task(task);
+ callback_new_task(std::move(task));
}

// all tasks in the current loop is processed, slots data is now ready
@@ -644,17 +644,12 @@ struct server_context {
bool load_model(const common_params & params_) {
params = params_;

- // reserve one extra sequence (seq_id == 0) for extra features
- params.n_parallel += 1;
-
common_init_result llama_init = common_init_from_params(params);

model = llama_init.model;
ctx = llama_init.context;
loras = llama_init.lora_adapters;

- params.n_parallel -= 1; // but be sneaky about it
-
if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
return false;
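
For context on the deletion above: load_model used to request one extra sequence beyond the user-visible slots (the removed comment calls it seq_id == 0 for extra features) and then hid it again by decrementing n_parallel. That reservation is why every KV-cache and batch call below addressed sequence slot.id + 1. With the extra sequence gone, a slot's id is also its sequence id, which this tiny sketch (hypothetical helper names, not part of the server) makes explicit:

```cpp
#include <cstdio>

// Hypothetical helpers, only to show the mapping change; the server itself
// just uses slot.id directly after this PR.
int seq_id_before(int slot_id) { return slot_id + 1; } // seq 0 was reserved
int seq_id_after (int slot_id) { return slot_id;     } // 1:1 mapping

int main() {
    for (int slot = 0; slot < 4; ++slot) {
        std::printf("slot %d -> seq %d (before), seq %d (after)\n",
                    slot, seq_id_before(slot), seq_id_after(slot));
    }
}
```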
@@ -1288,16 +1283,16 @@ struct server_context {

void send_embedding(const server_slot & slot, const llama_batch & batch) {
server_task_result res;
- res.id = slot.id_task;
- res.error = false;
- res.stop = true;
+ res.id = slot.id_task;
+ res.error = false;
+ res.stop = true;

const int n_embd = llama_n_embd(model);

std::vector<float> embd_res(n_embd, 0.0f);

for (int i = 0; i < batch.n_tokens; ++i) {
- if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
continue;
}
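
For context on the changed check above: all active slots decode inside one shared llama_batch, so send_embedding (and send_rerank below) scans the batch for the token that both requested output (batch.logits[i]) and belongs to this slot's sequence, which is now simply slot.id. A simplified sketch of that filter, using hypothetical flattened stand-ins for llama_batch and server_slot:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical, flattened stand-ins; the real llama_batch stores seq ids per
// token as an array (hence batch.seq_id[i][0] in the diff).
struct batch_view {
    int32_t              n_tokens;
    std::vector<int8_t>  logits;  // 1 if output was requested for token i
    std::vector<int32_t> seq_id;  // owning sequence of token i
};

struct slot_view {
    int32_t id; // after this PR, also the KV-cache sequence id
};

// Return the batch index holding this slot's output token, or -1 if none.
int find_output_token(const batch_view & batch, const slot_view & slot) {
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        if (!batch.logits[i] || batch.seq_id[i] != slot.id) {
            continue; // no output requested, or token belongs to another slot
        }
        return i;
    }
    return -1;
}
```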

@@ -1332,12 +1327,12 @@ struct server_context {

void send_rerank(const server_slot & slot, const llama_batch & batch) {
server_task_result res;
- res.id = slot.id_task;
- res.error = false;
- res.stop = true;
+ res.id = slot.id_task;
+ res.error = false;
+ res.stop = true;

for (int i = 0; i < batch.n_tokens; ++i) {
- if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
continue;
}

@@ -1510,7 +1505,7 @@ struct server_context {
// Functions to process the task
//

- void process_single_task(const server_task & task) {
+ void process_single_task(server_task task) {
switch (task.type) {
case SERVER_TASK_TYPE_INFERENCE:
{
@@ -1646,7 +1641,7 @@ struct server_context {
std::string filename = task.data.at("filename");
std::string filepath = task.data.at("filepath");

- const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
+ const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);

const int64_t t_end = ggml_time_us();
const double t_save_ms = (t_end - t_start) / 1000.0;
@@ -1688,7 +1683,7 @@ struct server_context {

slot->cache_tokens.resize(slot->n_ctx);
size_t token_count = 0;
- size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
+ size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
if (nread == 0) {
slot->cache_tokens.resize(0);
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
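
The slot save/restore handlers above now pass the slot's own id as the sequence id. A compile-level sketch of the call pattern, mirroring the arguments shown in the diff (the caller is assumed to already hold a live llama_context and the slot's cached tokens; the helper names are made up for illustration):

```cpp
#include <string>
#include <vector>
#include "llama.h"

// Persist one slot's KV-cache sequence to disk. seq_id is simply the slot id.
size_t save_slot(llama_context * ctx, llama_seq_id seq_id,
                 const std::vector<llama_token> & cache_tokens,
                 const std::string & filepath) {
    return llama_state_seq_save_file(ctx, filepath.c_str(), seq_id,
                                     cache_tokens.data(), cache_tokens.size());
}

// Read the sequence back; returns 0 on failure (no KV space or bad file).
size_t restore_slot(llama_context * ctx, llama_seq_id seq_id,
                    std::vector<llama_token> & cache_tokens, size_t n_ctx,
                    const std::string & filepath, size_t & token_count) {
    cache_tokens.resize(n_ctx); // make room for whatever the file contains
    const size_t nread = llama_state_seq_load_file(
        ctx, filepath.c_str(), seq_id,
        cache_tokens.data(), cache_tokens.size(), &token_count);
    cache_tokens.resize(nread == 0 ? 0 : token_count);
    return nread;
}
```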
@@ -1731,7 +1726,7 @@ struct server_context {

// Erase token cache
const size_t n_erased = slot->cache_tokens.size();
- llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
+ llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
slot->cache_tokens.clear();

server_task_result result;
@@ -1808,8 +1803,8 @@ struct server_context {

SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);

- llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
- llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);
+ llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
+ llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);

if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
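
The context-shift hunk above drops n_discard tokens that follow the first n_keep tokens and then slides the remaining positions back so the sequence stays contiguous, which is exactly what the seq_rm/seq_add pair expresses inside the KV cache. A toy simulation of the position bookkeeping (plain integers standing in for cache cells):

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int n_keep = 2, n_discard = 3;
    std::vector<int> pos = {0, 1, 2, 3, 4, 5, 6, 7}; // token positions in the cache

    std::vector<int> shifted;
    for (int p : pos) {
        if (p >= n_keep && p < n_keep + n_discard) {
            continue; // seq_rm: cells in [n_keep, n_keep + n_discard) are erased
        }
        // seq_add: positions past the erased window move back by n_discard
        shifted.push_back(p >= n_keep + n_discard ? p - n_discard : p);
    }

    for (int p : shifted) {
        std::printf("%d ", p); // prints: 0 1 2 3 4
    }
    std::printf("\n");
}
```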
@@ -1836,7 +1831,7 @@ struct server_context {

slot.i_batch = batch.n_tokens;

- common_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true);
+ common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);

slot.n_past += 1;

@@ -1983,8 +1978,8 @@ struct server_context {

const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;

- llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
- llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
+ llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
+ llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift);

for (size_t i = 0; i < n_match; i++) {
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
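
The hunk above is the prompt-cache reuse path: a run of n_match cached tokens found at offset head_c is kept for a prompt that needs it at offset head_p, so the gap is erased and the chunk's KV positions are shifted by kv_shift = head_p - head_c; the copy loop then keeps cache_tokens in sync. A toy illustration of the position arithmetic, with made-up numbers:

```cpp
#include <cstdio>

int main() {
    const int head_p = 3, head_c = 7, n_match = 4;
    const long long kv_shift = (long long) head_p - (long long) head_c; // -4

    // After seq_rm(head_p, head_c) and seq_add(head_c, -1, kv_shift), the
    // cached chunk answers for the new, earlier positions.
    for (int i = 0; i < n_match; ++i) {
        const int old_pos = head_c + i;
        const int new_pos = old_pos + (int) kv_shift; // == head_p + i
        std::printf("cached token at pos %d now serves pos %d\n", old_pos, new_pos);
    }
}
```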
@@ -2033,9 +2028,9 @@ struct server_context {
}

// keep only the common part
- if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
+ if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model)
- llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
+ llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);

// there is no common part left
slot.n_past = 0;
@@ -2048,7 +2043,7 @@ struct server_context {

// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
- common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
+ common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);

if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);