 import numpy
 import onnx
 from pydantic import BaseModel, Field
-from transformers import TextStreamer

 from deepsparse import Pipeline
 from deepsparse.pipeline import DEEPSPARSE_ENGINE
@@ -61,6 +60,7 @@ class FinishReason(Enum):
     STOP = "stop"
     LENGTH = "length"
     TIME = "time"
+    CALLBACK = "callback"


 class TextGenerationInput(BaseModel):
@@ -106,12 +106,12 @@ class Config:
         "to have consistent length so one "
         "can compute metric in a batched fashion. ",
     )
-    streamer: Optional[TextStreamer] = Field(
-        default=None,
-        description="Streamer object that will be used to stream the "
-        "generated sequences. Generated tokens are passed through "
-        "`streamer.put(token_ids)` and the streamer is responsible "
-        "for any further processing.",
+    streaming: bool = Field(
+        default=False,
+        description="Whether to stream the results back as they are generated. If "
+        "True, then the results are returned as a generator object which yields "
+        "the results as they are generated. If False, then the results are returned "
+        "as a list after it has completed.",
     )
     callback: Optional[Callable[[Any], Union[bool, Any]]] = Field(
         default=None,
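For context, a minimal usage sketch of the new flag (not part of this diff; `Pipeline.create` with `task="text_generation"`, the `sequences` keyword, and the model path are assumptions based on the existing DeepSparse text generation pipeline):

    from deepsparse import Pipeline

    # Hypothetical model stub; replace with a real model path.
    pipeline = Pipeline.create(task="text_generation", model_path="zoo:some/model/stub")

    # streaming=False (default): a single TextGenerationOutput once generation completes.
    output = pipeline(sequences="Write a haiku about sparsity", streaming=False)
    print(output.generations[0].text)

    # streaming=True: a generator yielding a TextGenerationOutput per generated token.
    for chunk in pipeline(sequences="Write a haiku about sparsity", streaming=True):
        print(chunk.generations[0].text, end="", flush=True)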
@@ -161,7 +161,7 @@ class GeneratedText(BaseModel):
         "The scores have the shape [sequence_length, vocab_size]"
     )
     finished: bool = Field(description="Whether generation has stopped.")
-    finished_reason: str = Field(
+    finished_reason: Optional[str] = Field(
         description="The reason for generation to stop. "
         "Defined by FinishReason. One of stop, length, or time."
     )
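Making finished_reason optional matters for streaming: chunks emitted before a stop condition fires carry finished=False and no finish reason (see _create_generated_text_output added later in this diff). A small illustrative sketch, assuming score accepts None as it does there (values are made up):

    # An in-flight streamed chunk vs. the final chunk.
    in_flight = GeneratedText(text=" world", score=None, finished=False)
    final = GeneratedText(text="!", score=None, finished=True, finished_reason="stop")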
@@ -473,9 +473,9 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:

         context = dict(
             prompts=original_inputs,
+            streaming=inputs.streaming,
             num_generated_predictions=inputs.num_generated_predictions,
             return_logits=inputs.return_logits,
-            streamer=inputs.streamer,
             include_prompt_logits=inputs.include_prompt_logits,
             callback=inputs.callback,
             stop=inputs.stop,
@@ -488,6 +488,40 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:

         return engine_input, context

+    def _create_generated_text_output(
+        self,
+        sequence: str,
+        finish_reason: Optional[FinishReason] = None,
+        logits: Optional[numpy.array] = None,
+    ):
+        if finish_reason:
+            return GeneratedText(
+                text=sequence,
+                score=logits,
+                finished=True,
+                finished_reason=finish_reason.value,
+            )
+        return GeneratedText(
+            text=sequence,
+            score=logits,
+            finished=False,
+        )
+
+    def _stream_engine_outputs(self, engine_outputs, prompts, kwargs):
+        for output in engine_outputs:
+            generated_tokens, generated_logits, finished_reason = output
+            logits = generated_logits if kwargs.get("return_logits") else None
+            generation = self._create_generated_text_output(
+                self.tokenizer.batch_decode(generated_tokens)[0],
+                finished_reason[0],
+                logits,
+            )
+            yield TextGenerationOutput(
+                created=datetime.datetime.now(),
+                prompts=prompts,
+                generations=[generation],
+            )
+
     def process_engine_outputs(
         self, engine_outputs: List[Union[numpy.ndarray, FinishReason]], **kwargs
     ) -> TextGenerationOutput:
@@ -497,33 +531,29 @@ def process_engine_outputs(
         :param engine_outputs: the outputs from the engine
         :return: the output schema for the pipeline
         """
-        generated_tokens, generated_logits, finished_reason, *debug = engine_outputs
-        finished_reason = [f[0] for f in finished_reason]

+        prompts = kwargs.get("prompts")
+        streaming = kwargs.get("streaming")
+
+        if streaming:
+            return self._stream_engine_outputs(engine_outputs, prompts, kwargs)
+
+        generated_tokens, generated_logits, finished_reason, *debug = list(
+            *engine_outputs
+        )
         sequences = self.tokenizer.batch_decode(
             generated_tokens, skip_special_tokens=True
         )
-        num_preds = kwargs.get("num_generated_predictions", 1)
-        prompts = kwargs.get("prompts")
-
-        def _create_generated_text_output(
-            sequence: str,
-            finish_reason: FinishReason,
-            logits: Optional[numpy.array] = None,
-        ):
-            return GeneratedText(
-                text=sequence,
-                score=logits,
-                finished=True,
-                finished_reason=finish_reason.value,
-            )

         logits = generated_logits if kwargs.get("return_logits") else None

+        num_preds = kwargs.get("num_generated_predictions", 1)
+        finished_reason = [f[0] for f in finished_reason]
+
         if logits is not None:
             generations = list(
                 self.executor.map(
-                    _create_generated_text_output,
+                    self._create_generated_text_output,
                     sequences,
                     finished_reason,
                     logits,
@@ -532,7 +562,7 @@ def _create_generated_text_output(
         else:
             generations = list(
                 self.executor.map(
-                    _create_generated_text_output, sequences, finished_reason
+                    self._create_generated_text_output, sequences, finished_reason
                 )
             )

@@ -582,8 +612,8 @@ def engine_forward(
         # names in this context

         with self.timer_manager.new_timer_context(total_inference=False) as timer:
-            streamer = context.get("streamer")
             finished_reason = []
+            streaming = context.get("streaming")

             if not self.cache_support_enabled:
                 prompt_logits = self.multitoken_engine(engine_inputs)
@@ -610,9 +640,6 @@ def engine_forward(
                 )
                 token_generator.generate(prompt_logits[-1][0, -1, :])

-                if streamer is not None:
-                    streamer.put(numpy.array(token_generator.tokens))
-
                 # create the generated output
                 max_tokens = context.get("max_tokens", 0)
                 max_tokens = max_tokens if max_tokens > 0 else (100 * self.sequence_length)
@@ -638,9 +665,6 @@ def engine_forward(
                         generated_tokens.append(token)
                         generated_logits.append(logits)

-                        if streamer is not None:
-                            streamer.put(numpy.array([token]))
-
                         if (
                             token == self.tokenizer.eos_token_id
                             and not self.force_max_tokens
@@ -656,30 +680,38 @@ def engine_forward(
                             finished_reason.append(FinishReason.STOP)
                             break

-                        # TODO: Add any generic callback reason?
                         if callback is not None and callback(token) is False:
                             _LOGGER.debug(
                                 "callback %s returned False, stopping generation."
                                 % callback.__qualname__
                             )
+                            finished_reason.append(FinishReason.CALLBACK)
                             break

                         if len(generated_tokens) == max_tokens:
                             finished_reason.append(FinishReason.LENGTH)

-            if streamer is not None:
-                streamer.end()
+                        if streaming:
+                            yield (numpy.array([token]), numpy.array([logits]), [None])

-        returns = (
-            numpy.array([generated_tokens]),
-            numpy.concatenate(generated_logits, axis=1),
-            finished_reason,
-        )
+        if streaming:
+            yield (
+                numpy.array([token]),
+                numpy.array([logits]),
+                [finished_reason[-1]],
+            )
+
+        if not streaming:
+            returns = (
+                numpy.array([generated_tokens]),
+                numpy.concatenate(generated_logits, axis=1),
+                finished_reason,
+            )

-        if self._debug is True:
-            return *returns, session
+            if self._debug is True:
+                yield *returns, session

-        return returns
+            yield returns

     def prompt_inference(
         self,
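A note on the control flow above: because engine_forward now contains yield statements, calling it always produces a generator, even when streaming is False. On the non-streaming path the generator yields exactly one item (the returns tuple), which is why later hunks unwrap it with list(*engine_outputs) and [list(*b) for b in batch_outputs]. A minimal standalone sketch of the pattern (illustrative only, not code from this PR):

    def forward(streaming: bool):
        tokens = []
        for token in (1, 2, 3):
            tokens.append(token)
            if streaming:
                yield token  # one item per generated token while streaming
        if not streaming:
            yield tokens  # a single joined result otherwise

    assert list(forward(streaming=True)) == [1, 2, 3]
    assert list(*forward(streaming=False)) == [1, 2, 3]  # unwraps the single yielded item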
@@ -870,6 +902,7 @@ def join_engine_outputs(
         self,
         batch_outputs: List[List[Union[numpy.ndarray, FinishReason]]],
         orig_batch_size: int,
+        **kwargs,
     ) -> List[Union[numpy.ndarray, FinishReason]]:
         """
         Takes a list of outputs (batches) from the engine
@@ -881,48 +914,63 @@ def join_engine_outputs(
881
914
:param orig_batch_size: The original batch size
882
915
:return: A list of joined outputs
883
916
"""
884
- tokens , logits , finish_reason , * debug = zip (* batch_outputs )
885
- if self .cache_support_enabled :
886
- # if the model has kv cache, we need to account for
887
- # the fact that the predicted outputs may have
888
- # different lengths
889
-
890
- # find the longest sequence in the batch of tokens
891
- max_len = max ([token .shape [1 ] for token in tokens ])
892
-
893
- # pad all tokens to the same length
894
- tokens = [
895
- pad_to_fixed_length (
896
- array = prediction ,
897
- max_len = max_len ,
898
- value = self .tokenizer .pad_token_id ,
899
- axis = 1 ,
900
- )
901
- for prediction in tokens
902
- ]
917
+ streaming = kwargs .get ("streaming" )
918
+ if streaming :
919
+ for batch in batch_outputs :
920
+ for outputs in batch :
921
+ yield outputs
922
+ else :
923
+ batch_outputs = [list (* b ) for b in batch_outputs ]
924
+ tokens , logits , finish_reason , * debug = zip (* batch_outputs )
925
+ if self .cache_support_enabled :
926
+ # if the model has kv cache, we need to account for
927
+ # the fact that the predicted outputs may have
928
+ # different lengths
929
+
930
+ # find the longest sequence in the batch of tokens
931
+ max_len = max ([token .shape [1 ] for token in tokens ])
932
+
933
+ # pad all tokens to the same length
934
+ tokens = [
935
+ pad_to_fixed_length (
936
+ array = prediction ,
937
+ max_len = max_len ,
938
+ value = self .tokenizer .pad_token_id ,
939
+ axis = 1 ,
940
+ )
941
+ for prediction in tokens
942
+ ]
903
943
904
- # find the longest sequence in the batch of logits
905
- max_len = max ([logits .shape [1 ] for logits in logits ])
944
+ # find the longest sequence in the batch of logits
945
+ max_len = max ([logits .shape [1 ] for logits in logits ])
906
946
907
- # pad all logits to the same length
908
- logits = [
909
- pad_to_fixed_length (array = single_logits , max_len = max_len , axis = 1 )
910
- for single_logits in logits
911
- ]
947
+ # pad all logits to the same length
948
+ logits = [
949
+ pad_to_fixed_length (array = single_logits , max_len = max_len , axis = 1 )
950
+ for single_logits in logits
951
+ ]
912
952
913
- tokens = numpy .concatenate (tokens , axis = 0 )
914
- logits = numpy .concatenate (logits , axis = 0 )
953
+ tokens = numpy .concatenate (tokens , axis = 0 )
954
+ logits = numpy .concatenate (logits , axis = 0 )
915
955
916
- if debug :
917
- sessions = debug [0 ]
918
- kv_cache_state = numpy .stack (session .cached_inputs for session in sessions )
919
- num_processed_tokens = numpy .stack (
920
- session .total_num_processed_tokens for session in sessions
921
- )
956
+ if debug :
957
+ sessions = debug [0 ]
958
+ kv_cache_state = numpy .stack (
959
+ session .cached_inputs for session in sessions
960
+ )
961
+ num_processed_tokens = numpy .stack (
962
+ session .total_num_processed_tokens for session in sessions
963
+ )
922
964
923
- return [tokens , logits , finish_reason , kv_cache_state , num_processed_tokens ]
965
+ yield [
966
+ tokens ,
967
+ logits ,
968
+ finish_reason ,
969
+ kv_cache_state ,
970
+ num_processed_tokens ,
971
+ ]
924
972
925
- return [tokens , logits , finish_reason ]
973
+ yield [tokens , logits , finish_reason ]
926
974
927
975
@staticmethod
928
976
def causal_mask_input_present (model_path : str ) -> bool :