11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ import copy
14
15
import logging
15
- from typing import Any , Dict , List , Optional
16
+ from typing import Any , Dict , List , Optional , Tuple
16
17
17
18
import numpy
18
- from transformers import AutoTokenizer
19
19
20
20
from deepsparse .engine import Context
21
21
from deepsparse .pipeline import DEEPSPARSE_ENGINE , create_engine
22
22
from deepsparse .transformers .utils .decoder_kv_cache import DecoderKVCache
23
- from deepsparse .transformers .utils .helpers import generate_session_id
24
23
from deepsparse .transformers .utils .timings import TextGenerationTimings
25
24
from deepsparse .utils import TimerManager
26
25
from deepsparse .utils .onnx import (
27
26
CACHE_INPUT_PREFIX ,
28
- CACHE_OUTPUT_PREFIX ,
29
27
overwrite_onnx_model_inputs_for_kv_cache_models ,
30
28
)
31
29
37
35
38
36
class NLDecoderEngine :
39
37
"""
40
- The NLDecoderEngine (NaturalLanguageDecoderEngine ) handles the
38
+ The NLDecoderEngine (Natural Language Decoder Engine ) handles the
41
39
logic around the inference for Natural Language pipeline,
42
- including batching and kv cache logic.
40
+ including batching and kv cache manipulation logic.
43
41
44
42
:param onnx_file_path: The path to the onnx model file
45
43
:param engine_type: The type of engine to use for the inference
46
44
:param engine_args: The arguments to pass to the engine
47
45
:param sequence_length: The maximum sequence length to run the engine for
48
46
:param input_ids_length: The maximum input ids length to run the engine for
49
47
:param engine_context: The context to run the engine in
50
- :param sampling_temperature: The temperature to use for sampling
51
- :param deterministic: Whether to use deterministic sampling
52
- :param tokenizer: The tokenizer to used for engine inputs
53
- :param engine_context: The context to run the engine in
54
48
:param internal_kv_cache: Whether to use the deepsparse
55
49
kv cache in the DecoderKVCache object or not
56
50
"""
@@ -62,9 +56,6 @@ def __init__(
62
56
engine_args : Dict [str , Any ],
63
57
sequence_length : int ,
64
58
input_ids_length : int ,
65
- tokenizer : AutoTokenizer ,
66
- sampling_temperature : float = 1.0 ,
67
- deterministic : bool = True ,
68
59
engine_context : Optional [Context ] = None ,
69
60
internal_kv_cache = False ,
70
61
timer_manager : TimerManager = None ,
@@ -82,9 +73,7 @@ def __init__(
82
73
input_ids_length = input_ids_length ,
83
74
)
84
75
85
- kv_cache_enabled = False
86
76
if any (output_indices_to_be_cached ):
87
- kv_cache_enabled = True
88
77
self .kv_cache_data_type = kv_cache_data_type
89
78
if internal_kv_cache and engine_type == DEEPSPARSE_ENGINE :
90
79
# inform the engine, that are using the kv cache
@@ -98,30 +87,10 @@ def __init__(
98
87
)
99
88
self .timer_manager = timer_manager or TimerManager ()
100
89
self .sequence_length = sequence_length
101
- self .sampling_temperature = sampling_temperature
102
- self .deterministic = deterministic
103
90
self .input_ids_length = input_ids_length
104
91
self .cache_length = sequence_length - input_ids_length
105
- self .kv_cache_enabled = kv_cache_enabled
106
- self .kv_cache = DecoderKVCache (internal_kv_cache ) if kv_cache_enabled else None
107
- self ._freeze_first_position = self ._should_freeze_first_position (tokenizer )
108
- self ._session_id = generate_session_id ()
109
92
self ._engine_type = engine_type
110
93
111
- @property
112
- def session_id (self ) -> str :
113
- """
114
- :return: The session id for the kv_cache if enabled
115
- """
116
- return self ._session_id
117
-
118
- @session_id .setter
119
- def session_id (self , session_id : str ):
120
- """
121
- :param session_id: The session id to set for the kv_cache
122
- """
123
- self ._session_id = session_id
124
-
125
94
@property
126
95
def onnx_input_names_no_cache (self ) -> List [str ]:
127
96
"""
@@ -135,25 +104,43 @@ def onnx_input_names_no_cache(self) -> List[str]:
135
104
]
136
105
137
106
@property
138
- def num_non_blank_cache_entries (self ) -> int :
107
+ def onnx_input_names_cached (self ) -> List [str ]:
108
+ """
109
+ :return: The cached input names for the onnx model
110
+ """
111
+ return [
112
+ name
113
+ for name in self .engine .input_names
114
+ if name .startswith (CACHE_INPUT_PREFIX )
115
+ ]
116
+
117
+ @property
118
+ def cache_shape (self ) -> Tuple [int , int , int , int ]:
139
119
"""
140
- :return A number of non-blank entries in the
141
- kv cache
120
+ :return: The shape of the kv cache inputs
121
+ for the onnx model. The shape is
122
+ (batch_size, num_heads, sequence_length, hidden_size)
142
123
"""
143
- return self .kv_cache .num_non_blank_entries
124
+ cache_engine_input_index = next (
125
+ i
126
+ for i , name in enumerate (self .engine .input_names )
127
+ if CACHE_INPUT_PREFIX in name
128
+ )
129
+ return self .engine .input_shapes [cache_engine_input_index ]
144
130
145
131
@property
146
- def internal_cache_active (self ) -> bool :
132
+ def output_names (self ) -> List [ str ] :
147
133
"""
148
- :return: Whether the internal kv cache is active
134
+ :return: The output names for the onnx model
149
135
"""
150
- return self .kv_cache_enabled and self . kv_cache . engine_internal_cache is not None
136
+ return self .engine . output_names
151
137
152
- def run (self , inputs : List [numpy .ndarray ], val_inp : bool ) -> List [numpy .ndarray ]:
138
+ def run (
139
+ self , inputs : List [numpy .ndarray ], val_inp : bool , kv_cache : DecoderKVCache
140
+ ) -> List [numpy .ndarray ]:
153
141
"""
154
142
Run the engine with the given inputs.
155
-
156
- If the self.internal_cache_active=True, the internal
143
+ If the kv_cache.engine_internal_cache=True, the internal
157
144
deepsparse kv cache management is enabled. In this case
158
145
the LIB.kv_cache class object will be passed to the engine
159
146
call as well. In this scenario also the inputs will not be
@@ -163,25 +150,27 @@ def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]
163
150
164
151
:param inputs: The inputs to run the engine with
165
152
:param val_inp: Whether the input is for validation or not
153
+ :param kv_cache: The kv cache object to use for the inference
154
+
166
155
:return: The output of the engine
167
156
"""
168
-
169
- if self .internal_cache_active :
157
+ if bool (kv_cache .engine_internal_cache ):
170
158
# conventionally, before dispatching
171
159
# inputs to the engine, we validate them
172
160
# if val_inp=True. However, in this case
173
161
# we want to pass the empty kv cache inputs
174
162
# (batch_size=0) to the engine. Therefore,
175
163
# we skip the validation
176
164
return self .engine ._eng_net .execute_list_out (
177
- inputs , self . kv_cache .engine_internal_cache
165
+ inputs , kv_cache .engine_internal_cache
178
166
)
179
167
# run the engine without the LIB.kv_cache object
180
168
return self .engine .run (inputs , val_inp )
181
169
182
170
def __call__ (
183
171
self ,
184
172
inp : List [numpy .ndarray ],
173
+ kv_cache : Optional [DecoderKVCache ] = None ,
185
174
val_inp : bool = True ,
186
175
) -> numpy .ndarray :
187
176
"""
@@ -190,23 +179,28 @@ def __call__(
190
179
:param inp: The input to run the engine with. We expect a
191
180
list of numpy arrays that contain the input ids,
192
181
attention mask, and position ids (optionally)
182
+ :param kv_cache: The DecoderKVCache object that contains
183
+ the kv cache state
193
184
:param val_inp: Whether the input is for validation or not
185
+
194
186
:return: The generated token and corresponding logits
195
187
"""
196
188
timer = self .timer_manager .current
197
- if self . kv_cache :
189
+ if kv_cache :
198
190
# if model has kv cache enabled, we need
199
191
# to add the kv cache state to the input
200
- inp = self .add_kv_cache_to_input (inp )
192
+ inp = self .add_kv_cache_to_input (inp , kv_cache )
201
193
202
194
with timer .time (f"EXECUTE_ENGINE_SEQ_LEN_{ self .sequence_length } " ):
203
- out = self .run (inp , val_inp )
195
+ out = self .run (inp , val_inp , kv_cache )
204
196
205
- if self . kv_cache :
197
+ if kv_cache :
206
198
with timer .time (TextGenerationTimings .KV_CACHE_UPDATE ):
207
199
logits , * kv_cache_state = out
208
200
self .update_kv_cache (
209
- kv_cache_state = kv_cache_state , input_ids_len = self .input_ids_length
201
+ kv_cache_state = kv_cache_state ,
202
+ input_ids_len = self .input_ids_length ,
203
+ kv_cache = kv_cache ,
210
204
)
211
205
else :
212
206
logits = out [0 ]
@@ -219,36 +213,11 @@ def __str__(self):
219
213
def __repr__ (self ):
220
214
return str (self )
221
215
222
- def transfer_cache_state (self , cache : DecoderKVCache ):
216
+ def add_kv_cache_to_input (
217
+ self , inp : List [numpy .ndarray ], kv_cache : DecoderKVCache
218
+ ) -> List [numpy .ndarray ]:
223
219
"""
224
- Transfers the kv cache state and the number of tokens processed
225
- information from another NLDecoderEngine. Call this method when
226
- you want to transfer the kv cache state from one engine to another.
227
-
228
- This method will also automatically set the kv cache capacity to
229
- the appropriate value for the new engine.
230
-
231
- :param cache: The `DecoderKVCache` object to transfer to the engine
232
- from
233
- """
234
- cache .set_capacity (self .cache_length )
235
- self .kv_cache = cache
236
-
237
- def reset_kv_cache (self ):
238
- """
239
- Resets the kv cache state.
240
- """
241
- kv_cache_state = self ._initialize_kv_cache_state (self .cache_length )
242
- self .kv_cache .setup (
243
- session_id = self ._session_id ,
244
- state = kv_cache_state ,
245
- num_processed_tokens = 0 ,
246
- freeze_first_position = self ._freeze_first_position ,
247
- )
248
-
249
- def add_kv_cache_to_input (self , inp : List [numpy .ndarray ]) -> List [numpy .ndarray ]:
250
- """
251
- Takes the input and adds the past kv cache state to it.
220
+ Takes the input and adds the kv cache state to it.
252
221
253
222
If the internal kv cache is enabled, the kv cache state
254
223
will always be an empty array. This is just to make sure
@@ -262,17 +231,11 @@ def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray]
262
231
263
232
264
233
:param inp: The input to the model
234
+ :param kv_cache: The kv cache object
235
+
265
236
:return The input with the kv cache state added to it
266
237
"""
267
- if self .internal_cache_active :
268
- kv_cache_state = self ._initialize_kv_cache_state (
269
- self .cache_length , empty = True
270
- )
271
- else :
272
- kv_cache_state = self .kv_cache .cached_inputs
273
- if kv_cache_state is None :
274
- self .reset_kv_cache ()
275
- kv_cache_state = self .kv_cache .cached_inputs
238
+ kv_cache_state = copy .copy (kv_cache .cached_inputs )
276
239
277
240
for idx , input_name in enumerate (self .onnx_input_names_no_cache ):
278
241
kv_cache_state [input_name ] = inp [idx ]
@@ -284,75 +247,29 @@ def update_kv_cache(
284
247
self ,
285
248
kv_cache_state : List [numpy .ndarray ],
286
249
input_ids_len : int ,
250
+ kv_cache : DecoderKVCache ,
287
251
):
288
252
"""
289
- Updates the state of the kv cache
253
+ Updates the kv cache using the new kv cache state.
290
254
291
255
If the internal kv cache is enabled, we refrain from
292
256
updating the kv cache state as it is being tracked internally
293
257
inside the engine. We only update the number of tokens processed.
294
258
295
- :param kv_cache_state: The state of the kv cache storage
259
+ :param kv_cache_state: The new state of the kv cache storage
296
260
:param input_ids_len: The length of input_ids
261
+ :param kv_cache: The kv cache object to update
297
262
"""
298
- if self . internal_cache_active :
299
- self . kv_cache .total_num_processed_tokens += input_ids_len
263
+ if bool ( kv_cache . engine_internal_cache ) :
264
+ kv_cache .total_num_processed_tokens += input_ids_len
300
265
return
301
266
302
- cache_onnx_names = [
303
- name
304
- for name in self .engine .input_names
305
- if name .startswith (CACHE_INPUT_PREFIX )
306
- ]
307
267
kv_cache_state = {
308
- name : array for name , array in zip (cache_onnx_names , kv_cache_state )
268
+ name : array
269
+ for name , array in zip (self .onnx_input_names_cached , kv_cache_state )
309
270
}
310
271
311
- self . kv_cache .update (
272
+ kv_cache .update (
312
273
state = kv_cache_state ,
313
274
input_ids_len = input_ids_len ,
314
275
)
315
-
316
- def _initialize_kv_cache_state (
317
- self , length : int , empty : bool = False
318
- ) -> Dict [str , numpy .ndarray ]:
319
- # initialize empty kv cache of size
320
- # (batch_size, num_attention_heads, length, hidden_dims)
321
- # if empty is True, we initialize empty kv_cache
322
- # and set the batch_size to 0
323
-
324
- cache_engine_input_index = next (
325
- i
326
- for i , name in enumerate (self .engine .input_names )
327
- if CACHE_INPUT_PREFIX in name
328
- )
329
- batch_size , num_attention_heads , _ , hidden_dims = self .engine .input_shapes [
330
- cache_engine_input_index
331
- ]
332
-
333
- empty_kv_cache_tensor = numpy .zeros (
334
- (
335
- batch_size if not empty else 0 ,
336
- num_attention_heads ,
337
- length ,
338
- hidden_dims ,
339
- ),
340
- dtype = self .kv_cache_data_type ,
341
- )
342
-
343
- cache_keys = [
344
- output_name .replace (CACHE_OUTPUT_PREFIX , CACHE_INPUT_PREFIX )
345
- for output_name in self .engine .output_names
346
- if output_name .startswith (CACHE_OUTPUT_PREFIX )
347
- ]
348
- return {key : empty_kv_cache_tensor for key in cache_keys }
349
-
350
- @staticmethod
351
- def _should_freeze_first_position (tokenizer ) -> bool :
352
- # use tokenizer to find out whether we should freeze the first position
353
- # (True if tokenizer has a prefix for a BOS token)
354
- if tokenizer is None :
355
- return False
356
- if hasattr (tokenizer , "add_bos_token" ):
357
- return True
358
- return False
0 commit comments