Skip to content

server : fix smart selection of available slot #10120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 12 additions & 23 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,12 +725,12 @@ struct server_context {
return nullptr;
}

server_slot * get_available_slot(const std::string & prompt) {
server_slot * get_available_slot(const server_task & task) {
server_slot * ret = nullptr;

// find the slot that has at least n% prompt similarity
if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) {
int max_lcp_len = 0;
if (ret == nullptr && slot_prompt_similarity != 0.0f) {
int max_lcs_len = 0;
float similarity = 0;

for (server_slot & slot : slots) {
Expand All @@ -740,25 +740,25 @@ struct server_context {
}

// skip the slot if it does not contains cached tokens
if (slot.prompt_tokens.empty()) {
if (slot.cache_tokens.empty()) {
continue;
}

// length of the Longest Common Prefix between the current slot's prompt and the input prompt
int lcp_len = longest_common_prefix(slot.cache_tokens, slot.prompt_tokens);
// length of the Longest Common Subsequence between the current slot's prompt and the input prompt
int lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);

// fraction of the common substring length compared to the current slot's prompt length
similarity = static_cast<float>(lcp_len) / static_cast<int>(slot.prompt_tokens.size());
// fraction of the common subsequence length compared to the current slot's prompt length
similarity = static_cast<float>(lcs_len) / static_cast<int>(slot.cache_tokens.size());

// select the current slot if the criteria match
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
max_lcp_len = lcp_len;
if (lcs_len > max_lcs_len && similarity > slot_prompt_similarity) {
max_lcs_len = lcs_len;
ret = &slot;
}
}

if (ret != nullptr) {
SLT_DBG(*ret, "selected slot by lcp similarity, max_lcp_len = %d, similarity = %f\n", max_lcp_len, similarity);
SLT_DBG(*ret, "selected slot by lcs similarity, max_lcs_len = %d, similarity = %f\n", max_lcs_len, similarity);
}
}

Expand Down Expand Up @@ -1514,18 +1514,7 @@ struct server_context {
{
const int id_slot = json_value(task.data, "id_slot", -1);

server_slot * slot;

if (id_slot != -1) {
slot = get_slot_by_id(id_slot);
} else {
std::string prompt;
if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
prompt = json_value(task.data, "prompt", std::string());
}

slot = get_available_slot(prompt);
}
server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);

if (slot == nullptr) {
// if no slot is available, we defer this task for processing later
Expand Down
50 changes: 46 additions & 4 deletions examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,11 +446,53 @@ static size_t longest_common_prefix(const std::vector<llama_token> & a, const st
return i;
}

static size_t longest_common_prefix(const std::string & a, const std::string & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
static size_t longest_common_subsequence(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}

// get the lengths of the input sequences
int a_len = a.size();
int b_len = b.size();

// initialize the maximum length of the longest common subsequence (LCS)
int max_length = 0;

// use two rows instead of a 2D matrix to optimize space
std::vector<int> prev_row(b_len + 1, 0);
std::vector<int> curr_row(b_len + 1, 0);

// iterate through the elements of a
for (int i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (int j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequences, set LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}

return i;
// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}

// update the previous row for the next iteration
prev_row = curr_row;
}

// return the maximum length of the LCS
return max_length;
}

static bool ends_with(const std::string & str, const std::string & suffix) {
Expand Down
Loading