
update: completion function #10


Merged
merged 1 commit into from Mar 17, 2024
18 changes: 11 additions & 7 deletions README.md
@@ -1,8 +1,8 @@
# gemma-cpp-python: Python Bindings for [gemma.cpp](https://github.com/google/gemma.cpp)

-**Latest Version: v0.1.2**
-- Support Completion function
-- Fix the MacOS pip install
+**Latest Version: v0.1.3**
+- Interface changes due to updates in gemma.cpp.
+- Enhanced user experience for ease of use 🙏. Give it a try!

[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)

@@ -18,7 +18,7 @@ Special thanks to the creators and contributors of [gemma.cpp](https://github.co
### Install from PyPI
For a quick setup, install directly from PyPI:
```bash
-pip install pygemma==0.1.2
+pip install pygemma==0.1.3
```

### For Developers: Install from Source
@@ -32,7 +32,7 @@ cd gemma-cpp-python

2. Install Python dependencies and pygemma:
```bash
-pip install -r requirements.txt && pip install .
+pip install .
```

## 🖥 Usage
@@ -41,8 +41,12 @@ To actually run the model, you need to install the model following the [gemma

For usage examples, refer to tests/test_chat.py. Here's a quick start:
```python
-import pygemma
-pygemma.show_help()
+from pygemma import Gemma
+gemma = Gemma()
+gemma.show_help()
+gemma.show_config()
+gemma.load_model("/path/to/tokenizer", "/path/to/compressed_weight/", "model_type")
+gemma.completion("Write a poem")
```

## 🤝 Contributing
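Note: the headline change in the quick start above is the move from module-level functions in v0.1.2 to a `Gemma` class in v0.1.3. A minimal migration sketch, using the same placeholder paths as the README (nothing here beyond what the diff shows):

```python
# v0.1.2 (removed lines): module-level helpers
#   import pygemma
#   pygemma.show_help()

# v0.1.3 (added lines): everything hangs off a Gemma instance
from pygemma import Gemma

gemma = Gemma()
gemma.show_help()
gemma.show_config()
gemma.load_model("/path/to/tokenizer", "/path/to/compressed_weight/", "model_type")
gemma.completion("Write a poem")
```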
2 changes: 1 addition & 1 deletion setup.py
@@ -56,7 +56,7 @@ def build_extension(self, ext):

setup(
name="pygemma",
version="0.1.2",
version="0.1.3",
author="Nam Tran",
author_email="[email protected]",
description="A Python package with a C++ backend using gemma.cpp",
53 changes: 52 additions & 1 deletion src/gemma_binding.cpp
@@ -182,6 +182,57 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
<< "command line flag.\n";
}


std::vector<int> tokenize(
const std::string& prompt_string,
const sentencepiece::SentencePieceProcessor* tokenizer) {
std::string formatted = "<start_of_turn>user\n" + prompt_string +
"<end_of_turn>\n<start_of_turn>model\n";
std::vector<int> tokens;
HWY_ASSERT(tokenizer->Encode(formatted, &tokens).ok());
tokens.insert(tokens.begin(), 2); // BOS token
return tokens;
}

int GemmaWrapper::completionPrompt(std::string& prompt) {
size_t pos = 0; // KV Cache position
size_t num_threads = static_cast<size_t>(std::clamp(
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
hwy::ThreadPool pool(num_threads);
// Initialize random number generator
std::mt19937 gen;
std::random_device rd;
gen.seed(rd());

// Tokenize instruction
std::vector<int> tokens =
tokenize(prompt, this->m_model->Tokenizer());
size_t ntokens = tokens.size();

// This callback function is invoked every time a token is generated
auto stream_token = [&pos, &gen, &ntokens, tokenizer = this->m_model->Tokenizer()](
int token, float) {
++pos;
if (pos < ntokens) {
// print feedback
} else if (token != gcpp::EOS_ID) {
std::string token_text;
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
std::cout << token_text << std::flush;
}
return true;
};

GenerateGemma(*this->m_model,
{.max_tokens = 2048,
.max_generated_tokens = 1024,
.temperature = 1.0,
.verbosity = 0},
tokens, /*KV cache position = */ 0, this->m_kvcache, pool,
stream_token, gen);
std::cout << std::endl;
return 0;  // success status; the function is declared to return int
}

void GemmaWrapper::loadModel(const std::vector<std::string> &args) {
int argc = args.size() + 1; // +1 for the program name
std::vector<char *> argv_vec;
@@ -269,5 +320,5 @@ PYBIND11_MODULE(pygemma, m) {
};
self.loadModel(args); // Assuming GemmaWrapper::loadModel accepts std::vector<std::string>
}, py::arg("tokenizer"), py::arg("compressed_weights"), py::arg("model"))
.def("completion", &GemmaWrapper::completionPrompt);
.def("completion", &GemmaWrapper::completionPrompt, "Function that completes given prompt.");
}
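The new `tokenize` helper hard-codes Gemma's instruction-tuned chat template around the prompt and prepends BOS token id 2 before the tokens reach `GenerateGemma`. For illustration only, here is the same framing expressed in Python; the `sp` SentencePiece processor mentioned in the trailing comment is hypothetical and not part of these bindings:

```python
# Illustration of the prompt framing applied by the C++ tokenize() above.
BOS_ID = 2  # matches the hard-coded BOS token id inserted by the C++ code


def format_prompt(prompt: str) -> str:
    # Gemma chat template: one user turn, then open the model's turn.
    return (
        "<start_of_turn>user\n"
        f"{prompt}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )


# With a hypothetical SentencePiece processor `sp`, tokenization would be:
#   tokens = [BOS_ID] + sp.EncodeAsIds(format_prompt("Write a poem"))
```

Note also that the streaming callback stays silent while `pos < ntokens`, so the prompt tokens are consumed without being echoed; only newly generated tokens are decoded and written to stdout.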
2 changes: 1 addition & 1 deletion src/gemma_binding.h
@@ -35,7 +35,7 @@ class GemmaWrapper {
void loadModel(const std::vector<std::string> &args); // Consider exception safety
void showConfig();
void showHelp();
-std::string completionPrompt();
+int completionPrompt(std::string& prompt);

private:
gcpp::LoaderArgs m_loader = gcpp::LoaderArgs(0, nullptr);
2 changes: 2 additions & 0 deletions tests/test_chat.py
@@ -32,6 +32,8 @@ def main():
gemma.show_config()
gemma.show_help()
gemma.load_model(args.tokenizer, args.compressed_weights, args.model)
gemma.completion("Write a poem")
gemma.completion("What is the best war in history")


if __name__ == "__main__":
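One behavioral detail worth noting in the updated test: `completionPrompt` hands `GenerateGemma` a KV-cache position of 0 on every call, so the two prompts above run as independent generations rather than one continuing conversation. A condensed sketch of that usage pattern (paths are placeholders):

```python
# Each completion() call is independent: the C++ side restarts generation
# at KV-cache position 0, so no conversational state carries over.
from pygemma import Gemma

gemma = Gemma()
gemma.load_model("/path/to/tokenizer", "/path/to/compressed_weight/", "model_type")
gemma.completion("Write a poem")
gemma.completion("What is the best war in history")  # unaffected by the prior prompt
```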