Skip to content

Commit 86ff638

Browse files
authored
Merge branch 'main' into feature_readme
2 parents 730f6ab + 8e7f3cc commit 86ff638

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

src/deepsparse/transformers/pipelines/text_generation.py

+26-8
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,15 @@ class Config:
138138
description="GenerationConfig file consisting of parameters used to control "
139139
"sequences generated for each prompt. The current supported parameters are: "
140140
"max_length, max_new_tokens, num_return_sequences, output_scores, top_p, "
141-
"top_k, repetition_penalty, do_sample, temperature",
141+
"top_k, repetition_penalty, do_sample, temperature. If None is provided, "
142+
"deepsparse defaults will be used. For all other input types, HuggingFace "
143+
"defaults for GenerationConfig will be used. ",
142144
)
143145

144-
kwargs: Optional[Dict] = Field(
146+
generation_kwargs: Optional[Dict] = Field(
145147
default=None,
146148
description="Any arguments to override generation_config arguments. Refer to "
147-
"the generation_config argument for a full list of supported variables. Only "
148-
"valid when generation_config is not None.",
149+
"the generation_config argument for a full list of supported variables.",
149150
)
150151

151152

@@ -201,6 +202,12 @@ class TextGenerationPipeline(TransformersPipeline):
201202
of tokens supplied even if the stop token is reached.
202203
:param internal_kv_cache: if True, the pipeline will use the deepsparse kv cache
203204
for caching the model outputs.
205+
:param generation_config: config file consisting of parameters used to control
206+
sequences generated for each prompt. The current supported parameters are:
207+
max_length, max_new_tokens, num_return_sequences, output_scores, top_p,
208+
top_k, repetition_penalty, do_sample, temperature. If None is provided,
209+
deepsparse defaults will be used. For all other input types, HuggingFace
210+
defaults for GenerationConfig will be used.
204211
:param kwargs: kwargs to pass to the TransformersPipeline
205212
"""
206213

@@ -409,6 +416,7 @@ def parse_inputs(self, *args, **kwargs) -> TextGenerationInput:
409416
if "sequences" in kwargs and "prompt" not in kwargs:
410417
# support prompt and sequences interchangeably
411418
kwargs["prompt"] = kwargs["sequences"]
419+
412420
if (
413421
args
414422
and not isinstance(args[0], TextGenerationInput)
@@ -419,6 +427,14 @@ def parse_inputs(self, *args, **kwargs) -> TextGenerationInput:
419427
kwargs["prompt"] = args[0]
420428
args = args[1:]
421429

430+
if kwargs:
431+
generation_kwargs = kwargs.get("generation_kwargs", {})
432+
for k, v in kwargs.items():
433+
if not generation_kwargs.get(k) and hasattr(GenerationDefaults, k):
434+
generation_kwargs[k] = v
435+
436+
kwargs["generation_kwargs"] = generation_kwargs
437+
422438
return super().parse_inputs(*args, **kwargs)
423439

424440
def process_inputs(
@@ -434,7 +450,7 @@ def process_inputs(
434450
self.generation_config, inputs.generation_config, GenerationDefaults()
435451
)
436452

437-
generation_config = override_config(inputs.kwargs, generation_config)
453+
generation_config = override_config(inputs.generation_kwargs, generation_config)
438454

439455
self.streaming = inputs.streaming
440456
if not self.cache_support_enabled and generation_config.max_length > 1:
@@ -527,10 +543,10 @@ def _create_generated_text_output(
527543
finished=False,
528544
)
529545

530-
def _stream_engine_outputs(self, engine_outputs, prompts, kwargs):
546+
def _stream_engine_outputs(self, engine_outputs, prompts, generation_config):
531547
for output in engine_outputs:
532548
generated_tokens, generated_logits, finished_reason = output
533-
logits = generated_logits if kwargs.get("return_logits") else None
549+
logits = generated_logits if generation_config.output_scores else None
534550
generation = self._create_generated_text_output(
535551
self.tokenizer.batch_decode(generated_tokens)[0],
536552
finished_reason[0],
@@ -557,7 +573,9 @@ def process_engine_outputs(
557573
streaming = kwargs.get("streaming")
558574

559575
if streaming:
560-
return self._stream_engine_outputs(engine_outputs, prompts, kwargs)
576+
return self._stream_engine_outputs(
577+
engine_outputs, prompts, generation_config
578+
)
561579

562580
if self._debug:
563581
(

src/deepsparse/transformers/utils/helpers.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -246,15 +246,14 @@ def override_config(
246246
return generation_config
247247

248248
for k, v in overrides.items():
249-
try:
250-
if getattr(generation_config, k):
251-
setattr(generation_config, k, v)
252-
_LOGGER.debug(f"Overriding attribute {k} in the generation config")
253-
except AttributeError as exception:
249+
if hasattr(generation_config, k):
250+
setattr(generation_config, k, v)
251+
_LOGGER.debug(f"Overriding attribute {k} in the generation config")
252+
else:
254253
raise AttributeError(
255-
"Argument provided for GenerationConfig is not "
254+
f"Argument {k} provided for GenerationConfig is not "
256255
"valid. Refer to the TextGenerationInput for supported attributes. "
257-
) from exception
256+
)
258257

259258
return generation_config
260259

0 commit comments

Comments (0)