16
16
import os
17
17
import warnings
18
18
from dataclasses import dataclass
19
- from typing import Generator , List , Optional , Tuple , Type , Union
19
+ from typing import Dict , Generator , List , Optional , Tuple , Type , Union
20
20
21
21
import numpy
22
22
from pydantic import BaseModel , Field
23
+ from transformers import TextStreamer
23
24
24
25
from deepsparse import Pipeline
25
26
from deepsparse .cpu import cpu_avx512_compatible
@@ -46,6 +47,9 @@ class _TextGenerationTimings:
46
47
47
48
48
49
class TextGenerationInput (BaseModel ):
50
+ class Config :
51
+ arbitrary_types_allowed = True
52
+
49
53
sequences : Union [str , List [str ]] = Field (
50
54
description = "The input sequences to generate the text from." ,
51
55
)
@@ -71,6 +75,13 @@ class TextGenerationInput(BaseModel):
71
75
"to have consistent length so one "
72
76
"can compute metric in a batched fashion. " ,
73
77
)
78
+ streamer : Optional [TextStreamer ] = Field (
79
+ default = None ,
80
+ description = "Streamer object that will be used to stream the "
81
+ "generated sequences. Generated tokens are passed through "
82
+ "`streamer.put(token_ids)` and the streamer is responsible "
83
+ "for any further processing." ,
84
+ )
74
85
75
86
76
87
class TextGenerationOutput (BaseModel ):
@@ -290,7 +301,9 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
290
301
self .engine .session_id = inputs .session_id
291
302
self .multitoken_engine .session_id = inputs .session_id
292
303
293
- postprocessing_kwargs = dict (return_logits = inputs .return_logits )
304
+ postprocessing_kwargs = dict (
305
+ return_logits = inputs .return_logits , streamer = inputs .streamer
306
+ )
294
307
return engine_input , postprocessing_kwargs
295
308
296
309
def process_engine_outputs (
@@ -311,7 +324,7 @@ def process_engine_outputs(
311
324
return TextGenerationOutput (sequences = sequences , logits = logits )
312
325
313
326
def engine_forward (
314
- self , engine_inputs : List [numpy .ndarray ], ** kwargs
327
+ self , engine_inputs : List [numpy .ndarray ], context : Dict
315
328
) -> Tuple [numpy .ndarray , numpy .ndarray ]:
316
329
"""
317
330
Run the forward pass on the engine.
@@ -327,6 +340,8 @@ def engine_forward(
327
340
# main thread. That is why `engine_` is prepended to each of the timer phase
328
341
# names in this context
329
342
with self .timer_manager .new_timer_context (total_inference = False ) as timer :
343
+ streamer = context .get ("streamer" )
344
+
330
345
if not self .multitoken_engine .kv_cache_enabled :
331
346
tokens , prompt_logits = self .multitoken_engine (engine_inputs )
332
347
return numpy .array ([tokens ]), prompt_logits
@@ -336,6 +351,9 @@ def engine_forward(
336
351
with timer .time (_TextGenerationTimings .PROMPT_PREFILL ):
337
352
tokens , prompt_logits = self .prompt_inference (engine_inputs )
338
353
354
+ if streamer is not None :
355
+ streamer .put (numpy .array (tokens ))
356
+
339
357
# create the generated output
340
358
max_tokens = (
341
359
self .max_generated_tokens
@@ -354,12 +372,18 @@ def engine_forward(
354
372
generated_tokens .append (token )
355
373
generated_logits .append (logits )
356
374
375
+ if streamer is not None :
376
+ streamer .put (numpy .array ([token ]))
377
+
357
378
if (
358
379
token == self .tokenizer .eos_token_id
359
380
and not self .force_max_tokens
360
381
):
361
382
break
362
383
384
+ if streamer is not None :
385
+ streamer .end ()
386
+
363
387
return numpy .array ([generated_tokens ]), numpy .concatenate (
364
388
generated_logits , axis = 1
365
389
)
0 commit comments