
Commit 9bac61e

dsikka (Dipika Sikka) and Dipika Sikka authored
[TextGeneration] Fix llama tokenizer (#1635)
* add llama tokenizer fix
* fix generated string
* only run for streaming
* add TODO

Co-authored-by: Dipika Sikka <[email protected]>
1 parent e09ae26 commit 9bac61e
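
In short: SentencePiece-based tokenizers (LlamaTokenizer and LlamaTokenizerFast) store the word-leading space inside each token, and that space is lost when tokens are decoded one at a time during streaming. A minimal sketch of the symptom, assuming the transformers library is installed; the checkpoint id and sample text below are illustrative and not part of the commit:

    # Illustrative only: the checkpoint and sample text are assumptions, not part of the commit.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b")
    ids = tokenizer("the quick brown fox", add_special_tokens=False).input_ids

    # Decoding the full sequence preserves the spaces ...
    print(tokenizer.decode(ids, skip_special_tokens=True))  # "the quick brown fox"

    # ... but decoding token by token, as naive streaming does, drops the
    # prefix space that SentencePiece stores inside each token.
    print("".join(tokenizer.decode([i], skip_special_tokens=True) for i in ids))  # e.g. "thequickbrownfox"

The commit works around this by keeping a sliding queue of recent token ids and decoding with that context, as shown in the two diffs below.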

File tree

2 files changed (+47 -4 lines changed)

Diff for: src/deepsparse/transformers/pipelines/text_generation/prep_for_generation.py

+1 -0

@@ -101,6 +101,7 @@ def run(
             else [],
             "finished_reason": [],
             "token_generator": token_generator,
+            "past_tokens_queue": copy.copy(tokens),
         }
 
         if kv_cache is None:
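
The queue seeded above is what the output-processing operator later mutates. A minimal sketch of the intended life-cycle, with illustrative variable names and ids that are not taken from the pipeline:

    import copy

    prompt_token_ids = [1, 450, 4996, 17354, 1701, 29916]    # hypothetical prompt ids
    past_tokens_queue = copy.copy(prompt_token_ids)           # seeded once, as in prep_for_generation.py

    # For every streamed token, process_outputs.py appends the new id and pops
    # the oldest one, so the queue keeps the original prompt length while the
    # decoding window slides forward by one token per step.
    for new_id in [432, 17204]:                                # pretend these are generated ids
        past_tokens_queue.append(new_id)
        past_tokens_queue.pop(0)
        assert len(past_tokens_queue) == len(prompt_token_ids)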

Diff for: src/deepsparse/transformers/pipelines/text_generation/process_outputs.py

+46 -4

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import datetime
-from typing import Optional
+from typing import List, Optional
 
 import numpy
 
@@ -54,6 +54,33 @@ def _create_generated_text_output(
             finished=False,
         )
 
+    def _generate_streamed_text_from_past_tokens(
+        self, generated_tokens: numpy.ndarray, past_tokens_queue: List[int]
+    ) -> str:
+        """
+        An auxiliary method that helps to properly generate the streamed text.
+        Some models, like llama2 and mistral, use LlamaTokenizer, which is
+        based on the SentencePiece tokenizer. This specific tokenizer doesn't
+        seem to output appropriate prefix spaces when decoding token by token.
+        One can make it work by including the previously generated tokens,
+        which lets the tokenizer figure out the appropriate spaces from the
+        last n consecutive tokens.
+
+        :param generated_tokens: the generated tokens from the engine
+        :param past_tokens_queue: the queue of last n tokens (n is the
+            original prompt length in tokens)
+        :return: the generated string
+        """
+        string_from_n_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.append(generated_tokens[0])
+        string_from_n_plus_1_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.pop(0)
+        return [string_from_n_plus_1_tokens[len(string_from_n_tokens) :]]
+
     def run(
         self,
         generated_tokens: numpy.ndarray,
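
For readers who want to try the trick outside the pipeline, here is a free-standing sketch of the same sliding-window decode; sliding_window_decode is a hypothetical name, and unlike the method above it returns a plain string rather than a one-element list:

    from typing import List

    def sliding_window_decode(tokenizer, new_token_id: int, past_tokens_queue: List[int]) -> str:
        # Decode the last n ids, append the new id, decode the last n + 1 ids,
        # and return only the suffix. The longer decode gives SentencePiece
        # enough context to re-insert the prefix space that a single-token
        # decode would silently drop.
        before = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
        past_tokens_queue.append(new_token_id)
        after = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
        past_tokens_queue.pop(0)  # keep the window at its original length
        return after[len(before):]
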
@@ -64,9 +91,24 @@ def run(
     ):
         generation_config = inference_state.current_state.get("generation_config")
         generated_logits = generated_logits if generation_config.output_scores else None
-        sequences = self.tokenizer.batch_decode(
-            generated_tokens, skip_special_tokens=True
-        )
+
+        import transformers
+
+        # Fix for LLAMA-specific models when running streaming
+        # TODO: make streaming a conditional input to this operator. using inference
+        # state is a quick fix.
+        if isinstance(
+            self.tokenizer,
+            (transformers.LlamaTokenizer, transformers.LlamaTokenizerFast),
+        ) and inference_state.current_state.get("streaming"):
+            past_tokens_queue = inference_state.current_state.get("past_tokens_queue")
+            sequences = self._generate_streamed_text_from_past_tokens(
+                generated_tokens, past_tokens_queue
+            )
+        else:
+            sequences = self.tokenizer.batch_decode(
+                generated_tokens, skip_special_tokens=True
+            )
 
         try:
             finished_reason = [f[-1] for f in finished_reason]
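
Putting the pieces together, a hedged end-to-end check of the behaviour the new streaming branch is meant to restore. It reuses the sliding_window_decode sketch shown between the hunks above, and the checkpoint id is again only illustrative:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b")
    prompt_ids = tokenizer("The quick", add_special_tokens=False).input_ids
    generated_ids = tokenizer(" brown fox jumps", add_special_tokens=False).input_ids

    queue, streamed = list(prompt_ids), ""
    for token_id in generated_ids:
        # each call returns the next chunk of text, prefix spaces included
        streamed += sliding_window_decode(tokenizer, token_id, queue)
    print(streamed)  # expected to keep the spaces, e.g. " brown fox jumps"

Comparing this against the naive per-token decode from the first sketch should show the missing spaces, which is exactly what this commit fixes for streaming output.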

0 commit comments
