Skip to content

Commit 92640ff

Browse files
sasha0552 authored and arthw committed
server : fix smart selection of available slot (ggml-org#10120)
* Fix smart selection of available slot * minor fix * replace vectors of tokens with shorthands
1 parent a0a4c1a commit 92640ff

File tree

2 files changed

+59
-28
lines changed

2 files changed

+59
-28
lines changed

examples/server/server.cpp

+12-23
Original file line numberDiff line numberDiff line change
@@ -725,12 +725,12 @@ struct server_context {
725725
return nullptr;
726726
}
727727

728-
server_slot * get_available_slot(const std::string & prompt) {
728+
server_slot * get_available_slot(const server_task & task) {
729729
server_slot * ret = nullptr;
730730

731731
// find the slot that has at least n% prompt similarity
732-
if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) {
733-
int max_lcp_len = 0;
732+
if (ret == nullptr && slot_prompt_similarity != 0.0f) {
733+
int max_lcs_len = 0;
734734
float similarity = 0;
735735

736736
for (server_slot & slot : slots) {
@@ -740,25 +740,25 @@ struct server_context {
740740
}
741741

742742
// skip the slot if it does not contain any cached tokens
743-
if (slot.prompt_tokens.empty()) {
743+
if (slot.cache_tokens.empty()) {
744744
continue;
745745
}
746746

747-
// length of the Longest Common Prefix between the current slot's prompt and the input prompt
748-
int lcp_len = longest_common_prefix(slot.cache_tokens, slot.prompt_tokens);
747+
// length of the longest common run of consecutive tokens between the slot's cached tokens and the task's prompt tokens
748+
int lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);
749749

750-
// fraction of the common substring length compared to the current slot's prompt length
751-
similarity = static_cast<float>(lcp_len) / static_cast<int>(slot.prompt_tokens.size());
750+
// fraction of the slot's cached tokens covered by the common token run (lcs_len / number of cached tokens)
751+
similarity = static_cast<float>(lcs_len) / static_cast<int>(slot.cache_tokens.size());
752752

753753
// select the current slot if the criteria match
754-
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
755-
max_lcp_len = lcp_len;
754+
if (lcs_len > max_lcs_len && similarity > slot_prompt_similarity) {
755+
max_lcs_len = lcs_len;
756756
ret = &slot;
757757
}
758758
}
759759

760760
if (ret != nullptr) {
761-
SLT_DBG(*ret, "selected slot by lcp similarity, max_lcp_len = %d, similarity = %f\n", max_lcp_len, similarity);
761+
SLT_DBG(*ret, "selected slot by lcs similarity, max_lcs_len = %d, similarity = %f\n", max_lcs_len, similarity);
762762
}
763763
}
764764

@@ -1514,18 +1514,7 @@ struct server_context {
15141514
{
15151515
const int id_slot = json_value(task.data, "id_slot", -1);
15161516

1517-
server_slot * slot;
1518-
1519-
if (id_slot != -1) {
1520-
slot = get_slot_by_id(id_slot);
1521-
} else {
1522-
std::string prompt;
1523-
if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
1524-
prompt = json_value(task.data, "prompt", std::string());
1525-
}
1526-
1527-
slot = get_available_slot(prompt);
1528-
}
1517+
server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
15291518

15301519
if (slot == nullptr) {
15311520
// if no slot is available, we defer this task for processing later

examples/server/utils.hpp

+47-5
Original file line numberDiff line numberDiff line change
@@ -439,18 +439,60 @@ static std::string gen_chatcmplid() {
439439
// other common utils
440440
//
441441

442-
static size_t longest_common_prefix(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
442+
static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) {
443443
size_t i;
444444
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
445445

446446
return i;
447447
}
448448

449-
static size_t longest_common_prefix(const std::string & a, const std::string & b) {
450-
size_t i;
451-
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
449+
static size_t longest_common_subsequence(const llama_tokens & a, const llama_tokens & b) {
450+
// check for empty sequences
451+
if (a.empty() || b.empty()) {
452+
return 0;
453+
}
452454

453-
return i;
455+
// get the lengths of the input sequences
456+
int a_len = a.size();
457+
int b_len = b.size();
458+
459+
// initialize the maximum length of the longest common subsequence (LCS)
460+
int max_length = 0;
461+
462+
// use two rows instead of a 2D matrix to optimize space
463+
std::vector<int> prev_row(b_len + 1, 0);
464+
std::vector<int> curr_row(b_len + 1, 0);
465+
466+
// iterate through the elements of a
467+
for (int i = 1; i <= a_len; i++) {
468+
// iterate through the elements of b
469+
for (int j = 1; j <= b_len; j++) {
470+
// if elements at the current positions match
471+
if (a[i - 1] == b[j - 1]) {
472+
// if it's the first element of either sequences, set LCS length to 1
473+
if (i == 1 || j == 1) {
474+
curr_row[j] = 1;
475+
} else {
476+
// increment LCS length by 1 compared to the previous element
477+
curr_row[j] = prev_row[j - 1] + 1;
478+
}
479+
480+
// update max_length if necessary
481+
if (curr_row[j] > max_length) {
482+
max_length = curr_row[j];
483+
}
484+
} else {
485+
// reset LCS length if elements don't match
486+
curr_row[j] = 0;
487+
}
488+
}
489+
490+
// update the previous row for the next iteration
491+
prev_row = curr_row;
492+
}
493+
494+
// return the maximum length of the LCS
495+
return max_length;
454496
}
455497

456498
static bool ends_with(const std::string & str, const std::string & suffix) {

0 commit comments

Comments
 (0)