@@ -273,7 +273,6 @@ def __init__(
         self.tokenizer.pad_token = self.tokenizer.eos_token

         self.engine, self.multitoken_engine = self.initialize_engines()
-        self.streaming = False

     def initialize_engines(
         self,
@@ -419,7 +418,6 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
         :param inputs: the input schema for the pipeline
         :return: the inputs for the engine
         """
-        self.streaming = inputs.streaming
         if not self.cache_support_enabled and inputs.max_tokens > 1:
             raise ValueError(
                 "The model used for inference does not support kv cache. It is "
@@ -488,6 +486,7 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:

         context = dict(
             prompts=original_inputs,
+            streaming=inputs.streaming,
             num_generated_predictions=inputs.num_generated_predictions,
             return_logits=inputs.return_logits,
             include_prompt_logits=inputs.include_prompt_logits,
@@ -547,8 +546,9 @@ def process_engine_outputs(
         """

         prompts = kwargs.get("prompts")
+        streaming = kwargs.get("streaming")

-        if self.streaming:
+        if streaming:
             return self._stream_engine_outputs(engine_outputs, prompts, kwargs)

         generated_tokens, generated_logits, finished_reason = list(*engine_outputs)
@@ -611,6 +611,7 @@ def engine_forward(

         with self.timer_manager.new_timer_context(total_inference=False) as timer:
             finished_reason = []
+            streaming = context.get("streaming")

             if not self.cache_support_enabled:
                 prompt_logits = self.multitoken_engine(engine_inputs)
@@ -688,17 +689,17 @@ def engine_forward(
                 if len(generated_tokens) == max_tokens:
                     finished_reason.append(FinishReason.LENGTH)

-                if self.streaming:
+                if streaming:
                     yield (numpy.array([token]), numpy.array([logits]), [None])

-            if self.streaming:
+            if streaming:
                 yield (
                     numpy.array([token]),
                     numpy.array([logits]),
                     [finished_reason[-1]],
                 )

-            if not self.streaming:
+            if not streaming:
                 yield (
                     numpy.array([generated_tokens]),
                     numpy.concatenate(generated_logits, axis=1),
@@ -895,6 +896,7 @@ def join_engine_outputs(
         self,
         batch_outputs: List[List[Union[numpy.ndarray, FinishReason]]],
         orig_batch_size: int,
+        **kwargs,
     ) -> List[Union[numpy.ndarray, FinishReason]]:
         """
         Takes a list of outputs (batches) from the engine
@@ -906,7 +908,8 @@ def join_engine_outputs(
         :param orig_batch_size: The original batch size
         :return: A list of joined outputs
         """
-        if self.streaming:
+        streaming = kwargs.get("streaming")
+        if streaming:
             for batch in batch_outputs:
                 for outputs in batch:
                     yield outputs
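
Taken together, these hunks replace the shared `self.streaming` attribute with a per-request `streaming` flag threaded through `context`/`kwargs`, so concurrent requests can no longer overwrite each other's setting. Below is a minimal, self-contained sketch of the race this avoids; the `StatefulPipeline`/`ContextPipeline` names and `run()` method are hypothetical stand-ins for illustration, not the actual deepsparse API.

```python
import threading


class StatefulPipeline:
    """Anti-pattern removed by this diff: per-request flag stored on `self`."""

    def __init__(self):
        self.streaming = False

    def run(self, streaming: bool) -> bool:
        self.streaming = streaming  # another request may overwrite this ...
        return self.streaming       # ... before this read happens


class ContextPipeline:
    """Pattern introduced by this diff: the flag travels with the request."""

    def run(self, streaming: bool) -> bool:
        context = dict(streaming=streaming)  # cf. process_inputs()
        return context.get("streaming")      # cf. engine_forward()


def count_mismatches(pipeline) -> int:
    """Run two concurrent 'requests' and count cross-request flag leaks."""
    mismatches = 0

    def request(flag: bool):
        nonlocal mismatches
        for _ in range(50_000):
            if pipeline.run(flag) is not flag:
                mismatches += 1

    threads = [threading.Thread(target=request, args=(f,)) for f in (True, False)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return mismatches


if __name__ == "__main__":
    print("self.streaming:", count_mismatches(StatefulPipeline()))  # typically > 0
    print("context flag:  ", count_mismatches(ContextPipeline()))   # always 0
```

The stateful variant typically reports nonzero mismatches, while the context-passing variant cannot, since the flag never leaves the call's own frame.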