# See the License for the specific language governing permissions and
# limitations under the License.

+import copy
import datetime
import logging
import os
@@ -578,10 +579,24 @@ def _stream_engine_outputs(
        self, engine_outputs, prompts, generation_config, **kwargs
    ):
        for output in engine_outputs:
-            generated_tokens, generated_logits, finished_reason = output
+            (
+                generated_tokens,
+                generated_logits,
+                finished_reason,
+                past_tokens_queue,
+            ) = output
            logits = generated_logits if generation_config.output_scores else None
+            from transformers import LlamaTokenizer, LlamaTokenizerFast
+
+            if isinstance(self.tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
+                # temporary fix for Llama2/Mistral/... models
+                generated_string = self._generate_streamed_text_from_past_tokens(
+                    generated_tokens, past_tokens_queue
+                )
+            else:
+                generated_string = self.tokenizer.batch_decode(generated_tokens)[0]
            generation = self._create_generated_text_output(
-                self.tokenizer.batch_decode(generated_tokens)[0],
+                generated_string,
                finished_reason[0],
                logits,
            )
@@ -599,6 +614,33 @@ def _stream_engine_outputs(
            **schema_kwargs,
        )

+    def _generate_streamed_text_from_past_tokens(
+        self, generated_tokens: numpy.ndarray, past_tokens_queue: List[int]
+    ) -> str:
+ """
621
+ An auxiliary method that helps to properly generate the streamed text.
622
+ Some models like llama2 and mistral are using LlamaTokenizer which is
623
+ based on SentencePiece tokenizer. This specific tokenizer doesn't seem
624
+ to output appropriate prefix spaces when decoding token by token.
625
+ One can make it work if the previously generated tokens are included.
626
+ This allows the tokenizer to figure out that the appropriate spaces
627
+ from last n consecutive tokens.
628
+
629
+ :param generated_tokens: the generated tokens from the engine
630
+ :param past_tokens_queue: the queue of last n tokens (n is the
631
+ original prompt length in tokens)
632
+ :return: the generated string
633
+ """
+        string_from_n_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.append(generated_tokens[0])
+        string_from_n_plus_1_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.pop(0)
+        return string_from_n_plus_1_tokens[len(string_from_n_tokens) :]
+
    def process_engine_outputs(
        self, engine_outputs: List[Union[numpy.ndarray, FinishReason]], **kwargs
    ) -> TextGenerationOutput:
@@ -733,6 +775,9 @@ def engine_forward(
            prompt_logits, session = self.prompt_inference(engine_inputs)

            tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()
+            # copy the tokens so that we can use them for streaming
+            past_tokens_queue = copy.copy(tokens)
+
            token_generator = TokenGenerator(
                logits_shape=prompt_logits[-1].shape[-1],
                tokens=tokens,
@@ -771,6 +816,7 @@ def engine_forward(
                    numpy.array([generated_tokens[-1]]),
                    numpy.array([generated_logits[-1]]),
                    [None],
+                    past_tokens_queue,
                )

            while len(generated_tokens) < max_tokens:
@@ -811,7 +857,12 @@ def engine_forward(
                        break

                if streaming:
-                    yield (numpy.array([token]), numpy.array([logits]), [None])
+                    yield (
+                        numpy.array([token]),
+                        numpy.array([logits]),
+                        [None],
+                        past_tokens_queue,
+                    )

            # Run the autoregressive inference only to put the
            # kv cache entry for the last generated token into the
@@ -826,12 +877,14 @@ def engine_forward(
                        numpy.array([generated_tokens]),
                        numpy.concatenate(generated_logits, axis=1),
                        [FinishReason.LENGTH],
+                        past_tokens_queue,
                    )
                else:
                    yield (
                        numpy.array([token]),
                        numpy.array([logits]),
                        [finished_reason[-1]],
+                        past_tokens_queue,
                    )

        if not streaming:
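
For reference, the decode-n-versus-n+1 trick that the new `_generate_streamed_text_from_past_tokens` helper relies on can be reproduced outside the pipeline. The sketch below is illustrative only: it stands alone from the DeepSparse pipeline classes, the `stream_decode` helper name is made up for this example, and the checkpoint id is a placeholder for any Llama-family tokenizer.

```python
from typing import List

from transformers import AutoTokenizer


def stream_decode(tokenizer, new_token: int, past_tokens_queue: List[int]) -> str:
    """Return only the text contributed by `new_token`.

    Decoding a single token with a SentencePiece-based tokenizer drops the
    leading space, so instead we decode the last n tokens, then the same
    window plus the new token, and return the suffix that the new token added.
    """
    text_before = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
    past_tokens_queue.append(new_token)
    text_after = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
    past_tokens_queue.pop(0)  # keep the window length constant
    return text_after[len(text_before):]


if __name__ == "__main__":
    # Placeholder checkpoint: any Llama-family tokenizer exhibits the behavior.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    prompt_ids = tokenizer.encode("The capital of France is", add_special_tokens=False)
    generated_ids = tokenizer.encode(" the city of Paris", add_special_tokens=False)

    # Naive per-token decoding tends to lose the spaces between words ...
    naive = "".join(
        tokenizer.decode([t], skip_special_tokens=True) for t in generated_ids
    )
    # ... while the sliding-window decode preserves them.
    window = list(prompt_ids)
    streamed = "".join(stream_decode(tokenizer, t, window) for t in generated_ids)

    print(repr(naive))     # e.g. 'thecityofParis'
    print(repr(streamed))  # ' the city of Paris'
```

This mirrors what the pipeline does per streamed chunk: `past_tokens_queue` starts as a copy of the prompt tokens and is threaded through every yielded tuple, so the window always holds the last n tokens seen by the tokenizer.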