@@ -55,6 +55,118 @@ def test_run_same_prompt_multiple_times(pipeline, prompt):
     )


+def _test_stop_inference_kv_cache_full(
+    pipeline,
+    prompt,
+    max_new_tokens,
+    expected_finished_reason,
+    expected_generated_tokens_length=None,
+):
+    out = pipeline(prompt=prompt, max_new_tokens=max_new_tokens)
+    kv_cache_state = out.kv_cache_state[0]
+    finished_reason = out.generations[0].finished_reason
+    generated_text = out.generations[0].text
+    assert finished_reason == expected_finished_reason
+    assert len(pipeline.tokenizer(generated_text)["input_ids"]) == (
+        expected_generated_tokens_length or max_new_tokens
+    )
+    return kv_cache_state
+
+
+def test_stop_inference_kv_cache_full(prompt):
+    # Tests the proper behavior of the kv cache in the
+    # scenario where the kv cache becomes full during inference
+
+    # We set the sequence length to a small value to ensure that
+    # the kv cache buffer fills up quickly
+    sequence_length = 32
+    # We set the prompt sequence length to 1 to ensure that the
+    # inference will run until the kv cache is full. If the
+    # `prompt_sequence_length` is larger than 1, it is very probable
+    # that the inference will stop before the kv cache is full
+    # (as the `prompt_sequence_length` reduces the number of
+    # tokens that are generated in the first iteration)
+    prompt_sequence_length = 1
+
+    pipeline = Pipeline.create(
+        task="text_generation",
+        model_path="hf:mgoin/TinyStories-1M-deepsparse",
+        engine_type="onnxruntime",
+        sequence_length=sequence_length,
+        force_max_tokens=True,
+        prompt_sequence_length=prompt_sequence_length,
+    )
+    pipeline._debug = True
+
+    prompt_length = len(pipeline.tokenizer(prompt)["input_ids"])
+
+    cache_capacity = sequence_length - prompt_sequence_length
+    # we need to subtract 1 to account for the initial token generated during the
+    # prompt inference
+    cache_capacity -= 1
+
+    # max_new_tokens so that there is still one more "free" space in the kv cache
+    # (we can still do autoregressive inference)
+    max_new_tokens_minus_one = cache_capacity - prompt_length - 1
+    # max_new_tokens so that the kv cache is full
+    # (so we can still do one last correct autoregressive
+    # inference in the next iteration)
+    max_new_tokens = cache_capacity - prompt_length
+    # max_new_tokens so that the kv cache has already removed the last entry
+    # (so we can no longer do autoregressive inference in the next iteration)
+    max_new_tokens_plus_one = cache_capacity - prompt_length + 1
+    # max_new_tokens so that the kv cache would remove the two last entries
+    # (but it will not: the inference terminates early and produces
+    # the same result as max_new_tokens_plus_one)
+    max_new_tokens_plus_two = cache_capacity - prompt_length + 2
+
+    kv_cache_state_full_minus_one = _test_stop_inference_kv_cache_full(
+        pipeline,
+        prompt,
+        max_new_tokens_minus_one,
+        expected_finished_reason="max_new_tokens",
+    )
+    kv_cache_state_full = _test_stop_inference_kv_cache_full(
+        pipeline, prompt, max_new_tokens, expected_finished_reason="max_new_tokens"
+    )
+    kv_cache_state_full_plus_one = _test_stop_inference_kv_cache_full(
+        pipeline, prompt, max_new_tokens_plus_one, expected_finished_reason="capacity"
+    )
+    kv_cache_state_full_plus_two = _test_stop_inference_kv_cache_full(
+        pipeline,
+        prompt,
+        max_new_tokens_plus_two,
+        expected_generated_tokens_length=max_new_tokens_plus_one,
+        expected_finished_reason="capacity",
+    )
+ """
143
+ Check the following structure ok the kv cache:
144
+ minus_one | full | plus_one | plus_two
145
+ --------------------------------------
146
+ [- 0 -] | [row A] | [row B] | [row B]
147
+ [row A] | [row B] | [row C] | [row C]
148
+ [row B] | [row C] | [row D] | [row D]
149
+ ... | ... | ... | ...
150
+ """
+    # check for the "free" space in the kv cache
+    assert kv_cache_state_full_minus_one["past_key_values.0.key"][:, :, 0, :].sum() == 0
+    # check for the row A
+    assert numpy.allclose(
+        kv_cache_state_full_minus_one["past_key_values.0.key"][:, :, 1, :],
+        kv_cache_state_full["past_key_values.0.key"][:, :, 0, :],
+    )
+    # check for the row B
+    assert numpy.allclose(
+        kv_cache_state_full["past_key_values.0.key"][:, :, 1, :],
+        kv_cache_state_full_plus_one["past_key_values.0.key"][:, :, 0, :],
+    )
+    # check equality between plus_one and plus_two
+    assert numpy.allclose(
+        kv_cache_state_full_plus_one["past_key_values.0.key"],
+        kv_cache_state_full_plus_two["past_key_values.0.key"],
+    )
+
+
 def test_run_multiple_prompts_in_parallel(pipeline, prompt):
     # Test the scenario, where multiple prompts are run in parallel
     # Same two prompts should produce the same output
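A minimal standalone sketch of the capacity arithmetic the new test relies on (illustrative only, not part of the diff above); the 10-token prompt length is an assumed example value, since the actual `prompt` fixture is defined elsewhere in the test module:

sequence_length = 32          # kv cache slots available per sequence
prompt_sequence_length = 1    # tokens processed per prompt-inference pass
prompt_length = 10            # assumed token length of the prompt fixture

# subtract the prompt-inference slot, then one more for the token
# generated during that prompt-inference pass
cache_capacity = sequence_length - prompt_sequence_length - 1   # 30

max_new_tokens_minus_one = cache_capacity - prompt_length - 1   # 19: one free slot remains
max_new_tokens = cache_capacity - prompt_length                 # 20: cache exactly full
max_new_tokens_plus_one = cache_capacity - prompt_length + 1    # 21: oldest entry evicted
max_new_tokens_plus_two = cache_capacity - prompt_length + 2    # 22: stops early, same result as 21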