
Commit b7726ed

Implement streamer for text-generation and add context arg to Pipeline.engine_forward (#1140)
* Update with engine_forward and context
* Style
* Fix test engine_forward patching
* postprocessing_kwargs -> context
1 parent ed9b1ee commit b7726ed
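As a usage sketch, here is how the new `streamer` argument could look from the caller's side once this commit lands. The model stub and tokenizer below are illustrative assumptions, not part of this commit; `TextStreamer` is the Hugging Face `transformers` class that the input schema accepts.

```python
# Hypothetical usage sketch; the model stub and tokenizer choice are placeholders.
from transformers import AutoTokenizer, TextStreamer

from deepsparse import Pipeline

pipeline = Pipeline.create(
    task="text_generation",
    model_path="zoo:some/text-generation-model",  # placeholder model stub
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer
streamer = TextStreamer(tokenizer)  # prints decoded tokens to stdout as they arrive

# generated token ids are pushed through `streamer.put(...)` during inference
output = pipeline(sequences="Write a haiku about sparsity", streamer=streamer)
```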

File tree

3 files changed (+40, -11)

src/deepsparse/pipeline.py

+12 -7
```diff
@@ -19,6 +19,7 @@
 import os
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
@@ -228,9 +229,9 @@ def __call__(self, *args, **kwargs) -> BaseModel:
         # batch size of the inputs may be `> self._batch_size` at this point
         engine_inputs: List[numpy.ndarray] = self.process_inputs(pipeline_inputs)
         if isinstance(engine_inputs, tuple):
-            engine_inputs, postprocess_kwargs = engine_inputs
+            engine_inputs, context = engine_inputs
         else:
-            postprocess_kwargs = {}
+            context = {}

         timer.stop(InferenceStages.PRE_PROCESS)
         self.log(
@@ -247,7 +248,10 @@ def __call__(self, *args, **kwargs) -> BaseModel:
         )

         # submit split batches to engine threadpool
-        batch_outputs = list(self.executor.map(self.engine_forward, batches))
+        engine_forward_with_context = partial(self.engine_forward, context=context)
+        batch_outputs = list(
+            self.executor.map(engine_forward_with_context, batches)
+        )

         # join together the batches of size `self._batch_size`
         engine_outputs = self.join_engine_outputs(batch_outputs, orig_batch_size)
```
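The threadpool change above uses a small `functools.partial` idiom: the shared `context` keyword is bound once, so `executor.map` only has to vary the positional batch argument. A self-contained sketch of that pattern with a toy function (not from this diff):

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def engine_forward(batch, context={}):
    # toy stand-in: scales each element by a factor carried in `context`
    return [x * context.get("scale", 1) for x in batch]

batches = [[1, 2], [3, 4]]
context = {"scale": 10}

with ThreadPoolExecutor(max_workers=2) as executor:
    # bind the shared kwarg so map only iterates over `batches`
    engine_forward_with_context = partial(engine_forward, context=context)
    batch_outputs = list(executor.map(engine_forward_with_context, batches))

print(batch_outputs)  # [[10, 20], [30, 40]]
```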
```diff
@@ -270,9 +274,7 @@ def __call__(self, *args, **kwargs) -> BaseModel:

         # ------ POSTPROCESSING ------
         timer.start(InferenceStages.POST_PROCESS)
-        pipeline_outputs = self.process_engine_outputs(
-            engine_outputs, **postprocess_kwargs
-        )
+        pipeline_outputs = self.process_engine_outputs(engine_outputs, **context)
         if not isinstance(pipeline_outputs, self.output_schema):
             raise ValueError(
                 f"Outputs of {self.__class__} must be instances of "
@@ -486,10 +488,13 @@ def split_engine_inputs(
         """
         return split_engine_inputs(items, batch_size)

-    def engine_forward(self, engine_inputs: List[numpy.ndarray]) -> List[numpy.ndarray]:
+    def engine_forward(
+        self, engine_inputs: List[numpy.ndarray], context: Dict = {}
+    ) -> List[numpy.ndarray]:
         """
         :param engine_inputs: list of numpy inputs to Pipeline engine forward
             pass
+        :param context: optional dictionary to be used during engine execution
         :return: result of forward pass to Pipeline engine
         """
         return self.engine(engine_inputs)
```
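Taken together with the `__call__` changes above, the contract is: when `process_inputs` returns a tuple, its second element becomes `context`, and the same dict now reaches both `engine_forward` and `process_engine_outputs`. A pure-Python mock of that flow, with free functions standing in for the pipeline methods (purely illustrative):

```python
def process_inputs(raw):
    engine_inputs = [raw.upper()]
    return engine_inputs, {"trace": []}  # (engine_inputs, context) tuple

def engine_forward(engine_inputs, context={}):
    context.setdefault("trace", []).append("forward")
    return engine_inputs

def process_engine_outputs(engine_outputs, **context):
    context["trace"].append("postprocess")
    return engine_outputs, context["trace"]

# the same context dict threads through both stages
engine_inputs, context = process_inputs("hello")
outputs = engine_forward(engine_inputs, context=context)
result, trace = process_engine_outputs(outputs, **context)
print(result, trace)  # ['HELLO'] ['forward', 'postprocess']
```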

src/deepsparse/transformers/pipelines/text_generation.py

+27 -3
```diff
@@ -16,10 +16,11 @@
 import os
 import warnings
 from dataclasses import dataclass
-from typing import Generator, List, Optional, Tuple, Type, Union
+from typing import Dict, Generator, List, Optional, Tuple, Type, Union

 import numpy
 from pydantic import BaseModel, Field
+from transformers import TextStreamer

 from deepsparse import Pipeline
 from deepsparse.cpu import cpu_avx512_compatible
@@ -46,6 +47,9 @@ class _TextGenerationTimings:


 class TextGenerationInput(BaseModel):
+    class Config:
+        arbitrary_types_allowed = True
+
     sequences: Union[str, List[str]] = Field(
         description="The input sequences to generate the text from.",
     )
@@ -71,6 +75,13 @@ class TextGenerationInput(BaseModel):
         "to have consistent length so one "
         "can compute metric in a batched fashion. ",
     )
+    streamer: Optional[TextStreamer] = Field(
+        default=None,
+        description="Streamer object that will be used to stream the "
+        "generated sequences. Generated tokens are passed through "
+        "`streamer.put(token_ids)` and the streamer is responsible "
+        "for any further processing.",
+    )


 class TextGenerationOutput(BaseModel):
```
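The `Config.arbitrary_types_allowed` addition above is what allows a non-pydantic class such as `TextStreamer` to be used as a field type. A minimal sketch of why it is needed, assuming pydantic v1 semantics as used here; the stand-in streamer class is hypothetical:

```python
from typing import Optional

from pydantic import BaseModel, Field

class FakeStreamer:
    """Stand-in for transformers.TextStreamer; not a pydantic-aware type."""

    def put(self, token_ids):
        print("tokens:", token_ids)

class ExampleInput(BaseModel):
    class Config:
        # without this, pydantic v1 raises at class definition:
        # "no validator found for <class 'FakeStreamer'>"
        arbitrary_types_allowed = True

    sequences: str
    streamer: Optional[FakeStreamer] = Field(default=None)

inp = ExampleInput(sequences="hello", streamer=FakeStreamer())
print(type(inp.streamer).__name__)  # FakeStreamer
```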
```diff
@@ -290,7 +301,9 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
             self.engine.session_id = inputs.session_id
             self.multitoken_engine.session_id = inputs.session_id

-        postprocessing_kwargs = dict(return_logits=inputs.return_logits)
+        postprocessing_kwargs = dict(
+            return_logits=inputs.return_logits, streamer=inputs.streamer
+        )
         return engine_input, postprocessing_kwargs

     def process_engine_outputs(
@@ -311,7 +324,7 @@ def process_engine_outputs(
         return TextGenerationOutput(sequences=sequences, logits=logits)

     def engine_forward(
-        self, engine_inputs: List[numpy.ndarray], **kwargs
+        self, engine_inputs: List[numpy.ndarray], context: Dict
     ) -> Tuple[numpy.ndarray, numpy.ndarray]:
         """
         Run the forward pass on the engine.
@@ -327,6 +340,8 @@
         # main thread. That is why `engine_` is prepended to each of the timer phase
         # names in this context
         with self.timer_manager.new_timer_context(total_inference=False) as timer:
+            streamer = context.get("streamer")
+
             if not self.multitoken_engine.kv_cache_enabled:
                 tokens, prompt_logits = self.multitoken_engine(engine_inputs)
                 return numpy.array([tokens]), prompt_logits
@@ -336,6 +351,9 @@
             with timer.time(_TextGenerationTimings.PROMPT_PREFILL):
                 tokens, prompt_logits = self.prompt_inference(engine_inputs)

+            if streamer is not None:
+                streamer.put(numpy.array(tokens))
+
             # create the generated output
             max_tokens = (
                 self.max_generated_tokens
@@ -354,12 +372,18 @@
                 generated_tokens.append(token)
                 generated_logits.append(logits)

+                if streamer is not None:
+                    streamer.put(numpy.array([token]))
+
                 if (
                     token == self.tokenizer.eos_token_id
                     and not self.force_max_tokens
                 ):
                     break

+            if streamer is not None:
+                streamer.end()
+
             return numpy.array([generated_tokens]), numpy.concatenate(
                 generated_logits, axis=1
             )
```
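The engine loop above only ever calls `streamer.put(...)` and `streamer.end()`, so a `TextStreamer` subclass that overrides those two hooks can capture the stream instead of printing it (subclassing matters because the pydantic field is typed `Optional[TextStreamer]` and validates by isinstance). A hedged sketch; the subclass below is hypothetical, not part of this commit:

```python
import numpy
from transformers import TextStreamer

class CollectingStreamer(TextStreamer):
    """Hypothetical streamer that collects token ids instead of printing text."""

    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        self.token_ids = []

    def put(self, value):
        # receives the prompt token ids once, then each generated token
        self.token_ids.extend(numpy.asarray(value).flatten().tolist())

    def end(self):
        # called once generation stops (EOS token or max tokens reached)
        print(f"stream finished after {len(self.token_ids)} tokens")

# usage sketch:
#   streamer = CollectingStreamer(AutoTokenizer.from_pretrained("gpt2"))
#   pipeline(sequences="...", streamer=streamer)
```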

tests/deepsparse/pipelines/test_pipeline.py

+1 -1
```diff
@@ -132,7 +132,7 @@ def test_pipeline_call_is_async(engine_mock):
     executor = ThreadPoolExecutor(max_workers=1)
     pipeline = Pipeline.create("token_classification", batch_size=1, executor=executor)

-    def sleep_then_engine_forward(xs):
+    def sleep_then_engine_forward(xs, context):
         # each call to engine_forward also sleeps
         time.sleep(20 / 1000)
         return pipeline.engine(xs)
```
