diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index 765e6ff413..185c2daae0 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -19,6 +19,7 @@ import os from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor +from functools import partial from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Type, Union @@ -228,9 +229,9 @@ def __call__(self, *args, **kwargs) -> BaseModel: # batch size of the inputs may be `> self._batch_size` at this point engine_inputs: List[numpy.ndarray] = self.process_inputs(pipeline_inputs) if isinstance(engine_inputs, tuple): - engine_inputs, postprocess_kwargs = engine_inputs + engine_inputs, context = engine_inputs else: - postprocess_kwargs = {} + context = {} timer.stop(InferenceStages.PRE_PROCESS) self.log( @@ -247,7 +248,10 @@ def __call__(self, *args, **kwargs) -> BaseModel: ) # submit split batches to engine threadpool - batch_outputs = list(self.executor.map(self.engine_forward, batches)) + engine_forward_with_context = partial(self.engine_forward, context=context) + batch_outputs = list( + self.executor.map(engine_forward_with_context, batches) + ) # join together the batches of size `self._batch_size` engine_outputs = self.join_engine_outputs(batch_outputs, orig_batch_size) @@ -270,9 +274,7 @@ def __call__(self, *args, **kwargs) -> BaseModel: # ------ POSTPROCESSING ------ timer.start(InferenceStages.POST_PROCESS) - pipeline_outputs = self.process_engine_outputs( - engine_outputs, **postprocess_kwargs - ) + pipeline_outputs = self.process_engine_outputs(engine_outputs, **context) if not isinstance(pipeline_outputs, self.output_schema): raise ValueError( f"Outputs of {self.__class__} must be instances of " @@ -486,10 +488,13 @@ def split_engine_inputs( """ return split_engine_inputs(items, batch_size) - def engine_forward(self, engine_inputs: List[numpy.ndarray]) -> List[numpy.ndarray]: + def engine_forward( + self, engine_inputs: List[numpy.ndarray], context: Dict = {} + ) -> List[numpy.ndarray]: """ :param engine_inputs: list of numpy inputs to Pipeline engine forward pass + :param context: optional dictionary to be used during engine execution :return: result of forward pass to Pipeline engine """ return self.engine(engine_inputs) diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index d5b08728d9..97f4a58621 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -16,10 +16,11 @@ import os import warnings from dataclasses import dataclass -from typing import Generator, List, Optional, Tuple, Type, Union +from typing import Dict, Generator, List, Optional, Tuple, Type, Union import numpy from pydantic import BaseModel, Field +from transformers import TextStreamer from deepsparse import Pipeline from deepsparse.cpu import cpu_avx512_compatible @@ -46,6 +47,9 @@ class _TextGenerationTimings: class TextGenerationInput(BaseModel): + class Config: + arbitrary_types_allowed = True + sequences: Union[str, List[str]] = Field( description="The input sequences to generate the text from.", ) @@ -71,6 +75,13 @@ class TextGenerationInput(BaseModel): "to have consistent length so one " "can compute metric in a batched fashion. ", ) + streamer: Optional[TextStreamer] = Field( + default=None, + description="Streamer object that will be used to stream the " + "generated sequences. Generated tokens are passed through " + "`streamer.put(token_ids)` and the streamer is responsible " + "for any further processing.", + ) class TextGenerationOutput(BaseModel): @@ -290,7 +301,9 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]: self.engine.session_id = inputs.session_id self.multitoken_engine.session_id = inputs.session_id - postprocessing_kwargs = dict(return_logits=inputs.return_logits) + postprocessing_kwargs = dict( + return_logits=inputs.return_logits, streamer=inputs.streamer + ) return engine_input, postprocessing_kwargs def process_engine_outputs( @@ -311,7 +324,7 @@ def process_engine_outputs( return TextGenerationOutput(sequences=sequences, logits=logits) def engine_forward( - self, engine_inputs: List[numpy.ndarray], **kwargs + self, engine_inputs: List[numpy.ndarray], context: Dict ) -> Tuple[numpy.ndarray, numpy.ndarray]: """ Run the forward pass on the engine. @@ -327,6 +340,8 @@ def engine_forward( # main thread. That is why `engine_` is prepended to each of the timer phase # names in this context with self.timer_manager.new_timer_context(total_inference=False) as timer: + streamer = context.get("streamer") + if not self.multitoken_engine.kv_cache_enabled: tokens, prompt_logits = self.multitoken_engine(engine_inputs) return numpy.array([tokens]), prompt_logits @@ -336,6 +351,9 @@ def engine_forward( with timer.time(_TextGenerationTimings.PROMPT_PREFILL): tokens, prompt_logits = self.prompt_inference(engine_inputs) + if streamer is not None: + streamer.put(numpy.array(tokens)) + # create the generated output max_tokens = ( self.max_generated_tokens @@ -354,12 +372,18 @@ def engine_forward( generated_tokens.append(token) generated_logits.append(logits) + if streamer is not None: + streamer.put(numpy.array([token])) + if ( token == self.tokenizer.eos_token_id and not self.force_max_tokens ): break + if streamer is not None: + streamer.end() + return numpy.array([generated_tokens]), numpy.concatenate( generated_logits, axis=1 ) diff --git a/tests/deepsparse/pipelines/test_pipeline.py b/tests/deepsparse/pipelines/test_pipeline.py index 139e579616..dad1105437 100644 --- a/tests/deepsparse/pipelines/test_pipeline.py +++ b/tests/deepsparse/pipelines/test_pipeline.py @@ -132,7 +132,7 @@ def test_pipeline_call_is_async(engine_mock): executor = ThreadPoolExecutor(max_workers=1) pipeline = Pipeline.create("token_classification", batch_size=1, executor=executor) - def sleep_then_engine_forward(xs): + def sleep_then_engine_forward(xs, context): # each call to engine_forward also sleeps time.sleep(20 / 1000) return pipeline.engine(xs)