@@ -138,14 +138,15 @@ class Config:
138
138
description = "GenerationConfig file consisting of parameters used to control "
139
139
"sequences generated for each prompt. The current supported parameters are: "
140
140
"max_length, max_new_tokens, num_return_sequences, output_scores, top_p, "
141
- "top_k, repetition_penalty, do_sample, temperature" ,
141
+ "top_k, repetition_penalty, do_sample, temperature. If None is provided, "
142
+ "deepsparse defaults will be used. For all other input types, HuggingFace "
143
+ "defaults for GenerationConfig will be used. " ,
142
144
)
143
145
144
- kwargs : Optional [Dict ] = Field (
146
+ generation_kwargs : Optional [Dict ] = Field (
145
147
default = None ,
146
148
description = "Any arguments to override generation_config arguments. Refer to "
147
- "the generation_config argument for a full list of supported variables. Only "
148
- "valid when generation_config is not None." ,
149
+ "the generation_config argument for a full list of supported variables." ,
149
150
)
150
151
151
152
@@ -201,6 +202,12 @@ class TextGenerationPipeline(TransformersPipeline):
201
202
of tokens supplied even if the stop token is reached.
202
203
:param internal_kv_cache: if True, the pipeline will use the deepsparse kv cache
203
204
for caching the model outputs.
205
+ :param generation_config: config file consisting of parameters used to control
206
+ sequences generated for each prompt. The current supported parameters are:
207
+ max_length, max_new_tokens, num_return_sequences, output_scores, top_p,
208
+ top_k, repetition_penalty, do_sample, temperature. If None is provided,
209
+ deepsparse defaults will be used. For all other input types, HuggingFace
210
+ defaults for GenerationConfig will be used.
204
211
:param kwargs: kwargs to pass to the TransformersPipeline
205
212
"""
206
213
@@ -409,6 +416,7 @@ def parse_inputs(self, *args, **kwargs) -> TextGenerationInput:
409
416
if "sequences" in kwargs and "prompt" not in kwargs :
410
417
# support prompt and sequences interchangeably
411
418
kwargs ["prompt" ] = kwargs ["sequences" ]
419
+
412
420
if (
413
421
args
414
422
and not isinstance (args [0 ], TextGenerationInput )
@@ -419,6 +427,14 @@ def parse_inputs(self, *args, **kwargs) -> TextGenerationInput:
419
427
kwargs ["prompt" ] = args [0 ]
420
428
args = args [1 :]
421
429
430
+ if kwargs :
431
+ generation_kwargs = kwargs .get ("generation_kwargs" , {})
432
+ for k , v in kwargs .items ():
433
+ if not generation_kwargs .get (k ) and hasattr (GenerationDefaults , k ):
434
+ generation_kwargs [k ] = v
435
+
436
+ kwargs ["generation_kwargs" ] = generation_kwargs
437
+
422
438
return super ().parse_inputs (* args , ** kwargs )
423
439
424
440
def process_inputs (
@@ -434,7 +450,7 @@ def process_inputs(
434
450
self .generation_config , inputs .generation_config , GenerationDefaults ()
435
451
)
436
452
437
- generation_config = override_config (inputs .kwargs , generation_config )
453
+ generation_config = override_config (inputs .generation_kwargs , generation_config )
438
454
439
455
self .streaming = inputs .streaming
440
456
if not self .cache_support_enabled and generation_config .max_length > 1 :
@@ -527,10 +543,10 @@ def _create_generated_text_output(
527
543
finished = False ,
528
544
)
529
545
530
- def _stream_engine_outputs (self , engine_outputs , prompts , kwargs ):
546
+ def _stream_engine_outputs (self , engine_outputs , prompts , generation_config ):
531
547
for output in engine_outputs :
532
548
generated_tokens , generated_logits , finished_reason = output
533
- logits = generated_logits if kwargs . get ( "return_logits" ) else None
549
+ logits = generated_logits if generation_config . output_scores else None
534
550
generation = self ._create_generated_text_output (
535
551
self .tokenizer .batch_decode (generated_tokens )[0 ],
536
552
finished_reason [0 ],
@@ -557,7 +573,9 @@ def process_engine_outputs(
557
573
streaming = kwargs .get ("streaming" )
558
574
559
575
if streaming :
560
- return self ._stream_engine_outputs (engine_outputs , prompts , kwargs )
576
+ return self ._stream_engine_outputs (
577
+ engine_outputs , prompts , generation_config
578
+ )
561
579
562
580
if self ._debug :
563
581
(
0 commit comments