From e42ea7553513626c7d32794fc9fbf1ecb6c07deb Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Thu, 18 Jan 2024 22:58:38 -0800
Subject: [PATCH 1/2] Add a tokenizer python script (#1611)

Summary:
Add a tokenizer python script that adds some post-processing to the vanilla
`sentencepiece` tokenizer model. This comes in handy when we want to consume
it in C++.

Pull Request resolved: https://github.com/pytorch/executorch/pull/1611

Differential Revision: D52821402

Pulled By: larryliu0820

fbshipit-source-id: a9b10b37a3157f00983c7ce0f0badeefbee1aa4a
---
 examples/models/llama2/tokenizer/__init__.py  |   0
 .../models/llama2/tokenizer/test/__init__.py  |   0
 .../llama2/tokenizer/test/test_tokenizer.py   |  45 ++++++
 examples/models/llama2/tokenizer/tokenizer.py | 147 ++++++++++++++++++
 4 files changed, 192 insertions(+)
 create mode 100644 examples/models/llama2/tokenizer/__init__.py
 create mode 100644 examples/models/llama2/tokenizer/test/__init__.py
 create mode 100644 examples/models/llama2/tokenizer/test/test_tokenizer.py
 create mode 100644 examples/models/llama2/tokenizer/tokenizer.py

diff --git a/examples/models/llama2/tokenizer/__init__.py b/examples/models/llama2/tokenizer/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/examples/models/llama2/tokenizer/test/__init__.py b/examples/models/llama2/tokenizer/test/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/examples/models/llama2/tokenizer/test/test_tokenizer.py b/examples/models/llama2/tokenizer/test/test_tokenizer.py
new file mode 100644
index 00000000000..0c861d6f7b8
--- /dev/null
+++ b/examples/models/llama2/tokenizer/test/test_tokenizer.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import struct
+import tempfile
+import unittest
+from unittest.mock import patch
+
+from executorch.examples.models.llama2.tokenizer.tokenizer import Tokenizer
+
+
+class TestTokenizer(unittest.TestCase):
+    @patch(
+        "executorch.examples.models.llama2.tokenizer.tokenizer.SentencePieceProcessor"
+    )
+    def test_export(self, mock_sp):
+        # Set up the mock SentencePieceProcessor
+        mock_sp.return_value.vocab_size.return_value = 0
+        mock_sp.return_value.bos_id.return_value = 1
+        mock_sp.return_value.eos_id.return_value = 2
+        mock_sp.return_value.get_piece_size.return_value = 0
+        # Create a temporary file
+        with tempfile.NamedTemporaryFile(delete=True) as temp:
+            # Initialize the tokenizer with the temporary file as the model
+            tokenizer = Tokenizer(temp.name)
+            # Export the tokenizer to another temporary file
+            with open("/tmp/test.bin", "wb") as output:
+                tokenizer.export(output.name)
+            # Open the output file in binary mode and read the first 16 bytes
+            with open(output.name, "rb") as f:
+                data = f.read(16)
+            # Unpack the data as 4 integers
+            vocab_size, bos_id, eos_id, max_token_length = struct.unpack(
+                "IIII", data
+            )
+            # Check that the integers match the properties of the tokenizer
+            self.assertEqual(vocab_size, 0)
+            self.assertEqual(bos_id, 1)
+            self.assertEqual(eos_id, 2)
+            # Check that the max token length is correct
+            self.assertEqual(max_token_length, 0)
diff --git a/examples/models/llama2/tokenizer/tokenizer.py b/examples/models/llama2/tokenizer/tokenizer.py
new file mode 100644
index 00000000000..785a49dff94
--- /dev/null
+++ b/examples/models/llama2/tokenizer/tokenizer.py
@@ -0,0 +1,147 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Script to rewrite the tokenizer model given by sentencepiece, with
+# lightweight postprocessing logic.
+
+import argparse
+import logging
+import os
+import struct
+from typing import List
+
+from sentencepiece import SentencePieceProcessor as SentencePieceProcessor
+
+
+class Tokenizer:
+    def __init__(self, model_path: str):
+        assert os.path.isfile(
+            model_path
+        ), f"Need a valid tokenizer model path but got {model_path}"
+        # pyre-fixme[28]: Unexpected keyword argument `model_file` to call `SentencePieceProcessor.__init__`.
+        self.sp_model = SentencePieceProcessor(model_file=model_path)
+        self.model_path = model_path
+
+        # BOS / EOS token IDs
+        self.n_words: int = self.sp_model.vocab_size()
+        self.bos_id: int = self.sp_model.bos_id()
+        self.eos_id: int = self.sp_model.eos_id()
+        logging.info(
+            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+        )
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_piece_size`.
+        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+        assert type(s) is str
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
+        t = self.sp_model.encode(s)
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+        return t
+
+    def decode(self, t: List[int]) -> str:
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `decode`.
+        return self.sp_model.decode(t)
+
+    def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
+        """
+        Export tokenizer.model to another serialization format.
+        Here we do some lightweight post-processing, such as optionally
+        prepending a padding token, recording the max token length, and
+        replacing '▁' with a plain space.
+
+        The binary format is:
+        1. vocab size: int32
+        2. bos id: int32
+        3. eos id: int32
+        4. max token length: int32
+        5. score: float32, len of bytes: int32, token bytes: [byte] for each token
+
+        :param output_path: output path of the new binary.
+        :param prepend_padding: a boolean to control if we want to prepend a padding token.
+
+        :return: None
+        """
+
+        # get all the tokens (post-processed) and their scores as floats
+        tokens, scores = [], []
+
+        if prepend_padding:
+            # Here we use the default padding token and its score.
+            tokens.append("<pad>".encode("utf-8"))
+            scores.append(-1)
+
+        for i in range(self.n_words):
+
+            # decode the token and do light postprocessing
+            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `id_to_piece`.
+            t = self.sp_model.id_to_piece(i)
+            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_score`.
+            s = self.sp_model.get_score(i)
+            # sentencepiece uses '<s>' as BOS and '</s>' as EOS
+            if i == self.bos_id:
+                t = "<s>"
+            elif i == self.eos_id:
+                t = "</s>"
+            t = t.replace("▁", " ")  # sentencepiece uses this character as whitespace
+            b = t.encode("utf-8")  # bytes of this token, utf-8 encoded
+
+            tokens.append(b)
+            scores.append(s)
+
+        # record the max token length
+        max_token_length = 0 if not tokens else max(len(t) for t in tokens)
+
+        # write to a binary file
+        with open(output_path, "wb") as f:
+            # write the vocab size, bos/eos ids and max token length
+            f.write(
+                struct.pack(
+                    "IIII", self.n_words, self.bos_id, self.eos_id, max_token_length
+                )
+            )
+            for token_bytes, score in zip(tokens, scores):
+                f.write(struct.pack("fI", score, len(token_bytes)))
+                f.write(token_bytes)
+        logging.info(f"Wrote tokenizer to {output_path}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-t",
+        "--tokenizer-model",
+        type=str,
+        default="tokenizer.model",
+        help="path to tokenizer model, given by sentencepiece",
+    )
+    parser.add_argument(
+        "-o",
+        "--output-path",
+        type=str,
+        default=None,
+        help="output path of postprocessed tokenizer model",
+    )
+    parser.add_argument(
+        "-p",
+        "--prepend-padding",
+        action="store_true",
+        help="whether to prepend a padding token to the beginning of the tokenizer",
+    )
+
+    args = parser.parse_args()
+
+    t = Tokenizer(args.tokenizer_model)
+
+    output_path = (
+        args.output_path
+        if args.output_path
+        else args.tokenizer_model.replace(".model", ".bin")
+    )
+    t.export(output_path, prepend_padding=args.prepend_padding)

From 239a111fd9751aa7f1e21e4107454a41cf26895c Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Thu, 18 Jan 2024 22:58:49 -0800
Subject: [PATCH 2/2] Add a tokenizer (#1641)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/1641

Add a tokenizer in examples. This needs to consume the artifact generated by
`tokenizer.py`.
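
For reference, the artifact can also be inspected from Python. The following
is a minimal sketch (not part of this patch) that reads the binary back,
assuming only the layout documented in `export()`; `read_tokenizer_bin` is a
hypothetical helper name:

```python
import struct
from typing import List, Tuple


def read_tokenizer_bin(path: str) -> Tuple[int, int, int, int, List[Tuple[float, bytes]]]:
    """Parse the binary written by Tokenizer.export()."""
    with open(path, "rb") as f:
        # header: vocab size, bos id, eos id, max token length (4 x int32)
        vocab_size, bos_id, eos_id, max_token_length = struct.unpack(
            "IIII", f.read(16)
        )
        entries = []
        for _ in range(vocab_size):
            # per token: score (float32) and byte length (int32), ...
            score, length = struct.unpack("fI", f.read(8))
            # ... followed by the raw token bytes
            entries.append((score, f.read(length)))
    return vocab_size, bos_id, eos_id, max_token_length, entries
```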
Reviewed By: mikekgfb

Differential Revision: D52894344

fbshipit-source-id: df192ad8a8d47da3cd5e70d911095d8dbc237b4c
---
 examples/models/llama2/tokenizer/targets.bzl  |  33 ++
 .../models/llama2/tokenizer/test/test.bin     | Bin 0 -> 16 bytes
 .../llama2/tokenizer/test/test_tokenizer.cpp  |  55 +++
 .../models/llama2/tokenizer/tokenizer.cpp     | 336 ++++++++++++++++++
 examples/models/llama2/tokenizer/tokenizer.h  |  70 ++++
 5 files changed, 494 insertions(+)
 create mode 100644 examples/models/llama2/tokenizer/targets.bzl
 create mode 100644 examples/models/llama2/tokenizer/test/test.bin
 create mode 100644 examples/models/llama2/tokenizer/test/test_tokenizer.cpp
 create mode 100644 examples/models/llama2/tokenizer/tokenizer.cpp
 create mode 100644 examples/models/llama2/tokenizer/tokenizer.h

diff --git a/examples/models/llama2/tokenizer/targets.bzl b/examples/models/llama2/tokenizer/targets.bzl
new file mode 100644
index 00000000000..ff22d87cf41
--- /dev/null
+++ b/examples/models/llama2/tokenizer/targets.bzl
@@ -0,0 +1,33 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+def define_common_targets():
+    runtime.cxx_library(
+        name = "tokenizer_lib",
+        srcs = ["tokenizer.cpp"],
+        headers = ["tokenizer.h"],
+        exported_deps = [
+            "//executorch/runtime/core/exec_aten:lib",
+            "//executorch/runtime/kernel:kernel_includes",
+        ],
+        visibility = [
+            "//executorch/...",
+        ],
+    )
+
+    if not runtime.is_oss:
+        # no resources support in OSS
+        runtime.export_file(
+            name = "tokenizer_file",
+            src = "test/test.bin",
+        )
+
+        runtime.cxx_test(
+            name = "test_tokenizer_cpp",
+            srcs = ["test/test_tokenizer.cpp"],
+            deps = [
+                ":tokenizer_lib",
+                "//executorch/codegen:macros",
+                "fbsource//xplat/tools/cxx:resources",
+            ],
+            resources = [":tokenizer_file"],
+        )
diff --git a/examples/models/llama2/tokenizer/test/test.bin b/examples/models/llama2/tokenizer/test/test.bin
new file mode 100644
index 0000000000000000000000000000000000000000..798ad70f157bfcf4e72d34ebf1f8f022f282af30
GIT binary patch
literal 16
RcmZQzU|?VbVkRI40RR9j00aO4

literal 0
HcmV?d00001

diff --git a/examples/models/llama2/tokenizer/test/test_tokenizer.cpp b/examples/models/llama2/tokenizer/test/test_tokenizer.cpp
new file mode 100644
index 00000000000..acabf00cba2
--- /dev/null
+++ b/examples/models/llama2/tokenizer/test/test_tokenizer.cpp
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
+#include <executorch/runtime/platform/runtime.h>
+#include <gtest/gtest.h>
+#include "tools/cxx/Resources.h"
+
+using namespace ::testing;
+
+namespace torch {
+namespace executor {
+
+class TokenizerExtensionTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    torch::executor::runtime_init();
+    modelPath_ =
+        build::getResourcePath(
+            "executorch/examples/models/llama2/tokenizer/test/test.bin")
+            .string();
+    tokenizer_ = std::make_unique<Tokenizer>(32000);
+  }
+
+  std::unique_ptr<Tokenizer> tokenizer_;
+  std::string modelPath_;
+};
+
+TEST_F(TokenizerExtensionTest, EncodeWithoutLoadFails) {
+  Error error = tokenizer_->encode("hello world", 0, 0, nullptr, nullptr);
+  EXPECT_EQ(error, Error::NotSupported);
+}
+
+TEST_F(TokenizerExtensionTest, DecodeWithoutLoadFails) {
+  auto result = tokenizer_->decode(0, 0);
+  EXPECT_EQ(result.error(), Error::NotSupported);
+}
+
+TEST_F(TokenizerExtensionTest, TokenizerVocabSizeIsExpected) {
+  Error res = tokenizer_->load(modelPath_.c_str());
+  EXPECT_EQ(res, Error::Ok);
+  // test.bin has vocab size 0, but the tokenizer respects the vocab size
+  // being passed in and adds placeholder tokens.
+  EXPECT_EQ(tokenizer_->vocab_size(), 32000);
+  EXPECT_EQ(tokenizer_->bos_tok(), 1);
+  EXPECT_EQ(tokenizer_->eos_tok(), 2);
+}
+
+} // namespace executor
+} // namespace torch
diff --git a/examples/models/llama2/tokenizer/tokenizer.cpp b/examples/models/llama2/tokenizer/tokenizer.cpp
new file mode 100644
index 00000000000..ccc67770cb4
--- /dev/null
+++ b/examples/models/llama2/tokenizer/tokenizer.cpp
@@ -0,0 +1,336 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
+
+namespace torch {
+namespace executor {
+
+static int compare_tokens(const void* a, const void* b) {
+  if (((TokenIndex*)a)->str == nullptr) {
+    return -1;
+  }
+  if (((TokenIndex*)b)->str == nullptr) {
+    return 1;
+  }
+  return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
+}
+
+Tokenizer::Tokenizer(int32_t vocab_size)
+    : initialized_(false),
+      vocab_size_(vocab_size),
+      vocab_(std::make_unique<char*[]>(vocab_size)),
+      vocab_scores_(std::make_unique<float[]>(vocab_size)),
+      sorted_vocab_(std::make_unique<TokenIndex[]>(vocab_size)) {
+  for (int i = 0; i < 256; i++) {
+    byte_pieces_[i * 2] = (unsigned char)i;
+    byte_pieces_[i * 2 + 1] = '\0';
+  }
+}
+
+/**
+ * @brief Load the tokenizer from a file. The tokenizer file contains the
+ * vocabulary and scores. The format is: four int32 values (vocab size, BOS
+ * id, EOS id, max token length), followed by (score: float32, length: int32,
+ * token bytes) for each token. Here we read all the vocabulary into memory
+ * and keep it sorted for fast lookup.
+ *
+ * @param tokenizer_path The path to the tokenizer file.
+ * @return Error
+ */
+Error Tokenizer::load(const char* tokenizer_path) {
+  if (initialized_) {
+    ET_LOG(Info, "Tokenizer already initialized");
+    return Error::Ok;
+  }
+  // read in the file
+  FILE* file = fopen(tokenizer_path, "rb");
+  if (!file) {
+    ET_LOG(Error, "couldn't load %s", tokenizer_path);
+    return Error::InvalidArgument;
+  }
+  int32_t metadata[4];
+  for (int i = 0; i < 4; i++) {
+    if (fread(metadata + i, sizeof(int32_t), 1, file) != 1) {
+      ET_LOG(
+          Error,
+          "Failed to read the metadata at position %d, the tokenizer file is not valid!",
+          i);
+      return Error::InvalidArgument;
+    }
+  }
+
+  // Now we have two vocab sizes: one from the model and another from the
+  // tokenizer file.
+  int32_t tokenizer_vocab_size = metadata[0];
+  if (tokenizer_vocab_size < vocab_size_) {
+    ET_LOG(
+        Info,
+        "The tokenizer vocab size %d is smaller than the model vocab size %d, will add padding tokens.",
+        tokenizer_vocab_size,
+        vocab_size_);
+  } else if (tokenizer_vocab_size > vocab_size_) {
+    ET_LOG(
+        Info,
+        "The tokenizer vocab size %d is larger than the model vocab size %d.",
+        tokenizer_vocab_size,
+        vocab_size_);
+  }
+
+  bos_tok_ = metadata[1];
+  eos_tok_ = metadata[2];
+  max_token_length_ = metadata[3];
+
+  // allocate space for the vocabulary
+  vocab_ = std::make_unique<char*[]>(vocab_size_);
+  vocab_scores_ = std::make_unique<float[]>(vocab_size_);
+  sorted_vocab_ = std::make_unique<TokenIndex[]>(vocab_size_);
+
+  // read in the vocabulary
+  for (int i = 0; i < vocab_size_; i++) {
+    if (fread(vocab_scores_.get() + i, sizeof(float), 1, file) != 1) {
+      // This is allowed; we just pad the rest of the vocab with <pad> strings
+      std::string padding = "<pad>";
+      vocab_[i] = new char[padding.length() + 1];
+      strcpy(vocab_[i], padding.c_str());
+      vocab_[i][padding.length()] = '\0';
+      continue;
+    }
+    int32_t len;
+    if (fread(&len, sizeof(int32_t), 1, file) != 1) {
+      ET_LOG(Error, "Failed to read the length of the word at index %d", i);
+      return Error::InvalidArgument;
+    }
+    vocab_[i] = new char[len + 1];
+    if (fread(vocab_[i], len, 1, file) != 1) {
+      ET_LOG(
+          Error,
+          "Failed to read the word, total length %d, index %d\n",
+          len,
+          i);
+      return Error::InvalidArgument;
+    }
+    vocab_[i][len] = '\0'; // add the string terminating token
+  }
+  fclose(file);
+
+  for (int32_t i = 0; i < vocab_size_; i++) {
+    sorted_vocab_[i].str = vocab_[i];
+    sorted_vocab_[i].id = i;
+  }
+  qsort(sorted_vocab_.get(), vocab_size_, sizeof(TokenIndex), compare_tokens);
+
+  initialized_ = true;
+  return Error::Ok;
+}
+
+Tokenizer::~Tokenizer() {
+  for (int i = 0; i < vocab_size_; i++) {
+    delete[] vocab_[i];
+  }
+}
+
+/**
+ * @brief Decode a token into a string.
+ *
+ * @param prev_token The previous token.
+ * @param token The current token.
+ * @return Result<const char*> A pointer to the string representation of the
+ * token.
+ */
+Result<const char*> Tokenizer::decode(int32_t prev_token, int32_t token) {
+  if (!initialized_) {
+    ET_LOG(Error, "Tokenizer not initialized");
+    return Error::NotSupported;
+  }
+  const char* piece = vocab_[token];
+  // following BOS token, sentencepiece decoder strips any leading
+  // whitespace
+  if (prev_token == bos_tok_ && piece[0] == ' ') {
+    piece++;
+  }
+  // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
+  // parse this and convert and return the actual byte
+  unsigned char byte_val;
+  if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
+    piece = (char*)byte_pieces_ + byte_val * 2;
+  }
+  return piece;
+}
+
+static int32_t
+str_lookup(const char* str, TokenIndex* sorted_vocab, int32_t vocab_size) {
+  // efficiently find the perfect match for str in vocab, return its index or
+  // -1 if not found
+  TokenIndex tok = {.str = str}; // acts as the key to search for
+  TokenIndex* res = (TokenIndex*)bsearch(
+      &tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
+  return res != nullptr ? res->id : -1;
+}
+
+/**
+ * @brief Encode a string into a sequence of tokens.
+ *
+ * @param text The string to be encoded.
+ * @param bos The number of BOS tokens to prepend to the token list.
+ * @param eos The number of EOS tokens to append to the token list.
+ * @param tokens The output tokens.
+ * @param n_tokens The number of tokens (out-parameter, set by this function).
+ * @return Error
+ */
+Error Tokenizer::encode(
+    const char* text,
+    int8_t bos,
+    int8_t eos,
+    int32_t* tokens,
+    int32_t* n_tokens) {
+  if (!initialized_) {
+    ET_LOG(Error, "Tokenizer not initialized");
+    return Error::NotSupported;
+  }
+  // encode the string text (input) into an upper-bound preallocated tokens[]
+  // array. bos != 0 means prepend the BOS token (=1), eos != 0 means append
+  // the EOS token (=2)
+  if (text == nullptr) {
+    ET_LOG(Error, "cannot encode null text");
+    return Error::InvalidArgument;
+  }
+
+  // create a temporary buffer that will store merge candidates of always two
+  // consecutive tokens. *2 for concat, +1 for null terminator, +2 for UTF-8
+  // (in case max_token_length is 1)
+  char* str_buffer = new char[max_token_length_ * 2 + 1 + 2];
+  size_t str_len = 0;
+
+  // start at 0 tokens
+  *n_tokens = 0;
+
+  // add optional BOS token, if desired
+  if (bos >= 0) {
+    while (bos--) {
+      tokens[(*n_tokens)++] = bos_tok_;
+    }
+  } else {
+    delete[] str_buffer;
+    ET_LOG(Error, "bos %d should be >= 0", bos);
+    return Error::InvalidArgument;
+  }
+
+  // add_dummy_prefix is true by default
+  // so prepend a dummy prefix token to the input string, but only if text != ""
+  // TODO: pretty sure this isn't correct in the general case but I don't have
+  // the energy to read more of the sentencepiece code to figure out what it's
+  // doing
+  const char* space = " ";
+  if (text[0] != '\0') {
+    int dummy_prefix = str_lookup(space, sorted_vocab_.get(), vocab_size_);
+    tokens[(*n_tokens)++] = dummy_prefix;
+  }
+
+  // Okay, UTF-8 time. This will get messy. Here is the reference from Wikipedia:
+  // Code point ↔ UTF-8 conversion
+  // First code point  Last code point  Byte 1    Byte 2    Byte 3    Byte 4
+  // U+0000            U+007F           0xxxxxxx
+  // U+0080            U+07FF           110xxxxx  10xxxxxx
+  // U+0800            U+FFFF           1110xxxx  10xxxxxx  10xxxxxx
+  // U+10000           U+10FFFF         11110xxx  10xxxxxx  10xxxxxx  10xxxxxx
+
+  // process the raw (UTF-8) byte sequence of the input string
+  for (const char* c = text; *c != '\0'; c++) {
+    // reset buffer if the current byte is ASCII or a leading byte
+    // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the
+    // rest. 0x80 is 10000000; in UTF-8, all continuation bytes start with "10"
+    // in the first two bits, so in English this is: "if this byte is not a
+    // continuation byte"
+    if ((*c & 0xC0) != 0x80) {
+      // this byte must be either a leading byte (11...) or an ASCII char
+      // (0x...)
+      // => reset our location, as we're starting a new UTF-8 codepoint
+      str_len = 0;
+    }
+
+    // append the current byte to the buffer
+    str_buffer[str_len++] =
+        *c; // ++ is post-increment, incremented after this line
+    str_buffer[str_len] = '\0';
+
+    // while the next character is a continuation byte, continue appending
+    // but if there are too many of them, just stop to avoid overrunning
+    // str_buffer size.
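+    // (str_len < 4 is a safe cap: a UTF-8 codepoint is at most 4 bytes)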
+    if ((*(c + 1) & 0xC0) == 0x80 && str_len < 4) {
+      continue;
+    }
+
+    // ok c+1 is not a continuation byte, so we've read in a full codepoint
+    int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
+    if (id != -1) {
+      // we found this codepoint in vocab, add it as a token
+      tokens[(*n_tokens)++] = id;
+    } else {
+      // byte_fallback encoding: just encode each byte as a token
+      // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
+      // so the individual bytes only start at index 3
+      for (int i = 0; i < str_len; i++) {
+        tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
+      }
+    }
+    str_len = 0; // protect against a sequence of stray UTF-8 continuation bytes
+  }
+
+  // merge the best consecutive pair each iteration, according to the scores
+  // in vocab_scores
+  while (1) {
+    float best_score = -1e10;
+    int best_id = -1;
+    int best_idx = -1;
+
+    for (int i = 0; i < (*n_tokens - 1); i++) {
+      // check if we can merge the pair (tokens[i], tokens[i+1])
+      snprintf(
+          str_buffer,
+          max_token_length_ * 2 + 3,
+          "%s%s",
+          vocab_[tokens[i]],
+          vocab_[tokens[i + 1]]);
+      int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
+      if (id != -1 && vocab_scores_[id] > best_score) {
+        // this merge pair exists in vocab! record its score and position
+        best_score = vocab_scores_[id];
+        best_id = id;
+        best_idx = i;
+      }
+    }
+
+    if (best_idx == -1) {
+      break; // we couldn't find any more pairs to merge, so we're done
+    }
+
+    // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
+    tokens[best_idx] = best_id;
+    // delete token at position best_idx+1, shift the entire sequence back 1
+    for (int i = best_idx + 1; i < (*n_tokens - 1); i++) {
+      tokens[i] = tokens[i + 1];
+    }
+    (*n_tokens)--; // token length decreased
+  }
+
+  // add optional EOS (=2) token, if desired
+  if (eos >= 0) {
+    while (eos--) {
+      tokens[(*n_tokens)++] = eos_tok_;
+    }
+  } else {
+    delete[] str_buffer;
+    ET_LOG(Error, "eos %d should be >= 0", eos);
+    return Error::InvalidArgument;
+  }
+
+  delete[] str_buffer;
+  return Error::Ok;
+}
+
+} // namespace executor
+} // namespace torch
diff --git a/examples/models/llama2/tokenizer/tokenizer.h b/examples/models/llama2/tokenizer/tokenizer.h
new file mode 100644
index 00000000000..7bd0d5ee891
--- /dev/null
+++ b/examples/models/llama2/tokenizer/tokenizer.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// A simple Byte Pair Encoding (BPE) tokenizer. Note that the vanilla
+// tokenizer model won't work with this class; it needs to go through
+// tokenizer.py first.
+#pragma once
+
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <memory>
+#include <string>
+
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/result.h>
+#include <executorch/runtime/platform/log.h>
+
+namespace torch {
+namespace executor {
+
+struct TokenIndex {
+  const char* str;
+  int32_t id;
+};
+
+class Tokenizer {
+ public:
+  explicit Tokenizer(int32_t vocab_size);
+  ~Tokenizer();
+
+  Error load(const char* tokenizer_path);
+
+  Error encode(
+      const char* text,
+      int8_t bos,
+      int8_t eos,
+      int32_t* tokens,
+      int32_t* n_tokens);
+
+  Result<const char*> decode(int32_t prev_token, int32_t token);
+
+  // getters
+  int32_t vocab_size() const {
+    return vocab_size_;
+  }
+
+  int32_t bos_tok() const {
+    return bos_tok_;
+  }
+
+  int32_t eos_tok() const {
+    return eos_tok_;
+  }
+
+ private:
+  bool initialized_;
+  const int32_t vocab_size_;
+  int32_t bos_tok_, eos_tok_;
+  std::unique_ptr<char*[]> vocab_;
+  std::unique_ptr<float[]> vocab_scores_;
+  std::unique_ptr<TokenIndex[]> sorted_vocab_;
+  unsigned int max_token_length_;
+  unsigned char byte_pieces_[512]; // stores all single-byte strings
+};
+
+} // namespace executor
+} // namespace torch
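
A minimal usage sketch (not part of the patch) tying the two halves together:
once `tokenizer.py` has produced a binary, the C++ class loads it, encodes a
prompt, and decodes the tokens back. The `tokenizer.bin` path and the buffer
size are illustrative assumptions, and `Result::ok()`/`Result::get()` are
assumed to behave as in the ExecuTorch runtime.

```cpp
#include <cstdio>

#include <executorch/examples/models/llama2/tokenizer/tokenizer.h>

using namespace torch::executor;

int main() {
  // The vocab size must match the model; 32000 is the llama2 default.
  Tokenizer tokenizer(32000);
  if (tokenizer.load("tokenizer.bin") != Error::Ok) {
    fprintf(stderr, "failed to load tokenizer.bin\n");
    return 1;
  }

  // Upper-bound allocation: at most one token per input byte, plus the
  // dummy prefix and the requested BOS/EOS tokens.
  const char* prompt = "hello world";
  int32_t tokens[16];
  int32_t n_tokens = 0;
  // bos=1: prepend one BOS token; eos=0: append no EOS token.
  if (tokenizer.encode(prompt, 1, 0, tokens, &n_tokens) != Error::Ok) {
    fprintf(stderr, "encode failed\n");
    return 1;
  }

  // Decode each token given its predecessor, as a runner loop would.
  for (int32_t i = 1; i < n_tokens; i++) {
    Result<const char*> piece = tokenizer.decode(tokens[i - 1], tokens[i]);
    if (piece.ok()) {
      printf("%s", piece.get());
    }
  }
  printf("\n");
  return 0;
}
```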