@@ -542,38 +542,22 @@ def __next__(self) -> List[int]:
             return value


-@torch.inference_mode()
-def inference_stream(model: PreTrainedModel,
-                     template: Template,
-                     query: str,
-                     history: Optional[History] = None,
-                     system: Optional[str] = None,
-                     images: Optional[List[str]] = None,
-                     *,
-                     generation_config: Optional[GenerationConfig] = None,
-                     stop_words: Optional[StopWords] = None,
-                     generation_info: Optional[Dict[str, int]] = None,
-                     adapter_names: Optional[List[str]] = None,
-                     **kwargs) -> Iterator[Tuple[str, History]]:
-    """
-    generation_config: Priority: generation_config > model.generation_config.
-    """
+def _prepare_inputs(model: PreTrainedModel,
+                    template: Template,
+                    query: str,
+                    history: History,
+                    system: Optional[str] = None,
+                    images: Optional[List[str]] = None,
+                    *,
+                    generation_config: Optional[GenerationConfig] = None,
+                    stop_words: Optional[StopWords] = None,
+                    adapter_names: Optional[List[str]] = None,
+                    **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any], int]:
     if stop_words is None:
         stop_words = []
-    if history is None:
-        history = []
-    else:
-        history = deepcopy(history)
     if images is None:
         images = []

-    # agent support
-    is_observation = history[-1][-1].endswith('Observation:') if history and history[-1][-1] else False
-    if is_observation:
-        history[-1][-1] = history[-1][-1] + query
-        act_length = len(history[-1][-1])
-        query = None
-
     example = {
         'query': query,
         'history': history,
@@ -587,7 +571,7 @@ def inference_stream(model: PreTrainedModel,
     truncation_strategy = kwargs.pop('truncation_strategy', 'delete')
     if len(inputs) == 0 and truncation_strategy == 'delete':
         # input_ids exceeds `max_length`. Please increase the value of `max_length`.
-        return '', history
+        return {}, tokenizer_kwargs, 0

     inputs.pop('labels', None)
     tokenizer = template.tokenizer
@@ -606,11 +590,8 @@ def inference_stream(model: PreTrainedModel,
         inputs['token_type_ids'] = torch.tensor(inputs['token_type_ids'])[None]
     model.eval()
     if generation_config is None:
-        generation_config = getattr(model, 'generation_config', None)
+        generation_config = getattr(model, 'generation_config')
     generation_config = deepcopy(generation_config)
-    if generation_config.num_beams != 1:
-        error_msg = 'Streaming generation does not support beam search.'
-        raise ValueError(error_msg)

     if tokenizer.eos_token_id is not None:
         generation_config.eos_token_id = tokenizer.eos_token_id
@@ -627,21 +608,69 @@
                 raise AssertionError('Current sentence length exceeds' f'the model max_length: {max_length}')
     if template.suffix[-1] not in stop_words:
         stop_words.append(template.suffix[-1])
-    stopping_criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer, stop_words, **tokenizer_kwargs)])
     inputs = to_device(inputs, device)
-    if generation_info is not None:
-        generation_info['num_prompt_tokens'] = token_len
    if 'inputs_embeds' in inputs:
         inputs.pop('input_ids', None)
-    streamer = TokenListIteratorStreamer()
     if adapter_names is not None:
         inputs['adapter_names'] = adapter_names
-    generation_kwargs = {
-        'streamer': streamer,
-        'generation_config': generation_config,
-        'stopping_criteria': stopping_criteria,
-        **inputs
-    }
+
+    stopping_criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer, stop_words, **tokenizer_kwargs)])
+    inputs['stopping_criteria'] = stopping_criteria
+    inputs['generation_config'] = generation_config
+    return inputs, tokenizer_kwargs, token_len
+
+
+@torch.inference_mode()
+def inference_stream(model: PreTrainedModel,
+                     template: Template,
+                     query: str,
+                     history: Optional[History] = None,
+                     system: Optional[str] = None,
+                     images: Optional[List[str]] = None,
+                     *,
+                     generation_config: Optional[GenerationConfig] = None,
+                     stop_words: Optional[StopWords] = None,
+                     generation_info: Optional[Dict[str, int]] = None,
+                     adapter_names: Optional[List[str]] = None,
+                     **kwargs) -> Iterator[Tuple[str, History]]:
+    """
+    generation_config: Priority: generation_config > model.generation_config.
+    """
+    if history is None:
+        history = []
+    else:
+        history = deepcopy(history)
+    inputs, tokenizer_kwargs, token_len = _prepare_inputs(
+        model,
+        template,
+        query,
+        history,
+        system,
+        images,
+        generation_config=generation_config,
+        stop_words=stop_words,
+        adapter_names=adapter_names,
+        **kwargs)
+    if len(inputs) == 0:
+        return '', history
+    if generation_info is None:
+        generation_info = {}
+    generation_info['num_prompt_tokens'] = token_len
+
+    # agent support
+    is_observation = history[-1][-1].endswith('Observation:') if history and history[-1][-1] else False
+    if is_observation:
+        history[-1][-1] = history[-1][-1] + query
+        act_length = len(history[-1][-1])
+        query = None
+
+    generation_config = inputs['generation_config']
+    if generation_config.num_beams != 1:
+        error_msg = 'Streaming generation does not support beam search.'
+        raise ValueError(error_msg)
+
+    streamer = TokenListIteratorStreamer()
+    generation_kwargs = {'streamer': streamer, **inputs}
     _model_generate = model.generate
     if is_torch_npu_available():

@@ -667,8 +696,7 @@ def _model_generate(*args, **kwargs):
         except StopIteration:
             is_finished = True
         generate_ids = template.get_generate_ids(torch.tensor(raw_generate_ids)[None], token_len)
-        if generation_info is not None:
-            generation_info['num_generated_tokens'] = len(generate_ids)
+        generation_info['num_generated_tokens'] = len(generate_ids)
         response = template.generate_ids_to_response(
             generate_ids,
             is_finished,
@@ -702,58 +730,38 @@ def inference(model: PreTrainedModel,
     """
     generation_config: Priority: generation_config > model.generation_config.
     """
-    if stop_words is None:
-        stop_words = []
     if history is None:
         history = []
     else:
         history = deepcopy(history)
-    if images is None:
-        images = []
+    inputs, tokenizer_kwargs, token_len = _prepare_inputs(
+        model,
+        template,
+        query,
+        history,
+        system,
+        images,
+        generation_config=generation_config,
+        stop_words=stop_words,
+        adapter_names=adapter_names,
+        **kwargs)
+    if len(inputs) == 0:
+        return '', history
+    if generation_info is None:
+        generation_info = {}
+    generation_info['num_prompt_tokens'] = token_len

+    # agent support
     is_observation = history[-1][-1].endswith('Observation:') if history and history[-1][-1] else False
     if is_observation:
         history[-1][-1] = history[-1][-1] + query
         query = None

-    example = {
-        'query': query,
-        'history': history,
-        'system': system,
-        'images': images,  # for vl. str.
-        'tools': kwargs.pop('tools', None)
-    }
-    template.model = model
-    inputs, tokenizer_kwargs = template.encode(example)
-
-    truncation_strategy = kwargs.pop('truncation_strategy', 'delete')
-    if len(inputs) == 0 and truncation_strategy == 'delete':
-        # input_ids exceeds `max_length`. Please increase the value of `max_length`.
-        return '', history
-
-    inputs.pop('labels', None)
-    tokenizer = template.tokenizer
-    device = next(model.parameters()).device
-    if 'input_ids' in inputs:
-        input_ids = torch.tensor(inputs['input_ids'])[None]
-        inputs['input_ids'] = input_ids
-        token_len = input_ids.shape[1]
-    if 'inputs_embeds' in inputs:
-        inputs_embeds = inputs['inputs_embeds'][None]
-        inputs['inputs_embeds'] = inputs_embeds
-        token_len = inputs_embeds.shape[1]
-
-    inputs['attention_mask'] = torch.ones(token_len)[None]
-    if 'token_type_ids' in inputs:
-        inputs['token_type_ids'] = torch.tensor(inputs['token_type_ids'])[None]
-    model.eval()
-    if generation_config is None:
-        generation_config = getattr(model, 'generation_config', None)
-    generation_config = deepcopy(generation_config)
     if stream and not verbose:
         logger.warning('Please set verbose to True to support TextStreamer, or use `inference_stream.`')
         stream = False
     streamer = None
+    tokenizer = template.tokenizer
     if stream:
         streamer = TextStreamer(tokenizer, skip_prompt=True)
     if verbose:
@@ -762,37 +770,12 @@ def inference(model: PreTrainedModel,
             print(
                 f'{prompt_prefix}{safe_tokenizer_decode(tokenizer, input_ids[0], **tokenizer_kwargs)}{output_prefix}',
                 end='')
-        elif 'query' in example:
-            query = example['query']
+        else:
             print(f'[QUERY]{query}\n{output_prefix}', end='')
-    if tokenizer.eos_token_id is not None:
-        generation_config.eos_token_id = tokenizer.eos_token_id
-    if tokenizer.pad_token_id is not None:
-        generation_config.pad_token_id = tokenizer.pad_token_id
-    if tokenizer.bos_token_id is not None:
-        generation_config.bos_token_id = tokenizer.bos_token_id
-    if generation_config.max_new_tokens is not None:
-        generation_config.max_length = 20  # fix max_length, max_new_tokens warning
-        max_length = get_max_model_len(model.config)
-        if max_length and token_len + generation_config.max_new_tokens > max_length:
-            generation_config.max_new_tokens = max_length - token_len
-            if generation_config.max_new_tokens <= 0:
-                raise AssertionError('Current sentence length exceeds' f'the model max_length: {max_length}')
-    if template.suffix[-1] not in stop_words:
-        stop_words.append(template.suffix[-1])
-    stopping_criteria = StoppingCriteriaList([StopWordsCriteria(tokenizer, stop_words, **tokenizer_kwargs)])
-    inputs = to_device(inputs, device)
-    if generation_info is not None:
-        generation_info['num_prompt_tokens'] = token_len
-    if 'inputs_embeds' in inputs:
-        inputs.pop('input_ids', None)
-    if adapter_names is not None:
-        inputs['adapter_names'] = adapter_names
-    generate_ids = model.generate(
-        streamer=streamer, generation_config=generation_config, stopping_criteria=stopping_criteria, **inputs)
+
+    generate_ids = model.generate(streamer=streamer, **inputs)
     generate_ids = template.get_generate_ids(generate_ids, token_len)
-    if generation_info is not None:
-        generation_info['num_generated_tokens'] = len(generate_ids)
+    generation_info['num_generated_tokens'] = len(generate_ids)
     if verbose and stream is False:
         response = tokenizer.decode(generate_ids, **tokenizer_kwargs)
         print(response)
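
For context, a minimal usage sketch of the two refactored entry points after this change. It is an illustration only: the helpers `get_model_tokenizer`, `get_template`, `get_default_template_type`, and the `ModelType.qwen_7b_chat` checkpoint are assumed ms-swift conveniences that are not part of this diff; substitute whatever model and template setup you already use.

# Sketch, not part of the patch: exercises inference() and inference_stream(),
# which now share _prepare_inputs() for prompt encoding and generate() kwargs.
from swift.llm import (ModelType, get_default_template_type, get_model_tokenizer,
                       get_template, inference, inference_stream)

model_type = ModelType.qwen_7b_chat  # assumed example checkpoint
model, tokenizer = get_model_tokenizer(model_type, model_kwargs={'device_map': 'auto'})
template = get_template(get_default_template_type(model_type), tokenizer)

# Non-streaming path: generation_info is filled with token counts.
generation_info = {}
response, history = inference(model, template, 'Who are you?', generation_info=generation_info)
print(response)
print(generation_info)  # {'num_prompt_tokens': ..., 'num_generated_tokens': ...}

# Streaming path: same input preparation, responses arrive incrementally.
for response, history in inference_stream(model, template, 'Tell me a joke.', history):
    print(response, end='\r')
print()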