[Text Generation] Terminate the inference when kv cache is full #1446

Merged

5 changes: 5 additions & 0 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -834,6 +834,11 @@ def engine_forward(
generated_tokens.append(token)
generated_logits.append(logits)

if session.total_num_processed_tokens >= session.capacity:
    # if the kv cache is full, stop generation
    finished_reason.append(FinishReason.CAPACITY)
    break

if (
    token == self.tokenizer.eos_token_id
    and not self.force_max_tokens
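
The hunk above stops generation inside engine_forward as soon as the session has processed as many tokens as the kv cache can hold, recording FinishReason.CAPACITY for that generation. Below is a minimal usage sketch of how a caller could observe this condition; the model path, the output fields, and the "capacity" string are taken from the test added further down, and the exact output schema should be treated as an assumption:

from deepsparse import Pipeline

# Hypothetical caller-side check; mirrors what the new test asserts.
pipeline = Pipeline.create(
    task="text_generation",
    model_path="hf:mgoin/TinyStories-1M-deepsparse",
    engine_type="onnxruntime",
    sequence_length=32,
)
# Ask for more tokens than a 32-position cache can accommodate.
output = pipeline(prompt="Once upon a time", max_new_tokens=64)
if output.generations[0].finished_reason == "capacity":
    # Generation terminated early because the kv cache filled up.
    print("Stopped: kv cache is full")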
112 changes: 112 additions & 0 deletions tests/deepsparse/transformers/pipelines/test_text_generation.py
@@ -55,6 +55,118 @@ def test_run_same_prompt_multiple_times(pipeline, prompt):
    )


def _test_stop_inference_kv_cache_full(
    pipeline,
    prompt,
    max_new_tokens,
    expected_finished_reason,
    expected_generated_tokens_length=None,
):
    out = pipeline(prompt=prompt, max_new_tokens=max_new_tokens)
    kv_cache_state = out.kv_cache_state[0]
    finished_reason = out.generations[0].finished_reason
    generated_text = out.generations[0].text
    assert finished_reason == expected_finished_reason
    assert len(pipeline.tokenizer(generated_text)["input_ids"]) == (
        expected_generated_tokens_length or max_new_tokens
    )
    return kv_cache_state


def test_stop_inference_kv_cache_full(prompt):
    # Tests the behavior of the kv cache in the scenario where it
    # becomes full during inference

    # We set the sequence length to a small value to ensure that
    # the kv cache buffer fills up quickly
    sequence_length = 32
    # We set the prompt sequence length to 1 to ensure that the
    # inference runs until the kv cache is full. If the
    # `prompt_sequence_length` is larger than 1, it is very likely
    # that the inference will stop before the kv cache is full
    # (as the `prompt_sequence_length` reduces the number of
    # tokens that are generated in the first iteration)
    prompt_sequence_length = 1

    pipeline = Pipeline.create(
        task="text_generation",
        model_path="hf:mgoin/TinyStories-1M-deepsparse",
        engine_type="onnxruntime",
        sequence_length=sequence_length,
        force_max_tokens=True,
        prompt_sequence_length=prompt_sequence_length,
    )
    pipeline._debug = True

    prompt_length = len(pipeline.tokenizer(prompt)["input_ids"])

    cache_capacity = sequence_length - prompt_sequence_length
    # we need to subtract 1 to account for the initial token generated during
    # the prompt inference
    cache_capacity -= 1

    # max_new_tokens so that there is still one "free" slot left in the kv cache
    # (we can still do autoregressive inference)
    max_new_tokens_minus_one = cache_capacity - prompt_length - 1
    # max_new_tokens so that the kv cache is exactly full
    # (we can still do one last correct autoregressive
    # inference in the next iteration)
    max_new_tokens = cache_capacity - prompt_length
    # max_new_tokens so that the kv cache has already dropped its oldest entry
    # (we can no longer do autoregressive inference in the next iteration)
    max_new_tokens_plus_one = cache_capacity - prompt_length + 1
    # max_new_tokens so that the kv cache would have to drop its two oldest
    # entries (it will not: the inference terminates early and produces
    # the same result as max_new_tokens_plus_one)
    max_new_tokens_plus_two = cache_capacity - prompt_length + 2
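
    # Worked example of the arithmetic above (illustrative only; the real
    # prompt_length depends on the `prompt` fixture): with sequence_length=32,
    # prompt_sequence_length=1 and a hypothetical 10-token prompt:
    #   cache_capacity           = 32 - 1 - 1  = 30
    #   max_new_tokens_minus_one = 30 - 10 - 1 = 19
    #   max_new_tokens           = 30 - 10     = 20
    #   max_new_tokens_plus_one  = 30 - 10 + 1 = 21
    #   max_new_tokens_plus_two  = 30 - 10 + 2 = 22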

    kv_cache_state_full_minus_one = _test_stop_inference_kv_cache_full(
        pipeline,
        prompt,
        max_new_tokens_minus_one,
        expected_finished_reason="max_new_tokens",
    )
    kv_cache_state_full = _test_stop_inference_kv_cache_full(
        pipeline, prompt, max_new_tokens, expected_finished_reason="max_new_tokens"
    )
    kv_cache_state_full_plus_one = _test_stop_inference_kv_cache_full(
        pipeline, prompt, max_new_tokens_plus_one, expected_finished_reason="capacity"
    )
    kv_cache_state_full_plus_two = _test_stop_inference_kv_cache_full(
        pipeline,
        prompt,
        max_new_tokens_plus_two,
        expected_generated_tokens_length=max_new_tokens_plus_one,
        expected_finished_reason="capacity",
    )
    """
    Check the following structure of the kv cache:
    minus_one |  full   | plus_one | plus_two
    -----------------------------------------
    [- 0 -]   | [row A] | [row B]  | [row B]
    [row A]   | [row B] | [row C]  | [row C]
    [row B]   | [row C] | [row D]  | [row D]
      ...     |   ...   |   ...    |   ...
    """
    # check for the "free" space in the kv cache
    assert kv_cache_state_full_minus_one["past_key_values.0.key"][:, :, 0, :].sum() == 0
    # check for row A
    assert numpy.array_equal(
        kv_cache_state_full_minus_one["past_key_values.0.key"][:, :, 1, :],
        kv_cache_state_full["past_key_values.0.key"][:, :, 0, :],
    )
    # check for row B
    assert numpy.array_equal(
        kv_cache_state_full["past_key_values.0.key"][:, :, 1, :],
        kv_cache_state_full_plus_one["past_key_values.0.key"][:, :, 0, :],
    )
    # check equality between plus_one and plus_two
    assert numpy.array_equal(
        kv_cache_state_full_plus_one["past_key_values.0.key"],
        kv_cache_state_full_plus_two["past_key_values.0.key"],
    )


def test_run_multiple_prompts_in_parallel(pipeline, prompt):
    # Test the scenario where multiple prompts are run in parallel
    # The same two prompts should produce the same output
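
The docstring table in test_stop_inference_kv_cache_full describes a sliding-window cache: once the buffer is full, each newly generated token evicts the oldest cached entry, so consecutive snapshots of past_key_values.0.key are offset by exactly one row along the sequence axis. A toy numpy sketch of that invariant (an illustration of what the assertions check, not the actual deepsparse cache implementation):

import numpy

# A full cache buffer with 4 sequence positions (axis 2) and head size 3.
buffer_before = numpy.arange(12).reshape(1, 1, 4, 3)
# Appending one more entry drops the oldest row and shifts the rest left.
new_entry = numpy.full((1, 1, 1, 3), 99)
buffer_after = numpy.concatenate([buffer_before[:, :, 1:, :], new_entry], axis=2)
# Row 1 of the old snapshot is row 0 of the new one, which is exactly the
# "row A" / "row B" relation asserted in the test above.
assert numpy.array_equal(buffer_before[:, :, 1, :], buffer_after[:, :, 0, :])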