Skip to content

Commit 7a16ce7

Browse files
authored
server : smart slot selection using Longest Common Prefix (ggml-org#7728)
* server : Smart selection of available slot using Longest Common Substring * add usage * remove trailing whitespaces * Use Longest Common Prefix (LCP) instead of LCS * Rename argument
1 parent da799b4 commit 7a16ce7

File tree

4 files changed

+138
-15
lines changed

4 files changed

+138
-15
lines changed

common/common.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -1491,6 +1491,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
14911491
params.chat_template = argv[i];
14921492
return true;
14931493
}
1494+
if (arg == "--slot-prompt-similarity" || arg == "-sps") {
1495+
if (++i >= argc) {
1496+
invalid_param = true;
1497+
return true;
1498+
}
1499+
params.slot_prompt_similarity = std::stof(argv[i]);
1500+
return true;
1501+
}
14941502
if (arg == "-pps") {
14951503
params.is_pp_shared = true;
14961504
return true;
@@ -1913,6 +1921,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
19131921
"set custom jinja chat template (default: template taken from model's metadata)\n"
19141922
"only commonly used templates are accepted:\n"
19151923
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
1924+
options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY",
1925+
"how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity });
19161926

19171927
#ifndef LOG_DISABLE_LOGS
19181928
options.push_back({ "logging" });

common/common.h

+2
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ struct gpt_params {
203203

204204
std::string slot_save_path;
205205

206+
float slot_prompt_similarity = 0.5f;
207+
206208
// batched-bench params
207209
bool is_pp_shared = false;
208210

examples/server/server.cpp

+119-15
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,9 @@ struct server_context {
647647

648648
server_metrics metrics;
649649

650+
// Necessary similarity of prompt for slot selection
651+
float slot_prompt_similarity = 0.0f;
652+
650653
~server_context() {
651654
if (ctx) {
652655
llama_free(ctx);
@@ -795,24 +798,88 @@ struct server_context {
795798
return prompt_tokens;
796799
}
797800

798-
server_slot * get_slot(int id) {
799-
int64_t t_last = ggml_time_us();
800-
801-
server_slot * last_used = nullptr;
802-
801+
// Look up a slot by its numeric id.
//
// Only the id is compared; slot.available() is NOT consulted here, so the
// returned slot may currently be unavailable.
//
// Returns a pointer to the matching element of `slots`, or nullptr when no
// slot carries the requested id.
server_slot * get_slot_by_id(int id) {
    server_slot * match = nullptr;

    for (server_slot & candidate : slots) {
        if (candidate.id != id) {
            continue;
        }

        // first slot with the requested id wins
        match = &candidate;
        break;
    }

    return match;
}
810+
811+
// Choose a slot for a request that did not name a specific slot.
//
// Only slots for which available() is true are considered. Selection runs in
// up to two passes over `slots`:
//
//   1. Prompt similarity (skipped when slot_prompt_similarity is 0.0 or when
//      `prompt` is empty): among slots whose stored prompt is a non-empty
//      string, pick the one sharing the longest common prefix (LCP) with
//      `prompt`, provided that prefix covers a fraction of the slot's own
//      prompt strictly greater than slot_prompt_similarity.
//   2. Least recently used (only when pass 1 selected nothing): pick the slot
//      with the smallest t_last_used that is strictly older than the
//      timestamp ggml_time_us() returns at the start of the pass.
//
// Returns a pointer to the chosen element of `slots`, or nullptr when no slot
// qualifies.
server_slot * get_available_slot(const std::string & prompt) {
    server_slot * ret = nullptr;

    // find the slot that has at least n% prompt similarity
    if (slot_prompt_similarity != 0.0f && !prompt.empty()) {
        size_t max_lcp_len = 0;
        float  similarity  = 0.0f; // similarity of the slot currently held in `ret`

        for (server_slot & slot : slots) {
            // skip the slot if it is not available
            if (!slot.available()) {
                continue;
            }

            // skip the slot if it does not contain a string prompt
            if (!slot.prompt.is_string()) {
                continue;
            }

            // current slot's prompt (borrowed, not copied; is_string() was checked above)
            const std::string & slot_prompt = slot.prompt.get_ref<const std::string &>();

            // an empty slot prompt has nothing in common with the request and
            // would otherwise produce a 0/0 division below
            if (slot_prompt.empty()) {
                continue;
            }

            // length of the Longest Common Prefix between the current slot's prompt and the input prompt
            const size_t lcp_len = common_part(slot_prompt, prompt);

            // fraction of the current slot's prompt that is covered by the common prefix
            const float cur_similarity = static_cast<float>(lcp_len) / static_cast<float>(slot_prompt.size());

            // select the current slot if the criteria match
            if (lcp_len > max_lcp_len && cur_similarity > slot_prompt_similarity) {
                max_lcp_len = lcp_len;
                similarity  = cur_similarity; // remember the value that belongs to the selected slot
                ret         = &slot;
            }
        }

        if (ret != nullptr) {
            LOG_VERBOSE("selected slot by lcp similarity", {
                {"id_slot", ret->id},
                {"max_lcp_len", max_lcp_len},
                {"similarity", similarity},
            });
        }
    }

    // find the slot that has been least recently used
    if (ret == nullptr) {
        int64_t t_last = ggml_time_us();

        for (server_slot & slot : slots) {
            // skip the slot if it is not available
            if (!slot.available()) {
                continue;
            }

            // select the current slot if the criteria match
            if (slot.t_last_used < t_last) {
                t_last = slot.t_last_used;
                ret    = &slot;
            }
        }

        if (ret != nullptr) {
            LOG_VERBOSE("selected slot by lru", {
                {"id_slot", ret->id},
                {"t_last", t_last},
            });
        }
    }

    return ret;
}
817884

818885
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
@@ -1515,13 +1582,29 @@ struct server_context {
15151582
switch (task.type) {
15161583
case SERVER_TASK_TYPE_COMPLETION:
15171584
{
1518-
server_slot * slot = get_slot(json_value(task.data, "id_slot", -1));
1585+
int id_slot = json_value(task.data, "id_slot", -1);
1586+
std::string prompt = json_value(task.data, "prompt", std::string());
1587+
1588+
server_slot * slot;
1589+
1590+
if (id_slot != -1) {
1591+
slot = get_slot_by_id(id_slot);
1592+
} else {
1593+
slot = get_available_slot(prompt);
1594+
}
1595+
15191596
if (slot == nullptr) {
15201597
// if no slot is available, we defer this task for processing later
15211598
LOG_VERBOSE("no slot is available", {{"id_task", task.id}});
15221599
queue_tasks.defer(task);
15231600
break;
15241601
}
1602+
if (!slot->available()) {
1603+
// if requested slot is unavailable, we defer this task for processing later
1604+
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
1605+
queue_tasks.defer(task);
1606+
break;
1607+
}
15251608

15261609
if (task.data.contains("system_prompt")) {
15271610
std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
@@ -1638,11 +1721,17 @@ struct server_context {
16381721
case SERVER_TASK_TYPE_SLOT_SAVE:
16391722
{
16401723
int id_slot = task.data.at("id_slot");
1641-
server_slot * slot = get_slot(id_slot);
1724+
server_slot * slot = get_slot_by_id(id_slot);
16421725
if (slot == nullptr) {
16431726
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
16441727
break;
16451728
}
1729+
if (!slot->available()) {
1730+
// if requested slot is unavailable, we defer this task for processing later
1731+
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
1732+
queue_tasks.defer(task);
1733+
break;
1734+
}
16461735

16471736
const size_t token_count = slot->cache_tokens.size();
16481737
const int64_t t_start = ggml_time_us();
@@ -1673,11 +1762,17 @@ struct server_context {
16731762
case SERVER_TASK_TYPE_SLOT_RESTORE:
16741763
{
16751764
int id_slot = task.data.at("id_slot");
1676-
server_slot * slot = get_slot(id_slot);
1765+
server_slot * slot = get_slot_by_id(id_slot);
16771766
if (slot == nullptr) {
16781767
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
16791768
break;
16801769
}
1770+
if (!slot->available()) {
1771+
// if requested slot is unavailable, we defer this task for processing later
1772+
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
1773+
queue_tasks.defer(task);
1774+
break;
1775+
}
16811776

16821777
const int64_t t_start = ggml_time_us();
16831778

@@ -1715,11 +1810,17 @@ struct server_context {
17151810
case SERVER_TASK_TYPE_SLOT_ERASE:
17161811
{
17171812
int id_slot = task.data.at("id_slot");
1718-
server_slot * slot = get_slot(id_slot);
1813+
server_slot * slot = get_slot_by_id(id_slot);
17191814
if (slot == nullptr) {
17201815
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
17211816
break;
17221817
}
1818+
if (!slot->available()) {
1819+
// if requested slot is unavailable, we defer this task for processing later
1820+
LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
1821+
queue_tasks.defer(task);
1822+
break;
1823+
}
17231824

17241825
// Erase token cache
17251826
const size_t n_erased = slot->cache_tokens.size();
@@ -2467,6 +2568,9 @@ int main(int argc, char ** argv) {
24672568
log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded";
24682569
}
24692570

2571+
// Necessary similarity of prompt for slot selection
2572+
ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
2573+
24702574
// load the model
24712575
if (!ctx_server.load_model(params)) {
24722576
state.store(SERVER_STATE_ERROR);

examples/server/utils.hpp

+7
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,13 @@ static size_t common_part(const std::vector<llama_token> & a, const std::vector<
253253
return i;
254254
}
255255

256+
// Length, in characters, of the longest common prefix of `a` and `b`.
// Returns 0 when either string is empty or the first characters differ.
static size_t common_part(const std::string & a, const std::string & b) {
    size_t matched = 0;

    // advance while both strings still have a character at this position and the characters agree
    while (matched < a.size() && matched < b.size() && a[matched] == b[matched]) {
        ++matched;
    }

    return matched;
}
262+
256263
// Returns true when `str` ends with `suffix`. An empty suffix always matches.
static bool ends_with(const std::string & str, const std::string & suffix) {
    const size_t suffix_len = suffix.size();

    // a suffix longer than the string can never match
    if (str.size() < suffix_len) {
        return false;
    }

    // compare the last suffix_len characters of str against suffix
    const size_t tail_start = str.size() - suffix_len;
    return str.compare(tail_start, suffix_len, suffix) == 0;
}

0 commit comments

Comments
 (0)