@@ -91,7 +91,7 @@ def update(
91
91
):
92
92
"""
93
93
Updating the session is identical with taking the kv cache
94
- output of from the forward pass and restructuring it, so it
94
+ output from the forward pass and restructuring it, so it
95
95
can be directly used as input for the next forward pass.
96
96
97
97
:param state: The state of the cache. This is a dictionary
@@ -103,79 +103,78 @@ def update(
103
103
Corresponds to `input_ids.shape[1]`
104
104
"""
105
105
self .total_num_processed_tokens += input_ids_len
106
- total_cache_capacity = state [list (state .keys ())[0 ]].shape [
106
+
107
+ input_state_capacity = state [list (state .keys ())[0 ]].shape [
107
108
self ._sequence_len_axis
108
109
]
109
- # total_capacity = num_tokens (num of non-blank tokens) +
110
- # + num_padded_entries (num of blank tokens)
110
+ num_entries_to_delete = input_ids_len
111
+
112
+ # compute the number of blank (padded) entries in the cache
111
113
num_padded_entries = max (
112
- 0 , total_cache_capacity - self .total_num_processed_tokens
114
+ 0 , input_state_capacity - self .total_num_processed_tokens
113
115
)
114
- num_entries_to_delete = input_ids_len
116
+ # compute how many of those entries need to be deleted
117
+ num_padded_entries_to_delete = min (num_padded_entries , num_entries_to_delete )
115
118
116
- if num_padded_entries :
117
- """
118
- Transforms input KV cache that contains blank entries.
119
- It removes the rightmost blank entries from the cache.
120
- Example 1:
121
- (entries in the cache denote the order in which they were
122
- added to the cache, zero is to denote a blank entry)
123
- ```
124
- state["state_name"]: (1, 1, 5, 1) = array([[[[0], [0], [1], [2], [3]]]])
125
- -> num_padded_entries = 2
126
- -> num_entries_to_delete = 1
127
- -> num_padded_entries > num_entries_to_delete
128
- # there are more blank entries than entries to delete
129
- results in:
130
- state["state_name"]: (1, 1, 4, 1) = array([[[[0], [1], [2], [3]]]])
131
- ```
132
- Example 2:
133
- ```
134
- state["state_name"]: (1, 1, 6, 1) = array([[[[0], [0], [0], [1], [2], [3]]]]) # noqa: E501
135
- -> num_padded_entries = 3
136
- -> num_entries_to_delete = 5
137
- -> num_padded_entries < num_entries_to_delete
138
- # there are less blank entries than entries to delete
139
- results in:
140
- state["state_name"]: (1, 1, 3, 1) = array([[[[1], [2], [3]]]])
141
- ```
142
- """
143
- num_padded_entries_to_delete = min (
144
- num_padded_entries , num_entries_to_delete
145
- )
146
- idxs_to_remove = [
147
- num_padded_entries - i - 1 for i in range (num_padded_entries_to_delete )
148
- ]
149
- # if we had fewer blank entries than entries to delete,
150
- # we updated the number of entries to delete to a non-zero value
151
- num_entries_to_delete = max (0 , num_entries_to_delete - num_padded_entries )
152
- # update the state of the cache
153
- state = self ._delete_entries (state , idxs_to_remove )
154
-
155
- if num_entries_to_delete :
156
- """
157
- Transforms the input KV cache that has been totally
158
- filled with non-blank entries.
159
- Example:
160
- ```
161
- state["state_name"]: (1, 1, 5, 1) = array([[[[1], [2], [3], [4], [5]]]])
162
- num_entries_to_delete = 2
163
- if self.freeze_first_position == False:
164
- state["state_name"]: (1, 1, 3, 1) = array([[[[3], [4], [5]]]])
165
- else:
166
-
167
- state["state_name"]: (1, 1, 3, 1) = array([[[[1], [4], [5]]]])
168
- ```
169
- """
170
- idxs_to_remove = [
171
- i + int (self ._freeze_first_position )
172
- for i in range (num_entries_to_delete )
173
- ]
174
-
175
- state = self ._delete_entries (state , idxs_to_remove )
119
+ # if we had fewer padded entries than num_entries_to_delete,
120
+ # we additionally are forced to delete some non-padded entries (the oldest ones)
121
+ num_non_padded_entries_to_delete = max (
122
+ 0 , num_entries_to_delete - num_padded_entries
123
+ )
124
+
125
+ for name , cache_array in state .items ():
126
+ if num_padded_entries_to_delete :
127
+ cache_array = self .remove_padded_entries (
128
+ cache_array , num_padded_entries_to_delete
129
+ )
130
+ if num_non_padded_entries_to_delete :
131
+ cache_array = self .remove_non_padded_entries (
132
+ cache_array , num_entries_to_delete
133
+ )
134
+ state [name ] = numpy .ascontiguousarray (cache_array )
176
135
177
136
self ._state = state
178
137
138
+ def remove_padded_entries (
139
+ self , cache_array : numpy .ndarray , num_padded_entries_to_delete : int
140
+ ):
141
+ """
142
+ Remove the num_padded_entries_to_delete entries from the cache array.
143
+ This function assumes that the cache_array has the number
144
+ of padded (blank) entries that is equal/larger than
145
+ num_padded_entries_to_delete.
146
+
147
+ :param cache_array: The cache array to be modified.
148
+ :param num_padded_entries_to_delete: The number of padded entries to delete.
149
+ """
150
+ return cache_array [:, :, num_padded_entries_to_delete :, :]
151
+
152
+ def remove_non_padded_entries (
153
+ self , cache_array : numpy .ndarray , num_non_padded_entries_to_delete : int
154
+ ):
155
+ """
156
+ Remove the num_non_padded_entries_to_delete entries from the cache array.
157
+ This function assumes that the cache_array has no padded (blank) entries and
158
+ thus we are forced to delete the oldest entries from the cache.
159
+
160
+ If self._freeze_first_position is set to True, that means that the oldest
161
+ entry in the cache_array is the one that corresponds to the BOS token. Because
162
+ we want to keep that entry in the cache, we will delete the oldest entry
163
+ starting from the second oldest entry.
164
+ """
165
+ new_cache_array = cache_array [
166
+ :,
167
+ :,
168
+ bool (self ._freeze_first_position ) + num_non_padded_entries_to_delete :,
169
+ :,
170
+ ]
171
+ if self ._freeze_first_position :
172
+ bos_entries = cache_array [:, :, :1 , :]
173
+ new_cache_array = numpy .concatenate (
174
+ [bos_entries , new_cache_array ], axis = self ._sequence_len_axis
175
+ )
176
+ return new_cache_array
177
+
179
178
def set_capacity (self , capacity : int ):
180
179
"""
181
180
Enforce a new total capacity for the state
@@ -212,14 +211,6 @@ def set_capacity(self, capacity: int):
212
211
213
212
self ._state = state
214
213
215
- def _delete_entries (
216
- self , state : Dict [str , Any ], indices : List [int ]
217
- ) -> Dict [str , Any ]:
218
- for key , value in state .items ():
219
- state [key ] = numpy .delete (value , indices , axis = self ._sequence_len_axis )
220
- state [key ] = numpy .ascontiguousarray (state [key ])
221
- return state
222
-
223
214
def _add_entries (
224
215
self , state : Dict [str , Any ], indices : List [int ], padding_value : int = 0
225
216
) -> Dict [str , Any ]:
0 commit comments