File tree 1 file changed +6
-2
lines changed
src/deepsparse/transformers/pipelines
1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -435,13 +435,17 @@ def engine_forward(
435
435
else 100 * self .sequence_length
436
436
) # set safety for absolute max generation
437
437
438
+ # last prompt token is the first generated token
439
+ # add it to generated tokens, and the logits
438
440
generated_tokens = [tokens [- 1 ]]
439
441
generated_logits = (
440
- prompt_logits if context .get ("include_prompt_logits" ) else []
442
+ prompt_logits
443
+ if context .get ("include_prompt_logits" )
444
+ else [prompt_logits [- 1 ]]
441
445
)
442
446
443
447
with timer .time (_TextGenerationTimings .TOKEN_GENERATION ):
444
- while len (generated_tokens ) <= max_tokens :
448
+ while len (generated_tokens ) < max_tokens :
445
449
with timer .time (_TextGenerationTimings .TOKEN_GENERATION_SINGLE ):
446
450
token , logits = self .autoregressive_inference (tokens )
447
451
tokens .append (token )
You can’t perform that action at this time.
0 commit comments