Skip to content

Commit 31a71a9

Browse files
authored
Conform output of TextGenerationPipelines with HuggingfacePipelines (#1205)
* Conform with Huggingface pipelines
* Match logits to generated tokens
1 parent ba617eb commit 31a71a9

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/deepsparse/transformers/pipelines/text_generation.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -435,13 +435,17 @@ def engine_forward(
435435
else 100 * self.sequence_length
436436
) # set safety for absolute max generation
437437

438+
# last prompt token is the first generated token
439+
# add it to generated tokens, and the logits
438440
generated_tokens = [tokens[-1]]
439441
generated_logits = (
440-
prompt_logits if context.get("include_prompt_logits") else []
442+
prompt_logits
443+
if context.get("include_prompt_logits")
444+
else [prompt_logits[-1]]
441445
)
442446

443447
with timer.time(_TextGenerationTimings.TOKEN_GENERATION):
444-
while len(generated_tokens) <= max_tokens:
448+
while len(generated_tokens) < max_tokens:
445449
with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE):
446450
token, logits = self.autoregressive_inference(tokens)
447451
tokens.append(token)

0 commit comments

Comments
 (0)