From 95e0afb977859d70c8b8aeffb712af9230ac7a26 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 12 Jan 2025 13:16:37 +0100
Subject: [PATCH 1/3] wip: chat cli

---
 examples/main/CMakeLists.txt |   2 +-
 examples/main/chat.hpp       | 222 +++++++++++++++++++++++++++++++++++
 examples/main/main.cpp       |  41 ++++---
 3 files changed, 243 insertions(+), 22 deletions(-)
 create mode 100644 examples/main/chat.hpp

diff --git a/examples/main/CMakeLists.txt b/examples/main/CMakeLists.txt
index af3d9150f8640..37fe0e76bbc02 100644
--- a/examples/main/CMakeLists.txt
+++ b/examples/main/CMakeLists.txt
@@ -1,5 +1,5 @@
 set(TARGET llama-cli)
-add_executable(${TARGET} main.cpp)
+add_executable(${TARGET} main.cpp chat.hpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/examples/main/chat.hpp b/examples/main/chat.hpp
new file mode 100644
index 0000000000000..801e7acdc1be6
--- /dev/null
+++ b/examples/main/chat.hpp
@@ -0,0 +1,222 @@
+#include "arg.h"
+#include "common.h"
+#include "console.h"
+#include "log.h"
+#include "sampling.h"
+#include "llama.h"
+
+#include <fstream>
+
+struct llama_cli_chat {
+    struct llama_context * ctx;
+    const struct llama_model * model;
+    struct common_sampler * smpl;
+    struct common_params params;
+
+    bool interacting = false;
+    std::vector<common_chat_msg> chat_msgs;
+    std::ostringstream pending_input;
+
+    struct llama_batch batch;
+    llama_tokens cache_tokens;
+    int n_past = 0;
+
+    llama_cli_chat(
+            struct common_params & params,
+            struct llama_context * ctx,
+            struct common_sampler * smpl) : ctx(ctx), smpl(smpl), params(params) {
+        model = llama_get_model(ctx);
+        batch = llama_batch_init(params.n_batch, 0, 1);
+    }
+
+    void decode(llama_tokens & eval_tokens, bool is_generating) {
+        if (is_generating) {
+            GGML_ASSERT(eval_tokens.size() == 1);
+        } else {
+            n_past = common_lcp(cache_tokens, eval_tokens);
+            // in case we do a re-generation, we need to prevent eval_tokens from being empty
+            if ((int) eval_tokens.size() == n_past) {
+                n_past--;
+            }
+            if (n_past > 0) {
+                eval_tokens.erase(eval_tokens.begin(), eval_tokens.begin() + n_past);
+                cache_tokens.erase(cache_tokens.begin() + n_past, cache_tokens.end());
+                LOG_DBG("remove from cache [%d, inf)\n", n_past);
+                LOG_DBG("in cache: %s\n", common_detokenize(ctx, cache_tokens, true).c_str());
+                LOG_DBG("to decode %d tokens\n", (int) eval_tokens.size());
+                llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
+            }
+        }
+
+        // decode
+        for (size_t i = 0; i < eval_tokens.size(); i += params.n_batch) {
+            if (interacting) {
+                break;
+            }
+
+            common_batch_clear(batch);
+            for (int j = 0; j < params.n_batch && i + j < eval_tokens.size(); ++j) {
+                n_past++;
+                bool is_last_token = i + j == eval_tokens.size() - 1;
+                common_batch_add(batch, eval_tokens[i + j], n_past, {0}, is_last_token);
+            }
+
+            if (llama_decode(ctx, batch)) {
+                GGML_ABORT("failed to decode\n");
+            }
+        }
+
+        // update cache tokens
+        if (is_generating) {
+            cache_tokens.push_back(eval_tokens[0]);
+        } else {
+            cache_tokens.insert(cache_tokens.end(), eval_tokens.begin(), eval_tokens.end());
+        }
+    }
+
+    [[noreturn]] void run() {
+        while (true) {
+            interacting = true;
+            LOG("\n> ");
+
+            // color user input only
+            console::set_display(console::user_input);
+            std::string line;
+            bool another_line = true;
+            bool continue_input = false;
+            do {
+                another_line = console::readline(line, params.multiline_input);
+                if (handle_command(line, continue_input)) {
+                    continue; // do not add this line to pending_input
+                }
+                pending_input << line;
+            } while (another_line);
+
+            if (continue_input) {
+                continue;
+            }
+
+            if (pending_input.tellp() == 0) {
+                LOG_DBG("empty line, passing control back\n");
+                continue;
+            }
+
+            // done taking input, reset color
+            console::set_display(console::reset);
+            interacting = false;
+
+            // add message and format chat
+            if (!chat_msgs.empty() && chat_msgs.back().role == "user") {
+                chat_msgs.pop_back();
+            }
+            chat_msgs.push_back({"user", string_strip(pending_input.str())});
+            pending_input.str(""); // clear
+            auto formatted = common_chat_apply_template(model, params.chat_template, chat_msgs, true);
+
+            // tokenize the new chat history and decode
+            llama_tokens prompt_tokens = common_tokenize(ctx, formatted, true, true);
+            decode(prompt_tokens, false);
+
+            // generate response
+            llama_token new_token_id = LLAMA_TOKEN_NULL;
+            llama_tokens generated_tokens;
+            common_sampler_reset(smpl);
+            while (true) {
+                if (interacting) {
+                    break;
+                }
+
+                // sample the next token
+                new_token_id = common_sampler_sample(smpl, ctx, -1);
+
+                // is it an end of generation?
+                if (llama_token_is_eog(model, new_token_id)) {
+                    break;
+                }
+
+                // print the token, then decode it
+                printf("%s", common_token_to_piece(ctx, new_token_id, params.special).c_str());
+                fflush(stdout);
+                generated_tokens.push_back(new_token_id);
+                llama_tokens new_tok = {new_token_id};
+                decode(new_tok, true);
+            }
+
+            // add the generated tokens to the chat history
+            std::string response = common_detokenize(ctx, generated_tokens, true);
+            chat_msgs.push_back({"assistant", response});
+
+            // print a new line if needed
+            if (!response.empty() && response.back() != '\n') {
+                printf("\n");
+            }
+        }
+    }
+
+    void interrupt() {
+        if (interacting) {
+            // exit
+            printf("\n");
+            console::cleanup();
+            common_perf_print(ctx, smpl);
+            common_log_pause(common_log_main());
+            exit(0);
+        }
+        interacting = true;
+    }
+
+    bool handle_command(std::string & inp, bool & continue_input) {
+        if (inp.empty() || inp[0] != '/') {
+            return false; // not a command
+        }
+        auto parts = string_split(string_strip(inp), ' ');
+        std::string & cmd = parts[0];
+        if (cmd == "/help") {
+            LOG("TODO\n");
+            continue_input = true;
+        } else if (cmd == "/history") {
+            display_history();
+            continue_input = true;
+        } else if (cmd == "/regen") {
+            if (chat_msgs.empty()) {
+                LOG_ERR("no chat history to regenerate\n");
+                continue_input = true;
+                return true;
+            }
+            if (chat_msgs.back().role == "assistant") {
+                chat_msgs.pop_back();
+            }
+            if (chat_msgs.back().role == "user") {
+                pending_input.str(""); // clear
+                pending_input << chat_msgs.back().content;
+                chat_msgs.pop_back();
+            }
+            continue_input = false;
+        } else if (cmd == "/readfile") {
+            const std::string filename = parts[1];
+            LOG_DBG("reading file: '%s'\n", filename.c_str());
+            std::ifstream text_file(filename);
+            if (!text_file) {
+                LOG("failed to open file '%s'\n", filename.c_str());
+            } else {
+                pending_input << text_file.rdbuf() << "\n\n";
+                LOG("read %zu characters from file\n", (size_t) text_file.tellg());
+            }
+            continue_input = true;
+        } else {
+            LOG_ERR("unknown command: %s\n", cmd.c_str());
+            continue_input = true;
+        }
+        return true;
+    }
+
+    void display_history() {
+        for (const auto & msg : chat_msgs) {
+            LOG("%s: %s\n\n", msg.role.c_str(), msg.content.c_str());
+        }
+    }
+
+    ~llama_cli_chat() {
+        llama_batch_free(batch);
+    }
+};
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index aaee47e325943..5a6e356e61951 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -4,6 +4,7 @@
 #include "log.h"
#include "sampling.h" #include "llama.h" +#include "chat.hpp" #include #include @@ -35,6 +36,7 @@ static llama_context ** g_ctx; static llama_model ** g_model; static common_sampler ** g_smpl; static common_params * g_params; +static llama_cli_chat * g_chat; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; static std::vector * g_output_tokens; @@ -65,7 +67,9 @@ static bool file_is_empty(const std::string & path) { #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) static void sigint_handler(int signo) { if (signo == SIGINT) { - if (!is_interacting && g_params->interactive) { + if (g_chat) { + g_chat->interrupt(); + } else if (!is_interacting && g_params->interactive) { is_interacting = true; need_insert_eot = true; } else { @@ -83,14 +87,6 @@ static void sigint_handler(int signo) { } #endif -static std::string chat_add_and_format(struct llama_model * model, std::vector & chat_msgs, const std::string & role, const std::string & content) { - common_chat_msg new_msg{role, content}; - auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user"); - chat_msgs.push_back({role, content}); - LOG_DBG("formatted: '%s'\n", formatted.c_str()); - return formatted; -} - int main(int argc, char ** argv) { common_params params; g_params = ¶ms; @@ -203,6 +199,12 @@ int main(int argc, char ** argv) { LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); } + // switch on conversation mode if chat template is present + if (!params.chat_template.empty() || !common_get_builtin_chat_template(model).empty()) { + LOG("%s: using chat mode\n", __func__); + params.conversation = true; + } + // print chat template example in conversation mode if (params.conversation) { if (params.enable_chat_template) { @@ -251,18 +253,15 @@ int main(int argc, char ** argv) { std::vector embd_inp; { - auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty()) - ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode - : params.prompt; if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { LOG_DBG("tokenize the prompt\n"); - embd_inp = common_tokenize(ctx, prompt, true, true); + embd_inp = common_tokenize(ctx, params.prompt, true, true); } else { LOG_DBG("use session tokens\n"); embd_inp = session_tokens; } - LOG_DBG("prompt: \"%s\"\n", prompt.c_str()); + LOG_DBG("prompt: \"%s\"\n", params.prompt.c_str()); LOG_DBG("tokens: %s\n", string_from(ctx, embd_inp).c_str()); } @@ -420,6 +419,12 @@ int main(int argc, char ** argv) { LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); + if (params.conversation) { + llama_cli_chat chat(params, ctx, smpl); + g_chat = &chat; + chat.run(); + } + // group-attention state // number of grouped KV tokens so far (used only if params.grp_attn_n > 1) int ga_i = 0; @@ -752,10 +757,6 @@ int main(int argc, char ** argv) { embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); is_antiprompt = true; } - - if (params.enable_chat_template) { - chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str()); - } is_interacting = true; LOG("\n"); } @@ -818,9 +819,7 @@ int main(int argc, char ** argv) { } bool format_chat = params.conversation && params.enable_chat_template; - std::string user_inp = format_chat - ? 
-                : std::move(buffer);
+            std::string user_inp = std::move(buffer);
             // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
             const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);
             const auto line_inp = common_tokenize(ctx, user_inp, false, format_chat);

From d07c9f6a7a15ad9a75a0c62377e06ed7c9d864cc Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 12 Jan 2025 13:27:27 +0100
Subject: [PATCH 2/3] adapt

---
 examples/main/chat.hpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/main/chat.hpp b/examples/main/chat.hpp
index 801e7acdc1be6..e17cbfa020b8f 100644
--- a/examples/main/chat.hpp
+++ b/examples/main/chat.hpp
@@ -10,6 +10,7 @@
 struct llama_cli_chat {
     struct llama_context * ctx;
     const struct llama_model * model;
+    const struct llama_vocab * vocab;
     struct common_sampler * smpl;
     struct common_params params;
 
@@ -26,6 +27,7 @@ struct llama_cli_chat {
             struct llama_context * ctx,
             struct common_sampler * smpl) : ctx(ctx), smpl(smpl), params(params) {
         model = llama_get_model(ctx);
+        vocab = llama_model_get_vocab(model);
         batch = llama_batch_init(params.n_batch, 0, 1);
     }
 
@@ -130,7 +132,7 @@ struct llama_cli_chat {
                 new_token_id = common_sampler_sample(smpl, ctx, -1);
 
                 // is it an end of generation?
-                if (llama_token_is_eog(model, new_token_id)) {
+                if (llama_vocab_is_eog(vocab, new_token_id)) {
                     break;
                 }
 

From c2b26000c3edebd20080eec9f40d2ae5e2e65005 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 12 Jan 2025 13:54:25 +0100
Subject: [PATCH 3/3] remove reference to params.conversation in main

---
 examples/main/main.cpp | 26 +++----------------------
 1 file changed, 3 insertions(+), 23 deletions(-)

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index e3ecac0d0e9f1..1ba4dcb3dfcbc 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -40,7 +40,6 @@ static std::vector<llama_token> * g_input_tokens;
 static std::ostringstream * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting = false;
-static bool need_insert_eot = false;
 
 static void print_usage(int argc, char ** argv) {
     (void) argc;
@@ -70,7 +69,6 @@ static void sigint_handler(int signo) {
             g_chat->interrupt();
         } else if (!is_interacting && g_params->interactive) {
             is_interacting = true;
-            need_insert_eot = true;
         } else {
             console::cleanup();
             LOG("\n");
@@ -763,26 +761,16 @@ int main(int argc, char ** argv) {
             }
         }
 
-        // if current token is not EOG, we add it to current assistant message
-        if (params.conversation) {
-            const auto id = common_sampler_last(smpl);
-            assistant_ss << common_token_to_piece(ctx, id, false);
-        }
-
         if (n_past > 0 && is_interacting) {
             LOG_DBG("waiting for user input\n");
 
-            if (params.conversation) {
-                LOG("\n> ");
-            }
-
             if (params.input_prefix_bos) {
                 LOG_DBG("adding input prefix BOS token\n");
                 embd_inp.push_back(llama_vocab_bos(vocab));
            }
 
             std::string buffer;
-            if (!params.input_prefix.empty() && !params.conversation) {
+            if (!params.input_prefix.empty()) {
                 LOG_DBG("appending input prefix: '%s'\n", params.input_prefix.c_str());
                 LOG("%s", params.input_prefix.c_str());
             }
@@ -806,7 +794,7 @@ int main(int argc, char ** argv) {
             // Entering a empty line lets the user pass control back
             if (buffer.length() > 1) {
                 // append input suffix if any
-                if (!params.input_suffix.empty() && !params.conversation) {
+                if (!params.input_suffix.empty()) {
                     LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str());
                    LOG("%s", params.input_suffix.c_str());
                }
@@ -819,22 +807,14 @@ int main(int argc, char ** argv) {
                 string_process_escapes(buffer);
             }
 
-            bool format_chat = params.conversation && params.enable_chat_template;
             std::string user_inp = std::move(buffer);
             // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
             const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);
-            const auto line_inp = common_tokenize(ctx, user_inp, false, format_chat);
+            const auto line_inp = common_tokenize(ctx, user_inp, false, true);
             const auto line_sfx = common_tokenize(ctx, params.input_suffix, false, true);
 
             LOG_DBG("input tokens: %s\n", string_from(ctx, line_inp).c_str());
 
-            // if user stop generation mid-way, we must add EOT to finish model's last response
-            if (need_insert_eot && format_chat) {
-                llama_token eot = llama_vocab_eot(vocab);
-                embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_vocab_eos(vocab) : eot);
-                need_insert_eot = false;
-            }
-
             embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
             embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
             embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());