
Commit 11c9225

initial commit (#1431)
* cleanup
* more cleaning up
* specify the tokenizer types
* respond to PR comments
* add tests
* rename variable to be clearer
1 parent 83a511d commit 11c9225

2 files changed (+82, −3 lines)

src/deepsparse/transformers/pipelines/text_generation.py

+56 −3
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import datetime
 import logging
 import os
@@ -580,10 +581,24 @@ def _stream_engine_outputs(
         self, engine_outputs, prompts, generation_config, **kwargs
     ):
         for output in engine_outputs:
-            generated_tokens, generated_logits, finished_reason = output
+            (
+                generated_tokens,
+                generated_logits,
+                finished_reason,
+                past_tokens_queue,
+            ) = output
             logits = generated_logits if generation_config.output_scores else None
+            from transformers import LlamaTokenizer, LlamaTokenizerFast
+
+            if isinstance(self.tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
+                # temporary fix for LLama2/Mistral/... models
+                generated_string = self._generate_streamed_text_from_past_tokens(
+                    generated_tokens, past_tokens_queue
+                )
+            else:
+                generated_string = self.tokenizer.batch_decode(generated_tokens)[0]
             generation = self._create_generated_text_output(
-                self.tokenizer.batch_decode(generated_tokens)[0],
+                generated_string,
                 finished_reason[0],
                 logits,
             )
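The tokenizer-type branch above exists because SentencePiece-based Llama tokenizers drop the prefix space when a token is decoded on its own, so naive token-by-token decoding glues words together. A minimal sketch of the symptom, not part of the commit; the checkpoint id is an assumption, substitute any Llama-style SentencePiece tokenizer you have access to:

from transformers import AutoTokenizer

# Hypothetical checkpoint, used only to illustrate the behaviour.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

ids = tokenizer("The quick brown fox", add_special_tokens=False).input_ids
print(tokenizer.decode(ids))                        # "The quick brown fox"
print("".join(tokenizer.decode([i]) for i in ids))  # words run together
# The leading space of each token is only restored when it is decoded together
# with the tokens that precede it, which is what the past-tokens queue provides.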
@@ -601,6 +616,33 @@ def _stream_engine_outputs(
                 **schema_kwargs,
             )
 
+    def _generate_streamed_text_from_past_tokens(
+        self, generated_tokens: numpy.ndarray, past_tokens_queue: List[int]
+    ) -> str:
+        """
+        An auxiliary method that helps to properly generate the streamed text.
+        Some models, like llama2 and mistral, use LlamaTokenizer, which is
+        based on the SentencePiece tokenizer. This tokenizer does not seem
+        to output appropriate prefix spaces when decoding token by token.
+        One can make it work if the previously generated tokens are included.
+        This allows the tokenizer to figure out the appropriate spaces
+        from the last n consecutive tokens.
+
+        :param generated_tokens: the generated tokens from the engine
+        :param past_tokens_queue: the queue of last n tokens (n is the
+            original prompt length in tokens)
+        :return: the generated string
+        """
+        string_from_n_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.append(generated_tokens[0])
+        string_from_n_plus_1_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.pop(0)
+        return string_from_n_plus_1_tokens[len(string_from_n_tokens) :]
+
     def process_engine_outputs(
         self, engine_outputs: List[Union[numpy.ndarray, FinishReason]], **kwargs
     ) -> TextGenerationOutput:
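For clarity, the same incremental-decoding idea as a standalone sketch (a simplification, not part of the commit): decode the last n tokens, decode the last n + 1 tokens after appending the new one, and emit only the suffix the new token contributed, keeping the window length constant.

from typing import List


def incremental_decode(tokenizer, past_tokens: List[int], new_token: int) -> str:
    """Return only the text contributed by new_token.

    tokenizer is assumed to expose a Hugging Face style
    decode(ids, skip_special_tokens=...) method; past_tokens is mutated in
    place so its length stays constant, mirroring past_tokens_queue above.
    """
    before = tokenizer.decode(past_tokens, skip_special_tokens=True)
    past_tokens.append(new_token)
    after = tokenizer.decode(past_tokens, skip_special_tokens=True)
    past_tokens.pop(0)  # drop the oldest token to keep the window size fixed
    return after[len(before):]

Keeping the queue at its original length bounds the per-token decode cost while still giving the tokenizer enough left context to restore the spaces.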
@@ -738,6 +780,9 @@ def engine_forward(
             prompt_logits, session = self.prompt_inference(engine_inputs)
 
             tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()
+            # copy the tokens so that we can use them for streaming
+            past_tokens_queue = copy.copy(tokens)
+
             token_generator = TokenGenerator(
                 logits_shape=prompt_logits[-1].shape[-1],
                 tokens=tokens,
@@ -776,6 +821,7 @@ def engine_forward(
                     numpy.array([generated_tokens[-1]]),
                     numpy.array([generated_logits[-1]]),
                     [None],
+                    past_tokens_queue,
                 )
 
             while len(generated_tokens) < max_tokens:
@@ -816,7 +862,12 @@ def engine_forward(
                     break
 
                 if streaming:
-                    yield (numpy.array([token]), numpy.array([logits]), [None])
+                    yield (
+                        numpy.array([token]),
+                        numpy.array([logits]),
+                        [None],
+                        past_tokens_queue,
+                    )
 
                 # Run the autoregressive inference only to put the
                 # kv cache entry for the last generated token into the
@@ -831,12 +882,14 @@ def engine_forward(
                         numpy.array([generated_tokens]),
                         numpy.concatenate(generated_logits, axis=1),
                         [FinishReason.LENGTH],
+                        past_tokens_queue,
                     )
                 else:
                     yield (
                         numpy.array([token]),
                         numpy.array([logits]),
                         [finished_reason[-1]],
+                        past_tokens_queue,
                     )
 
             if not streaming:

tests/deepsparse/transformers/pipelines/test_text_generation.py

+26
@@ -130,3 +130,29 @@ def test_streaming_mode_returns_generator(pipeline, prompt):
         isinstance(response, pipeline.output_schema) for response in response_generator
     ), "Pipeline should return a generator of output_schema \
         objects in streaming mode"
+
+
+def test_streaming_with_several_prompts(pipeline, prompt):
+    additional_prompt = "Never gonna run around and desert you"
+    prompts = [prompt, additional_prompt]
+
+    generations_first_prompt_only = list(pipeline(prompt=prompts[0], streaming=True))
+    generations_second_prompt_only = list(pipeline(prompt=prompts[1], streaming=True))
+
+    bag_of_words_first_prompt = [
+        g.generations[0].text for g in generations_first_prompt_only
+    ]
+    bag_of_words_second_prompt = [
+        g.generations[0].text for g in generations_second_prompt_only
+    ]
+
+    generations = pipeline(prompt=prompts, streaming=True)
+    bag_of_words_shared = []
+    for r in generations:
+        for gen in r.generations:
+            text = gen.text
+            bag_of_words_shared.append(text)
+
+    assert sorted(bag_of_words_first_prompt + bag_of_words_second_prompt) == sorted(
+        bag_of_words_shared
+    )
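Outside the test suite, the streaming generator exercised by this test is consumed the same way. A hedged usage sketch, assuming a deepsparse install; the model path is a placeholder, not something this commit specifies:

from deepsparse import Pipeline

# Placeholder model path; substitute a real text-generation model stub or a
# local deployment directory.
pipeline = Pipeline.create(task="text-generation", model_path="<model-stub-or-path>")

for result in pipeline(prompt=["Hello, my name is", "The weather today is"], streaming=True):
    for generation in result.generations:
        print(generation.text, end="", flush=True)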
