From 5ec2dbc2bb25dd3a9f6ea89a06633f9ad99f9b60 Mon Sep 17 00:00:00 2001
From: namtranase <namtran.ase@gmail.com>
Date: Sun, 17 Mar 2024 22:27:39 +0700
Subject: [PATCH] update: completion function

---
 README.md             | 18 +++++++++------
 setup.py              |  2 +-
 src/gemma_binding.cpp | 54 +++++++++++++++++++++++++++++++++++++++++++-
 src/gemma_binding.h   |  2 +-
 tests/test_chat.py    |  2 ++
 5 files changed, 68 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index e2c9e0e..d5529ff 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
 # gemma-cpp-python: Python Bindings for [gemma.cpp](https://github.com/google/gemma.cpp)
 
-**Latest Version: v0.1.2**
-- Support Completion function
-- Fix the MacOS pip install
+**Latest Version: v0.1.3**
+- Interface updated to follow upstream changes in gemma.cpp.
+- Smoother, easier-to-use API 🙏. Give it a try!
 
 [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)
 
@@ -18,7 +18,7 @@ Special thanks to the creators and contributors of [gemma.cpp](https://github.co
 ### Install from PyPI
 For a quick setup, install directly from PyPI:
 ```bash
-pip install pygemma==0.1.2
+pip install pygemma==0.1.3
 ```
 
 ### For Developers: Install from Source
@@ -32,7 +32,7 @@ cd gemma-cpp-python
 
 2. Install Python dependencies and pygemma:
 ```bash
-pip install -r requirements.txt && pip install .
+pip install .
 ```
 
 ## 🖥 Usage
@@ -41,8 +41,12 @@ To actually run the model, you need to download the model following the [gemma
 For usage examples, refer to tests/test_chat.py. Here's a quick start:
 
 ```python
-import pygemma
-pygemma.show_help()
+from pygemma import Gemma
+gemma = Gemma()
+gemma.show_help()
+gemma.show_config()
+gemma.load_model("/path/to/tokenizer", "/path/to/compressed_weight/", "model_type")
+gemma.completion("Write a poem")
 ```
 
 ## 🤝 Contributing
diff --git a/setup.py b/setup.py
index 9bd0c0b..030cffe 100644
--- a/setup.py
+++ b/setup.py
@@ -56,7 +56,7 @@ def build_extension(self, ext):
 
 setup(
     name="pygemma",
-    version="0.1.2",
+    version="0.1.3",
     author="Nam Tran",
     author_email="namtran.ase@gmail.com",
     description="A Python package with a C++ backend using gemma.cpp",
diff --git a/src/gemma_binding.cpp b/src/gemma_binding.cpp
index 8baa3e1..5e177c5 100644
--- a/src/gemma_binding.cpp
+++ b/src/gemma_binding.cpp
@@ -182,6 +182,58 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
           << "command line flag.\n";
 }
 
+std::vector<int> tokenize(
+    const std::string& prompt_string,
+    const sentencepiece::SentencePieceProcessor* tokenizer) {
+  std::string formatted = "<start_of_turn>user\n" + prompt_string +
+                          "<end_of_turn>\n<start_of_turn>model\n";
+  std::vector<int> tokens;
+  HWY_ASSERT(tokenizer->Encode(formatted, &tokens).ok());
+  tokens.insert(tokens.begin(), 2);  // prepend the BOS token
+  return tokens;
+}
+
+int GemmaWrapper::completionPrompt(std::string& prompt) {
+  size_t pos = 0;  // KV cache position
+  size_t num_threads = static_cast<size_t>(std::clamp(
+      static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
+  hwy::ThreadPool pool(num_threads);
+  // Initialize the random number generator used for sampling.
+  std::mt19937 gen;
+  std::random_device rd;
+  gen.seed(rd());
+
+  // Tokenize the instruction.
+  std::vector<int> tokens =
+      tokenize(prompt, this->m_model->Tokenizer());
+  size_t ntokens = tokens.size();
+
+  // This callback is invoked every time a token is generated.
+  auto stream_token = [&pos, &gen, &ntokens,
+                       tokenizer = this->m_model->Tokenizer()](int token, float) {
+    ++pos;
+    if (pos < ntokens) {
+      // Prompt tokens are still being consumed; nothing to print yet.
+    } else if (token != gcpp::EOS_ID) {
+      std::string token_text;
+      HWY_ASSERT(tokenizer->Decode(std::vector<int>{token},
+                                   &token_text).ok());
+      std::cout << token_text << std::flush;
+    }
+    return true;
+  };
+
+  GenerateGemma(*this->m_model,
+                {.max_tokens = 2048,
+                 .max_generated_tokens = 1024,
+                 .temperature = 1.0,
+                 .verbosity = 0},
+                tokens, /*KV cache position=*/0, this->m_kvcache, pool,
+                stream_token, gen);
+  std::cout << std::endl;
+  return static_cast<int>(pos);  // number of tokens processed
+}
+
 void GemmaWrapper::loadModel(const std::vector<std::string> &args) {
   int argc = args.size() + 1;  // +1 for the program name
   std::vector<char*> argv_vec;
@@ -269,5 +321,5 @@ PYBIND11_MODULE(pygemma, m) {
         };
         self.loadModel(args);  // Assuming GemmaWrapper::loadModel accepts std::vector<std::string>
       }, py::arg("tokenizer"), py::arg("compressed_weights"), py::arg("model"))
-      .def("completion", &GemmaWrapper::completionPrompt);
+      .def("completion", &GemmaWrapper::completionPrompt, "Generate a completion for the given prompt.");
 }
diff --git a/src/gemma_binding.h b/src/gemma_binding.h
index 7f5da9c..c02b70b 100644
--- a/src/gemma_binding.h
+++ b/src/gemma_binding.h
@@ -35,7 +35,7 @@ class GemmaWrapper {
   void loadModel(const std::vector<std::string> &args);  // Consider exception safety
   void showConfig();
   void showHelp();
-  std::string completionPrompt();
+  int completionPrompt(std::string& prompt);
 
 private:
   gcpp::LoaderArgs m_loader = gcpp::LoaderArgs(0, nullptr);
diff --git a/tests/test_chat.py b/tests/test_chat.py
index d2b84f9..4b5efc8 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -32,6 +32,8 @@ def main():
     gemma.show_config()
     gemma.show_help()
     gemma.load_model(args.tokenizer, args.compressed_weights, args.model)
+    gemma.completion("Write a poem")
+    gemma.completion("What is the best war in history")
 
 
 if __name__ == "__main__":
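
Reviewer note: for anyone trying this patch locally, below is a minimal sketch of the new `completion` flow end to end. The three `load_model` arguments are the placeholder strings from the README above, not real files; substitute the tokenizer, compressed weights, and model type downloaded per the gemma.cpp instructions.

```python
# Quick local smoke test for the completion() binding added by this patch.
# The load_model() arguments below are placeholders taken from the README;
# replace them with your downloaded gemma.cpp artifacts before running.
from pygemma import Gemma

gemma = Gemma()
gemma.show_config()  # print the active generation settings
gemma.load_model(
    "/path/to/tokenizer",           # SentencePiece tokenizer file
    "/path/to/compressed_weight/",  # compressed weights directory
    "model_type",                   # model identifier as expected by gemma.cpp
)
# completion() streams generated tokens to stdout as they are produced and
# stops at EOS or after max_generated_tokens (1024 in this patch).
gemma.completion("Write a poem")
```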