Skip to content

Commit 0fc35f0

Browse files
dbogunowiczbfineranSageMoore
authored
[Text Generation] Support for causal masks, internal KV cache, and initial testing framework (#1172)
* initial commit * improved logic * additional improvements * Update src/deepsparse/transformers/pipelines/text_generation.py * Update src/deepsparse/utils/onnx.py Co-authored-by: Benjamin Fineran <[email protected]> * Update src/deepsparse/utils/onnx.py Co-authored-by: Benjamin Fineran <[email protected]> * response to Ben's comments * finish rebasing * update user messages + add assertion for safety * minor improvements before landing * Fix the helper function that has been broken after a merge * [Text Generation] Internal KV Cache Support + Initial Testing Framework (#1163) * Create test_nl_decoder_engine.py * [Text Generation][Tests] DecoderKVCache (#1154) * [Text Generation][Tests] NLDecoderEngine (#1155) * initial commit * initial commit * [Text Generation][Tests] Text Generation Pipeline (#1162) * initial implementation * problems with multitoken prefill * almost there... * finally all tests pass * just need to change to stub * fix bad merge * Make tests work with stub (as much as possible), cleanup test names, disable heavy tests, include patch for running without causal mask * use patch from unittest library - remove additional dependency * Update tests/deepsparse/transformers/pipelines/test_text_generation.py * clarify todo comment * [Text Generation] KV Cache internal Deepsparse support (#1135) * fix kv cache * refactor * add validation pathway * avx2 support * initial commit * initial commit * initial implementation * problems with multitoken prefill * its working * almost there... 
* finally all tests pass * just need to change to stub * fix bad merge * added some tests * ready for review * full support --------- Co-authored-by: dbogunowicz <[email protected]> Co-authored-by: Damian <[email protected]> * incomplete string in parametrize * few nits before the merge --------- Co-authored-by: Benjamin Fineran <[email protected]> Co-authored-by: Sage Moore <[email protected]> --------- Co-authored-by: Benjamin Fineran <[email protected]> Co-authored-by: Sage Moore <[email protected]>
1 parent c723225 commit 0fc35f0

File tree

11 files changed

+675
-195
lines changed

11 files changed

+675
-195
lines changed

src/deepsparse/engine.py

+3
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ def __init__(
300300
num_streams: int = None,
301301
scheduler: Scheduler = None,
302302
input_shapes: List[List[int]] = None,
303+
cached_outputs: List[bool] = None,
303304
):
304305
BaseEngine.construct(
305306
self, model, batch_size, num_cores, num_streams, scheduler, input_shapes
@@ -316,6 +317,7 @@ def __init__(
316317
self._num_streams,
317318
self._scheduler.value,
318319
None,
320+
cached_outputs,
319321
)
320322
else:
321323
self._eng_net = LIB.deepsparse_engine(
@@ -325,6 +327,7 @@ def __init__(
325327
self._num_streams,
326328
self._scheduler.value,
327329
None,
330+
cached_outputs,
328331
)
329332

330333
def __call__(

src/deepsparse/transformers/engines/nl_decoder_engine.py

+44-82
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,20 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import copy
1514
import logging
1615
from typing import Any, Dict, List, Optional, Tuple
1716

1817
import numpy
19-
import onnx
2018
from transformers import AutoTokenizer
2119

2220
from deepsparse.engine import Context
2321
from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
2422
from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
25-
from deepsparse.transformers.utils.helpers import generate_session_id
23+
from deepsparse.transformers.utils.helpers import (
24+
generate_session_id,
25+
overwrite_onnx_model_inputs,
26+
)
2627
from deepsparse.utils.data import numpy_softmax
27-
from deepsparse.utils.onnx import translate_onnx_type_to_numpy
28-
from sparsezoo.utils.onnx import save_onnx
2928

3029

3130
_LOGGER = logging.getLogger(__name__)
@@ -71,7 +70,11 @@ def __init__(
7170
# flag to indicate if the model is quantized or not
7271
self.kv_cache_data_type = None
7372

74-
onnx_file_path, output_indices_to_be_cached = self.overwrite_onnx_model_inputs(
73+
(
74+
onnx_file_path,
75+
output_indices_to_be_cached,
76+
kv_cache_data_type,
77+
) = overwrite_onnx_model_inputs(
7578
onnx_file_path=onnx_file_path,
7679
batch_size=engine_args.get("batch_size", 1),
7780
sequence_length=sequence_length,
@@ -80,9 +83,10 @@ def __init__(
8083
kv_cache_enabled = False
8184
if sum(output_indices_to_be_cached):
8285
kv_cache_enabled = True
86+
self.kv_cache_data_type = kv_cache_data_type
8387
if use_deepsparse_cache and engine_type == DEEPSPARSE_ENGINE:
8488
# inform the engine that we are using the kv cache
85-
engine_args["cache_output_bools"] = output_indices_to_be_cached
89+
engine_args["cached_outputs"] = output_indices_to_be_cached
8690

8791
self.engine = create_engine(
8892
onnx_file_path=onnx_file_path,
@@ -100,6 +104,7 @@ def __init__(
100104
)
101105
self._freeze_first_position = self._should_freeze_first_position(tokenizer)
102106
self._session_id = generate_session_id()
107+
self._engine_type = engine_type
103108

104109
@property
105110
def session_id(self) -> str:
@@ -135,6 +140,32 @@ def num_non_blank_cache_entries(self) -> int:
135140
"""
136141
return self.kv_cache.num_non_blank_entries
137142

143+
def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]:
144+
"""
145+
Run the engine with the given inputs.
146+
147+
If the internal deepsparse kv cache management is enabled,
148+
the LIB.kv_cache class object will be passed to the engine
149+
call as well.
150+
151+
:param inputs: The inputs to run the engine with
152+
:param val_inp: Whether to validate the inputs before running the engine
153+
154+
:return: The output of the engine
155+
"""
156+
157+
if self.kv_cache is not None:
158+
if self.kv_cache._kv_cache is not None:
159+
if val_inp:
160+
self.engine._validate_inputs(inputs)
161+
# model has kv cache support, as well as deepsparse
162+
# internal management of the kv cache
163+
return self.engine._eng_net.execute_list_out(
164+
inputs, self.kv_cache._kv_cache
165+
)
166+
167+
return self.engine.run(inputs, val_inp)
168+
138169
def __call__(
139170
self,
140171
inp: List[numpy.ndarray],
@@ -154,7 +185,7 @@ def __call__(
154185
# to the input
155186
inp = self.add_kv_cache_to_input(inp)
156187

157-
out = self.engine.run(inp, val_inp)
188+
out = self.run(inp, val_inp)
158189

159190
if self.kv_cache:
160191
logits, *kv_cache_state = out
@@ -187,78 +218,9 @@ def transfer_cache_state(self, cache: DecoderKVCache):
187218
:param cache: The `DecoderKVCache` object to transfer to the engine
188219
from
189220
"""
190-
cache_to_copy = copy.deepcopy(cache)
191221
target_cache_capacity = self.sequence_length - self.input_ids_length
192-
cache_to_copy.set_capacity(target_cache_capacity)
193-
self.kv_cache = cache_to_copy
194-
195-
def overwrite_onnx_model_inputs(
196-
self,
197-
onnx_file_path: str,
198-
sequence_length: int,
199-
input_ids_length: int,
200-
batch_size: int = 1,
201-
) -> Tuple[str, List[int]]:
202-
"""
203-
Enforces the appropriate input shapes for the onnx model, as well as
204-
checks whether kv cache is enabled or not.
205-
206-
:param onnx_file_path: The path to the onnx model file that will be
207-
overwritten with the new input shapes
208-
:param batch_size: The batch size to use for the input
209-
:param sequence_length: The sequence length to use for the input
210-
:param input_ids_length: The length of input_ids
211-
:return: The path to the onnx model file that has been overwritten
212-
with the new input shapes, as well as the indices of the inputs
213-
that should be cached
214-
"""
215-
model = onnx.load(onnx_file_path, load_external_data=False)
216-
initializer_input_names = set(node.name for node in model.graph.initializer)
217-
external_inputs = [
218-
inp for inp in model.graph.input if inp.name not in initializer_input_names
219-
]
220-
for external_input in external_inputs:
221-
# overwrite the batch size for all the inputs
222-
external_input.type.tensor_type.shape.dim[0].dim_value = batch_size
223-
224-
if external_input.name in ["input_ids", "positions"]:
225-
external_input.type.tensor_type.shape.dim[
226-
1
227-
].dim_value = input_ids_length
228-
elif external_input.name == "attention_mask":
229-
external_input.type.tensor_type.shape.dim[1].dim_value = sequence_length
230-
elif external_input.name.startswith(_CACHE_INPUT_NAME):
231-
external_input.type.tensor_type.shape.dim[2].dim_value = (
232-
sequence_length - input_ids_length
233-
)
234-
elif external_input.name.startswith("causal_mask"):
235-
external_input.type.tensor_type.shape.dim[
236-
2
237-
].dim_value = input_ids_length
238-
external_input.type.tensor_type.shape.dim[3].dim_value = sequence_length
239-
else:
240-
raise ValueError(
241-
f"Unexpected external input name: {external_input.name}"
242-
)
243-
244-
_LOGGER.info(
245-
"Overwriting in-place the input shapes "
246-
f"of the transformer model at {onnx_file_path}"
247-
)
248-
save_onnx(model, onnx_file_path)
249-
250-
output_indices_to_be_cached = [
251-
1 if inp.name.startswith("present") else 0 for inp in model.graph.output
252-
]
253-
if any(output_indices_to_be_cached):
254-
kv_cache_elem_type = next(
255-
inp
256-
for inp in model.graph.input
257-
if inp.name.startswith(_CACHE_INPUT_NAME)
258-
).type.tensor_type.elem_type
259-
self.kv_cache_data_type = translate_onnx_type_to_numpy(kv_cache_elem_type)
260-
261-
return onnx_file_path, output_indices_to_be_cached
222+
cache.set_capacity(target_cache_capacity)
223+
self.kv_cache = cache
262224

263225
def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
264226
"""
@@ -283,7 +245,7 @@ def reset_kv_cache(self):
283245
kv_cache_state = self._initialize_kv_cache_state(
284246
self.sequence_length - self.input_ids_length
285247
)
286-
self.kv_cache.setup_session(
248+
self.kv_cache.setup(
287249
session_id=self._session_id,
288250
state=kv_cache_state,
289251
num_processed_tokens=0,
@@ -328,7 +290,7 @@ def update_kv_cache(
328290
name: array for name, array in zip(cache_onnx_names, kv_cache_state)
329291
}
330292

331-
self.kv_cache.update_session(
293+
self.kv_cache.update(
332294
state=kv_cache_state,
333295
input_ids_len=input_ids_len,
334296
)
@@ -364,6 +326,6 @@ def _should_freeze_first_position(tokenizer) -> bool:
364326
# (True if tokenizer has a prefix for a BOS token)
365327
if tokenizer is None:
366328
return False
367-
if hasattr(tokenizer, "bos_token"):
329+
if hasattr(tokenizer, "add_bos_token"):
368330
return True
369331
return False

0 commit comments

Comments
 (0)