 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import datetime
 import logging
 import os
@@ -580,10 +581,24 @@ def _stream_engine_outputs(
         self, engine_outputs, prompts, generation_config, **kwargs
     ):
         for output in engine_outputs:
-            generated_tokens, generated_logits, finished_reason = output
+            (
+                generated_tokens,
+                generated_logits,
+                finished_reason,
+                past_tokens_queue,
+            ) = output
             logits = generated_logits if generation_config.output_scores else None
+            from transformers import LlamaTokenizer, LlamaTokenizerFast
+
+            if isinstance(self.tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
+                # temporary fix for Llama2/Mistral/... models
+                generated_string = self._generate_streamed_text_from_past_tokens(
+                    generated_tokens, past_tokens_queue
+                )
+            else:
+                generated_string = self.tokenizer.batch_decode(generated_tokens)[0]
             generation = self._create_generated_text_output(
-                self.tokenizer.batch_decode(generated_tokens)[0],
+                generated_string,
                 finished_reason[0],
                 logits,
             )
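
The branch above exists because SentencePiece-based tokenizers (used by Llama 2, Mistral, and similar models) drop the prefix space when a token id is decoded on its own. A minimal sketch of the symptom, assuming a locally available Llama-style tokenizer; the checkpoint name is an illustrative assumption, not part of this change:

from transformers import AutoTokenizer

# Illustrative checkpoint only; any SentencePiece/Llama-style tokenizer
# shows the same behavior.
tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

ids = tok("hello world", add_special_tokens=False)["input_ids"]
whole = tok.decode(ids)                            # "hello world"
piecewise = "".join(tok.decode([i]) for i in ids)  # "helloworld" - the space is lost
print(repr(whole), repr(piecewise))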
@@ -601,6 +616,33 @@ def _stream_engine_outputs(
             **schema_kwargs,
         )
 
+    def _generate_streamed_text_from_past_tokens(
+        self, generated_tokens: numpy.ndarray, past_tokens_queue: List[int]
+    ) -> str:
+        """
+        An auxiliary method that helps to properly generate the streamed text.
+        Some models, like Llama 2 and Mistral, use LlamaTokenizer, which is
+        based on the SentencePiece tokenizer. This tokenizer doesn't seem to
+        output the appropriate prefix spaces when decoding token by token.
+        One can make it work if the previously generated tokens are included.
+        This allows the tokenizer to figure out the appropriate spacing from
+        the last n consecutive tokens.
+
+        :param generated_tokens: the generated tokens from the engine
+        :param past_tokens_queue: the queue of the last n tokens (n is the
+            original prompt length in tokens)
+        :return: the generated string
+        """
+        string_from_n_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.append(generated_tokens[0])
+        string_from_n_plus_1_tokens = self.tokenizer.decode(
+            past_tokens_queue, skip_special_tokens=True
+        )
+        past_tokens_queue.pop(0)
+        return string_from_n_plus_1_tokens[len(string_from_n_tokens):]
+
     def process_engine_outputs(
         self, engine_outputs: List[Union[numpy.ndarray, FinishReason]], **kwargs
     ) -> TextGenerationOutput:
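
A self-contained sketch of the same rolling-window decode, pulled out of the pipeline class so it can be tried in isolation. The function name and checkpoint below are assumptions for illustration; the pipeline's own implementation is the method above:

from typing import List

from transformers import AutoTokenizer


def streamed_text_from_past_tokens(
    tokenizer, new_token_id: int, past_tokens_queue: List[int]
) -> str:
    # Decode the last n tokens, then the last n + 1 tokens, and return the
    # difference; the extra context lets the tokenizer reinsert the prefix
    # space that a single-token decode would drop.
    before = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
    past_tokens_queue.append(new_token_id)
    after = tokenizer.decode(past_tokens_queue, skip_special_tokens=True)
    past_tokens_queue.pop(0)  # keep the window at a fixed length
    return after[len(before):]


# Illustrative usage (checkpoint name is an assumption):
tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
queue = tok("The quick brown", add_special_tokens=False)["input_ids"]
next_id = tok(" fox", add_special_tokens=False)["input_ids"][-1]
print(repr(streamed_text_from_past_tokens(tok, next_id, queue)))  # expected: ' fox'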
@@ -738,6 +780,9 @@ def engine_forward(
             prompt_logits, session = self.prompt_inference(engine_inputs)
 
             tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()
+            # copy the tokens so that we can use them for streaming
+            past_tokens_queue = copy.copy(tokens)
+
             token_generator = TokenGenerator(
                 logits_shape=prompt_logits[-1].shape[-1],
                 tokens=tokens,
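
The copy is needed because the streamed decode mutates the queue in place with append()/pop(0), while the original tokens list keeps feeding the TokenGenerator. A toy illustration with made-up token ids:

import copy

tokens = [1, 2, 3]                     # prompt token ids (made-up values)
past_tokens_queue = copy.copy(tokens)  # shallow copy suffices for a flat list of ints

past_tokens_queue.append(4)            # what the streamed decode does...
past_tokens_queue.pop(0)               # ...each time a new token arrives

assert tokens == [1, 2, 3]             # the generator's token list stays untouched
assert past_tokens_queue == [2, 3, 4]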
@@ -776,6 +821,7 @@ def engine_forward(
                     numpy.array([generated_tokens[-1]]),
                     numpy.array([generated_logits[-1]]),
                     [None],
+                    past_tokens_queue,
                 )
 
             while len(generated_tokens) < max_tokens:
@@ -816,7 +862,12 @@ def engine_forward(
                     break
 
                 if streaming:
-                    yield (numpy.array([token]), numpy.array([logits]), [None])
+                    yield (
+                        numpy.array([token]),
+                        numpy.array([logits]),
+                        [None],
+                        past_tokens_queue,
+                    )
 
             # Run the autoregressive inference only to put the
             # kv cache entry for the last generated token into the
@@ -831,12 +882,14 @@ def engine_forward(
                         numpy.array([generated_tokens]),
                         numpy.concatenate(generated_logits, axis=1),
                         [FinishReason.LENGTH],
+                        past_tokens_queue,
                     )
                 else:
                     yield (
                         numpy.array([token]),
                         numpy.array([logits]),
                         [finished_reason[-1]],
+                        past_tokens_queue,
                     )
 
             if not streaming:
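
Because engine_forward is a generator, every yield site has to produce the same 4-tuple; otherwise the unpacking in _stream_engine_outputs above would raise a ValueError. A minimal sketch of the consumer-side contract, with stand-in values rather than real engine output:

import numpy

# Stand-in for what engine_forward yields per step in streaming mode.
engine_outputs = [
    (numpy.array([42]), numpy.array([[0.1, 0.9]]), [None], [1, 2, 3]),
]

for output in engine_outputs:
    generated_tokens, generated_logits, finished_reason, past_tokens_queue = output
    # ... decode generated_tokens with the help of past_tokens_queue,
    # as _stream_engine_outputs does above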