|
35 | 35 |
|
36 | 36 | @dataclass(frozen=True)
|
37 | 37 | class _TextGenerationTimings:
|
38 |
| - PROMPT_PREFILL = "engine_prompt_prefill" |
39 |
| - PROMPT_PREFILL_SINGLE = "engine_prompt_prefill_single" |
40 |
| - TOKEN_GENERATION = "engine_token_generation" |
41 |
| - TOKEN_GENERATION_SINGLE = "engine_token_generation_single" |
| 38 | + PROMPT_PREFILL: str = "engine_prompt_prefill" |
| 39 | + PROMPT_PREFILL_SINGLE: str = "engine_prompt_prefill_single" |
| 40 | + TOKEN_GENERATION: str = "engine_token_generation" |
| 41 | + TOKEN_GENERATION_SINGLE: str = "engine_token_generation_single" |
42 | 42 |
|
43 | 43 |
|
44 | 44 | class TextGenerationInput(BaseModel):
|
@@ -344,17 +344,19 @@ def engine_forward(
|
344 | 344 | generated_tokens = [tokens[-1]]
|
345 | 345 | generated_logits = prompt_logits
|
346 | 346 |
|
347 |
| - timer.start(_TextGenerationTimings.TOKEN_GENERATION) |
348 |
| - while len(generated_tokens) < max_tokens: |
349 |
| - with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE): |
350 |
| - token, logits = self.autoregressive_inference(tokens) |
351 |
| - tokens.append(token) |
352 |
| - generated_tokens.append(token) |
353 |
| - generated_logits.append(logits) |
354 |
| - |
355 |
| - if token == self.tokenizer.eos_token_id and not self.force_max_tokens: |
356 |
| - break |
357 |
| - timer.stop(_TextGenerationTimings.TOKEN_GENERATION) |
| 347 | + with timer.time(_TextGenerationTimings.TOKEN_GENERATION): |
| 348 | + while len(generated_tokens) < max_tokens: |
| 349 | + with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE): |
| 350 | + token, logits = self.autoregressive_inference(tokens) |
| 351 | + tokens.append(token) |
| 352 | + generated_tokens.append(token) |
| 353 | + generated_logits.append(logits) |
| 354 | + |
| 355 | + if ( |
| 356 | + token == self.tokenizer.eos_token_id |
| 357 | + and not self.force_max_tokens |
| 358 | + ): |
| 359 | + break |
358 | 360 |
|
359 | 361 | return numpy.array([generated_tokens]), numpy.concatenate(
|
360 | 362 | generated_logits, axis=1
|
|
0 commit comments