from deepsparse.transformers.utils.helpers import (
    create_causal_mask,
    pad_to_fixed_length,
+    repeat_inputs,
)
from deepsparse.transformers.utils.timings import TextGenerationTimings
from deepsparse.utils.onnx import default_cached_outputs
@@ -57,6 +58,18 @@ class Config:
    sequences: Union[str, List[str]] = Field(
        description="The input sequences to generate the text from.",
    )
+    num_generated_predictions: int = Field(
+        default=1,
+        description="The number of text generations to create from a single "
+        "prompt. If the same sequence is given as an input multiple times, "
+        "the number of generated predictions is equivalent to the number of "
+        "times the sequence is repeated.",
+    )
+    max_tokens: int = Field(
+        default=1024,
+        description="Maximum number of tokens to generate per output sequence. If no "
+        "value is provided, will default to 1024.",
+    )
    return_logits: bool = Field(
        default=False,
        description="A flag that indicates whether to return "
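For context, a minimal sketch of how the two new fields behave on the request schema; it assumes the `TextGenerationInput` model referenced in the hunks below and standard pydantic defaulting, and the import path is an assumption rather than something shown in this diff:

from deepsparse.transformers.pipelines.text_generation import TextGenerationInput  # assumed path

# Explicitly request three generations, each capped at 64 tokens.
request = TextGenerationInput(
    sequences="Write a haiku about sparsity",
    num_generated_predictions=3,
    max_tokens=64,
)
print(request.num_generated_predictions)  # 3 (defaults to 1 when omitted)
print(request.max_tokens)                 # 64 (defaults to 1024 when omitted)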
@@ -110,7 +123,7 @@ class Config:


class TextGenerationOutput(BaseModel):
-    sequences: Union[str, List[str]] = Field(
+    sequences: Union[str, List[str], List[List[str]]] = Field(
        description="The generated text sequences.",
    )
    logits: Optional[Any] = Field(  # numpy array, set to Any for FastAPI compatibility
@@ -143,11 +156,6 @@ class TextGenerationPipeline(TransformersPipeline):
        from the probability distribution computed from the logits.
        Higher values will result in more random samples. Should
        be greater than 0.0.
-    :param max_generated_tokens: the maximum number of tokens to generate
-        given the input sequence. If None, the model will generate
-        tokens until the end of the sequence is reached.
-        Otherwise, it will generate up to the maximum number of tokens or end of
-        sequence is reached.
    :param sequence_length: sequence length to compile model and tokenizer for.
        This controls the maximum context length of the pipeline. Default is 512
    :param prompt_sequence_length: For large prompts, the prompt is
@@ -164,7 +172,6 @@ def __init__(
        self,
        deterministic: bool = True,
        sampling_temperature: float = 1.0,
-        max_generated_tokens: Optional[int] = 1024,
        prompt_sequence_length: int = 64,
        sequence_length: int = 512,
        force_max_tokens: bool = False,
@@ -203,16 +210,8 @@ def __init__(
        if "WAND_OPT_FLAGS" not in os.environ:
            os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"

-        if not self.cache_support_enabled and max_generated_tokens > 1:
-            raise ValueError(
-                "The model used for inference does not support kv cache. It is "
-                "assumed that it maps from the token sequence to predicted logits."
-                "Set `max_generated_tokens` to 1 to support that scenario."
-            )
-
        self.deterministic = deterministic
        self.sampling_temperature = sampling_temperature
-        self.max_generated_tokens = max_generated_tokens
        self.prompt_sequence_length = prompt_sequence_length
        self.force_max_tokens = force_max_tokens
        self.internal_kv_cache = internal_kv_cache
@@ -369,6 +368,26 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
        :param inputs: the input schema for the pipeline
        :return: the inputs for the engine
        """
+        if not self.cache_support_enabled and inputs.max_tokens > 1:
+            raise ValueError(
+                "The model used for inference does not support kv cache. It is "
+                "assumed that it maps from the token sequence to predicted logits. "
+                "Set `max_tokens` to 1 to support that scenario."
+            )
+
+        # If num_generated_predictions > 1, repeat the prompt
+        # num_generated_predictions times. Also, update the engines so that
+        # deterministic is set to False.
+        if inputs.num_generated_predictions > 1:
+            if isinstance(inputs.sequences, str):
+                inputs.sequences = [inputs.sequences]
+            inputs.sequences = repeat_inputs(
+                inputs.sequences, inputs.num_generated_predictions
+            )
+            if self.engine:
+                self.engine.deterministic = False
+            if self.multitoken_engine:
+                self.multitoken_engine.deterministic = False

        if inputs.fixed_sequences_length or not self.cache_support_enabled:
            # to enforce a fixed sequence length, we need to
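The `repeat_inputs` helper added to the import block above is not shown in this diff; as a hedged sketch of the behavior the grouping in `process_engine_outputs` appears to rely on (each prompt repeated back-to-back), something along these lines:

from typing import List


def repeat_inputs_sketch(sequences: List[str], num_generated_predictions: int) -> List[str]:
    """Hypothetical stand-in for repeat_inputs: repeat each prompt consecutively."""
    repeated: List[str] = []
    for sequence in sequences:
        # keep all copies of a prompt adjacent so they can be regrouped later
        repeated.extend([sequence] * num_generated_predictions)
    return repeated


assert repeat_inputs_sketch(["a", "b"], 2) == ["a", "a", "b", "b"]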
@@ -414,14 +433,16 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
        self.engine.session_id = inputs.session_id
        self.multitoken_engine.session_id = inputs.session_id

-        postprocessing_kwargs = dict(
+        context = dict(
+            num_generated_predictions=inputs.num_generated_predictions,
            return_logits=inputs.return_logits,
            streamer=inputs.streamer,
            include_prompt_logits=inputs.include_prompt_logits,
            callback=inputs.callback,
            stop=inputs.stop,
+            max_tokens=inputs.max_tokens,
        )
-        return engine_input, postprocessing_kwargs
+        return engine_input, context

    def process_engine_outputs(
        self, engine_outputs: List[numpy.ndarray], **kwargs
@@ -436,6 +457,18 @@ def process_engine_outputs(
        sequences = self.tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )
+        num_preds = kwargs.get("num_generated_predictions", 1)
+        # If num_generated_predictions > 1, group the generated sequences and
+        # return them as a list of lists, where each inner list holds the
+        # generated predictions for a given prompt and the lists appear in the
+        # same order that the prompts were given as inputs.
+        if num_preds > 1:
+            grouped_seq = [
+                sequences[n : n + num_preds]
+                for n in range(0, len(sequences), num_preds)
+            ]
+            sequences = grouped_seq
+
        logits = generated_logits if kwargs.get("return_logits") else None

        return TextGenerationOutput(sequences=sequences, logits=logits)
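As a standalone illustration of the grouping above, using hypothetical decoded strings and the same slicing logic:

sequences = ["a1", "a2", "a3", "b1", "b2", "b3"]  # 2 prompts, 3 predictions each
num_preds = 3
grouped_seq = [
    sequences[n : n + num_preds] for n in range(0, len(sequences), num_preds)
]
assert grouped_seq == [["a1", "a2", "a3"], ["b1", "b2", "b3"]]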
@@ -472,11 +505,8 @@ def engine_forward(
                streamer.put(numpy.array(tokens))

            # create the generated output
-            max_tokens = (
-                self.max_generated_tokens
-                if self.max_generated_tokens and self.max_generated_tokens > 0
-                else 100 * self.sequence_length
-            )  # set safety for absolute max generation
+            max_tokens = context.get("max_tokens", 0)
+            max_tokens = max_tokens if max_tokens > 0 else (100 * self.sequence_length)

            # last prompt token is the first generated token
            # add it to generated tokens, and the logits
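Taken together, a hedged end-to-end sketch of how the new inputs might be exercised; the task name, the placeholder model path, and the keyword routing through `Pipeline.create` are assumptions based on the usual deepsparse pipeline interface, not something this diff shows:

from deepsparse import Pipeline

# Placeholder model path / SparseZoo stub; substitute a real text-generation model.
text_pipeline = Pipeline.create(
    task="text_generation",
    model_path="path-or-zoo-stub-to-a-text-generation-model",
)

# Three sampled generations for one prompt, each capped at 64 new tokens.
output = text_pipeline(
    sequences="Write a haiku about sparsity",
    num_generated_predictions=3,
    max_tokens=64,
)
print(output.sequences)  # expected: a list of lists, one inner list per prompt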