
Commit 43e70a5

dbogunowicz, bfineran, and SageMoore authored
[Text Generation] Internal KV Cache Support + Initial Testing Framework (#1163)
* Create test_nl_decoder_engine.py
* [Text Generation][Tests] DecoderKVCache (#1154)
* [Text Generation][Tests] NLDecoderEngine (#1155)
* initial commit
* initial commit
* [Text Generation][Tests] Text Generation Pipeline (#1162)
* initial implementation
* problems with multitoken prefill
* almost there...
* finally all tests pass
* just need to change to stub
* fix bad merge
* Make tests work with stub (as much as possible), cleanup test names, disable heavy tests, include patch for running without causal mask
* use patch from unittest library - remove additional dependency
* Update tests/deepsparse/transformers/pipelines/test_text_generation.py
* clarify todo comment
* [Text Generation] KV Cache internal Deepsparse support (#1135)
* fix kv cache
* refactor
* add validation pathway
* avx2 support
* initial commit
* initial commit
* initial implementation
* problems with multitoken prefill
* its working
* almost there...
* finally all tests pass
* just need to change to stub
* fix bad merge
* added some tests
* ready for review
* full support

---------

Co-authored-by: dbogunowicz <[email protected]>
Co-authored-by: Damian <[email protected]>

* incomplete string in parametrize
* few nits before the merge

---------

Co-authored-by: Benjamin Fineran <[email protected]>
Co-authored-by: Sage Moore <[email protected]>
1 parent 4881d30 commit 43e70a5

File tree

11 files changed: +552 -147

src/deepsparse/engine.py

+3
@@ -300,6 +300,7 @@ def __init__(
         num_streams: int = None,
         scheduler: Scheduler = None,
         input_shapes: List[List[int]] = None,
+        cached_outputs: List[bool] = None,
     ):
         BaseEngine.construct(
             self, model, batch_size, num_cores, num_streams, scheduler, input_shapes
@@ -316,6 +317,7 @@ def __init__(
                 self._num_streams,
                 self._scheduler.value,
                 None,
+                cached_outputs,
             )
         else:
             self._eng_net = LIB.deepsparse_engine(
@@ -325,6 +327,7 @@ def __init__(
                 self._num_streams,
                 self._scheduler.value,
                 None,
+                cached_outputs,
             )

     def __call__(
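
Note: the diff above adds `cached_outputs` to the `Engine` constructor so the runtime can keep selected outputs (the KV cache tensors) internal. A minimal sketch of how this argument might be passed; the model path and the number of boolean flags are illustrative and must match the model's real outputs:

    from deepsparse import Engine

    # Hypothetical flags for a decoder ONNX model whose outputs are
    # [logits, present.0.key, present.0.value, ...]: True marks an output
    # whose tensor the engine should retain internally as KV cache.
    cached_outputs = [False, True, True, True, True]

    engine = Engine(
        model="path/to/decoder/model.onnx",  # placeholder path
        batch_size=1,
        cached_outputs=cached_outputs,
    )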

src/deepsparse/transformers/engines/nl_decoder_engine.py

+44 -82
@@ -11,21 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import logging
 from typing import Any, Dict, List, Optional, Tuple

 import numpy
-import onnx
 from transformers import AutoTokenizer

 from deepsparse.engine import Context
 from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
 from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
-from deepsparse.transformers.utils.helpers import generate_session_id
+from deepsparse.transformers.utils.helpers import (
+    generate_session_id,
+    overwrite_onnx_model_inputs,
+)
 from deepsparse.utils.data import numpy_softmax
-from deepsparse.utils.onnx import translate_onnx_type_to_numpy
-from sparsezoo.utils.onnx import save_onnx


 _LOGGER = logging.getLogger(__name__)
@@ -71,7 +70,11 @@ def __init__(
         # flag to indicate if the model is quantized or not
         self.kv_cache_data_type = None

-        onnx_file_path, output_indices_to_be_cached = self.overwrite_onnx_model_inputs(
+        (
+            onnx_file_path,
+            output_indices_to_be_cached,
+            kv_cache_data_type,
+        ) = overwrite_onnx_model_inputs(
             onnx_file_path=onnx_file_path,
             batch_size=engine_args.get("batch_size", 1),
             sequence_length=sequence_length,
@@ -80,9 +83,10 @@ def __init__(
         kv_cache_enabled = False
         if sum(output_indices_to_be_cached):
             kv_cache_enabled = True
+            self.kv_cache_data_type = kv_cache_data_type
             if use_deepsparse_cache and engine_type == DEEPSPARSE_ENGINE:
                 # inform the engine, that are using the kv cache
-                engine_args["cache_output_bools"] = output_indices_to_be_cached
+                engine_args["cached_outputs"] = output_indices_to_be_cached

         self.engine = create_engine(
             onnx_file_path=onnx_file_path,
@@ -100,6 +104,7 @@ def __init__(
         )
         self._freeze_first_position = self._should_freeze_first_position(tokenizer)
         self._session_id = generate_session_id()
+        self._engine_type = engine_type

     @property
     def session_id(self) -> str:
@@ -135,6 +140,32 @@ def num_non_blank_cache_entries(self) -> int:
         """
         return self.kv_cache.num_non_blank_entries

+    def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]:
+        """
+        Run the engine with the given inputs.
+
+        If the internal deepsparse kv cache management is enable,
+        the LIB.kv_cache class object will be passed to the engine
+        call as well.
+
+        :param inputs: The inputs to run the engine with
+        :param val_inp: Whether the input is for validation or not
+
+        :return: The output of the engine
+        """
+
+        if self.kv_cache is not None:
+            if self.kv_cache._kv_cache is not None:
+                if val_inp:
+                    self.engine._validate_inputs(inputs)
+                # model has kv cache support, as well as deepsparse
+                # internal management of the kv cache
+                return self.engine._eng_net.execute_list_out(
+                    inputs, self.kv_cache._kv_cache
+                )
+
+        return self.engine.run(inputs, val_inp)
+
     def __call__(
         self,
         inp: List[numpy.ndarray],
@@ -154,7 +185,7 @@ def __call__(
         # to the input
         inp = self.add_kv_cache_to_input(inp)

-        out = self.engine.run(inp, val_inp)
+        out = self.run(inp, val_inp)

         if self.kv_cache:
             logits, *kv_cache_state = out
@@ -187,78 +218,9 @@ def transfer_cache_state(self, cache: DecoderKVCache):
         :param cache: The `DecoderKVCache` object to transfer to the engine
             from
         """
-        cache_to_copy = copy.deepcopy(cache)
         target_cache_capacity = self.sequence_length - self.input_ids_length
-        cache_to_copy.set_capacity(target_cache_capacity)
-        self.kv_cache = cache_to_copy
-
-    def overwrite_onnx_model_inputs(
-        self,
-        onnx_file_path: str,
-        sequence_length: int,
-        input_ids_length: int,
-        batch_size: int = 1,
-    ) -> Tuple[str, List[int]]:
-        """
-        Enforces the appropriate input shapes for the onnx model, as well as
-        checks whether kv cache is enabled or not.
-
-        :param onnx_file_path: The path to the onnx model file that will be
-            overwritten with the new input shapes
-        :param batch_size: The batch size to use for the input
-        :param sequence_length: The sequence length to use for the input
-        :param input_ids_length: The length of input_ids
-        :return: The path to the onnx model file that has been overwritten
-            with the new input shapes, as well as the indices of the inputs
-            that should be cached
-        """
-        model = onnx.load(onnx_file_path, load_external_data=False)
-        initializer_input_names = set(node.name for node in model.graph.initializer)
-        external_inputs = [
-            inp for inp in model.graph.input if inp.name not in initializer_input_names
-        ]
-        for external_input in external_inputs:
-            # overwrite the batch size for all the inputs
-            external_input.type.tensor_type.shape.dim[0].dim_value = batch_size
-
-            if external_input.name in ["input_ids", "positions"]:
-                external_input.type.tensor_type.shape.dim[
-                    1
-                ].dim_value = input_ids_length
-            elif external_input.name == "attention_mask":
-                external_input.type.tensor_type.shape.dim[1].dim_value = sequence_length
-            elif external_input.name.startswith(_CACHE_INPUT_NAME):
-                external_input.type.tensor_type.shape.dim[2].dim_value = (
-                    sequence_length - input_ids_length
-                )
-            elif external_input.name.startswith("causal_mask"):
-                external_input.type.tensor_type.shape.dim[
-                    2
-                ].dim_value = input_ids_length
-                external_input.type.tensor_type.shape.dim[3].dim_value = sequence_length
-            else:
-                raise ValueError(
-                    f"Unexpected external input name: {external_input.name}"
-                )
-
-        _LOGGER.info(
-            "Overwriting in-place the input shapes "
-            f"of the transformer model at {onnx_file_path}"
-        )
-        save_onnx(model, onnx_file_path)
-
-        output_indices_to_be_cached = [
-            1 if inp.name.startswith("present") else 0 for inp in model.graph.output
-        ]
-        if any(output_indices_to_be_cached):
-            kv_cache_elem_type = next(
-                inp
-                for inp in model.graph.input
-                if inp.name.startswith(_CACHE_INPUT_NAME)
-            ).type.tensor_type.elem_type
-            self.kv_cache_data_type = translate_onnx_type_to_numpy(kv_cache_elem_type)
-
-        return onnx_file_path, output_indices_to_be_cached
+        cache.set_capacity(target_cache_capacity)
+        self.kv_cache = cache

     def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
         """
@@ -283,7 +245,7 @@ def reset_kv_cache(self):
         kv_cache_state = self._initialize_kv_cache_state(
             self.sequence_length - self.input_ids_length
         )
-        self.kv_cache.setup_session(
+        self.kv_cache.setup(
             session_id=self._session_id,
             state=kv_cache_state,
             num_processed_tokens=0,
@@ -328,7 +290,7 @@ def update_kv_cache(
             name: array for name, array in zip(cache_onnx_names, kv_cache_state)
         }

-        self.kv_cache.update_session(
+        self.kv_cache.update(
             state=kv_cache_state,
             input_ids_len=input_ids_len,
         )
@@ -364,6 +326,6 @@ def _should_freeze_first_position(tokenizer) -> bool:
         # (True if tokenizer has a prefix for a BOS token)
         if tokenizer is None:
             return False
-        if hasattr(tokenizer, "bos_token"):
+        if hasattr(tokenizer, "add_bos_token"):
             return True
         return False
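
Note: `overwrite_onnx_model_inputs` moves out of `NLDecoderEngine` into `deepsparse.transformers.utils.helpers` and now returns the KV cache data type as a third value instead of setting `self.kv_cache_data_type` as a side effect. A hedged usage sketch, assuming the helper keeps the removed method's keyword arguments; the path and shape values are placeholders:

    from deepsparse.transformers.utils.helpers import overwrite_onnx_model_inputs

    # Rewrites the model's input shapes in place and reports which outputs
    # (the "present.*" KV tensors) should be cached, plus the cache dtype.
    (
        onnx_path,
        output_indices_to_be_cached,
        kv_cache_data_type,
    ) = overwrite_onnx_model_inputs(
        onnx_file_path="path/to/model.onnx",  # placeholder path
        sequence_length=128,
        input_ids_length=1,
        batch_size=1,
    )
    # output_indices_to_be_cached -> e.g. [0, 1, 1, ...]: 1 per "present.*" output
    # kv_cache_data_type -> numpy dtype of the cache tensors (when kv cache exists)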

src/deepsparse/transformers/pipelines/text_generation.py

+19 -28
@@ -24,7 +24,6 @@
 from transformers import TextStreamer

 from deepsparse import Pipeline
-from deepsparse.cpu import cpu_avx512_compatible
 from deepsparse.pipeline import DEEPSPARSE_ENGINE
 from deepsparse.transformers.engines import NLDecoderEngine
 from deepsparse.transformers.pipelines import TransformersPipeline
@@ -146,22 +145,16 @@ def __init__(
         **kwargs,
     ):
         kwargs_engine_type = kwargs.get("engine_type", DEEPSPARSE_ENGINE)
-        if not cpu_avx512_compatible() and kwargs_engine_type == DEEPSPARSE_ENGINE:
-            warnings.warn(
-                "AVX512 support not detected, disabling internal management "
-                "of KV cache which may affect performance. To enable full "
-                "performance, deploy on an AVX512-compatible system."
-            )
-            use_deepsparse_cache = False

         if use_deepsparse_cache:
             if kwargs_engine_type != DEEPSPARSE_ENGINE:
-                raise ValueError(
+                _LOGGER.warning(
                     "`use_deepsparse_cache` is set to True "
                     "but the chosen `engine_type` "
                     f"is {kwargs_engine_type}. "
-                    f"Make sure to set `engine_type` to {DEEPSPARSE_ENGINE}"
+                    f"The optimized kv cache management is disabled."
                 )
+                use_deepsparse_cache = False

         super().__init__(
             **kwargs, _delay_engine_initialize=True, _delay_overwriting_inputs=True
@@ -493,9 +486,8 @@ def prompt_inference(
             with self.timer_manager.current.time(
                 _TextGenerationTimings.PROMPT_PREFILL_SINGLE
             ):
-                new_token, new_logits = self.autoregressive_inference(
-                    run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
-                )
+                new_token, new_logits = self.autoregressive_inference(run_tokens)
+
             prompt_logits.append(new_logits)

             tokens.append(new_token)
@@ -505,16 +497,12 @@ def prompt_inference(
     def autoregressive_inference(
         self,
         tokens: List[int],
-        shift_positions_by_one: bool = False,
     ) -> Tuple[int, numpy.ndarray]:
         """
         An inference run that processes the last token to generate
         a new token and new logits.

         :param tokens: The current context (prompt + generated tokens so far)
-        :param shift_positions_by_one: Whether to shift the positions
-            by one. Used if we are processing the prompt from the scratch
-            (i.e. not using the multitoken engine)
         :return: The new, generated token and the logits for the new token
             (with dimensions ['batch_size', 'num_tokens', 'vocab_size'])
         """
@@ -526,8 +514,7 @@ def autoregressive_inference(
         num_tokens_processed = min(len(tokens), self.sequence_length)  # cap by seq len
         attention_mask[:, -num_tokens_processed:] = 1
         positions = numpy.array([[len(tokens)]], dtype=numpy.int64)
-        if shift_positions_by_one:
-            positions -= 1
+        positions -= 1
         input_ids = numpy.array([[new_token]])
         causal_mask = create_causal_mask(input_ids, attention_mask)

@@ -584,28 +571,28 @@ def engine_inputs_for_prefill(
         num_batches = len(tokens) // self.prompt_processing_sequence_length

         token_batches = [
-            tokens[i : i + self.prompt_processing_sequence_length]
-            for i in range(num_batches)
+            tokens[
+                i
+                * self.prompt_processing_sequence_length : (i + 1)
+                * self.prompt_processing_sequence_length
+            ]
+            for i in range(0, num_batches)
         ]

         for idx, token_batch in enumerate(token_batches):
             engine_inputs = []
-
+            num_cached_entries = self.multitoken_engine.num_non_blank_cache_entries
             for name in self.multitoken_engine.onnx_input_names_no_cache:
                 if name == "input_ids":
                     engine_input = numpy.array([token_batch])

                 elif name == "attention_mask":
-                    num_cached_entries = (
-                        self.multitoken_engine.num_non_blank_cache_entries
-                    )
-
                     # create an empty attention mask
                     engine_input = numpy.zeros(
                         (1, self.sequence_length), dtype=numpy.int64
                     )
                     # fill it out with 1s (from the right), so that the number
-                    # of unmaksed entries is equal to the sum of:
+                    # of unmasked entries is equal to the sum of:
                     engine_input[
                         :,
                         -(
@@ -625,7 +612,11 @@ def engine_inputs_for_prefill(
                     engine_input = numpy.array([[idx]], dtype=numpy.int64)
                 else:
                     engine_input = (
-                        numpy.arange(self.prompt_processing_sequence_length)
+                        numpy.arange(
+                            num_cached_entries,
+                            num_cached_entries
+                            + self.prompt_processing_sequence_length,
+                        )
                         .reshape(1, -1)
                         .astype(numpy.int64)
                     )
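
Note: with positions in the prefill branch now offset by `num_non_blank_cache_entries`, each multi-token batch gets absolute position ids rather than batch-local ones. A self-contained sketch of the indexing, with illustrative values; `num_cached_entries` is simulated here, standing in for `multitoken_engine.num_non_blank_cache_entries`:

    import numpy

    # Stand-ins for pipeline attributes (illustrative values)
    prompt_processing_sequence_length = 4
    tokens = list(range(10))  # a 10-token prompt; ids are arbitrary

    num_batches = len(tokens) // prompt_processing_sequence_length  # -> 2
    token_batches = [
        tokens[
            i
            * prompt_processing_sequence_length : (i + 1)
            * prompt_processing_sequence_length
        ]
        for i in range(num_batches)
    ]  # [[0, 1, 2, 3], [4, 5, 6, 7]]; the remainder is prefilled token by token

    for idx, token_batch in enumerate(token_batches):
        # each completed batch fills one batch's worth of cache entries
        num_cached_entries = idx * prompt_processing_sequence_length
        positions = (
            numpy.arange(
                num_cached_entries,
                num_cached_entries + prompt_processing_sequence_length,
            )
            .reshape(1, -1)
            .astype(numpy.int64)
        )
        # idx 0 -> [[0, 1, 2, 3]]; idx 1 -> [[4, 5, 6, 7]]: absolute positions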
