
Commit 0907842

added tests
1 parent 8d0768d commit 0907842

1 file changed: +112 -0 lines changed


tests/deepsparse/transformers/pipelines/test_text_generation.py

@@ -55,6 +55,118 @@ def test_run_same_prompt_multiple_times(pipeline, prompt):
    )


def _test_stop_inference_kv_cache_full(
    pipeline,
    prompt,
    max_new_tokens,
    expected_finished_reason,
    expected_generated_tokens_length=None,
):
    out = pipeline(prompt=prompt, max_new_tokens=max_new_tokens)
    kv_cache_state = out.kv_cache_state[0]
    finished_reason = out.generations[0].finished_reason
    generated_text = out.generations[0].text
    assert finished_reason == expected_finished_reason
    assert len(pipeline.tokenizer(generated_text)["input_ids"]) == (
        expected_generated_tokens_length or max_new_tokens
    )
    return kv_cache_state


def test_stop_inference_kv_cache_full(prompt):
    # Tests the proper behavior of the kv cache around the
    # scenario when the kv cache becomes full during the inference

    # We set the sequence length to a small value to assert that
    # the kv cache buffer fills up quickly
    sequence_length = 32
    # We set the prompt sequence length to 1 to assert that the
    # inference will run until the kv cache is full. If the
    # `prompt_sequence_length` is larger than 1, it is very probable
    # that the inference will stop before the kv cache is full
    # (as the `prompt_sequence_length` reduces the number of
    # tokens that are generated in the first iteration)
    prompt_sequence_length = 1

    pipeline = Pipeline.create(
        task="text_generation",
        model_path="hf:mgoin/TinyStories-1M-deepsparse",
        engine_type="onnxruntime",
        sequence_length=sequence_length,
        force_max_tokens=True,
        prompt_sequence_length=prompt_sequence_length,
    )
    pipeline._debug = True

    prompt_length = len(pipeline.tokenizer(prompt)["input_ids"])

    cache_capacity = sequence_length - prompt_sequence_length
    # we need to subtract 1 to account for the initial generated token during the
    # prompt inference
    cache_capacity -= 1

    # max_new_tokens so that there is still one more "free" space in the kv cache
    # (we can still do autoregressive inference)
    max_new_tokens_minus_one = cache_capacity - prompt_length - 1
    # max_new_tokens so that the kv cache is full
    # (so we can still do one last correct autoregressive
    # inference in the next iteration)
    max_new_tokens = cache_capacity - prompt_length
    # max_new_tokens so that kv cache has already removed the last entry
    # (so we can no longer do autoregressive inference in the next iteration)
    max_new_tokens_plus_one = cache_capacity - prompt_length + 1
    # max_new_tokens so that kv cache would remove two last entries
    # (but it will not, the inference terminates early and produces
    # the same result as max_new_tokens_plus_one)
    max_new_tokens_plus_two = cache_capacity - prompt_length + 2
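    # Illustrative numbers only (the real prompt length depends on the `prompt`
    # fixture): with sequence_length=32, prompt_sequence_length=1 and a prompt
    # that tokenizes to, say, 10 tokens:
    #   cache_capacity           = 32 - 1 - 1  = 30
    #   max_new_tokens_minus_one = 30 - 10 - 1 = 19  -> one free cache slot remains
    #   max_new_tokens           = 30 - 10     = 20  -> cache exactly full
    #   max_new_tokens_plus_one  = 30 - 10 + 1 = 21  -> oldest entry evicted
    #   max_new_tokens_plus_two  = 30 - 10 + 2 = 22  -> stops early, same result as +1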

    kv_cache_state_full_minus_one = _test_stop_inference_kv_cache_full(
        pipeline,
        prompt,
        max_new_tokens_minus_one,
        expected_finished_reason="max_new_tokens",
    )
    kv_cache_state_full = _test_stop_inference_kv_cache_full(
        pipeline, prompt, max_new_tokens, expected_finished_reason="max_new_tokens"
    )
    kv_cache_state_full_plus_one = _test_stop_inference_kv_cache_full(
        pipeline, prompt, max_new_tokens_plus_one, expected_finished_reason="capacity"
    )
    kv_cache_state_full_plus_two = _test_stop_inference_kv_cache_full(
        pipeline,
        prompt,
        max_new_tokens_plus_two,
        expected_generated_tokens_length=max_new_tokens_plus_one,
        expected_finished_reason="capacity",
    )
    """
    Check the following structure of the kv cache:
    minus_one | full    | plus_one | plus_two
    -----------------------------------------
    [- 0 -]   | [row A] | [row B]  | [row B]
    [row A]   | [row B] | [row C]  | [row C]
    [row B]   | [row C] | [row D]  | [row D]
    ...       | ...     | ...      | ...
    """
    # check for the "free" space in the kv cache
    assert kv_cache_state_full_minus_one["past_key_values.0.key"][:, :, 0, :].sum() == 0
    # check for the row A
    assert numpy.allclose(
        kv_cache_state_full_minus_one["past_key_values.0.key"][:, :, 1, :],
        kv_cache_state_full["past_key_values.0.key"][:, :, 0, :],
    )
    # check for the row B
    assert numpy.allclose(
        kv_cache_state_full["past_key_values.0.key"][:, :, 1, :],
        kv_cache_state_full_plus_one["past_key_values.0.key"][:, :, 0, :],
    )
    # check equality between plus_one and plus_two
    assert numpy.allclose(
        kv_cache_state_full_plus_one["past_key_values.0.key"],
        kv_cache_state_full_plus_two["past_key_values.0.key"],
    )


def test_run_multiple_prompts_in_parallel(pipeline, prompt):
    # Test the scenario, where multiple prompts are run in parallel
    # Same two prompts should produce the same output
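
For reference, the row checks at the end of test_stop_inference_kv_cache_full treat the kv cache as a window that shifts left along the sequence axis (axis 2 of past_key_values.0.key): new entries are appended at the end, and whatever sits at the front (a free zeroed slot while capacity remains, the oldest entry once the cache is full) is dropped, so row 1 of a shorter run reappears as row 0 of a run that generated one more token. Below is a minimal toy numpy sketch of that invariant; the shapes, sizes, and the helper name are made up for illustration and this is not the pipeline's actual cache implementation.

import numpy


def toy_cache_append(cache, new_row):
    # drop the front entry (index 0 on the sequence axis) and append the new
    # row at the end -- the same left shift the assertions above rely on
    return numpy.concatenate([cache[:, :, 1:, :], new_row], axis=2)


# toy dimensions: batch=1, heads=2, capacity=4 sequence slots, head_dim=3
cache = numpy.zeros((1, 2, 4, 3))
rows = [numpy.full((1, 2, 1, 3), float(i + 1)) for i in range(5)]

shorter_run = cache
for row in rows[:4]:
    shorter_run = toy_cache_append(shorter_run, row)  # cache exactly full

longer_run = toy_cache_append(shorter_run, rows[4])  # one token past capacity

# row at index 1 of the shorter run shows up at index 0 of the longer run,
# mirroring the "row A" / "row B" checks in the test above
assert numpy.allclose(shorter_run[:, :, 1, :], longer_run[:, :, 0, :])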
