Skip to content

Commit b495931

Browse files
authored
Merge pull request #10 from namtranase/dev
update: completion function
2 parents 6589e94 + 5ec2dbc commit b495931

File tree

5 files changed

+67
-10
lines changed

5 files changed

+67
-10
lines changed

Diff for: README.md

+11-7
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# gemma-cpp-python: Python Bindings for [gemma.cpp](https://github.com/google/gemma.cpp)
22

3-
**Latest Version: v0.1.2**
4-
- Support Completion function
5-
- Fix the MacOS pip install
3+
**Latest Version: v0.1.3**
4+
- Interface changes due to updates in gemma.cpp.
5+
- Enhanced user experience for ease of use 🙏. Give it a try!
66

77
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)
88

@@ -18,7 +18,7 @@ Special thanks to the creators and contributors of [gemma.cpp](https://github.co
1818
### Install from PyPI
1919
For a quick setup, install directly from PyPI:
2020
```bash
21-
pip install pygemma==0.1.2
21+
pip install pygemma==0.1.3
2222
```
2323

2424
### For Developers: Install from Source
@@ -32,7 +32,7 @@ cd gemma-cpp-python
3232

3333
2. Install Python dependencies and pygemma:
3434
```bash
35-
pip install -r requirements.txt && pip install .
35+
pip install .
3636
```
3737

3838
## 🖥 Usage
@@ -41,8 +41,12 @@ To actually run the model, you need to install the model followed on the [gemma
4141

4242
For usage examples, refer to tests/test_chat.py. Here's a quick start:
4343
```bash
44-
import pygemma
45-
pygemma.show_help()
44+
from pygemma import Gemma
45+
gemma = Gemma()
46+
gemma.show_help()
47+
gemma.show_config()
48+
gemma.load_model("/path/to/tokenizer", "/path/to/compressed_weight/", "model_type")
49+
gemma.completion("Write a poem")
4650
```
4751

4852
## 🤝 Contributing

Diff for: setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def build_extension(self, ext):
5656

5757
setup(
5858
name="pygemma",
59-
version="0.1.2",
59+
version="0.1.3",
6060
author="Nam Tran",
6161
author_email="[email protected]",
6262
description="A Python package with a C++ backend using gemma.cpp",

Diff for: src/gemma_binding.cpp

+52-1
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,57 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
182182
<< "command line flag.\n";
183183
}
184184

185+
186+
std::vector<int> tokenize(
187+
const std::string& prompt_string,
188+
const sentencepiece::SentencePieceProcessor* tokenizer) {
189+
std::string formatted = "<start_of_turn>user\n" + prompt_string +
190+
"<end_of_turn>\n<start_of_turn>model\n";
191+
std::vector<int> tokens;
192+
HWY_ASSERT(tokenizer->Encode(formatted, &tokens).ok());
193+
tokens.insert(tokens.begin(), 2); // BOS token
194+
return tokens;
195+
}
196+
197+
int GemmaWrapper::completionPrompt(std::string& prompt) {
198+
size_t pos = 0; // KV Cache position
199+
size_t num_threads = static_cast<size_t>(std::clamp(
200+
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
201+
hwy::ThreadPool pool(num_threads);
202+
// Initialize random number generator
203+
std::mt19937 gen;
204+
std::random_device rd;
205+
gen.seed(rd());
206+
207+
// Tokenize instruction
208+
std::vector<int> tokens =
209+
tokenize(prompt, this->m_model->Tokenizer());
210+
size_t ntokens = tokens.size();
211+
212+
// This callback function gets invoked everytime a token is generated
213+
auto stream_token = [&pos, &gen, &ntokens, tokenizer = this->m_model->Tokenizer()](
214+
int token, float) {
215+
++pos;
216+
if (pos < ntokens) {
217+
// print feedback
218+
} else if (token != gcpp::EOS_ID) {
219+
std::string token_text;
220+
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text).ok());
221+
std::cout << token_text << std::flush;
222+
}
223+
return true;
224+
};
225+
226+
GenerateGemma(*this->m_model,
227+
{.max_tokens = 2048,
228+
.max_generated_tokens = 1024,
229+
.temperature = 1.0,
230+
.verbosity = 0},
231+
tokens, /*KV cache position = */ 0, this->m_kvcache, pool,
232+
stream_token, gen);
233+
std::cout << std::endl;
234+
}
235+
185236
void GemmaWrapper::loadModel(const std::vector<std::string> &args) {
186237
int argc = args.size() + 1; // +1 for the program name
187238
std::vector<char *> argv_vec;
@@ -269,5 +320,5 @@ PYBIND11_MODULE(pygemma, m) {
269320
};
270321
self.loadModel(args); // Assuming GemmaWrapper::loadModel accepts std::vector<std::string>
271322
}, py::arg("tokenizer"), py::arg("compressed_weights"), py::arg("model"))
272-
.def("completion", &GemmaWrapper::completionPrompt);
323+
.def("completion", &GemmaWrapper::completionPrompt, "Function that completes given prompt.");
273324
}

Diff for: src/gemma_binding.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class GemmaWrapper {
3535
void loadModel(const std::vector<std::string> &args); // Consider exception safety
3636
void showConfig();
3737
void showHelp();
38-
std::string completionPrompt();
38+
int completionPrompt(std::string& prompt);
3939

4040
private:
4141
gcpp::LoaderArgs m_loader = gcpp::LoaderArgs(0, nullptr);

Diff for: tests/test_chat.py

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def main():
3232
gemma.show_config()
3333
gemma.show_help()
3434
gemma.load_model(args.tokenizer, args.compressed_weights, args.model)
35+
gemma.completion("Write a poem")
36+
gemma.completion("What is the best war in history")
3537

3638

3739
if __name__ == "__main__":

0 commit comments

Comments
 (0)