Skip to content

Commit a6d46be

Browse files
dbogunowicz, SageMoore, bfineran, Benjamin
authored
[Text Generation] Turn off the (currently) inefficient external KV cache logic when internal KV cache management enabled (#1175)
* fix kv cache * refactor * add validation pathway * avx2 support * initial commit * initial commit * initial implementation * problems with multitoken prefill * its working * Create test_nl_decoder_engine.py * almost there... * finally all tests pass * just need to change to stub * fix bad merge * added some tests * ready for review * [Text Generation][Tests] DecoderKVCache (#1154) * [Text Generation][Tests] NLDecoderEngine (#1155) * initial commit * initial commit * [Text Generation][Tests] Text Generation Pipeline (#1162) * initial implementation * problems with multitoken prefill * almost there... * finally all tests pass * just need to change to stub * fix bad merge * Make tests work with stub (as much as possible), cleanup test names, disable heavy tests, include patch for running without causal mask * initial commit * use patch from unittest library - remove additional dependency * improved logic * additional improvements * Update src/deepsparse/transformers/pipelines/text_generation.py * Update src/deepsparse/utils/onnx.py Co-authored-by: Benjamin Fineran <[email protected]> * Update src/deepsparse/utils/onnx.py Co-authored-by: Benjamin Fineran <[email protected]> * response to Ben's comments * finish rebasing * full support * Update tests/deepsparse/transformers/pipelines/test_text_generation.py * initial commit * clarify todo comment * update user messages + add assertion for safety * [Text Generation] KV Cache internal Deepsparse support (#1135) * fix kv cache * refactor * add validation pathway * avx2 support * initial commit * initial commit * initial implementation * problems with multitoken prefill * its working * almost there... 
* finally all tests pass * just need to change to stub * fix bad merge * added some tests * ready for review * full support --------- Co-authored-by: dbogunowicz <[email protected]> Co-authored-by: Damian <[email protected]> * minor improvements before landing * Fix the helper function that has been broken after a merge * incomplete string in parametrize * few nits before the merge * pass dummy cache if internal cache management supported * Apply suggestions from code review * add missing property * cleaner func * PR ready * add timing for KV cache update * initial commit * code review comments * Nit: docstring typo * nit: docstring style * fix style * fix broken test * fixing bad rebase --------- Co-authored-by: Sage Moore <[email protected]> Co-authored-by: Benjamin Fineran <[email protected]> Co-authored-by: Benjamin <[email protected]>
1 parent 9cd948f commit a6d46be

File tree

3 files changed

+61
-29
lines changed

3 files changed

+61
-29
lines changed

src/deepsparse/transformers/engines/nl_decoder_engine.py

+52-25
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
sampling_temperature: float = 1.0,
6565
deterministic: bool = True,
6666
engine_context: Optional[Context] = None,
67-
use_deepsparse_cache=False,
67+
use_deepsparse_cache: bool = False,
6868
):
6969
# flag to indicate if the model is quantized or not
7070
self.kv_cache_data_type = None
@@ -93,10 +93,12 @@ def __init__(
9393
engine_args=engine_args,
9494
context=engine_context,
9595
)
96+
9697
self.sequence_length = sequence_length
9798
self.sampling_temperature = sampling_temperature
9899
self.deterministic = deterministic
99100
self.input_ids_length = input_ids_length
101+
self.cache_length = sequence_length - input_ids_length
100102
self.kv_cache_enabled = kv_cache_enabled
101103
self.kv_cache = (
102104
DecoderKVCache(use_deepsparse_cache) if kv_cache_enabled else None
@@ -134,35 +136,41 @@ def onnx_input_names_no_cache(self) -> List[str]:
134136
@property
135137
def num_non_blank_cache_entries(self) -> int:
136138
"""
137-
:return a number of non-blank entries in the
138-
kv cache
139+
:return: The number of non-blank entries in the
140+
kv cache
139141
"""
140142
return self.kv_cache.num_non_blank_entries
141143

144+
@property
145+
def internal_cache_active(self) -> bool:
146+
"""
147+
:return: Whether the internal kv cache is active
148+
"""
149+
return self.kv_cache_enabled and self.kv_cache.engine_internal_cache is not None
150+
142151
def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]:
143152
"""
144153
Run the engine with the given inputs.
145154
146-
If the internal deepsparse kv cache management is enable,
155+
If self.internal_cache_active is True, the internal
156+
deepsparse kv cache management is enabled. In this case
147157
the LIB.kv_cache class object will be passed to the engine
148158
call as well.
149159
150160
:param inputs: The inputs to run the engine with
151161
:param val_inp: Whether the input is for validation or not
152-
153162
:return: The output of the engine
154163
"""
155164

156-
if self.kv_cache is not None:
157-
if self.kv_cache._kv_cache is not None:
158-
if val_inp:
159-
self.engine._validate_inputs(inputs)
160-
# model has kv cache support, as well as deepsparse
161-
# internal management of the kv cache
162-
return self.engine._eng_net.execute_list_out(
163-
inputs, self.kv_cache._kv_cache
164-
)
165-
165+
if self.internal_cache_active:
166+
# validate the inputs if needed
167+
if val_inp:
168+
self.engine._validate_inputs(inputs)
169+
# run the engine with the LIB.kv_cache object
170+
return self.engine._eng_net.execute_list_out(
171+
inputs, self.kv_cache.engine_internal_cache
172+
)
173+
# run the engine without the LIB.kv_cache object
166174
return self.engine.run(inputs, val_inp)
167175

168176
def __call__(
@@ -180,8 +188,8 @@ def __call__(
180188
:return: The generated token and corresponding logits
181189
"""
182190
if self.kv_cache:
183-
# if kv cache is enabled, we need to add the kv cache state
184-
# to the input
191+
# if model has kv cache enabled, we need
192+
# to add the kv cache state to the input
185193
inp = self.add_kv_cache_to_input(inp)
186194

187195
out = self.run(inp, val_inp)
@@ -217,8 +225,7 @@ def transfer_cache_state(self, cache: DecoderKVCache):
217225
:param cache: The `DecoderKVCache` object to transfer to the engine
218226
from
219227
"""
220-
target_cache_capacity = self.sequence_length - self.input_ids_length
221-
cache.set_capacity(target_cache_capacity)
228+
cache.set_capacity(self.cache_length)
222229
self.kv_cache = cache
223230

224231
def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
@@ -241,9 +248,7 @@ def reset_kv_cache(self):
241248
"""
242249
Resets the kv cache state.
243250
"""
244-
kv_cache_state = self._initialize_kv_cache_state(
245-
self.sequence_length - self.input_ids_length
246-
)
251+
kv_cache_state = self._initialize_kv_cache_state(self.cache_length)
247252
self.kv_cache.setup(
248253
session_id=self._session_id,
249254
state=kv_cache_state,
@@ -255,13 +260,27 @@ def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray]
255260
"""
256261
Takes the input and adds the past kv cache state to it.
257262
263+
If the internal kv cache is enabled, the kv cache state
264+
will always be reinitialized to zeros. This is just to make sure
265+
that the input shapes of the kv cache arrays to the
266+
model are correct; the actual values are
267+
being tracked internally inside the engine.
268+
269+
If the internal kv cache is disabled, we need to
270+
fetch the kv cache state as numpy arrays
271+
from the current session, or initialize it if required.
272+
273+
258274
:param inp: The input to the model
259275
:return: The input with the kv cache state added to it
260276
"""
261-
kv_cache_state = self.kv_cache.cached_inputs
262-
if kv_cache_state is None:
263-
self.reset_kv_cache()
277+
if self.internal_cache_active:
278+
kv_cache_state = self._initialize_kv_cache_state(self.cache_length)
279+
else:
264280
kv_cache_state = self.kv_cache.cached_inputs
281+
if kv_cache_state is None:
282+
self.reset_kv_cache()
283+
kv_cache_state = self.kv_cache.cached_inputs
265284

266285
for idx, input_name in enumerate(self.onnx_input_names_no_cache):
267286
kv_cache_state[input_name] = inp[idx]
@@ -277,9 +296,17 @@ def update_kv_cache(
277296
"""
278297
Updates the state of the kv cache
279298
299+
If the internal kv cache is enabled, we refrain from
300+
updating the kv cache state as it is being tracked internally
301+
inside the engine. We only update the number of tokens processed.
302+
280303
:param kv_cache_state: The state of the kv cache storage
281304
:param input_ids_len: The length of input_ids
282305
"""
306+
if self.internal_cache_active:
307+
self.kv_cache.total_num_processed_tokens += input_ids_len
308+
return
309+
283310
cache_onnx_names = [
284311
name
285312
for name in self.engine.input_names

src/deepsparse/transformers/utils/decoder_kv_cache.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ def __init__(self, use_deepsparse_cache: bool = False):
3131
The goal of this object is to handle the manipulation
3232
of the key value cache.
3333
34-
:param use_deepsparse_cache: If set to True, the `kv_cache` object
35-
from the deepsparse.LIB will be loaded as an attribute.
34+
:param use_deepsparse_cache: If set to True, the
35+
`kv_cache` object from the deepsparse.LIB will
36+
be loaded as an engine_internal_cache attribute.
3637
This object is used to handle the manipulation of the
3738
key/value buffers on the DeepSparse engine side.
3839
"""
@@ -45,7 +46,7 @@ def __init__(self, use_deepsparse_cache: bool = False):
4546
self._session_id = None
4647
self._freeze_first_position = None
4748
self._state = None
48-
self._kv_cache = None
49+
self.engine_internal_cache = None
4950

5051
def setup(
5152
self,
@@ -82,7 +83,9 @@ def setup(
8283
if self._use_deepsparse_cache:
8384
prev_num_tokens = self.total_num_processed_tokens
8485
num_frozen_tokens = int(self._freeze_first_position)
85-
self._kv_cache = LIB.kv_cache(prev_num_tokens, num_frozen_tokens)
86+
self.engine_internal_cache = LIB.kv_cache(
87+
prev_num_tokens, num_frozen_tokens
88+
)
8689

8790
def update(
8891
self,

tests/deepsparse/transformers/engine/test_nl_decoder_engine.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class DummyKVCacheDecoder:
2525
"past_key_values_1": np.array([10, 11, 12]),
2626
"past_key_values_2": np.array([13, 14, 15]),
2727
}
28+
engine_internal_cache = None
2829

2930

3031
class DummyEngine:
@@ -62,6 +63,7 @@ def test_add_kv_cache_to_input():
6263
nl_decoder_engine = NLDecoderEngine(None, None)
6364
nl_decoder_engine.engine = DummyEngine()
6465
nl_decoder_engine.kv_cache = DummyKVCacheDecoder()
66+
nl_decoder_engine.kv_cache_enabled = True
6567
result = nl_decoder_engine.add_kv_cache_to_input(inp)
6668

6769
for (x, y) in zip(result, expected_result):

0 commit comments

Comments
 (0)