
Commit d8d84ed

[Cherry-Pick][Fix] Appropriate whitespace missing in streaming output for Llama2, Mistral models (#1439)
* [Fix] Remove erroneous LIB.kv_cache input when using external kv cache management (#1337)
* initial commit
* cleanup
* specify the tokenizer types
* respond to PR comments
* add tests
* rename variable to be more clear
1 parent 32ab7c1 commit d8d84ed
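For context, the symptom this commit addresses can be reproduced outside the pipeline. A minimal sketch, assuming a SentencePiece-based Llama tokenizer is available (the checkpoint id below is only an illustrative placeholder): decoding the generated ids one at a time drops the prefix space that SentencePiece attaches to each word piece, so the concatenated stream runs words together.

```python
# Illustrative only: reproduces the missing-whitespace symptom of token-by-token decoding.
# The checkpoint id is a placeholder; Llama2/Mistral tokenizers show the same behaviour.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

ids = tokenizer("The quick brown fox", add_special_tokens=False)["input_ids"]

piecewise = "".join(tokenizer.decode([i]) for i in ids)  # decode one token at a time
whole = tokenizer.decode(ids)                            # decode the full sequence at once

print(repr(piecewise))  # spaces between words are typically missing, e.g. 'Thequickbrownfox'
print(repr(whole))      # 'The quick brown fox'
```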

2 files changed: +82 −3 lines changed


Diff for: src/deepsparse/transformers/pipelines/text_generation.py

+56-3
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import datetime
 import logging
 import os
@@ -578,10 +579,24 @@ def _stream_engine_outputs(
         self, engine_outputs, prompts, generation_config, **kwargs
     ):
         for output in engine_outputs:
-            generated_tokens, generated_logits, finished_reason = output
+            (
+                generated_tokens,
+                generated_logits,
+                finished_reason,
+                past_tokens_queue,
+            ) = output
             logits = generated_logits if generation_config.output_scores else None
+            from transformers import LlamaTokenizer, LlamaTokenizerFast
+
+            if isinstance(self.tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
+                # temporary fix for Llama2/Mistral/... models
+                generated_string = self._generate_streamed_text_from_past_tokens(
+                    generated_tokens, past_tokens_queue
+                )
+            else:
+                generated_string = self.tokenizer.batch_decode(generated_tokens)[0]
             generation = self._create_generated_text_output(
-                self.tokenizer.batch_decode(generated_tokens)[0],
+                generated_string,
                 finished_reason[0],
                 logits,
             )
@@ -599,6 +614,33 @@ def _stream_engine_outputs(
             **schema_kwargs,
         )
 
+    def _generate_streamed_text_from_past_tokens(
+        self, generated_tokens: numpy.ndarray, past_tokens_queue: List[int]
+    ) -> str:
+        """
+        An auxiliary method that helps to properly generate the streamed text.
+        Some models, such as Llama2 and Mistral, use LlamaTokenizer, which is
+        based on the SentencePiece tokenizer. That tokenizer does not emit the
+        appropriate prefix spaces when decoding token by token.
+        It can be made to work if the previously generated tokens are included.
+        This allows the tokenizer to figure out the appropriate spacing from
+        the last n consecutive tokens.
+
+        :param generated_tokens: the generated tokens from the engine
+        :param past_tokens_queue: the queue of last n tokens (n is the
+            original prompt length in tokens)
+        :return: the generated string
+        """
+        string_from_n_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.append(generated_tokens[0])
+        string_from_n_plus_1_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.pop(0)
+        return string_from_n_plus_1_tokens[len(string_from_n_tokens) :]
+
     def process_engine_outputs(
         self, engine_outputs: List[Union[numpy.ndarray, FinishReason]], **kwargs
     ) -> TextGenerationOutput:
@@ -733,6 +775,9 @@ def engine_forward(
         prompt_logits, session = self.prompt_inference(engine_inputs)
 
         tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()
+        # copy the tokens so that we can use them for streaming
+        past_tokens_queue = copy.copy(tokens)
+
         token_generator = TokenGenerator(
             logits_shape=prompt_logits[-1].shape[-1],
             tokens=tokens,
@@ -771,6 +816,7 @@ def engine_forward(
                 numpy.array([generated_tokens[-1]]),
                 numpy.array([generated_logits[-1]]),
                 [None],
+                past_tokens_queue,
             )
 
         while len(generated_tokens) < max_tokens:
@@ -811,7 +857,12 @@ def engine_forward(
                 break
 
             if streaming:
-                yield (numpy.array([token]), numpy.array([logits]), [None])
+                yield (
+                    numpy.array([token]),
+                    numpy.array([logits]),
+                    [None],
+                    past_tokens_queue,
+                )
 
             # Run the autoregressive inference only to put the
             # kv cache entry for the last generated token into the
@@ -826,12 +877,14 @@ def engine_forward(
                 numpy.array([generated_tokens]),
                 numpy.concatenate(generated_logits, axis=1),
                 [FinishReason.LENGTH],
+                past_tokens_queue,
             )
         else:
             yield (
                 numpy.array([token]),
                 numpy.array([logits]),
                 [finished_reason[-1]],
+                past_tokens_queue,
             )
 
         if not streaming:
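The `_generate_streamed_text_from_past_tokens` helper added above boils down to a decode-with-context delta. Below is a standalone sketch of the same idea, with illustrative names and any Hugging Face tokenizer assumed: decode the queue before and after appending the new token, emit only the suffix the new token contributed, then pop the oldest id so the queue stays bounded at the original prompt length.

```python
# Sketch of the decode-with-context technique used by the new helper; names are illustrative.
from typing import List


def streamed_text_delta(tokenizer, past_tokens_queue: List[int], new_token: int) -> str:
    """Return only the text contributed by new_token, with correct spacing."""
    before = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
    past_tokens_queue.append(new_token)
    after = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
    past_tokens_queue.pop(0)  # keep the queue at a fixed (prompt-length) size
    return after[len(before):]
```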

Diff for: tests/deepsparse/transformers/pipelines/test_text_generation.py

+26
@@ -130,3 +130,29 @@ def test_streaming_mode_returns_generator(pipeline, prompt):
         isinstance(response, pipeline.output_schema) for response in response_generator
     ), "Pipeline should return a generator of output_schema \
         objects in streaming mode"
+
+
+def test_streaming_with_several_prompts(pipeline, prompt):
+    additional_prompt = "Never gonna run around and desert you"
+    prompts = [prompt, additional_prompt]
+
+    generations_first_prompt_only = list(pipeline(prompt=prompts[0], streaming=True))
+    generations_second_prompt_only = list(pipeline(prompt=prompts[1], streaming=True))
+
+    bag_of_words_first_prompt = [
+        g.generations[0].text for g in generations_first_prompt_only
+    ]
+    bag_of_words_second_prompt = [
+        g.generations[0].text for g in generations_second_prompt_only
+    ]
+
+    generations = pipeline(prompt=prompts, streaming=True)
+    bag_of_words_shared = []
+    for r in generations:
+        for gen in r.generations:
+            text = gen.text
+            bag_of_words_shared.append(text)
+
+    assert sorted(bag_of_words_first_prompt + bag_of_words_second_prompt) == sorted(
+        bag_of_words_shared
+    )
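For completeness, a hedged sketch of how the streaming interface exercised by the new test might be consumed from user code; the `Pipeline.create` arguments and the model path below are assumptions, not part of this commit.

```python
# Illustrative consumer of the streaming text generation pipeline.
# The model path is a placeholder; substitute a real SparseZoo stub or local deployment path.
from deepsparse import Pipeline

pipeline = Pipeline.create(task="text-generation", model_path="<llama2-or-mistral-stub>")

pieces = []
for response in pipeline(prompt="Never gonna give you up", streaming=True):
    # With this fix, each streamed piece carries its leading whitespace
    # for Llama2/Mistral (SentencePiece-based) tokenizers.
    pieces.append(response.generations[0].text)

print("".join(pieces))
```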
