
Commit 92b2fac

Add types to _TextGenerationTimings attributes

- Revert to using timer.time for `TOKEN_GENERATION`
- Remove finally clause from `contextmanagers`
- Address review comments from @rahul-tuli

1 parent 3ad06e3 commit 92b2fac

File tree: 2 files changed (+22, -26 lines)

src/deepsparse/transformers/pipelines/text_generation.py (+17, -15)
```diff
@@ -35,10 +35,10 @@

 @dataclass(frozen=True)
 class _TextGenerationTimings:
-    PROMPT_PREFILL = "engine_prompt_prefill"
-    PROMPT_PREFILL_SINGLE = "engine_prompt_prefill_single"
-    TOKEN_GENERATION = "engine_token_generation"
-    TOKEN_GENERATION_SINGLE = "engine_token_generation_single"
+    PROMPT_PREFILL: str = "engine_prompt_prefill"
+    PROMPT_PREFILL_SINGLE: str = "engine_prompt_prefill_single"
+    TOKEN_GENERATION: str = "engine_token_generation"
+    TOKEN_GENERATION_SINGLE: str = "engine_token_generation_single"


 class TextGenerationInput(BaseModel):
```
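A side note on the annotation change (general Python behavior, not something the commit message spells out): `@dataclass` only turns annotated class attributes into fields, so the untyped constants above were invisible to the dataclass machinery. A minimal, self-contained illustration:

```python
from dataclasses import dataclass, fields

@dataclass(frozen=True)
class Untyped:
    STAGE = "engine_prompt_prefill"       # plain class attribute: not a field

@dataclass(frozen=True)
class Typed:
    STAGE: str = "engine_prompt_prefill"  # annotated: a real dataclass field

print([f.name for f in fields(Untyped)])  # []
print([f.name for f in fields(Typed)])    # ['STAGE']
```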
```diff
@@ -344,17 +344,19 @@ def engine_forward(
         generated_tokens = [tokens[-1]]
         generated_logits = prompt_logits

-        timer.start(_TextGenerationTimings.TOKEN_GENERATION)
-        while len(generated_tokens) < max_tokens:
-            with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE):
-                token, logits = self.autoregressive_inference(tokens)
-            tokens.append(token)
-            generated_tokens.append(token)
-            generated_logits.append(logits)
-
-            if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
-                break
-        timer.stop(_TextGenerationTimings.TOKEN_GENERATION)
+        with timer.time(_TextGenerationTimings.TOKEN_GENERATION):
+            while len(generated_tokens) < max_tokens:
+                with timer.time(_TextGenerationTimings.TOKEN_GENERATION_SINGLE):
+                    token, logits = self.autoregressive_inference(tokens)
+                tokens.append(token)
+                generated_tokens.append(token)
+                generated_logits.append(logits)
+
+                if (
+                    token == self.tokenizer.eos_token_id
+                    and not self.force_max_tokens
+                ):
+                    break

         return numpy.array([generated_tokens]), numpy.concatenate(
             generated_logits, axis=1
```
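The restructured block keeps one outer `timer.time` context around the whole loop and a nested per-iteration context, replacing the manual `start`/`stop` pair. A runnable sketch of that nesting with a stand-in timer (it only mimics the `start`/`stop`/`time` surface the diff uses; it is not deepsparse's actual `StagedTimer`):

```python
import time
from collections import defaultdict
from contextlib import contextmanager

class StubTimer:
    """Mimics only the start/stop/time surface used in the diff above."""
    def __init__(self):
        self._starts = {}
        self.measurements = defaultdict(list)

    def start(self, stage: str):
        self._starts[stage] = time.perf_counter()

    def stop(self, stage: str):
        self.measurements[stage].append(time.perf_counter() - self._starts.pop(stage))

    @contextmanager
    def time(self, stage: str):
        self.start(stage)
        yield
        self.stop(stage)

timer = StubTimer()
with timer.time("engine_token_generation"):                 # one outer measurement
    for _ in range(3):
        with timer.time("engine_token_generation_single"):  # one per iteration
            time.sleep(0.01)                                # stands in for autoregressive_inference

print(len(timer.measurements["engine_token_generation"]))         # 1
print(len(timer.measurements["engine_token_generation_single"]))  # 3
```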

src/deepsparse/utils/timer.py (+5, -11)
```diff
@@ -128,11 +128,8 @@ def time(self, stage: str):
         :param stage: the name of the stage to time
         """
         self.start(stage)
-
-        try:
-            yield
-        finally:
-            self.stop(stage)
+        yield
+        self.stop(stage)

     def start(self, stage: str):
         """
```
```diff
@@ -363,9 +360,6 @@ def new_timer_context(self, total_inference: bool = True) -> StagedTimer:
         self._timers = [timer]

         timer_context.set(timer)
-
-        try:
-            yield timer
-        finally:
-            if total_inference:
-                timer.stop(InferenceStages.TOTAL_INFERENCE)
+        yield timer
+        if total_inference:
+            timer.stop(InferenceStages.TOTAL_INFERENCE)
```
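For the same reason, `new_timer_context` now stops the total-inference stage only on a clean exit. A runnable miniature of the pattern in this hunk, using stand-in names (`MiniTimer` and the `"total_inference"` label are hypothetical, not deepsparse's real classes):

```python
from contextlib import contextmanager
from contextvars import ContextVar

timer_context: ContextVar = ContextVar("timer_context")  # stand-in context var

class MiniTimer:
    """Stand-in timer that only records which stages were stopped."""
    def __init__(self):
        self.stopped = []

    def start(self, stage: str):
        pass

    def stop(self, stage: str):
        self.stopped.append(stage)

@contextmanager
def new_timer_context(total_inference: bool = True):
    timer = MiniTimer()
    timer.start("total_inference")  # assumed started earlier in the real method
    timer_context.set(timer)
    yield timer
    if total_inference:             # runs only when the body exits cleanly
        timer.stop("total_inference")

with new_timer_context() as timer:
    pass                            # the inference pipeline would run here
print(timer.stopped)                # ['total_inference']
```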
