Skip to content

Commit 1bd60d2

Browse files
authored
[Text Generation] Optimize the slow update method in the KVCacheDecoder (#1190)
* initial commit * Nit: docstring typo * fix style
1 parent 2cf112a commit 1bd60d2

File tree

1 file changed

+65
-74
lines changed

1 file changed

+65
-74
lines changed

Diff for: src/deepsparse/transformers/utils/decoder_kv_cache.py

+65-74
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def update(
9191
):
9292
"""
9393
Updating the session is identical with taking the kv cache
94-
output of from the forward pass and restructuring it, so it
94+
output from the forward pass and restructuring it, so it
9595
can be directly used as input for the next forward pass.
9696
9797
:param state: The state of the cache. This is a dictionary
@@ -103,79 +103,78 @@ def update(
103103
Corresponds to `input_ids.shape[1]`
104104
"""
105105
self.total_num_processed_tokens += input_ids_len
106-
total_cache_capacity = state[list(state.keys())[0]].shape[
106+
107+
input_state_capacity = state[list(state.keys())[0]].shape[
107108
self._sequence_len_axis
108109
]
109-
# total_capacity = num_tokens (num of non-blank tokens) +
110-
# + num_padded_entries (num of blank tokens)
110+
num_entries_to_delete = input_ids_len
111+
112+
# compute the number of blank (padded) entries in the cache
111113
num_padded_entries = max(
112-
0, total_cache_capacity - self.total_num_processed_tokens
114+
0, input_state_capacity - self.total_num_processed_tokens
113115
)
114-
num_entries_to_delete = input_ids_len
116+
# compute how many of those entries need to be deleted
117+
num_padded_entries_to_delete = min(num_padded_entries, num_entries_to_delete)
115118

116-
if num_padded_entries:
117-
"""
118-
Transforms input KV cache that contains blank entries.
119-
It removes the rightmost blank entries from the cache.
120-
Example 1:
121-
(entries in the cache denote the order in which they were
122-
added to the cache, zero is to denote a blank entry)
123-
```
124-
state["state_name"]: (1, 1, 5, 1) = array([[[[0], [0], [1], [2], [3]]]])
125-
-> num_padded_entries = 2
126-
-> num_entries_to_delete = 1
127-
-> num_padded_entries > num_entries_to_delete
128-
# there are more blank entries than entries to delete
129-
results in:
130-
state["state_name"]: (1, 1, 4, 1) = array([[[[0], [1], [2], [3]]]])
131-
```
132-
Example 2:
133-
```
134-
state["state_name"]: (1, 1, 6, 1) = array([[[[0], [0], [0], [1], [2], [3]]]]) # noqa: E501
135-
-> num_padded_entries = 3
136-
-> num_entries_to_delete = 5
137-
-> num_padded_entries < num_entries_to_delete
138-
# there are less blank entries than entries to delete
139-
results in:
140-
state["state_name"]: (1, 1, 3, 1) = array([[[[1], [2], [3]]]])
141-
```
142-
"""
143-
num_padded_entries_to_delete = min(
144-
num_padded_entries, num_entries_to_delete
145-
)
146-
idxs_to_remove = [
147-
num_padded_entries - i - 1 for i in range(num_padded_entries_to_delete)
148-
]
149-
# if we had fewer blank entries than entries to delete,
150-
# we updated the number of entries to delete to a non-zero value
151-
num_entries_to_delete = max(0, num_entries_to_delete - num_padded_entries)
152-
# update the state of the cache
153-
state = self._delete_entries(state, idxs_to_remove)
154-
155-
if num_entries_to_delete:
156-
"""
157-
Transforms the input KV cache that has been totally
158-
filled with non-blank entries.
159-
Example:
160-
```
161-
state["state_name"]: (1, 1, 5, 1) = array([[[[1], [2], [3], [4], [5]]]])
162-
num_entries_to_delete = 2
163-
if self.freeze_first_position == False:
164-
state["state_name"]: (1, 1, 3, 1) = array([[[[3], [4], [5]]]])
165-
else:
166-
167-
state["state_name"]: (1, 1, 3, 1) = array([[[[1], [4], [5]]]])
168-
```
169-
"""
170-
idxs_to_remove = [
171-
i + int(self._freeze_first_position)
172-
for i in range(num_entries_to_delete)
173-
]
174-
175-
state = self._delete_entries(state, idxs_to_remove)
119+
# if we had fewer padded entries than num_entries_to_delete,
120+
# we additionally are forced to delete some non-padded entries (the oldest ones)
121+
num_non_padded_entries_to_delete = max(
122+
0, num_entries_to_delete - num_padded_entries
123+
)
124+
125+
for name, cache_array in state.items():
126+
if num_padded_entries_to_delete:
127+
cache_array = self.remove_padded_entries(
128+
cache_array, num_padded_entries_to_delete
129+
)
130+
if num_non_padded_entries_to_delete:
131+
cache_array = self.remove_non_padded_entries(
132+
cache_array, num_entries_to_delete
133+
)
134+
state[name] = numpy.ascontiguousarray(cache_array)
176135

177136
self._state = state
178137

138+
def remove_padded_entries(
139+
self, cache_array: numpy.ndarray, num_padded_entries_to_delete: int
140+
):
141+
"""
142+
Remove the num_padded_entries_to_delete entries from the cache array.
143+
This function assumes that the cache_array has the number
144+
of padded (blank) entries that is equal/larger than
145+
num_padded_entries_to_delete.
146+
147+
:param cache_array: The cache array to be modified.
148+
:param num_padded_entries_to_delete: The number of padded entries to delete.
149+
"""
150+
return cache_array[:, :, num_padded_entries_to_delete:, :]
151+
152+
def remove_non_padded_entries(
153+
self, cache_array: numpy.ndarray, num_non_padded_entries_to_delete: int
154+
):
155+
"""
156+
Remove the num_non_padded_entries_to_delete entries from the cache array.
157+
This function assumes that the cache_array has no padded (blank) entries and
158+
thus we are forced to delete the oldest entries from the cache.
159+
160+
If self._freeze_first_position is set to True, that means that the oldest
161+
entry in the cache_array is the one that corresponds to the BOS token. Because
162+
we want to keep that entry in the cache, we will delete the oldest entry
163+
starting from the second oldest entry.
164+
"""
165+
new_cache_array = cache_array[
166+
:,
167+
:,
168+
bool(self._freeze_first_position) + num_non_padded_entries_to_delete :,
169+
:,
170+
]
171+
if self._freeze_first_position:
172+
bos_entries = cache_array[:, :, :1, :]
173+
new_cache_array = numpy.concatenate(
174+
[bos_entries, new_cache_array], axis=self._sequence_len_axis
175+
)
176+
return new_cache_array
177+
179178
def set_capacity(self, capacity: int):
180179
"""
181180
Enforce a new total capacity for the state
@@ -212,14 +211,6 @@ def set_capacity(self, capacity: int):
212211

213212
self._state = state
214213

215-
def _delete_entries(
216-
self, state: Dict[str, Any], indices: List[int]
217-
) -> Dict[str, Any]:
218-
for key, value in state.items():
219-
state[key] = numpy.delete(value, indices, axis=self._sequence_len_axis)
220-
state[key] = numpy.ascontiguousarray(state[key])
221-
return state
222-
223214
def _add_entries(
224215
self, state: Dict[str, Any], indices: List[int], padding_value: int = 0
225216
) -> Dict[str, Any]:

0 commit comments

Comments
 (0)