
Commit 27fec25

Benjamin committed: review suggestion - names to dataclass
1 parent 83b1412 · commit 27fec25

File tree: 1 file changed (+14, -9 lines)

1 file changed

+14
-9
lines changed

Diff for: src/deepsparse/transformers/pipelines/text_generation.py

@@ -13,6 +13,7 @@
 # limitations under the License.

 import logging
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Type, Union

 import numpy
@@ -30,10 +31,12 @@
 __all__ = ["TextGenerationPipeline"]


-PROMPT_PREFILL = "engine_prompt_prefill"
-PROMPT_PREFILL_SINGLE = "engine_prompt_prefill_single"
-TOKEN_GENERATION = "engine_token_generation"
-TOKEN_GENERATION_SINGLE = "engine_token_generation_single"
+@dataclass(frozen=True)
+class _TextGenerationTimings:
+    PROMPT_PREFILL = "engine_prompt_prefill"
+    PROMPT_PREFILL_SINGLE = "engine_prompt_prefill_single"
+    TOKEN_GENERATION = "engine_token_generation"
+    TOKEN_GENERATION_SINGLE = "engine_token_generation_single"


 class TextGenerationInput(BaseModel):
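
A note on the new class: the attributes carry no type annotations, so PROMPT_PREFILL and the other names remain plain class attributes rather than dataclass fields, and frozen=True only blocks mutation on instances; the keys can be read straight off the class without instantiating it. A minimal, self-contained sketch of the pattern (standalone, not part of the DeepSparse codebase):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class _TextGenerationTimings:
        PROMPT_PREFILL = "engine_prompt_prefill"
        PROMPT_PREFILL_SINGLE = "engine_prompt_prefill_single"
        TOKEN_GENERATION = "engine_token_generation"
        TOKEN_GENERATION_SINGLE = "engine_token_generation_single"

    # class attributes are readable without an instance
    assert _TextGenerationTimings.PROMPT_PREFILL == "engine_prompt_prefill"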
@@ -321,7 +324,7 @@ def engine_forward(

         else:
             # run the prompt through
-            with timer.time(PROMPT_PREFILL):
+            with timer.time(_TextGenerationTimings.PROMPT_PREFILL):
                 tokens, prompt_logits = self.prompt_inference(engine_inputs)

             # create the generated output
@@ -334,17 +337,17 @@ def engine_forward(
             generated_tokens = [tokens[-1]]
             generated_logits = prompt_logits

-            timer.start(TOKEN_GENERATION)
+            timer.start(_TextGenerationTimings.TOKEN_GENERATION)
             while len(generated_tokens) < max_tokens:
-                with timer.time(TOKEN_GENERATION_SINGLE):
+                with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE):
                     token, logits = self.autoregressive_inference(tokens)
                 tokens.append(token)
                 generated_tokens.append(token)
                 generated_logits.append(logits)

                 if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
                     break
-            timer.stop(TOKEN_GENERATION)
+            timer.stop(_TextGenerationTimings.TOKEN_GENERATION)

         return numpy.array([generated_tokens]), numpy.concatenate(
             generated_logits, axis=1
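
The loop above uses both timing styles seen in this diff: explicit start/stop brackets the whole token-generation phase, while the time(...) context manager scopes each iteration. A minimal stand-in timer illustrating that interface (hypothetical sketch; the real timer comes from the pipeline's timer_manager):

    import time
    from collections import defaultdict
    from contextlib import contextmanager

    class SketchTimer:
        # hypothetical stand-in exposing only the calls used above:
        # start/stop for a long-running phase, time() for a scoped block
        def __init__(self):
            self.measurements = defaultdict(list)  # key -> elapsed seconds
            self._starts = {}

        def start(self, key):
            self._starts[key] = time.perf_counter()

        def stop(self, key):
            self.measurements[key].append(time.perf_counter() - self._starts.pop(key))

        @contextmanager
        def time(self, key):
            self.start(key)
            try:
                yield
            finally:
                self.stop(key)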
@@ -400,7 +403,9 @@ def prompt_inference(

         for token in tokens[num_tokens_processed:]:
             run_tokens.append(token)
-            with self.timer_manager.current.time(PROMPT_PREFILL_SINGLE):
+            with self.timer_manager.current.time(
+                _TextGenerationTimings.PROMPT_PREFILL_SINGLE
+            ):
                 new_token, new_logits = self.autoregressive_inference(
                     run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
                 )
