Skip to content

Commit fdb5d44

Browse files
authored
[Text Generation][KVCacheStorage] TextGenerationPipeline refactor (#1254)
* initial commit * upload draft for review * initial implementation. testing now * in this form tests pass * cleanup * ready for review * PR review changes
1 parent 1ba444f commit fdb5d44

File tree

9 files changed

+346
-293
lines changed

9 files changed

+346
-293
lines changed

src/deepsparse/transformers/engines/nl_decoder_engine.py

+63-146
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,19 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import copy
1415
import logging
15-
from typing import Any, Dict, List, Optional
16+
from typing import Any, Dict, List, Optional, Tuple
1617

1718
import numpy
18-
from transformers import AutoTokenizer
1919

2020
from deepsparse.engine import Context
2121
from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
2222
from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
23-
from deepsparse.transformers.utils.helpers import generate_session_id
2423
from deepsparse.transformers.utils.timings import TextGenerationTimings
2524
from deepsparse.utils import TimerManager
2625
from deepsparse.utils.onnx import (
2726
CACHE_INPUT_PREFIX,
28-
CACHE_OUTPUT_PREFIX,
2927
overwrite_onnx_model_inputs_for_kv_cache_models,
3028
)
3129

@@ -37,20 +35,16 @@
3735

3836
class NLDecoderEngine:
3937
"""
40-
The NLDecoderEngine (NaturalLanguageDecoderEngine) handles the
38+
The NLDecoderEngine (Natural Language Decoder Engine) handles the
4139
logic around the inference for Natural Language pipeline,
42-
including batching and kv cache logic.
40+
including batching and kv cache manipulation logic.
4341
4442
:param onnx_file_path: The path to the onnx model file
4543
:param engine_type: The type of engine to use for the inference
4644
:param engine_args: The arguments to pass to the engine
4745
:param sequence_length: The maximum sequence length to run the engine for
4846
:param input_ids_length: The maximum input ids length to run the engine for
4947
:param engine_context: The context to run the engine in
50-
:param sampling_temperature: The temperature to use for sampling
51-
:param deterministic: Whether to use deterministic sampling
52-
:param tokenizer: The tokenizer to used for engine inputs
53-
:param engine_context: The context to run the engine in
5448
:param internal_kv_cache: Whether to use the deepsparse
5549
kv cache in the DecoderKVCache object or not
5650
"""
@@ -62,9 +56,6 @@ def __init__(
6256
engine_args: Dict[str, Any],
6357
sequence_length: int,
6458
input_ids_length: int,
65-
tokenizer: AutoTokenizer,
66-
sampling_temperature: float = 1.0,
67-
deterministic: bool = True,
6859
engine_context: Optional[Context] = None,
6960
internal_kv_cache=False,
7061
timer_manager: TimerManager = None,
@@ -82,9 +73,7 @@ def __init__(
8273
input_ids_length=input_ids_length,
8374
)
8475

85-
kv_cache_enabled = False
8676
if any(output_indices_to_be_cached):
87-
kv_cache_enabled = True
8877
self.kv_cache_data_type = kv_cache_data_type
8978
if internal_kv_cache and engine_type == DEEPSPARSE_ENGINE:
9079
# inform the engine, that are using the kv cache
@@ -98,30 +87,10 @@ def __init__(
9887
)
9988
self.timer_manager = timer_manager or TimerManager()
10089
self.sequence_length = sequence_length
101-
self.sampling_temperature = sampling_temperature
102-
self.deterministic = deterministic
10390
self.input_ids_length = input_ids_length
10491
self.cache_length = sequence_length - input_ids_length
105-
self.kv_cache_enabled = kv_cache_enabled
106-
self.kv_cache = DecoderKVCache(internal_kv_cache) if kv_cache_enabled else None
107-
self._freeze_first_position = self._should_freeze_first_position(tokenizer)
108-
self._session_id = generate_session_id()
10992
self._engine_type = engine_type
11093

111-
@property
112-
def session_id(self) -> str:
113-
"""
114-
:return: The session id for the kv_cache if enabled
115-
"""
116-
return self._session_id
117-
118-
@session_id.setter
119-
def session_id(self, session_id: str):
120-
"""
121-
:param session_id: The session id to set for the kv_cache
122-
"""
123-
self._session_id = session_id
124-
12594
@property
12695
def onnx_input_names_no_cache(self) -> List[str]:
12796
"""
@@ -135,25 +104,43 @@ def onnx_input_names_no_cache(self) -> List[str]:
135104
]
136105

137106
@property
138-
def num_non_blank_cache_entries(self) -> int:
107+
def onnx_input_names_cached(self) -> List[str]:
108+
"""
109+
:return: The cached input names for the onnx model
110+
"""
111+
return [
112+
name
113+
for name in self.engine.input_names
114+
if name.startswith(CACHE_INPUT_PREFIX)
115+
]
116+
117+
@property
118+
def cache_shape(self) -> Tuple[int, int, int, int]:
139119
"""
140-
:return A number of non-blank entries in the
141-
kv cache
120+
:return: The shape of the kv cache inputs
121+
for the onnx model. The shape is
122+
(batch_size, num_heads, sequence_length, hidden_size)
142123
"""
143-
return self.kv_cache.num_non_blank_entries
124+
cache_engine_input_index = next(
125+
i
126+
for i, name in enumerate(self.engine.input_names)
127+
if CACHE_INPUT_PREFIX in name
128+
)
129+
return self.engine.input_shapes[cache_engine_input_index]
144130

145131
@property
146-
def internal_cache_active(self) -> bool:
132+
def output_names(self) -> List[str]:
147133
"""
148-
:return: Whether the internal kv cache is active
134+
:return: The output names for the onnx model
149135
"""
150-
return self.kv_cache_enabled and self.kv_cache.engine_internal_cache is not None
136+
return self.engine.output_names
151137

152-
def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]:
138+
def run(
139+
self, inputs: List[numpy.ndarray], val_inp: bool, kv_cache: DecoderKVCache
140+
) -> List[numpy.ndarray]:
153141
"""
154142
Run the engine with the given inputs.
155-
156-
If the self.internal_cache_active=True, the internal
143+
If the kv_cache.engine_internal_cache=True, the internal
157144
deepsparse kv cache management is enabled. In this case
158145
the LIB.kv_cache class object will be passed to the engine
159146
call as well. In this scenario also the inputs will not be
@@ -163,25 +150,27 @@ def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]
163150
164151
:param inputs: The inputs to run the engine with
165152
:param val_inp: Whether the input is for validation or not
153+
:param kv_cache: The kv cache object to use for the inference
154+
166155
:return: The output of the engine
167156
"""
168-
169-
if self.internal_cache_active:
157+
if bool(kv_cache.engine_internal_cache):
170158
# conventionally, before dispatching
171159
# inputs to the engine, we validate them
172160
# if val_inp=True. However, in this case
173161
# we want to pass the empty kv cache inputs
174162
# (batch_size=0) to the engine. Therefore,
175163
# we skip the validation
176164
return self.engine._eng_net.execute_list_out(
177-
inputs, self.kv_cache.engine_internal_cache
165+
inputs, kv_cache.engine_internal_cache
178166
)
179167
# run the engine without the LIB.kv_cache object
180168
return self.engine.run(inputs, val_inp)
181169

182170
def __call__(
183171
self,
184172
inp: List[numpy.ndarray],
173+
kv_cache: Optional[DecoderKVCache] = None,
185174
val_inp: bool = True,
186175
) -> numpy.ndarray:
187176
"""
@@ -190,23 +179,28 @@ def __call__(
190179
:param inp: The input to run the engine with. We expect a
191180
list of numpy arrays that contain the input ids,
192181
attention mask, and position ids (optionally)
182+
:param kv_cache: The DecoderKVCache object that contains
183+
the kv cache state
193184
:param val_inp: Whether the input is for validation or not
185+
194186
:return: The generated token and corresponding logits
195187
"""
196188
timer = self.timer_manager.current
197-
if self.kv_cache:
189+
if kv_cache:
198190
# if model has kv cache enabled, we need
199191
# to add the kv cache state to the input
200-
inp = self.add_kv_cache_to_input(inp)
192+
inp = self.add_kv_cache_to_input(inp, kv_cache)
201193

202194
with timer.time(f"EXECUTE_ENGINE_SEQ_LEN_{self.sequence_length}"):
203-
out = self.run(inp, val_inp)
195+
out = self.run(inp, val_inp, kv_cache)
204196

205-
if self.kv_cache:
197+
if kv_cache:
206198
with timer.time(TextGenerationTimings.KV_CACHE_UPDATE):
207199
logits, *kv_cache_state = out
208200
self.update_kv_cache(
209-
kv_cache_state=kv_cache_state, input_ids_len=self.input_ids_length
201+
kv_cache_state=kv_cache_state,
202+
input_ids_len=self.input_ids_length,
203+
kv_cache=kv_cache,
210204
)
211205
else:
212206
logits = out[0]
@@ -219,36 +213,11 @@ def __str__(self):
219213
def __repr__(self):
220214
return str(self)
221215

222-
def transfer_cache_state(self, cache: DecoderKVCache):
216+
def add_kv_cache_to_input(
217+
self, inp: List[numpy.ndarray], kv_cache: DecoderKVCache
218+
) -> List[numpy.ndarray]:
223219
"""
224-
Transfers the kv cache state and the number of tokens processed
225-
information from another NLDecoderEngine. Call this method when
226-
you want to transfer the kv cache state from one engine to another.
227-
228-
This method will also automatically set the kv cache capacity to
229-
the appropriate value for the new engine.
230-
231-
:param cache: The `DecoderKVCache` object to transfer to the engine
232-
from
233-
"""
234-
cache.set_capacity(self.cache_length)
235-
self.kv_cache = cache
236-
237-
def reset_kv_cache(self):
238-
"""
239-
Resets the kv cache state.
240-
"""
241-
kv_cache_state = self._initialize_kv_cache_state(self.cache_length)
242-
self.kv_cache.setup(
243-
session_id=self._session_id,
244-
state=kv_cache_state,
245-
num_processed_tokens=0,
246-
freeze_first_position=self._freeze_first_position,
247-
)
248-
249-
def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray]:
250-
"""
251-
Takes the input and adds the past kv cache state to it.
220+
Takes the input and adds the kv cache state to it.
252221
253222
If the internal kv cache is enabled, the kv cache state
254223
will always be an empty array. This is just to make sure
@@ -262,17 +231,11 @@ def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray]
262231
263232
264233
:param inp: The input to the model
234+
:param kv_cache: The kv cache object
235+
265236
:return The input with the kv cache state added to it
266237
"""
267-
if self.internal_cache_active:
268-
kv_cache_state = self._initialize_kv_cache_state(
269-
self.cache_length, empty=True
270-
)
271-
else:
272-
kv_cache_state = self.kv_cache.cached_inputs
273-
if kv_cache_state is None:
274-
self.reset_kv_cache()
275-
kv_cache_state = self.kv_cache.cached_inputs
238+
kv_cache_state = copy.copy(kv_cache.cached_inputs)
276239

277240
for idx, input_name in enumerate(self.onnx_input_names_no_cache):
278241
kv_cache_state[input_name] = inp[idx]
@@ -284,75 +247,29 @@ def update_kv_cache(
284247
self,
285248
kv_cache_state: List[numpy.ndarray],
286249
input_ids_len: int,
250+
kv_cache: DecoderKVCache,
287251
):
288252
"""
289-
Updates the state of the kv cache
253+
Updates the kv cache using the new kv cache state.
290254
291255
If the internal kv cache is enabled, we refrain from
292256
updating the kv cache state as it is being tracked internally
293257
inside the engine. We only update the number of tokens processed.
294258
295-
:param kv_cache_state: The state of the kv cache storage
259+
:param kv_cache_state: The new state of the kv cache storage
296260
:param input_ids_len: The length of input_ids
261+
:param kv_cache: The kv cache object to update
297262
"""
298-
if self.internal_cache_active:
299-
self.kv_cache.total_num_processed_tokens += input_ids_len
263+
if bool(kv_cache.engine_internal_cache):
264+
kv_cache.total_num_processed_tokens += input_ids_len
300265
return
301266

302-
cache_onnx_names = [
303-
name
304-
for name in self.engine.input_names
305-
if name.startswith(CACHE_INPUT_PREFIX)
306-
]
307267
kv_cache_state = {
308-
name: array for name, array in zip(cache_onnx_names, kv_cache_state)
268+
name: array
269+
for name, array in zip(self.onnx_input_names_cached, kv_cache_state)
309270
}
310271

311-
self.kv_cache.update(
272+
kv_cache.update(
312273
state=kv_cache_state,
313274
input_ids_len=input_ids_len,
314275
)
315-
316-
def _initialize_kv_cache_state(
317-
self, length: int, empty: bool = False
318-
) -> Dict[str, numpy.ndarray]:
319-
# initialize empty kv cache of size
320-
# (batch_size, num_attention_heads, length, hidden_dims)
321-
# if empty is True, we initialize empty kv_cache
322-
# and set the batch_size to 0
323-
324-
cache_engine_input_index = next(
325-
i
326-
for i, name in enumerate(self.engine.input_names)
327-
if CACHE_INPUT_PREFIX in name
328-
)
329-
batch_size, num_attention_heads, _, hidden_dims = self.engine.input_shapes[
330-
cache_engine_input_index
331-
]
332-
333-
empty_kv_cache_tensor = numpy.zeros(
334-
(
335-
batch_size if not empty else 0,
336-
num_attention_heads,
337-
length,
338-
hidden_dims,
339-
),
340-
dtype=self.kv_cache_data_type,
341-
)
342-
343-
cache_keys = [
344-
output_name.replace(CACHE_OUTPUT_PREFIX, CACHE_INPUT_PREFIX)
345-
for output_name in self.engine.output_names
346-
if output_name.startswith(CACHE_OUTPUT_PREFIX)
347-
]
348-
return {key: empty_kv_cache_tensor for key in cache_keys}
349-
350-
@staticmethod
351-
def _should_freeze_first_position(tokenizer) -> bool:
352-
# use tokenizer to find out whether we should freeze the first position
353-
# (True if tokenizer has a prefix for a BOS token)
354-
if tokenizer is None:
355-
return False
356-
if hasattr(tokenizer, "add_bos_token"):
357-
return True
358-
return False

src/deepsparse/transformers/pipelines/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@
1919
from .question_answering import *
2020
from .text_classification import *
2121
from .token_classification import *
22+
from .text_generation import *
2223
from .zero_shot_text_classification import *
2324
from .embedding_extraction import *

0 commit comments

Comments
 (0)