16
16
from typing import Any , Dict , List , Optional , Tuple
17
17
18
18
import numpy
19
- import onnx
20
19
from transformers import AutoTokenizer
21
20
22
21
from deepsparse .engine import Context
23
22
from deepsparse .pipeline import DEEPSPARSE_ENGINE , create_engine
24
23
from deepsparse .transformers .utils .decoder_kv_cache import DecoderKVCache
25
- from deepsparse .transformers .utils .helpers import generate_session_id
24
+ from deepsparse .transformers .utils .helpers import (
25
+ generate_session_id ,
26
+ overwrite_onnx_model_inputs ,
27
+ )
26
28
from deepsparse .utils .data import numpy_softmax
27
- from deepsparse .utils .onnx import translate_onnx_type_to_numpy
28
- from sparsezoo .utils .onnx import save_onnx
29
29
30
30
31
31
_LOGGER = logging .getLogger (__name__ )
@@ -71,7 +71,11 @@ def __init__(
71
71
# flag to indicate if the model is quantized or not
72
72
self .kv_cache_data_type = None
73
73
74
- onnx_file_path , output_indices_to_be_cached = self .overwrite_onnx_model_inputs (
74
+ (
75
+ onnx_file_path ,
76
+ output_indices_to_be_cached ,
77
+ kv_cache_data_type ,
78
+ ) = overwrite_onnx_model_inputs (
75
79
onnx_file_path = onnx_file_path ,
76
80
batch_size = engine_args .get ("batch_size" , 1 ),
77
81
sequence_length = sequence_length ,
@@ -80,6 +84,7 @@ def __init__(
80
84
kv_cache_enabled = False
81
85
if sum (output_indices_to_be_cached ):
82
86
kv_cache_enabled = True
87
+ self .kv_cache_data_type = kv_cache_data_type
83
88
if use_deepsparse_cache and engine_type == DEEPSPARSE_ENGINE :
84
89
# inform the engine, that are using the kv cache
85
90
engine_args ["cache_output_bools" ] = output_indices_to_be_cached
@@ -192,74 +197,6 @@ def transfer_cache_state(self, cache: DecoderKVCache):
192
197
cache_to_copy .set_capacity (target_cache_capacity )
193
198
self .kv_cache = cache_to_copy
194
199
195
- def overwrite_onnx_model_inputs (
196
- self ,
197
- onnx_file_path : str ,
198
- sequence_length : int ,
199
- input_ids_length : int ,
200
- batch_size : int = 1 ,
201
- ) -> Tuple [str , List [int ]]:
202
- """
203
- Enforces the appropriate input shapes for the onnx model, as well as
204
- checks whether kv cache is enabled or not.
205
-
206
- :param onnx_file_path: The path to the onnx model file that will be
207
- overwritten with the new input shapes
208
- :param batch_size: The batch size to use for the input
209
- :param sequence_length: The sequence length to use for the input
210
- :param input_ids_length: The length of input_ids
211
- :return: The path to the onnx model file that has been overwritten
212
- with the new input shapes, as well as the indices of the inputs
213
- that should be cached
214
- """
215
- model = onnx .load (onnx_file_path , load_external_data = False )
216
- initializer_input_names = set (node .name for node in model .graph .initializer )
217
- external_inputs = [
218
- inp for inp in model .graph .input if inp .name not in initializer_input_names
219
- ]
220
- for external_input in external_inputs :
221
- # overwrite the batch size for all the inputs
222
- external_input .type .tensor_type .shape .dim [0 ].dim_value = batch_size
223
-
224
- if external_input .name in ["input_ids" , "positions" ]:
225
- external_input .type .tensor_type .shape .dim [
226
- 1
227
- ].dim_value = input_ids_length
228
- elif external_input .name == "attention_mask" :
229
- external_input .type .tensor_type .shape .dim [1 ].dim_value = sequence_length
230
- elif external_input .name .startswith (_CACHE_INPUT_NAME ):
231
- external_input .type .tensor_type .shape .dim [2 ].dim_value = (
232
- sequence_length - input_ids_length
233
- )
234
- elif external_input .name .startswith ("causal_mask" ):
235
- external_input .type .tensor_type .shape .dim [
236
- 2
237
- ].dim_value = input_ids_length
238
- external_input .type .tensor_type .shape .dim [3 ].dim_value = sequence_length
239
- else :
240
- raise ValueError (
241
- f"Unexpected external input name: { external_input .name } "
242
- )
243
-
244
- _LOGGER .info (
245
- "Overwriting in-place the input shapes "
246
- f"of the transformer model at { onnx_file_path } "
247
- )
248
- save_onnx (model , onnx_file_path )
249
-
250
- output_indices_to_be_cached = [
251
- 1 if inp .name .startswith ("present" ) else 0 for inp in model .graph .output
252
- ]
253
- if any (output_indices_to_be_cached ):
254
- kv_cache_elem_type = next (
255
- inp
256
- for inp in model .graph .input
257
- if inp .name .startswith (_CACHE_INPUT_NAME )
258
- ).type .tensor_type .elem_type
259
- self .kv_cache_data_type = translate_onnx_type_to_numpy (kv_cache_elem_type )
260
-
261
- return onnx_file_path , output_indices_to_be_cached
262
-
263
200
def generate_token (self , logits : numpy .ndarray ) -> numpy .ndarray :
264
201
"""
265
202
Samples a token from the logits using the sampling temperature.
@@ -283,7 +220,7 @@ def reset_kv_cache(self):
283
220
kv_cache_state = self ._initialize_kv_cache_state (
284
221
self .sequence_length - self .input_ids_length
285
222
)
286
- self .kv_cache .setup_session (
223
+ self .kv_cache .setup (
287
224
session_id = self ._session_id ,
288
225
state = kv_cache_state ,
289
226
num_processed_tokens = 0 ,
@@ -328,7 +265,7 @@ def update_kv_cache(
328
265
name : array for name , array in zip (cache_onnx_names , kv_cache_state )
329
266
}
330
267
331
- self .kv_cache .update_session (
268
+ self .kv_cache .update (
332
269
state = kv_cache_state ,
333
270
input_ids_len = input_ids_len ,
334
271
)
@@ -364,6 +301,6 @@ def _should_freeze_first_position(tokenizer) -> bool:
364
301
# (True if tokenizer has a prefix for a BOS token)
365
302
if tokenizer is None :
366
303
return False
367
- if hasattr (tokenizer , "bos_token " ):
304
+ if hasattr (tokenizer , "add_bos_token " ):
368
305
return True
369
306
return False
0 commit comments