@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Type, Union
 
 import numpy
@@ -30,10 +31,12 @@
 __all__ = ["TextGenerationPipeline"]
 
 
-PROMPT_PREFILL = "engine_prompt_prefill"
-PROMPT_PREFILL_SINGLE = "engine_prompt_prefill_single"
-TOKEN_GENERATION = "engine_token_generation"
-TOKEN_GENERATION_SINGLE = "engine_token_generation_single"
+@dataclass(frozen=True)
+class _TextGenerationTimings:
+    PROMPT_PREFILL = "engine_prompt_prefill"
+    PROMPT_PREFILL_SINGLE = "engine_prompt_prefill_single"
+    TOKEN_GENERATION = "engine_token_generation"
+    TOKEN_GENERATION_SINGLE = "engine_token_generation_single"
 
 
 class TextGenerationInput(BaseModel):
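A note on the `@dataclass(frozen=True)` namespace introduced above: the constants are assigned without type annotations, so the dataclass machinery generates no instance fields and the strings remain plain class attributes, readable straight off the class without instantiating it. A minimal, self-contained sketch of that behavior (mirroring the names from this diff, not additional library code):

```python
from dataclasses import dataclass


# Sketch of the constants namespace added in this diff. Without type
# annotations, @dataclass(frozen=True) sees no fields; the strings stay
# ordinary class attributes, and frozen=True only guards instances
# (which this class is never expected to have).
@dataclass(frozen=True)
class _TextGenerationTimings:
    PROMPT_PREFILL = "engine_prompt_prefill"
    PROMPT_PREFILL_SINGLE = "engine_prompt_prefill_single"
    TOKEN_GENERATION = "engine_token_generation"
    TOKEN_GENERATION_SINGLE = "engine_token_generation_single"


print(_TextGenerationTimings.TOKEN_GENERATION)  # engine_token_generation
```

The class is effectively a read-only namespace for timer keys, replacing the four loose module-level constants with a single prefixed, private grouping.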
@@ -321,7 +324,7 @@ def engine_forward(
 
         else:
             # run the prompt through
-            with timer.time(PROMPT_PREFILL):
+            with timer.time(_TextGenerationTimings.PROMPT_PREFILL):
                 tokens, prompt_logits = self.prompt_inference(engine_inputs)
 
         # create the generated output
@@ -334,17 +337,17 @@ def engine_forward(
         generated_tokens = [tokens[-1]]
         generated_logits = prompt_logits
 
-        timer.start(TOKEN_GENERATION)
+        timer.start(_TextGenerationTimings.TOKEN_GENERATION)
         while len(generated_tokens) < max_tokens:
-            with timer.time(TOKEN_GENERATION_SINGLE):
+            with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE):
                 token, logits = self.autoregressive_inference(tokens)
             tokens.append(token)
             generated_tokens.append(token)
             generated_logits.append(logits)
 
             if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
                 break
-        timer.stop(TOKEN_GENERATION)
+        timer.stop(_TextGenerationTimings.TOKEN_GENERATION)
 
         return numpy.array([generated_tokens]), numpy.concatenate(
             generated_logits, axis=1
@@ -400,7 +403,9 @@ def prompt_inference(
 
         for token in tokens[num_tokens_processed:]:
             run_tokens.append(token)
-            with self.timer_manager.current.time(PROMPT_PREFILL_SINGLE):
+            with self.timer_manager.current.time(
+                _TextGenerationTimings.PROMPT_PREFILL_SINGLE
+            ):
                 new_token, new_logits = self.autoregressive_inference(
                     run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
                 )
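For context on the timing pattern being renamed throughout: `engine_forward` brackets the whole generation loop with `start()`/`stop()` under one key while recording each iteration separately under a per-token key via the `time()` context manager. A self-contained sketch of that pattern with a hypothetical `Timer` class (the real `timer_manager` API in deepsparse is not shown in this diff and may differ):

```python
import time
from contextlib import contextmanager


class Timer:
    """Hypothetical stand-in for the pipeline's timer, for illustration only."""

    def __init__(self):
        self.measurements = {}  # key -> list of elapsed seconds
        self._starts = {}       # key -> perf_counter value at start()

    def start(self, key):
        self._starts[key] = time.perf_counter()

    def stop(self, key):
        elapsed = time.perf_counter() - self._starts.pop(key)
        self.measurements.setdefault(key, []).append(elapsed)

    @contextmanager
    def time(self, key):
        # Context-manager form of start()/stop() for per-iteration timing.
        self.start(key)
        try:
            yield
        finally:
            self.stop(key)


timer = Timer()
# Aggregate timing around the loop, per-iteration timing inside it,
# using the string keys held by _TextGenerationTimings in this diff.
timer.start("engine_token_generation")
for _ in range(3):
    with timer.time("engine_token_generation_single"):
        time.sleep(0.01)  # stand-in for autoregressive_inference
timer.stop("engine_token_generation")
print(timer.measurements)
```

Because the inner and outer measurements use distinct keys, the aggregate loop time and the individual token latencies land in separate buckets, which is what makes the single/aggregate key pairs in `_TextGenerationTimings` useful.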