
Commit cd74aa2

[TextGeneration] Add Streaming Functionality (#1246)
* add streaming functionality
* remove print
* set back default value
* rebase
* update to yield
* update pipeline.py
* update tests
* refactor out streaming functions and remove yield in process_engine_output
* fix tests
* update pipeline to use kwargs
* rebase
* Update src/deepsparse/transformers/pipelines/text_generation.py
1 parent fdb5d44 commit cd74aa2

File tree

3 files changed: +143 -86 lines changed


src/deepsparse/pipeline.py (+9 -4)
@@ -21,7 +21,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union

 import numpy
 from pydantic import BaseModel, Field
@@ -259,7 +259,9 @@ def __call__(self, *args, **kwargs) -> BaseModel:
             )

             # join together the batches of size `self._batch_size`
-            engine_outputs = self.join_engine_outputs(batch_outputs, orig_batch_size)
+            engine_outputs = self.join_engine_outputs(
+                batch_outputs, orig_batch_size, **context
+            )
             timer.stop(InferenceStages.ENGINE_FORWARD)

             self.log(
@@ -280,7 +282,10 @@ def __call__(self, *args, **kwargs) -> BaseModel:
             # ------ POSTPROCESSING ------
             timer.start(InferenceStages.POST_PROCESS)
             pipeline_outputs = self.process_engine_outputs(engine_outputs, **context)
-            if not isinstance(pipeline_outputs, self.output_schema):
+            if not (
+                isinstance(pipeline_outputs, (self.output_schema, Generator))
+                or isinstance(pipeline_outputs, Generator)
+            ):
                 raise ValueError(
                     f"Outputs of {self.__class__} must be instances of "
                     f"{self.output_schema} found output of type "
@@ -467,7 +472,7 @@ def to_config(self) -> "PipelineConfig":
         )

     def join_engine_outputs(
-        self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int
+        self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int, **kwargs
     ) -> List[numpy.ndarray]:
         """
         Joins list of engine outputs together into one list.
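With this change, process_engine_outputs may hand back a generator instead of an output_schema instance, and the validation in __call__ accepts either. A minimal standalone sketch (not part of the commit) of why the unsubscripted typing.Generator works in an isinstance check:

    from typing import Generator

    def _outputs():
        yield "partial result"

    # unsubscripted typing.Generator forwards isinstance checks to
    # collections.abc.Generator, so a generator object passes the new guard
    assert isinstance(_outputs(), Generator)
    assert not isinstance("a plain string", Generator)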

src/deepsparse/transformers/pipelines/text_generation.py (+129 -81)
@@ -33,7 +33,6 @@
 import numpy
 import onnx
 from pydantic import BaseModel, Field
-from transformers import TextStreamer

 from deepsparse import Pipeline
 from deepsparse.pipeline import DEEPSPARSE_ENGINE
@@ -61,6 +60,7 @@ class FinishReason(Enum):
     STOP = "stop"
     LENGTH = "length"
     TIME = "time"
+    CALLBACK = "callback"


 class TextGenerationInput(BaseModel):
@@ -106,12 +106,12 @@ class Config:
         "to have consistent length so one "
         "can compute metric in a batched fashion. ",
     )
-    streamer: Optional[TextStreamer] = Field(
-        default=None,
-        description="Streamer object that will be used to stream the "
-        "generated sequences. Generated tokens are passed through "
-        "`streamer.put(token_ids)` and the streamer is responsible "
-        "for any further processing.",
+    streaming: bool = Field(
+        default=False,
+        description="Whether to stream the results back as they are generated. If "
+        "True, then the results are returned as a generator object which yields "
+        "the results as they are generated. If False, then the results are returned "
+        "as a list after it has completed.",
     )
     callback: Optional[Callable[[Any], Union[bool, Any]]] = Field(
         default=None,
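For context, a minimal sketch of how the new flag is meant to be used from the caller's side; the model stub, prompt, and the `sequences` input name are illustrative and not part of this commit:

    from deepsparse import Pipeline

    # hypothetical model stub, for illustration only
    pipeline = Pipeline.create(
        task="text-generation",
        model_path="zoo:example/text-generation-model",
    )

    # streaming=False (default): the call blocks and returns one TextGenerationOutput
    full_output = pipeline(sequences="Tell me a story", streaming=False)

    # streaming=True: the call returns a generator of TextGenerationOutput objects,
    # one per generated token, which can be consumed as they arrive
    for partial in pipeline(sequences="Tell me a story", streaming=True):
        print(partial.generations[0].text, end="", flush=True)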
@@ -161,7 +161,7 @@ class GeneratedText(BaseModel):
         "The scores have the shape [sequence_length, vocab_size]"
     )
     finished: bool = Field(description="Whether generation has stopped.")
-    finished_reason: str = Field(
+    finished_reason: Optional[str] = Field(
         description="The reason for generation to stop. "
         "Defined by FinishReason. One of stop, length, or time."
     )
@@ -473,9 +473,9 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:

         context = dict(
             prompts=original_inputs,
+            streaming=inputs.streaming,
             num_generated_predictions=inputs.num_generated_predictions,
             return_logits=inputs.return_logits,
-            streamer=inputs.streamer,
             include_prompt_logits=inputs.include_prompt_logits,
             callback=inputs.callback,
             stop=inputs.stop,
@@ -488,6 +488,40 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:

         return engine_input, context

+    def _create_generated_text_output(
+        self,
+        sequence: str,
+        finish_reason: Optional[FinishReason] = None,
+        logits: Optional[numpy.array] = None,
+    ):
+        if finish_reason:
+            return GeneratedText(
+                text=sequence,
+                score=logits,
+                finished=True,
+                finished_reason=finish_reason.value,
+            )
+        return GeneratedText(
+            text=sequence,
+            score=logits,
+            finished=False,
+        )
+
+    def _stream_engine_outputs(self, engine_outputs, prompts, kwargs):
+        for output in engine_outputs:
+            generated_tokens, generated_logits, finished_reason = output
+            logits = generated_logits if kwargs.get("return_logits") else None
+            generation = self._create_generated_text_output(
+                self.tokenizer.batch_decode(generated_tokens)[0],
+                finished_reason[0],
+                logits,
+            )
+            yield TextGenerationOutput(
+                created=datetime.datetime.now(),
+                prompts=prompts,
+                generations=[generation],
+            )
+
     def process_engine_outputs(
         self, engine_outputs: List[Union[numpy.ndarray, FinishReason]], **kwargs
     ) -> TextGenerationOutput:
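Each yield from _stream_engine_outputs is a full TextGenerationOutput wrapping a single GeneratedText. A minimal sketch of consuming those chunks, with attribute names taken from the schemas above and the stream itself assumed to come from a streaming pipeline call:

    # assuming `stream` is the generator returned when the pipeline is called
    # with streaming=True
    for chunk in stream:
        generation = chunk.generations[0]  # one GeneratedText per yielded token
        # finished stays False (finished_reason is None) until the final chunk,
        # which carries the FinishReason value as a string
        print(generation.text, generation.finished, generation.finished_reason)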
@@ -497,33 +531,29 @@ def process_engine_outputs(
         :param engine_outputs: the outputs from the engine
         :return: the output schema for the pipeline
         """
-        generated_tokens, generated_logits, finished_reason, *debug = engine_outputs
-        finished_reason = [f[0] for f in finished_reason]

+        prompts = kwargs.get("prompts")
+        streaming = kwargs.get("streaming")
+
+        if streaming:
+            return self._stream_engine_outputs(engine_outputs, prompts, kwargs)
+
+        generated_tokens, generated_logits, finished_reason, *debug = list(
+            *engine_outputs
+        )
         sequences = self.tokenizer.batch_decode(
             generated_tokens, skip_special_tokens=True
         )
-        num_preds = kwargs.get("num_generated_predictions", 1)
-        prompts = kwargs.get("prompts")
-
-        def _create_generated_text_output(
-            sequence: str,
-            finish_reason: FinishReason,
-            logits: Optional[numpy.array] = None,
-        ):
-            return GeneratedText(
-                text=sequence,
-                score=logits,
-                finished=True,
-                finished_reason=finish_reason.value,
-            )

         logits = generated_logits if kwargs.get("return_logits") else None

+        num_preds = kwargs.get("num_generated_predictions", 1)
+        finished_reason = [f[0] for f in finished_reason]
+
         if logits is not None:
             generations = list(
                 self.executor.map(
-                    _create_generated_text_output,
+                    self._create_generated_text_output,
                     sequences,
                     finished_reason,
                     logits,
@@ -532,7 +562,7 @@ def _create_generated_text_output(
         else:
             generations = list(
                 self.executor.map(
-                    _create_generated_text_output, sequences, finished_reason
+                    self._create_generated_text_output, sequences, finished_reason
                 )
             )

@@ -582,8 +612,8 @@ def engine_forward(
         # names in this context

         with self.timer_manager.new_timer_context(total_inference=False) as timer:
-            streamer = context.get("streamer")
             finished_reason = []
+            streaming = context.get("streaming")

             if not self.cache_support_enabled:
                 prompt_logits = self.multitoken_engine(engine_inputs)
@@ -610,9 +640,6 @@
                 )
                 token_generator.generate(prompt_logits[-1][0, -1, :])

-                if streamer is not None:
-                    streamer.put(numpy.array(token_generator.tokens))
-
                 # create the generated output
                 max_tokens = context.get("max_tokens", 0)
                 max_tokens = max_tokens if max_tokens > 0 else (100 * self.sequence_length)
@@ -638,9 +665,6 @@
                     generated_tokens.append(token)
                     generated_logits.append(logits)

-                    if streamer is not None:
-                        streamer.put(numpy.array([token]))
-
                     if (
                         token == self.tokenizer.eos_token_id
                         and not self.force_max_tokens
@@ -656,30 +680,38 @@
                         finished_reason.append(FinishReason.STOP)
                         break

-                    # TODO: Add any generic callback reason?
                     if callback is not None and callback(token) is False:
                         _LOGGER.debug(
                             "callback %s returned False, stopping generation."
                             % callback.__qualname__
                         )
+                        finished_reason.append(FinishReason.CALLBACK)
                         break

                     if len(generated_tokens) == max_tokens:
                         finished_reason.append(FinishReason.LENGTH)

-            if streamer is not None:
-                streamer.end()
+                    if streaming:
+                        yield (numpy.array([token]), numpy.array([logits]), [None])

-            returns = (
-                numpy.array([generated_tokens]),
-                numpy.concatenate(generated_logits, axis=1),
-                finished_reason,
-            )
+            if streaming:
+                yield (
+                    numpy.array([token]),
+                    numpy.array([logits]),
+                    [finished_reason[-1]],
+                )
+
+            if not streaming:
+                returns = (
+                    numpy.array([generated_tokens]),
+                    numpy.concatenate(generated_logits, axis=1),
+                    finished_reason,
+                )

-            if self._debug is True:
-                return *returns, session
+                if self._debug is True:
+                    yield *returns, session

-            return returns
+                yield returns

     def prompt_inference(
         self,
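engine_forward is now a generator in both modes: while streaming it yields one (token, logits, finish_reason) triple per step with the reason left as [None], followed by a final triple carrying the actual FinishReason; without streaming it yields the joined arrays once. A minimal sketch of that contract with dummy values (not deepsparse code):

    import numpy

    def fake_engine_forward(streaming: bool):
        tokens = [11, 22, 33]               # dummy generated token ids
        logits = [numpy.zeros((1, 4))] * 3  # dummy per-token logits
        if streaming:
            for token, logit in zip(tokens, logits):
                yield numpy.array([token]), numpy.array([logit]), [None]
            # final yield repeats the last token with the real finish reason
            yield numpy.array([tokens[-1]]), numpy.array([logits[-1]]), ["length"]
        else:
            yield numpy.array([tokens]), numpy.concatenate(logits, axis=1), ["length"]

    streamed = list(fake_engine_forward(streaming=True))
    assert len(streamed) == 4 and streamed[-1][2] == ["length"]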
@@ -870,6 +902,7 @@ def join_engine_outputs(
         self,
         batch_outputs: List[List[Union[numpy.ndarray, FinishReason]]],
         orig_batch_size: int,
+        **kwargs,
     ) -> List[Union[numpy.ndarray, FinishReason]]:
         """
         Takes a list of outputs (batches) from the engine
@@ -881,48 +914,63 @@ def join_engine_outputs(
         :param orig_batch_size: The original batch size
         :return: A list of joined outputs
         """
-        tokens, logits, finish_reason, *debug = zip(*batch_outputs)
-        if self.cache_support_enabled:
-            # if the model has kv cache, we need to account for
-            # the fact that the predicted outputs may have
-            # different lengths
-
-            # find the longest sequence in the batch of tokens
-            max_len = max([token.shape[1] for token in tokens])
-
-            # pad all tokens to the same length
-            tokens = [
-                pad_to_fixed_length(
-                    array=prediction,
-                    max_len=max_len,
-                    value=self.tokenizer.pad_token_id,
-                    axis=1,
-                )
-                for prediction in tokens
-            ]
+        streaming = kwargs.get("streaming")
+        if streaming:
+            for batch in batch_outputs:
+                for outputs in batch:
+                    yield outputs
+        else:
+            batch_outputs = [list(*b) for b in batch_outputs]
+            tokens, logits, finish_reason, *debug = zip(*batch_outputs)
+            if self.cache_support_enabled:
+                # if the model has kv cache, we need to account for
+                # the fact that the predicted outputs may have
+                # different lengths
+
+                # find the longest sequence in the batch of tokens
+                max_len = max([token.shape[1] for token in tokens])
+
+                # pad all tokens to the same length
+                tokens = [
+                    pad_to_fixed_length(
+                        array=prediction,
+                        max_len=max_len,
+                        value=self.tokenizer.pad_token_id,
+                        axis=1,
+                    )
+                    for prediction in tokens
+                ]

-            # find the longest sequence in the batch of logits
-            max_len = max([logits.shape[1] for logits in logits])
+                # find the longest sequence in the batch of logits
+                max_len = max([logits.shape[1] for logits in logits])

-            # pad all logits to the same length
-            logits = [
-                pad_to_fixed_length(array=single_logits, max_len=max_len, axis=1)
-                for single_logits in logits
-            ]
+                # pad all logits to the same length
+                logits = [
+                    pad_to_fixed_length(array=single_logits, max_len=max_len, axis=1)
+                    for single_logits in logits
+                ]

-            tokens = numpy.concatenate(tokens, axis=0)
-            logits = numpy.concatenate(logits, axis=0)
+                tokens = numpy.concatenate(tokens, axis=0)
+                logits = numpy.concatenate(logits, axis=0)

-        if debug:
-            sessions = debug[0]
-            kv_cache_state = numpy.stack(session.cached_inputs for session in sessions)
-            num_processed_tokens = numpy.stack(
-                session.total_num_processed_tokens for session in sessions
-            )
+            if debug:
+                sessions = debug[0]
+                kv_cache_state = numpy.stack(
+                    session.cached_inputs for session in sessions
+                )
+                num_processed_tokens = numpy.stack(
+                    session.total_num_processed_tokens for session in sessions
+                )

-            return [tokens, logits, finish_reason, kv_cache_state, num_processed_tokens]
+                yield [
+                    tokens,
+                    logits,
+                    finish_reason,
+                    kv_cache_state,
+                    num_processed_tokens,
+                ]

-        return [tokens, logits, finish_reason]
+            yield [tokens, logits, finish_reason]

     @staticmethod
     def causal_mask_input_present(model_path: str) -> bool:
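Because the non-streaming branch also uses yield, join_engine_outputs is now a generator function in every mode, which is exactly what the list(*engine_outputs) change in process_engine_outputs accounts for. A standalone sketch of that Python behavior (not deepsparse code):

    import types

    def join(streaming):
        # any `yield` in the body makes the whole function a generator function,
        # even when only one branch streams
        if streaming:
            for item in ("tok1", "tok2"):
                yield item
        else:
            yield ["tokens", "logits", "finish_reason"]

    assert isinstance(join(True), types.GeneratorType)
    assert isinstance(join(False), types.GeneratorType)
    assert list(*join(False)) == ["tokens", "logits", "finish_reason"]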
