
Commit 403fb31

larryliu0820 authored and facebook-github-bot committed
Add a tokenizer python script (#1611)
Summary: Add a tokenizer Python script that adds some post-processing to the vanilla `sentencepiece` tokenizer model. This comes in handy when we want to consume it in C++. Differential Revision: D52821402 Pulled By: larryliu0820
1 parent 05d169b commit 403fb31
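
Usage note (not part of the commit): a minimal sketch of driving the new `Tokenizer` class from Python, assuming a sentencepiece `tokenizer.model` is available locally; the file paths below are placeholders.

from executorch.examples.models.llama2.tokenizer.tokenizer import Tokenizer

# Load an existing sentencepiece model (placeholder path) and re-serialize it
# into the flat binary format described in Tokenizer.export(), for consumption from C++.
t = Tokenizer("/path/to/tokenizer.model")
t.export("/path/to/tokenizer.bin", prepend_padding=False)

The same conversion is available from the command line through the argparse entry point at the bottom of tokenizer.py (-t/--tokenizer-model, -o/--output-path, -p/--prepend-padding).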

File tree

4 files changed: +190 -0 lines changed


examples/models/llama2/tokenizer/__init__.py

Whitespace-only changes.

examples/models/llama2/tokenizer/test/__init__.py

Whitespace-only changes.
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import struct
import tempfile
import unittest
from unittest.mock import patch

from executorch.examples.models.llama2.tokenizer.tokenizer import Tokenizer


class TestTokenizer(unittest.TestCase):
    @patch(
        "executorch.examples.models.llama2.tokenizer.tokenizer.SentencePieceProcessor"
    )
    def test_export(self, mock_sp):
        # Set up the mock SentencePieceProcessor
        mock_sp.return_value.vocab_size.return_value = 0
        mock_sp.return_value.bos_id.return_value = 1
        mock_sp.return_value.eos_id.return_value = 2
        mock_sp.return_value.get_piece_size.return_value = 0
        # Create a temporary file
        with tempfile.NamedTemporaryFile(delete=True) as temp:
            # Initialize the tokenizer with the temporary file as the model
            tokenizer = Tokenizer(temp.name)
            # Export the tokenizer to another temporary file
            with tempfile.NamedTemporaryFile(delete=True) as output:
                tokenizer.export(output.name)
                # Open the output file in binary mode and read the first 16 bytes
                with open(output.name, "rb") as f:
                    data = f.read(16)
                # Unpack the data as 4 integers
                vocab_size, bos_id, eos_id, max_token_length = struct.unpack(
                    "IIII", data
                )
                # Check that the integers match the properties of the tokenizer
                self.assertEqual(vocab_size, 0)
                self.assertEqual(bos_id, 1)
                self.assertEqual(eos_id, 2)
                # Check that the max token length is correct
                self.assertEqual(max_token_length, 0)
@@ -0,0 +1,145 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Script to rewrite tokenizer model given by sentencepiece, with lightweight
# postprocessing logic.

import argparse
import logging
import os
import struct
from typing import List

from sentencepiece import SentencePieceProcessor as SentencePieceProcessor


class Tokenizer:
    def __init__(self, model_path: str):
        assert os.path.isfile(
            model_path
        ), f"Need a valid tokenizer model path but got {model_path}"
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        self.model_path = model_path

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        logging.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        assert type(s) is str
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `decode`.
        return self.sp_model.decode(t)

    def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
        """
        Export tokenizer.model to another serialization format. Here we did some lightweight
        processing such as supporting prepend padding token, prepend max token length and
        replace '▁' back to a space.

        The binary format is:
        1. vocab size: int32
        2. bos id: int32
        3. eos id: int32
        4. max token length: int32
        5. score: float32, len of bytes: int32, token bytes: [byte] for each token

        :param output_path: output path of the new binary.
        :param prepend_padding: a boolean to control if we want to prepend a padding token.

        :return: None
        """

        # get all the tokens (postprocessed) and their scores as floats
        tokens, scores = [], []

        if prepend_padding:
            # Here we use the default padding token and its score.
            tokens.append("<pad>".encode("utf-8"))
            scores.append(-1)

        for i in range(self.n_words):

            # decode the token and light postprocessing
            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `id_to_piece`.
            t = self.sp_model.id_to_piece(i)
            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_score`.
            s = self.sp_model.get_score(i)
            # sentencepiece uses '<s>' as BOS and '</s>' for EOS
            if i == self.bos_id:
                t = "<s>"
            elif i == self.eos_id:
                t = "</s>"
            t = t.replace("▁", " ")  # sentencepiece uses this character as whitespace
            b = t.encode("utf-8")  # bytes of this token, utf-8 encoded

            tokens.append(b)
            scores.append(s)

        # record the max token length
        max_token_length = 0 if not tokens else max(len(t) for t in tokens)

        # write to a binary file
        with open(output_path, "wb") as f:
            # write the vocab size, bos/eos ids and max token length
            f.write(
                struct.pack(
                    "IIII", self.n_words, self.bos_id, self.eos_id, max_token_length
                )
            )
            for bytes, score in zip(tokens, scores):
                f.write(struct.pack("fI", score, len(bytes)))
                f.write(bytes)
        logging.info(f"Wrote tokenizer to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--tokenizer-model",
        type=str,
        default="tokenizer.model",
        help="path to tokenizer model, given by sentencepiece",
    )
    parser.add_argument(
        "-o",
        "--output-path",
        type=str,
        default=None,
        help="output path of postprocessed tokenizer model",
    )
    parser.add_argument(
        "-p",
        "--prepend-padding",
        action="store_true",
        help="whether to prepend a padding token to the beginning of the tokenizer",
    )

    args = parser.parse_args()

    t = Tokenizer(args.tokenizer_model)

    output_path = (
        args.output_path
        if args.output_path
        else args.tokenizer_model.replace(".model", ".bin")
    )
    t.export(output_path, prepend_padding=args.prepend_padding)
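
To illustrate the binary layout documented in Tokenizer.export() above, here is a hedged reader sketch. The function name load_tokenizer_bin and the returned structure are invented for this example and are not part of the commit; it also assumes the file is read on a machine with the same native byte order and struct alignment as the writer.

import struct
from typing import List, Tuple


def load_tokenizer_bin(path: str) -> Tuple[int, int, int, int, List[Tuple[float, bytes]]]:
    """Read back the header and (score, token bytes) records written by Tokenizer.export()."""
    with open(path, "rb") as f:
        # Header: vocab size, bos id, eos id, max token length, four native uint32s.
        vocab_size, bos_id, eos_id, max_token_length = struct.unpack("IIII", f.read(16))
        entries: List[Tuple[float, bytes]] = []
        # Read (score: float32, length: uint32, token bytes) records until EOF.
        # Note: when export() was called with prepend_padding=True, one more record
        # than vocab_size is present, since the header keeps the original n_words.
        while True:
            record_header = f.read(8)
            if len(record_header) < 8:
                break
            score, length = struct.unpack("fI", record_header)
            entries.append((score, f.read(length)))
    return vocab_size, bos_id, eos_id, max_token_length, entries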

0 commit comments
