
Commit 459c21f

[Text Generation][Tests] Text Generation Pipeline (#1162)
* initial implementation
* problems with multitoken prefill
* almost there...
* finally all tests pass
* just need to change to stub
* fix bad merge
1 parent d851ab4 commit 459c21f

4 files changed: +307 -116 lines changed


Diff for: src/deepsparse/transformers/engines/nl_decoder_engine.py

+13 -76
@@ -16,16 +16,16 @@
 from typing import Any, Dict, List, Optional, Tuple
 
 import numpy
-import onnx
 from transformers import AutoTokenizer
 
 from deepsparse.engine import Context
 from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
 from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
-from deepsparse.transformers.utils.helpers import generate_session_id
+from deepsparse.transformers.utils.helpers import (
+    generate_session_id,
+    overwrite_onnx_model_inputs,
+)
 from deepsparse.utils.data import numpy_softmax
-from deepsparse.utils.onnx import translate_onnx_type_to_numpy
-from sparsezoo.utils.onnx import save_onnx
 
 
 _LOGGER = logging.getLogger(__name__)
@@ -71,7 +71,11 @@ def __init__(
         # flag to indicate if the model is quantized or not
         self.kv_cache_data_type = None
 
-        onnx_file_path, output_indices_to_be_cached = self.overwrite_onnx_model_inputs(
+        (
+            onnx_file_path,
+            output_indices_to_be_cached,
+            kv_cache_data_type,
+        ) = overwrite_onnx_model_inputs(
             onnx_file_path=onnx_file_path,
             batch_size=engine_args.get("batch_size", 1),
             sequence_length=sequence_length,
@@ -80,6 +84,7 @@ def __init__(
         kv_cache_enabled = False
         if sum(output_indices_to_be_cached):
             kv_cache_enabled = True
+            self.kv_cache_data_type = kv_cache_data_type
             if use_deepsparse_cache and engine_type == DEEPSPARSE_ENGINE:
                 # inform the engine, that are using the kv cache
                 engine_args["cache_output_bools"] = output_indices_to_be_cached
@@ -192,74 +197,6 @@ def transfer_cache_state(self, cache: DecoderKVCache):
             cache_to_copy.set_capacity(target_cache_capacity)
             self.kv_cache = cache_to_copy
 
-    def overwrite_onnx_model_inputs(
-        self,
-        onnx_file_path: str,
-        sequence_length: int,
-        input_ids_length: int,
-        batch_size: int = 1,
-    ) -> Tuple[str, List[int]]:
-        """
-        Enforces the appropriate input shapes for the onnx model, as well as
-        checks whether kv cache is enabled or not.
-
-        :param onnx_file_path: The path to the onnx model file that will be
-            overwritten with the new input shapes
-        :param batch_size: The batch size to use for the input
-        :param sequence_length: The sequence length to use for the input
-        :param input_ids_length: The length of input_ids
-        :return: The path to the onnx model file that has been overwritten
-            with the new input shapes, as well as the indices of the inputs
-            that should be cached
-        """
-        model = onnx.load(onnx_file_path, load_external_data=False)
-        initializer_input_names = set(node.name for node in model.graph.initializer)
-        external_inputs = [
-            inp for inp in model.graph.input if inp.name not in initializer_input_names
-        ]
-        for external_input in external_inputs:
-            # overwrite the batch size for all the inputs
-            external_input.type.tensor_type.shape.dim[0].dim_value = batch_size
-
-            if external_input.name in ["input_ids", "positions"]:
-                external_input.type.tensor_type.shape.dim[
-                    1
-                ].dim_value = input_ids_length
-            elif external_input.name == "attention_mask":
-                external_input.type.tensor_type.shape.dim[1].dim_value = sequence_length
-            elif external_input.name.startswith(_CACHE_INPUT_NAME):
-                external_input.type.tensor_type.shape.dim[2].dim_value = (
-                    sequence_length - input_ids_length
-                )
-            elif external_input.name.startswith("causal_mask"):
-                external_input.type.tensor_type.shape.dim[
-                    2
-                ].dim_value = input_ids_length
-                external_input.type.tensor_type.shape.dim[3].dim_value = sequence_length
-            else:
-                raise ValueError(
-                    f"Unexpected external input name: {external_input.name}"
-                )
-
-        _LOGGER.info(
-            "Overwriting in-place the input shapes "
-            f"of the transformer model at {onnx_file_path}"
-        )
-        save_onnx(model, onnx_file_path)
-
-        output_indices_to_be_cached = [
-            1 if inp.name.startswith("present") else 0 for inp in model.graph.output
-        ]
-        if any(output_indices_to_be_cached):
-            kv_cache_elem_type = next(
-                inp
-                for inp in model.graph.input
-                if inp.name.startswith(_CACHE_INPUT_NAME)
-            ).type.tensor_type.elem_type
-            self.kv_cache_data_type = translate_onnx_type_to_numpy(kv_cache_elem_type)
-
-        return onnx_file_path, output_indices_to_be_cached
-
     def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
         """
         Samples a token from the logits using the sampling temperature.
@@ -283,7 +220,7 @@ def reset_kv_cache(self):
         kv_cache_state = self._initialize_kv_cache_state(
             self.sequence_length - self.input_ids_length
         )
-        self.kv_cache.setup_session(
+        self.kv_cache.setup(
             session_id=self._session_id,
             state=kv_cache_state,
             num_processed_tokens=0,
@@ -328,7 +265,7 @@ def update_kv_cache(
             name: array for name, array in zip(cache_onnx_names, kv_cache_state)
         }
 
-        self.kv_cache.update_session(
+        self.kv_cache.update(
             state=kv_cache_state,
             input_ids_len=input_ids_len,
         )
@@ -364,6 +301,6 @@ def _should_freeze_first_position(tokenizer) -> bool:
         # (True if tokenizer has a prefix for a BOS token)
         if tokenizer is None:
             return False
-        if hasattr(tokenizer, "bos_token"):
+        if hasattr(tokenizer, "add_bos_token"):
             return True
         return False
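
Not part of the commit, but for context on the last hunk: `bos_token` is defined on the Hugging Face tokenizer base class for essentially every tokenizer (even when its value is None), so the old `hasattr(tokenizer, "bos_token")` check was effectively always true. `add_bos_token` is only present on tokenizers that actually prepend a BOS token, which is the property `_should_freeze_first_position` is trying to detect. A minimal sketch, assuming the `transformers` package and the purely illustrative "gpt2" checkpoint:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint only

# Defined on the SpecialTokensMixin base class, so present on virtually any
# tokenizer; this is why the old check could not tell BOS-prefixing tokenizers apart.
print(hasattr(tokenizer, "bos_token"))

# Only tokenizers that actually prepend a BOS token expose this attribute,
# which is what the updated check keys off.
print(hasattr(tokenizer, "add_bos_token"))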

Diff for: src/deepsparse/transformers/pipelines/text_generation.py

+16 -18
@@ -432,9 +432,8 @@ def prompt_inference(
             with self.timer_manager.current.time(
                 _TextGenerationTimings.PROMPT_PREFILL_SINGLE
             ):
-                new_token, new_logits = self.autoregressive_inference(
-                    run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
-                )
+                new_token, new_logits = self.autoregressive_inference(run_tokens)
+
             prompt_logits.append(new_logits)
 
             tokens.append(new_token)
@@ -444,16 +443,12 @@ def prompt_inference(
     def autoregressive_inference(
         self,
         tokens: List[int],
-        shift_positions_by_one: bool = False,
     ) -> Tuple[int, numpy.ndarray]:
         """
         An inference run that processes the last token to generate
         a new token and new logits.
 
         :param tokens: The current context (prompt + generated tokens so far)
-        :param shift_positions_by_one: Whether to shift the positions
-            by one. Used if we are processing the prompt from the scratch
-            (i.e. not using the multitoken engine)
         :return: The new, generated token and the logits for the new token
             (with dimensions ['batch_size', 'num_tokens', 'vocab_size'])
         """
@@ -465,8 +460,7 @@ def autoregressive_inference(
         num_tokens_processed = min(len(tokens), self.sequence_length)  # cap by seq len
         attention_mask[:, -num_tokens_processed:] = 1
         positions = numpy.array([[len(tokens)]], dtype=numpy.int64)
-        if shift_positions_by_one:
-            positions -= 1
+        positions -= 1
         input_ids = numpy.array([[new_token]])
         causal_mask = create_causal_mask(input_ids, attention_mask)
 
@@ -523,28 +517,28 @@ def engine_inputs_for_prefill(
         num_batches = len(tokens) // self.prompt_processing_sequence_length
 
         token_batches = [
-            tokens[i : i + self.prompt_processing_sequence_length]
-            for i in range(num_batches)
+            tokens[
+                i
+                * self.prompt_processing_sequence_length : (i + 1)
+                * self.prompt_processing_sequence_length
+            ]
+            for i in range(0, num_batches)
         ]
 
         for idx, token_batch in enumerate(token_batches):
             engine_inputs = []
-
+            num_cached_entries = self.multitoken_engine.num_non_blank_cache_entries
             for name in self.multitoken_engine.onnx_input_names_no_cache:
                 if name == "input_ids":
                     engine_input = numpy.array([token_batch])
 
                 elif name == "attention_mask":
-                    num_cached_entries = (
-                        self.multitoken_engine.num_non_blank_cache_entries
-                    )
-
                     # create an empty attention mask
                     engine_input = numpy.zeros(
                         (1, self.sequence_length), dtype=numpy.int64
                     )
                     # fill it out with 1s (from the right), so that the number
-                    # of unmaksed entries is equal to the sum of:
+                    # of unmasked entries is equal to the sum of:
                     engine_input[
                         :,
                         -(
@@ -564,7 +558,11 @@ def engine_inputs_for_prefill(
                     engine_input = numpy.array([[idx]], dtype=numpy.int64)
                 else:
                     engine_input = (
-                        numpy.arange(self.prompt_processing_sequence_length)
+                        numpy.arange(
+                            num_cached_entries,
+                            num_cached_entries
+                            + self.prompt_processing_sequence_length,
+                        )
                         .reshape(1, -1)
                         .astype(numpy.int64)
                     )
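
Not part of the diff, but a small worked sketch of what the prefill changes compute; the chunk size, prompt, and the running num_cached_entries counter (standing in for `multitoken_engine.num_non_blank_cache_entries`) are made up. Each prompt chunk is now sliced by absolute offsets, and its positions continue from the number of tokens already held in the kv cache instead of restarting at zero:

import numpy

prompt_processing_sequence_length = 4  # hypothetical prefill chunk size
tokens = list(range(8))                # hypothetical tokenized prompt
num_batches = len(tokens) // prompt_processing_sequence_length

# slice by absolute offsets, mirroring the new token_batches comprehension
token_batches = [
    tokens[
        i * prompt_processing_sequence_length : (i + 1) * prompt_processing_sequence_length
    ]
    for i in range(num_batches)
]

num_cached_entries = 0  # stands in for num_non_blank_cache_entries
for token_batch in token_batches:
    positions = (
        numpy.arange(
            num_cached_entries,
            num_cached_entries + prompt_processing_sequence_length,
        )
        .reshape(1, -1)
        .astype(numpy.int64)
    )
    print(token_batch, positions)  # chunk 0 -> [[0 1 2 3]], chunk 1 -> [[4 5 6 7]]
    num_cached_entries += len(token_batch)  # the cache fills up after each engine run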

Diff for: src/deepsparse/transformers/utils/helpers.py

+76 -1
@@ -12,18 +12,93 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import uuid
-from typing import List, Union
+from typing import List, Tuple, Union
 
 import numpy
+import onnx
+
+from deepsparse.utils.onnx import translate_onnx_type_to_numpy
+from sparsezoo.utils import save_onnx
 
 
 __all__ = [
     "generate_session_id",
     "pad_to_fixed_length",
     "create_causal_mask",
+    "overwrite_onnx_model_inputs",
 ]
 
+_LOGGER = logging.getLogger(__name__)
+
+
+def overwrite_onnx_model_inputs(
+    onnx_file_path: str,
+    sequence_length: int,
+    input_ids_length: int,
+    batch_size: int = 1,
+) -> Tuple[str, List[int]]:
+    """
+    Enforces the appropriate input shapes for the onnx model, as well as
+    checks whether kv cache is enabled or not.
+
+    :param onnx_file_path: The path to the onnx model file that will be
+        overwritten with the new input shapes
+    :param batch_size: The batch size to use for the input
+    :param sequence_length: The sequence length to use for the input
+    :param input_ids_length: The length of input_ids
+    :return: A tuple that contains:
+        -   the path to the onnx model file that has been overwritten
+            with the new input shapes
+        -   boolean list, where elements are set to True if the
+            corresponding model output should be cached or False
+            if not.
+        -   the data type of the kv cache. If the model does not
+            use kv cache, then the data type is None
+    """
+    model = onnx.load(onnx_file_path, load_external_data=False)
+    initializer_input_names = set(node.name for node in model.graph.initializer)
+    external_inputs = [
+        inp for inp in model.graph.input if inp.name not in initializer_input_names
+    ]
+    for external_input in external_inputs:
+        # overwrite the batch size for all the inputs
+        external_input.type.tensor_type.shape.dim[0].dim_value = batch_size
+
+        if external_input.name in ["input_ids", "positions"]:
+            external_input.type.tensor_type.shape.dim[1].dim_value = input_ids_length
+        elif external_input.name == "attention_mask":
+            external_input.type.tensor_type.shape.dim[1].dim_value = sequence_length
+        elif external_input.name.startswith("past_key_values"):
+            external_input.type.tensor_type.shape.dim[2].dim_value = (
+                sequence_length - input_ids_length
+            )
+        elif external_input.name.startswith("causal_mask"):
+            external_input.type.tensor_type.shape.dim[2].dim_value = input_ids_length
+            external_input.type.tensor_type.shape.dim[3].dim_value = sequence_length
+        else:
+            raise ValueError(f"Unexpected external input name: {external_input.name}")
+
+    _LOGGER.info(
+        "Overwriting in-place the input shapes "
+        f"of the transformer model at {onnx_file_path}"
+    )
+    save_onnx(model, onnx_file_path)
+
+    output_indices_to_be_cached = [
+        1 if inp.name.startswith("present") else 0 for inp in model.graph.output
+    ]
+
+    kv_cache_data_type = None
+    if any(output_indices_to_be_cached):
+        kv_cache_elem_type = next(
+            inp for inp in model.graph.input if inp.name.startswith("past_key_values")
+        ).type.tensor_type.elem_type
+        kv_cache_data_type = translate_onnx_type_to_numpy(kv_cache_elem_type)
+
+    return onnx_file_path, output_indices_to_be_cached, kv_cache_data_type
+
 
 def generate_session_id() -> str:
     """
