11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- import copy
15
14
import logging
16
15
from typing import Any , Dict , List , Optional , Tuple
17
16
18
17
import numpy
19
- import onnx
20
18
from transformers import AutoTokenizer
21
19
22
20
from deepsparse .engine import Context
23
21
from deepsparse .pipeline import DEEPSPARSE_ENGINE , create_engine
24
22
from deepsparse .transformers .utils .decoder_kv_cache import DecoderKVCache
25
- from deepsparse .transformers .utils .helpers import generate_session_id
23
+ from deepsparse .transformers .utils .helpers import (
24
+ generate_session_id ,
25
+ overwrite_onnx_model_inputs ,
26
+ )
26
27
from deepsparse .utils .data import numpy_softmax
27
- from deepsparse .utils .onnx import translate_onnx_type_to_numpy
28
- from sparsezoo .utils .onnx import save_onnx
29
28
30
29
31
30
_LOGGER = logging .getLogger (__name__ )
@@ -71,7 +70,11 @@ def __init__(
71
70
# flag to indicate if the model is quantized or not
72
71
self .kv_cache_data_type = None
73
72
74
- onnx_file_path , output_indices_to_be_cached = self .overwrite_onnx_model_inputs (
73
+ (
74
+ onnx_file_path ,
75
+ output_indices_to_be_cached ,
76
+ kv_cache_data_type ,
77
+ ) = overwrite_onnx_model_inputs (
75
78
onnx_file_path = onnx_file_path ,
76
79
batch_size = engine_args .get ("batch_size" , 1 ),
77
80
sequence_length = sequence_length ,
@@ -80,9 +83,10 @@ def __init__(
80
83
kv_cache_enabled = False
81
84
if sum (output_indices_to_be_cached ):
82
85
kv_cache_enabled = True
86
+ self .kv_cache_data_type = kv_cache_data_type
83
87
if use_deepsparse_cache and engine_type == DEEPSPARSE_ENGINE :
84
88
# inform the engine that we are using the kv cache
85
- engine_args ["cache_output_bools " ] = output_indices_to_be_cached
89
+ engine_args ["cached_outputs " ] = output_indices_to_be_cached
86
90
87
91
self .engine = create_engine (
88
92
onnx_file_path = onnx_file_path ,
@@ -100,6 +104,7 @@ def __init__(
100
104
)
101
105
self ._freeze_first_position = self ._should_freeze_first_position (tokenizer )
102
106
self ._session_id = generate_session_id ()
107
+ self ._engine_type = engine_type
103
108
104
109
@property
105
110
def session_id (self ) -> str :
@@ -135,6 +140,32 @@ def num_non_blank_cache_entries(self) -> int:
135
140
"""
136
141
return self .kv_cache .num_non_blank_entries
137
142
143
def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]:
    """
    Run the engine with the given inputs.

    If the internal deepsparse kv cache management is enabled
    (i.e. ``self.kv_cache`` is set and holds a non-None ``_kv_cache``
    object, referred to as the LIB.kv_cache class object), that object
    is passed to the engine call together with the inputs. Otherwise
    the call is delegated to the regular ``self.engine.run``.

    :param inputs: The inputs to run the engine with
    :param val_inp: Whether to validate the inputs before running.
        On the internal kv cache path this triggers
        ``self.engine._validate_inputs(inputs)``; on the regular path
        the flag is forwarded unchanged to ``self.engine.run``

    :return: The output of the engine
    """
    # resolve the internally managed cache handle once; it stays None
    # when there is no kv cache at all or when the cache is not managed
    # internally by the engine
    internal_kv_cache = None
    if self.kv_cache is not None:
        internal_kv_cache = self.kv_cache._kv_cache

    if internal_kv_cache is None:
        # no internal kv cache management -> plain engine call
        return self.engine.run(inputs, val_inp)

    # this path bypasses `engine.run`, so the (optional) input
    # validation has to be triggered explicitly
    if val_inp:
        self.engine._validate_inputs(inputs)

    # model has kv cache support, as well as deepsparse
    # internal management of the kv cache
    return self.engine._eng_net.execute_list_out(inputs, internal_kv_cache)
168
+
138
169
def __call__ (
139
170
self ,
140
171
inp : List [numpy .ndarray ],
@@ -154,7 +185,7 @@ def __call__(
154
185
# to the input
155
186
inp = self .add_kv_cache_to_input (inp )
156
187
157
- out = self .engine . run (inp , val_inp )
188
+ out = self .run (inp , val_inp )
158
189
159
190
if self .kv_cache :
160
191
logits , * kv_cache_state = out
@@ -187,78 +218,9 @@ def transfer_cache_state(self, cache: DecoderKVCache):
187
218
:param cache: The `DecoderKVCache` object to transfer to the engine
188
219
from
189
220
"""
190
- cache_to_copy = copy .deepcopy (cache )
191
221
target_cache_capacity = self .sequence_length - self .input_ids_length
192
- cache_to_copy .set_capacity (target_cache_capacity )
193
- self .kv_cache = cache_to_copy
194
-
195
- def overwrite_onnx_model_inputs (
196
- self ,
197
- onnx_file_path : str ,
198
- sequence_length : int ,
199
- input_ids_length : int ,
200
- batch_size : int = 1 ,
201
- ) -> Tuple [str , List [int ]]:
202
- """
203
- Enforces the appropriate input shapes for the onnx model, as well as
204
- checks whether kv cache is enabled or not.
205
-
206
- :param onnx_file_path: The path to the onnx model file that will be
207
- overwritten with the new input shapes
208
- :param batch_size: The batch size to use for the input
209
- :param sequence_length: The sequence length to use for the input
210
- :param input_ids_length: The length of input_ids
211
- :return: The path to the onnx model file that has been overwritten
212
- with the new input shapes, as well as the indices of the inputs
213
- that should be cached
214
- """
215
- model = onnx .load (onnx_file_path , load_external_data = False )
216
- initializer_input_names = set (node .name for node in model .graph .initializer )
217
- external_inputs = [
218
- inp for inp in model .graph .input if inp .name not in initializer_input_names
219
- ]
220
- for external_input in external_inputs :
221
- # overwrite the batch size for all the inputs
222
- external_input .type .tensor_type .shape .dim [0 ].dim_value = batch_size
223
-
224
- if external_input .name in ["input_ids" , "positions" ]:
225
- external_input .type .tensor_type .shape .dim [
226
- 1
227
- ].dim_value = input_ids_length
228
- elif external_input .name == "attention_mask" :
229
- external_input .type .tensor_type .shape .dim [1 ].dim_value = sequence_length
230
- elif external_input .name .startswith (_CACHE_INPUT_NAME ):
231
- external_input .type .tensor_type .shape .dim [2 ].dim_value = (
232
- sequence_length - input_ids_length
233
- )
234
- elif external_input .name .startswith ("causal_mask" ):
235
- external_input .type .tensor_type .shape .dim [
236
- 2
237
- ].dim_value = input_ids_length
238
- external_input .type .tensor_type .shape .dim [3 ].dim_value = sequence_length
239
- else :
240
- raise ValueError (
241
- f"Unexpected external input name: { external_input .name } "
242
- )
243
-
244
- _LOGGER .info (
245
- "Overwriting in-place the input shapes "
246
- f"of the transformer model at { onnx_file_path } "
247
- )
248
- save_onnx (model , onnx_file_path )
249
-
250
- output_indices_to_be_cached = [
251
- 1 if inp .name .startswith ("present" ) else 0 for inp in model .graph .output
252
- ]
253
- if any (output_indices_to_be_cached ):
254
- kv_cache_elem_type = next (
255
- inp
256
- for inp in model .graph .input
257
- if inp .name .startswith (_CACHE_INPUT_NAME )
258
- ).type .tensor_type .elem_type
259
- self .kv_cache_data_type = translate_onnx_type_to_numpy (kv_cache_elem_type )
260
-
261
- return onnx_file_path , output_indices_to_be_cached
222
+ cache .set_capacity (target_cache_capacity )
223
+ self .kv_cache = cache
262
224
263
225
def generate_token (self , logits : numpy .ndarray ) -> numpy .ndarray :
264
226
"""
@@ -283,7 +245,7 @@ def reset_kv_cache(self):
283
245
kv_cache_state = self ._initialize_kv_cache_state (
284
246
self .sequence_length - self .input_ids_length
285
247
)
286
- self .kv_cache .setup_session (
248
+ self .kv_cache .setup (
287
249
session_id = self ._session_id ,
288
250
state = kv_cache_state ,
289
251
num_processed_tokens = 0 ,
@@ -328,7 +290,7 @@ def update_kv_cache(
328
290
name : array for name , array in zip (cache_onnx_names , kv_cache_state )
329
291
}
330
292
331
- self .kv_cache .update_session (
293
+ self .kv_cache .update (
332
294
state = kv_cache_state ,
333
295
input_ids_len = input_ids_len ,
334
296
)
@@ -364,6 +326,6 @@ def _should_freeze_first_position(tokenizer) -> bool:
364
326
# (True if tokenizer has a prefix for a BOS token)
365
327
if tokenizer is None :
366
328
return False
367
- if hasattr (tokenizer , "bos_token " ):
329
+ if hasattr (tokenizer , "add_bos_token " ):
368
330
return True
369
331
return False
0 commit comments