From 837905a56d97edf1c8830d5f17ba70c44f4d9d6a Mon Sep 17 00:00:00 2001 From: Damian Date: Thu, 14 Sep 2023 15:20:15 +0000 Subject: [PATCH 1/7] initial commit --- tests/test_data/pipeline_bench_config.json | 1 - tests/test_pipeline_benchmark.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/test_data/pipeline_bench_config.json b/tests/test_data/pipeline_bench_config.json index 5886762cea..afd4db352d 100644 --- a/tests/test_data/pipeline_bench_config.json +++ b/tests/test_data/pipeline_bench_config.json @@ -2,7 +2,6 @@ "data_type": "dummy", "gen_sequence_length": 100, "input_image_shape": [500,500,3], - "data_folder": "/home/sadkins/imagenette2-320/", "recursive_search": true, "max_string_length": -1, "pipeline_kwargs": {}, diff --git a/tests/test_pipeline_benchmark.py b/tests/test_pipeline_benchmark.py index 485599d044..782a1f8016 100644 --- a/tests/test_pipeline_benchmark.py +++ b/tests/test_pipeline_benchmark.py @@ -95,7 +95,6 @@ def test_pipeline_benchmark( if res.stdout is not None: print(f"\n==== test_benchmark output ====\n{res.stdout}") assert res.returncode == 0 - assert "error" not in res.stdout.lower() assert "fail" not in res.stdout.lower() assert "total_inference" in res.stdout.lower() From 803e5369ee0d57d9c090dbb07cd40c6372c0e94b Mon Sep 17 00:00:00 2001 From: Damian Date: Tue, 19 Sep 2023 07:43:48 +0000 Subject: [PATCH 2/7] upload draft for review --- .../transformers/engines/nl_decoder_engine.py | 9 +- .../transformers/pipelines/__init__.py | 1 + .../transformers/pipelines/chat_pipeline.py | 50 ++++++++++ .../transformers/pipelines/text_generation.py | 52 +++++------ src/deepsparse/transformers/utils/__init__.py | 1 + .../transformers/utils/storage_kv_cache.py | 92 +++++++++++++++++++ 6 files changed, 171 insertions(+), 34 deletions(-) create mode 100644 src/deepsparse/transformers/pipelines/chat_pipeline.py create mode 100644 src/deepsparse/transformers/utils/storage_kv_cache.py diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py index 223d4f0a60..1ea2cd89f9 100644 --- a/src/deepsparse/transformers/engines/nl_decoder_engine.py +++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py @@ -183,6 +183,7 @@ def __call__( self, inp: List[numpy.ndarray], val_inp: bool = True, + decoder=DecoderKVCache, ) -> numpy.ndarray: """ The main entry point for running the engine. @@ -197,7 +198,7 @@ def __call__( if self.kv_cache: # if model has kv cache enabled, we need # to add the kv cache state to the input - inp = self.add_kv_cache_to_input(inp) + inp = self.add_kv_cache_to_input(inp, decoder) with timer.time(f"EXECUTE_ENGINE_SEQ_LEN_{self.sequence_length}"): out = self.run(inp, val_inp) @@ -206,12 +207,14 @@ def __call__( with timer.time(TextGenerationTimings.KV_CACHE_UPDATE): logits, *kv_cache_state = out self.update_kv_cache( - kv_cache_state=kv_cache_state, input_ids_len=self.input_ids_length + kv_cache_state=kv_cache_state, + input_ids_len=self.input_ids_length, + decoder=decoder, ) else: logits = out[0] - return logits + return logits, decoder def __str__(self): return f"{self.__class__.__name__}: {self.engine}" diff --git a/src/deepsparse/transformers/pipelines/__init__.py b/src/deepsparse/transformers/pipelines/__init__.py index 3e2e88381d..fdc32b17ee 100644 --- a/src/deepsparse/transformers/pipelines/__init__.py +++ b/src/deepsparse/transformers/pipelines/__init__.py @@ -19,5 +19,6 @@ from .question_answering import * from .text_classification import * from .token_classification import * +from .text_generation import * from .zero_shot_text_classification import * from .embedding_extraction import * diff --git a/src/deepsparse/transformers/pipelines/chat_pipeline.py b/src/deepsparse/transformers/pipelines/chat_pipeline.py new file mode 100644 index 0000000000..4e76add403 --- /dev/null +++ b/src/deepsparse/transformers/pipelines/chat_pipeline.py @@ -0,0 +1,50 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline, TextGenerationOutput, TextGenerationInput +from deepsparse.transformers.utils import SessionStorageKVCache, DecoderKVCache +from pydantic import Field + +class ChatOutput(TextGenerationOutput): + session_id: Optional[str] = Field( + default=None, description="A string identifier for the kv cache session." + +class ChatInput(TextGenerationInput): + session_id: Optional[str] = Field( + default=None, description="A string identifier for the kv cache session." + ) + +class ChatPipeline(TextGenerationPipeline): + def __init__(self, **kwargs): + self.session_storage = SessionStorageKVCache() + super().__init__(**kwargs) + + + + def get_decoder_kv_cache(self, context) -> Optional[DecoderKVCache]: + session_id = context.get("session_id", None) + session = self.session_storage.get(session_id) + if session is None: + session = self._create_decoder(...) + return session + + def process_inputs(...): + + engine_input, context = super().process_inputs(...) + # add session_id context + return engine_input, context + + def split_engine_inputs(...): + pass diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 454cbcb05f..7b7c55171e 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -37,6 +37,7 @@ from deepsparse.pipeline import DEEPSPARSE_ENGINE from deepsparse.transformers.engines import NLDecoderEngine from deepsparse.transformers.pipelines import TransformersPipeline +from deepsparse.transformers.utils import DecoderKVCache from deepsparse.transformers.utils.helpers import ( create_causal_mask, pad_to_fixed_length, @@ -85,13 +86,7 @@ class Config: "Note: This flag is only applicable when return_logits " "is `True`.", ) - session_id: Optional[str] = Field( - default=None, - description="A user may set a string identifier " - "for the kv cache session. If None, " - "and the model is using kv cache, it " - "will be set to a random uuid.", - ) + fixed_sequences_length: bool = Field( default=False, description="A flag that indicates whether to modify " @@ -156,9 +151,6 @@ class TextGenerationOutput(BaseModel): "The logits have dimensions " "[batch_size, sequence_length, vocab_size]", ) - session_id: Optional[str] = Field( - default=None, description="A string identifier for the kv cache session." - ) class Config: arbitrary_types_allowed = True @@ -451,11 +443,6 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]: ) engine_input = self.tokens_to_engine_input(input_tokens, onnx_input_names) - if inputs.session_id is not None: - # if session_id is provided, we need to set it in engines - self.engine.session_id = inputs.session_id - self.multitoken_engine.session_id = inputs.session_id - context = dict( num_generated_predictions=inputs.num_generated_predictions, return_logits=inputs.return_logits, @@ -537,7 +524,7 @@ def engine_forward( else: # run the prompt through with timer.time(TextGenerationTimings.PROMPT_PREFILL): - prompt_logits = self.prompt_inference(engine_inputs) + prompt_logits, decoder = self.prompt_inference(engine_inputs) tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist() token_generator = TokenGenerator( @@ -569,8 +556,8 @@ def engine_forward( with timer.time(TextGenerationTimings.TOKEN_GENERATION): while len(generated_tokens) < max_tokens: with timer.time(TextGenerationTimings.TOKEN_GENERATION_SINGLE): - logits = self.autoregressive_inference( - tokens=token_generator.tokens + logits, decoder = self.autoregressive_inference( + tokens=token_generator.tokens, decoder ) token = token_generator.generate(logits=logits[0, -1, :]) generated_tokens.append(token) @@ -609,7 +596,7 @@ def engine_forward( def prompt_inference( self, engine_inputs: List[numpy.ndarray], - ) -> Tuple[List[int], List[numpy.ndarray]]: + ) -> Tuple[List[numpy.ndarray], DecoderKVCache]: """ An inference run that processes the prompt through the model to generate the new token and logits @@ -617,7 +604,6 @@ def prompt_inference( :param engine_inputs: the prompt (context) represented by a list of numpy inputs to the engine :return: A tuple of: - - The list of prompt tokens plus the new, generated token - The logits generated from the prompt (with dimensions ['batch_size', 'num_tokens', 'vocab_size']) """ @@ -628,17 +614,14 @@ def prompt_inference( num_tokens_processed = 0 if len(tokens) > self.prompt_sequence_length and self.enable_multitoken_prefill: - self.multitoken_engine.reset_kv_cache() + decoder = get_decoder_kv_cache(...) for engine_inputs in self.engine_inputs_for_prefill(tokens): - new_logits = self.multitoken_engine(engine_inputs) + new_logits, decoder = self.multitoken_engine(engine_inputs, decoder) num_tokens_processed += self.prompt_sequence_length prompt_logits.append(new_logits) - if num_tokens_processed: - # transfer the cache state from the multi-token engine to the main engine - self.engine.transfer_cache_state(cache=self.multitoken_engine.kv_cache) - else: - self.engine.reset_kv_cache() + if not num_tokens_processed: + decoder = get_decoder_kv_cache(...) # prompt size is small, run autoregressive inference to populate kv cache run_tokens = [] if num_tokens_processed == 0 else tokens[:num_tokens_processed] @@ -648,15 +631,16 @@ def prompt_inference( with self.timer_manager.current.time( TextGenerationTimings.PROMPT_PREFILL_SINGLE ): - new_logits = self.autoregressive_inference(run_tokens) + new_logits, decoder = self.autoregressive_inference(run_tokens, decoder) prompt_logits.append(new_logits) - return prompt_logits + return prompt_logits, decoder def autoregressive_inference( self, tokens: List[int], + decoder: DecoderKVCache, ) -> Tuple[int, numpy.ndarray]: """ An inference run that processes the last token to generate @@ -689,9 +673,9 @@ def autoregressive_inference( engine_inputs_map[name] for name in self.engine.onnx_input_names_no_cache ] - generated_logits = self.engine(engine_inputs) + generated_logits, decoder = self.engine(engine_inputs, decoder) - return generated_logits + return generated_logits, decoder def engine_inputs_for_prefill( self, tokens: List[int] @@ -861,6 +845,12 @@ def causal_mask_input_present(model_path: str) -> bool: return is_causal_mask_input + def get_decoder_kv_cache(self) -> DecoderKVCache: + return self._create_decoder() + + def _create_decoder(self, context) -> Optional[DecoderKVCache]: + pass + def _stop_token_generated( self, token, stop_tokens: Union[None, str, Sequence[str]] ) -> bool: diff --git a/src/deepsparse/transformers/utils/__init__.py b/src/deepsparse/transformers/utils/__init__.py index 80f24f1040..d5c081936d 100644 --- a/src/deepsparse/transformers/utils/__init__.py +++ b/src/deepsparse/transformers/utils/__init__.py @@ -14,6 +14,7 @@ # flake8: noqa +from .storage_kv_cache import * from .decoder_kv_cache import * from .helpers import * from .timings import * diff --git a/src/deepsparse/transformers/utils/storage_kv_cache.py b/src/deepsparse/transformers/utils/storage_kv_cache.py new file mode 100644 index 0000000000..6e525b9693 --- /dev/null +++ b/src/deepsparse/transformers/utils/storage_kv_cache.py @@ -0,0 +1,92 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Dict, Union + +from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache + + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["SessionStorageKVCache"] + + +class SessionStorageKVCache: + """ + A storage that stores the kv cache sessions. + Each session is a DecoderKVCache object that + stores the state of the kv cache. + The storage is a dictionary that where keys are session_ids + and values are of all the active sessions. + """ + + def __init__(self): + self._memory: Dict[str, DecoderKVCache] = dict() + + def __len__(self): + return len(self._memory) + + def __str__(self): + return ( + f"{SessionStorageKVCache.__name__}:\n " + f"\tsessions: {[session_name for session_name in self._memory.keys()]}\n" + ) + + def has_session(self, session_id: str) -> bool: + """ + Check if the storage has a session with the given session id. + :param session_id: The identifier of the cache session. + :return: True if the storage has a session with the given session id. + """ + return session_id in self._memory + + def put(self, session: DecoderKVCache): + """ + Put the cache session in the storage. + + :param session: The session to store. + """ + session_id = session.id + if self.has_session(session_id): + _LOGGER.debug( + f"Session: {session_id} already exists in the storage. " + f"It will be overwritten." + ) + self._memory[session.id] = session + + def get(self, session_id: str) -> Union[DecoderKVCache, None]: + """ + Get the state of the kv cache for a session from the storage. + + :param session_id: The identifier of the cache session. + :return: The state of the kv cache for the session. + """ + session = self._memory.get(session_id) + if session is None: + _LOGGER.debug(f"No cache session found for session id: {session_id}") + return session + + def pop(self, session_id: str) -> DecoderKVCache: + """ + Pop the session correspond to session_id from the storage. + :param session_id: The identifier of the cache session. + """ + session = self._memory.pop(session_id, None) + if session is None: + raise ValueError( + f"Attempting to remove session: {session_id} from the storage. " + f"However, the session does not exist in the storage." + ) + return session From 76ae85627e646c8f7a264ad9ddfed7522c5ab718 Mon Sep 17 00:00:00 2001 From: Damian Date: Tue, 19 Sep 2023 16:43:55 +0000 Subject: [PATCH 3/7] initial implementation. testing now --- src/deepsparse/pipeline.py | 7 +- .../transformers/engines/nl_decoder_engine.py | 198 ++++-------------- .../transformers/pipelines/chat_pipeline.py | 50 ----- .../transformers/pipelines/text_generation.py | 83 ++++---- src/deepsparse/transformers/utils/__init__.py | 5 +- .../transformers/utils/decoder_kv_cache.py | 70 +++---- src/deepsparse/transformers/utils/helpers.py | 60 +++++- .../pipelines/test_text_generation.py | 60 ++---- 8 files changed, 194 insertions(+), 339 deletions(-) delete mode 100644 src/deepsparse/transformers/pipelines/chat_pipeline.py diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index 1fb09cca0f..96aea4ed8e 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -254,9 +254,10 @@ def __call__(self, *args, **kwargs) -> BaseModel: # submit split batches to engine threadpool engine_forward_with_context = partial(self.engine_forward, context=context) - batch_outputs = list( - self.executor.map(engine_forward_with_context, batches) - ) + # batch_outputs = list( + # self.executor.map(engine_forward_with_context, batches) + # ) + batch_outputs = [engine_forward_with_context(x) for x in batches] # join together the batches of size `self._batch_size` engine_outputs = self.join_engine_outputs(batch_outputs, orig_batch_size) diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py index 1ea2cd89f9..e49d0091bb 100644 --- a/src/deepsparse/transformers/engines/nl_decoder_engine.py +++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py @@ -11,21 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import numpy -from transformers import AutoTokenizer from deepsparse.engine import Context from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache -from deepsparse.transformers.utils.helpers import generate_session_id from deepsparse.transformers.utils.timings import TextGenerationTimings from deepsparse.utils import TimerManager from deepsparse.utils.onnx import ( CACHE_INPUT_PREFIX, - CACHE_OUTPUT_PREFIX, overwrite_onnx_model_inputs_for_kv_cache_models, ) @@ -37,9 +35,9 @@ class NLDecoderEngine: """ - The NLDecoderEngine (NaturalLanguageDecoderEngine) handles the + The NLDecoderEngine (Natural Language Decoder Engine) handles the logic around the inference for Natural Language pipeline, - including batching and kv cache logic. + including batching and kv cache manipulation logic. :param onnx_file_path: The path to the onnx model file :param engine_type: The type of engine to use for the inference @@ -47,10 +45,6 @@ class NLDecoderEngine: :param sequence_length: The maximum sequence length to run the engine for :param input_ids_length: The maximum input ids length to run the engine for :param engine_context: The context to run the engine in - :param sampling_temperature: The temperature to use for sampling - :param deterministic: Whether to use deterministic sampling - :param tokenizer: The tokenizer to used for engine inputs - :param engine_context: The context to run the engine in :param internal_kv_cache: Whether to use the deepsparse kv cache in the DecoderKVCache object or not """ @@ -62,9 +56,6 @@ def __init__( engine_args: Dict[str, Any], sequence_length: int, input_ids_length: int, - tokenizer: AutoTokenizer, - sampling_temperature: float = 1.0, - deterministic: bool = True, engine_context: Optional[Context] = None, internal_kv_cache=False, timer_manager: TimerManager = None, @@ -98,30 +89,11 @@ def __init__( ) self.timer_manager = timer_manager or TimerManager() self.sequence_length = sequence_length - self.sampling_temperature = sampling_temperature - self.deterministic = deterministic self.input_ids_length = input_ids_length self.cache_length = sequence_length - input_ids_length self.kv_cache_enabled = kv_cache_enabled - self.kv_cache = DecoderKVCache(internal_kv_cache) if kv_cache_enabled else None - self._freeze_first_position = self._should_freeze_first_position(tokenizer) - self._session_id = generate_session_id() self._engine_type = engine_type - @property - def session_id(self) -> str: - """ - :return: The session id for the kv_cache if enabled - """ - return self._session_id - - @session_id.setter - def session_id(self, session_id: str): - """ - :param session_id: The session id to set for the kv_cache - """ - self._session_id = session_id - @property def onnx_input_names_no_cache(self) -> List[str]: """ @@ -135,38 +107,26 @@ def onnx_input_names_no_cache(self) -> List[str]: ] @property - def num_non_blank_cache_entries(self) -> int: - """ - :return A number of non-blank entries in the - kv cache - """ - return self.kv_cache.num_non_blank_entries + def cache_shape(self) -> Tuple[int, int, int, int]: + cache_engine_input_index = next( + i + for i, name in enumerate(self.engine.input_names) + if CACHE_INPUT_PREFIX in name + ) + return self.engine.input_shapes[cache_engine_input_index] @property - def internal_cache_active(self) -> bool: - """ - :return: Whether the internal kv cache is active - """ - return self.kv_cache_enabled and self.kv_cache.engine_internal_cache is not None - - def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]: + def output_names(self) -> List[str]: """ - Run the engine with the given inputs. - - If the self.internal_cache_active=True, the internal - deepsparse kv cache management is enabled. In this case - the LIB.kv_cache class object will be passed to the engine - call as well. In this scenario also the inputs will not be - validated, even if the val_inp=True. This is because we - want to pass the empty kv cache inputs (batch_size=0) to - the engine. - - :param inputs: The inputs to run the engine with - :param val_inp: Whether the input is for validation or not - :return: The output of the engine + :return: The output names for the onnx model """ + return self.engine.output_names - if self.internal_cache_active: + def run( + self, inputs: List[numpy.ndarray], val_inp: bool, kv_cache: DecoderKVCache + ) -> List[numpy.ndarray]: + """ """ + if bool(kv_cache.engine_internal_cache): # conventionally, before dispatching # inputs to the engine, we validate them # if val_inp=True. However, in this case @@ -174,7 +134,7 @@ def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray] # (batch_size=0) to the engine. Therefore, # we skip the validation return self.engine._eng_net.execute_list_out( - inputs, self.kv_cache.engine_internal_cache + inputs, kv_cache.engine_internal_cache ) # run the engine without the LIB.kv_cache object return self.engine.run(inputs, val_inp) @@ -182,8 +142,8 @@ def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray] def __call__( self, inp: List[numpy.ndarray], + kv_cache: Optional[DecoderKVCache] = None, val_inp: bool = True, - decoder=DecoderKVCache, ) -> numpy.ndarray: """ The main entry point for running the engine. @@ -191,30 +151,33 @@ def __call__( :param inp: The input to run the engine with. We expect a list of numpy arrays that contain the input ids, attention mask, and position ids (optionally) + :param kv_cache: The DecoderKVCache object that contains + the kv cache state :param val_inp: Whether the input is for validation or not + :return: The generated token and corresponding logits """ timer = self.timer_manager.current - if self.kv_cache: + if self.kv_cache_enabled: # if model has kv cache enabled, we need # to add the kv cache state to the input - inp = self.add_kv_cache_to_input(inp, decoder) + inp = self.add_kv_cache_to_input(inp, kv_cache) with timer.time(f"EXECUTE_ENGINE_SEQ_LEN_{self.sequence_length}"): - out = self.run(inp, val_inp) + out = self.run(inp, val_inp, kv_cache) - if self.kv_cache: + if self.kv_cache_enabled: with timer.time(TextGenerationTimings.KV_CACHE_UPDATE): logits, *kv_cache_state = out self.update_kv_cache( kv_cache_state=kv_cache_state, input_ids_len=self.input_ids_length, - decoder=decoder, + kv_cache=kv_cache, ) else: logits = out[0] - return logits, decoder + return logits def __str__(self): return f"{self.__class__.__name__}: {self.engine}" @@ -222,36 +185,11 @@ def __str__(self): def __repr__(self): return str(self) - def transfer_cache_state(self, cache: DecoderKVCache): + def add_kv_cache_to_input( + self, inp: List[numpy.ndarray], kv_cache: DecoderKVCache + ) -> List[numpy.ndarray]: """ - Transfers the kv cache state and the number of tokens processed - information from another NLDecoderEngine. Call this method when - you want to transfer the kv cache state from one engine to another. - - This method will also automatically set the kv cache capacity to - the appropriate value for the new engine. - - :param cache: The `DecoderKVCache` object to transfer to the engine - from - """ - cache.set_capacity(self.cache_length) - self.kv_cache = cache - - def reset_kv_cache(self): - """ - Resets the kv cache state. - """ - kv_cache_state = self._initialize_kv_cache_state(self.cache_length) - self.kv_cache.setup( - session_id=self._session_id, - state=kv_cache_state, - num_processed_tokens=0, - freeze_first_position=self._freeze_first_position, - ) - - def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray]: - """ - Takes the input and adds the past kv cache state to it. + Takes the input and adds the kv cache state to it. If the internal kv cache is enabled, the kv cache state will always be an empty array. This is just to make sure @@ -265,17 +203,11 @@ def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray] :param inp: The input to the model + :param kv_cache: The kv cache object + :return The input with the kv cache state added to it """ - if self.internal_cache_active: - kv_cache_state = self._initialize_kv_cache_state( - self.cache_length, empty=True - ) - else: - kv_cache_state = self.kv_cache.cached_inputs - if kv_cache_state is None: - self.reset_kv_cache() - kv_cache_state = self.kv_cache.cached_inputs + kv_cache_state = copy.copy(kv_cache.cached_inputs) for idx, input_name in enumerate(self.onnx_input_names_no_cache): kv_cache_state[input_name] = inp[idx] @@ -287,19 +219,21 @@ def update_kv_cache( self, kv_cache_state: List[numpy.ndarray], input_ids_len: int, + kv_cache: DecoderKVCache, ): """ - Updates the state of the kv cache + Updates the kv cache using the new kv cache state. If the internal kv cache is enabled, we refrain from updating the kv cache state as it is being tracked internally inside the engine. We only update the number of tokens processed. - :param kv_cache_state: The state of the kv cache storage + :param kv_cache_state: The new state of the kv cache storage :param input_ids_len: The length of input_ids + :param kv_cache: The kv cache object to update """ - if self.internal_cache_active: - self.kv_cache.total_num_processed_tokens += input_ids_len + if bool(kv_cache.engine_internal_cache): + kv_cache.total_num_processed_tokens += input_ids_len return cache_onnx_names = [ @@ -311,51 +245,7 @@ def update_kv_cache( name: array for name, array in zip(cache_onnx_names, kv_cache_state) } - self.kv_cache.update( + kv_cache.update( state=kv_cache_state, input_ids_len=input_ids_len, ) - - def _initialize_kv_cache_state( - self, length: int, empty: bool = False - ) -> Dict[str, numpy.ndarray]: - # initialize empty kv cache of size - # (batch_size, num_attention_heads, length, hidden_dims) - # if empty is True, we initialize empty kv_cache - # and set the batch_size to 0 - - cache_engine_input_index = next( - i - for i, name in enumerate(self.engine.input_names) - if CACHE_INPUT_PREFIX in name - ) - batch_size, num_attention_heads, _, hidden_dims = self.engine.input_shapes[ - cache_engine_input_index - ] - - empty_kv_cache_tensor = numpy.zeros( - ( - batch_size if not empty else 0, - num_attention_heads, - length, - hidden_dims, - ), - dtype=self.kv_cache_data_type, - ) - - cache_keys = [ - output_name.replace(CACHE_OUTPUT_PREFIX, CACHE_INPUT_PREFIX) - for output_name in self.engine.output_names - if output_name.startswith(CACHE_OUTPUT_PREFIX) - ] - return {key: empty_kv_cache_tensor for key in cache_keys} - - @staticmethod - def _should_freeze_first_position(tokenizer) -> bool: - # use tokenizer to find out whether we should freeze the first position - # (True if tokenizer has a prefix for a BOS token) - if tokenizer is None: - return False - if hasattr(tokenizer, "add_bos_token"): - return True - return False diff --git a/src/deepsparse/transformers/pipelines/chat_pipeline.py b/src/deepsparse/transformers/pipelines/chat_pipeline.py deleted file mode 100644 index 4e76add403..0000000000 --- a/src/deepsparse/transformers/pipelines/chat_pipeline.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional -from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline, TextGenerationOutput, TextGenerationInput -from deepsparse.transformers.utils import SessionStorageKVCache, DecoderKVCache -from pydantic import Field - -class ChatOutput(TextGenerationOutput): - session_id: Optional[str] = Field( - default=None, description="A string identifier for the kv cache session." - -class ChatInput(TextGenerationInput): - session_id: Optional[str] = Field( - default=None, description="A string identifier for the kv cache session." - ) - -class ChatPipeline(TextGenerationPipeline): - def __init__(self, **kwargs): - self.session_storage = SessionStorageKVCache() - super().__init__(**kwargs) - - - - def get_decoder_kv_cache(self, context) -> Optional[DecoderKVCache]: - session_id = context.get("session_id", None) - session = self.session_storage.get(session_id) - if session is None: - session = self._create_decoder(...) - return session - - def process_inputs(...): - - engine_input, context = super().process_inputs(...) - # add session_id context - return engine_input, context - - def split_engine_inputs(...): - pass diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 7b7c55171e..85ac3e2cd3 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -40,7 +40,9 @@ from deepsparse.transformers.utils import DecoderKVCache from deepsparse.transformers.utils.helpers import ( create_causal_mask, + initialize_kv_cache_state, pad_to_fixed_length, + prepends_bos_token, repeat_inputs, ) from deepsparse.transformers.utils.timings import TextGenerationTimings @@ -314,11 +316,8 @@ def initialize_engines( engine_type=self.engine_type, engine_args=self.engine_args, engine_context=self.context, - sampling_temperature=self.sampling_temperature, - deterministic=self.deterministic, sequence_length=self.sequence_length, input_ids_length=input_ids_length, - tokenizer=self.tokenizer, internal_kv_cache=self.internal_kv_cache, timer_manager=self.timer_manager, ) @@ -329,11 +328,8 @@ def initialize_engines( engine_type=self.engine_type, engine_args=self.engine_args, engine_context=self.context, - sampling_temperature=self.sampling_temperature, - deterministic=self.deterministic, sequence_length=self.sequence_length, input_ids_length=1, - tokenizer=self.tokenizer, internal_kv_cache=self.internal_kv_cache, timer_manager=self.timer_manager, ) @@ -524,7 +520,7 @@ def engine_forward( else: # run the prompt through with timer.time(TextGenerationTimings.PROMPT_PREFILL): - prompt_logits, decoder = self.prompt_inference(engine_inputs) + prompt_logits, session = self.prompt_inference(engine_inputs) tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist() token_generator = TokenGenerator( @@ -556,8 +552,8 @@ def engine_forward( with timer.time(TextGenerationTimings.TOKEN_GENERATION): while len(generated_tokens) < max_tokens: with timer.time(TextGenerationTimings.TOKEN_GENERATION_SINGLE): - logits, decoder = self.autoregressive_inference( - tokens=token_generator.tokens, decoder + logits = self.autoregressive_inference( + tokens=token_generator.tokens, kv_cache=session ) token = token_generator.generate(logits=logits[0, -1, :]) generated_tokens.append(token) @@ -613,15 +609,17 @@ def prompt_inference( prompt_logits = [] num_tokens_processed = 0 + session = self.get_kv_cache_decoder() + if len(tokens) > self.prompt_sequence_length and self.enable_multitoken_prefill: - decoder = get_decoder_kv_cache(...) - for engine_inputs in self.engine_inputs_for_prefill(tokens): - new_logits, decoder = self.multitoken_engine(engine_inputs, decoder) + for engine_inputs in self.engine_inputs_for_prefill( + tokens, num_total_processed_tokens=session.total_num_processed_tokens + ): + new_logits = self.multitoken_engine(engine_inputs, kv_cache=session) num_tokens_processed += self.prompt_sequence_length prompt_logits.append(new_logits) - if not num_tokens_processed: - decoder = get_decoder_kv_cache(...) + session.set_capacity(self.sequence_length - 1) # prompt size is small, run autoregressive inference to populate kv cache run_tokens = [] if num_tokens_processed == 0 else tokens[:num_tokens_processed] @@ -631,16 +629,16 @@ def prompt_inference( with self.timer_manager.current.time( TextGenerationTimings.PROMPT_PREFILL_SINGLE ): - new_logits, decoder = self.autoregressive_inference(run_tokens, decoder) + new_logits = self.autoregressive_inference(run_tokens, session) prompt_logits.append(new_logits) - return prompt_logits, decoder + return prompt_logits, session def autoregressive_inference( self, tokens: List[int], - decoder: DecoderKVCache, + kv_cache: DecoderKVCache, ) -> Tuple[int, numpy.ndarray]: """ An inference run that processes the last token to generate @@ -673,12 +671,12 @@ def autoregressive_inference( engine_inputs_map[name] for name in self.engine.onnx_input_names_no_cache ] - generated_logits, decoder = self.engine(engine_inputs, decoder) + generated_logits = self.engine(engine_inputs, kv_cache) - return generated_logits, decoder + return generated_logits def engine_inputs_for_prefill( - self, tokens: List[int] + self, tokens: List[int], num_total_processed_tokens: int ) -> Generator[List[numpy.ndarray], None, None]: """ Takes a list of tokens and creates a generator @@ -698,8 +696,8 @@ def engine_inputs_for_prefill( the sum of: a) the number of tokens in the batch (self.prompt_sequence_length) - b) the number of non-blank cache entries - (num_non_blank_cache_entries) + b) the number of processed tokens so far + (num_total_processed_tokens) so that the attention_mask properly attends to the current input tokens, as well as the previous cache entries. @@ -723,7 +721,6 @@ def engine_inputs_for_prefill( for idx, token_batch in enumerate(token_batches): engine_inputs = [] - num_cached_entries = self.multitoken_engine.num_non_blank_cache_entries for name in self.multitoken_engine.onnx_input_names_no_cache: if name == "input_ids": engine_input = numpy.array([token_batch]) @@ -733,25 +730,19 @@ def engine_inputs_for_prefill( engine_input = numpy.zeros( (1, self.sequence_length), dtype=numpy.int64 ) - # fill it out with 1s (from the right), so that the number - # of unmasked entries is equal to the sum of: - engine_input[ - :, - -( - # ...the number of current input tokens... - self.prompt_sequence_length - # ...and the number of the previous cache entries - + num_cached_entries - ) :, - ] = 1 + num_attention_entries_to_unmask = min( + num_total_processed_tokens + self.prompt_sequence_length, + self.sequence_length, + ) + engine_input[:, -num_attention_entries_to_unmask:] = 1 elif name == "causal_mask": # delay creation of the causal mask continue elif name == "positions": engine_input = ( numpy.arange( - num_cached_entries, - num_cached_entries + self.prompt_sequence_length, + num_total_processed_tokens, + num_total_processed_tokens + self.prompt_sequence_length, ) .reshape(1, -1) .astype(numpy.int64) @@ -845,11 +836,23 @@ def causal_mask_input_present(model_path: str) -> bool: return is_causal_mask_input - def get_decoder_kv_cache(self) -> DecoderKVCache: - return self._create_decoder() + def get_kv_cache_decoder(self) -> DecoderKVCache: + engine = self.multitoken_engine or self.engine + + kv_cache_state = initialize_kv_cache_state( + cache_shape=engine.cache_shape, + kv_cache_data_type=engine.kv_cache_data_type, + output_names=engine.output_names, + length=self.sequence_length - self.prompt_sequence_length, + empty=bool(self.internal_kv_cache), + ) - def _create_decoder(self, context) -> Optional[DecoderKVCache]: - pass + kv_cache = DecoderKVCache(self.internal_kv_cache) + kv_cache.setup( + state=kv_cache_state, + freeze_first_position=prepends_bos_token(self.tokenizer), + ) + return kv_cache def _stop_token_generated( self, token, stop_tokens: Union[None, str, Sequence[str]] diff --git a/src/deepsparse/transformers/utils/__init__.py b/src/deepsparse/transformers/utils/__init__.py index d5c081936d..2caefa5216 100644 --- a/src/deepsparse/transformers/utils/__init__.py +++ b/src/deepsparse/transformers/utils/__init__.py @@ -13,8 +13,9 @@ # limitations under the License. -# flake8: noqa -from .storage_kv_cache import * from .decoder_kv_cache import * from .helpers import * + +# flake8: noqa +from .storage_kv_cache import * from .timings import * diff --git a/src/deepsparse/transformers/utils/decoder_kv_cache.py b/src/deepsparse/transformers/utils/decoder_kv_cache.py index bc182f4931..08a03394e2 100644 --- a/src/deepsparse/transformers/utils/decoder_kv_cache.py +++ b/src/deepsparse/transformers/utils/decoder_kv_cache.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List +from typing import Any, Dict import numpy @@ -42,14 +42,12 @@ def __init__(self, internal_kv_cache: bool = False): # [batch_size, num_heads, sequence_length, hidden_size] self._sequence_len_axis = SEQUENCE_LENGTH_AXIS self._internal_kv_cache = internal_kv_cache - self._session_id = None self._freeze_first_position = None self._state = None self.engine_internal_cache = None def setup( self, - session_id: str, state: Dict[str, Any], num_processed_tokens: int = 0, freeze_first_position: bool = False, @@ -58,8 +56,6 @@ def setup( Setup the session - a level of abstraction that allocates the resources to store and manipulate the kv cache. - :param session_id: The session id to use for the current - session. Used to identify the kv cache state :param state: The state of the cache. This is a dictionary that maps the name of the cache array to the cache array. The cache tensor is a numpy array of shape @@ -74,7 +70,6 @@ def setup( that corresponds to the BOS token in the sequence. By default, is set to False. """ - self._session_id = session_id self._state = state self._freeze_first_position = freeze_first_position self.total_num_processed_tokens = num_processed_tokens @@ -90,6 +85,7 @@ def update( self, state: Dict[str, Any], input_ids_len: int, + increment_total_num_processed_tokens: int = True, ): """ Updating the session is identical with taking the kv cache @@ -103,8 +99,12 @@ def update( :param input_ids_len: The number of input ids in the current input batch: (batch_size, length). Corresponds to `input_ids.shape[1]` + :param increment_total_num_processed_tokens: If set to True, + the total number of processed tokens will be incremented + by the input_ids_len. """ - self.total_num_processed_tokens += input_ids_len + if increment_total_num_processed_tokens: + self.total_num_processed_tokens += input_ids_len input_state_capacity = state[list(state.keys())[0]].shape[ self._sequence_len_axis @@ -131,7 +131,7 @@ def update( ) if num_non_padded_entries_to_delete: cache_array = self.remove_non_padded_entries( - cache_array, num_entries_to_delete + cache_array, num_non_padded_entries_to_delete ) state[name] = numpy.ascontiguousarray(cache_array) @@ -167,7 +167,7 @@ def remove_non_padded_entries( new_cache_array = cache_array[ :, :, - bool(self._freeze_first_position) + num_non_padded_entries_to_delete :, + int(self._freeze_first_position) + num_non_padded_entries_to_delete :, :, ] if self._freeze_first_position: @@ -198,50 +198,32 @@ def set_capacity(self, capacity: int): state = self.cached_inputs if capacity_difference > 0: - raise NotImplementedError( - "The scenario when capacity" - "needs to be expanded is not yet" - "supported." + self.update( + state, + input_ids_len=capacity_difference, + increment_total_num_processed_tokens=False, ) elif capacity_difference < 0: - indices = [0] * abs(capacity_difference) - state = self._add_entries(state, indices=indices) - + state = self._expand_capacity( + state, num_additional_entries=abs(capacity_difference) + ) + self._state = state else: - return + pass - self._state = state + return - def _add_entries( - self, state: Dict[str, Any], indices: List[int], padding_value: int = 0 + def _expand_capacity( + self, state: Dict[str, Any], num_additional_entries: int, padding_value: int = 0 ) -> Dict[str, Any]: for key, value in state.items(): - # required to make sure that both - # quantized and non quantized caches - # are supported - state_dtype = value.dtype - # change padding_value dtype to match the state dtype - padding_value = numpy.array(padding_value, dtype=state_dtype) - - state[key] = numpy.insert( - value, indices, padding_value, axis=self._sequence_len_axis + zeros = numpy.zeros_like( + (value[:, :, :num_additional_entries, :]), dtype=value.dtype ) + state[key] = numpy.concatenate([zeros, value], axis=self._sequence_len_axis) return state - @property - def id(self): - if self._session_id is None: - raise ValueError("Attempted to access session_id before setting up session") - return self._session_id - - @property - def num_non_blank_entries(self): - """ - :return: the number of non-blank entries in the kv cache - """ - return min(self.capacity, self.total_num_processed_tokens) - @property def capacity(self) -> int: """ @@ -256,10 +238,6 @@ def capacity(self) -> int: self._sequence_len_axis ] - @id.setter - def id(self, session_id: str): - self._session_id = session_id - @property def cached_inputs(self): return self._state diff --git a/src/deepsparse/transformers/utils/helpers.py b/src/deepsparse/transformers/utils/helpers.py index b904570492..92dd400079 100644 --- a/src/deepsparse/transformers/utils/helpers.py +++ b/src/deepsparse/transformers/utils/helpers.py @@ -13,9 +13,12 @@ # limitations under the License. import logging import uuid -from typing import List, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy +from transformers import AutoTokenizer + +from deepsparse.utils.onnx import CACHE_INPUT_PREFIX, CACHE_OUTPUT_PREFIX __all__ = [ @@ -23,11 +26,66 @@ "pad_to_fixed_length", "create_causal_mask", "repeat_inputs", + "initialize_kv_cache_state", + "prepends_bos_token", ] _LOGGER = logging.getLogger(__name__) +def prepends_bos_token(tokenizer: AutoTokenizer) -> bool: + """ + Check whether the tokenizer prepends a BOS token to the input sequence. + + :param tokenizer: tokenizer to check + :return: True if the tokenizer prepends a BOS token to the input sequence, + False otherwise + """ + if hasattr(tokenizer, "add_bos_token"): + return bool(tokenizer.add_bos_token) + return False + + +def initialize_kv_cache_state( + cache_shape: Tuple[int, int, int, int], + kv_cache_data_type: Any, # TODO: add type + output_names: List[str], + length: Optional[int] = None, + empty: bool = False, +) -> Dict[str, numpy.ndarray]: + """ + Initialize the kv cache state for the given set of arguments. + + :param cache_shape: shape of the kv cache tensor. Should be + (batch_size, num_attention_heads, length, hidden_dims) + :param kv_cache_data_type: data type of the kv cache tensor + :param output_names: list of output names from the engine + :param length: length of the input sequence. If None, the length + is taken from the cache_shape + :param empty: if True, initialize an empty kv cache tensor + with batch_size set to 0. Otherwise, initialize a kv cache + tensor with zeros + """ + batch_size, num_attention_heads, length_, hidden_dims = cache_shape + + empty_kv_cache_tensor = numpy.zeros( + ( + batch_size if not empty else 0, + num_attention_heads, + length if length is not None else length_, + hidden_dims, + ), + dtype=kv_cache_data_type, + ) + + cache_keys = [ + output_name.replace(CACHE_OUTPUT_PREFIX, CACHE_INPUT_PREFIX) + for output_name in output_names + if output_name.startswith(CACHE_OUTPUT_PREFIX) + ] + return {key: empty_kv_cache_tensor for key in cache_keys} + + def generate_session_id() -> str: """ Generate uuid for session id. This is used to diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index 33e87328ff..0a8b36c01a 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -19,6 +19,7 @@ import pytest from deepsparse import Pipeline from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache +from deepsparse.transformers.utils.helpers import prepends_bos_token from tests.deepsparse.transformers.pipelines.helpers import TorchGroundTruthSource @@ -66,18 +67,18 @@ def Fibonacci(n): CODE_LANGUAGE_PROMPT, 13, ), - ( - "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/" - "opt_pretrain/base-none", - "facebook/opt-1.3b", - True, - NATURAL_LANGUAGE_PROMPT, - 3.9, - ), + # ( + # "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/" + # "opt_pretrain/base-none", + # "facebook/opt-1.3b", + # True, + # NATURAL_LANGUAGE_PROMPT, + # 3.9, + # ), ], scope="class", ) -@pytest.mark.skip(reason="Those tests are too heavy to run as a normal part of the CI.") +# @pytest.mark.skip(reason="Those tests are too heavy to run as a normal part of the CI.") class TestTextGenerationPipeline: """ This test suite is meant to test the main scenarios of @@ -153,7 +154,7 @@ def test_freeze_first_position(self, setup): # the kv cache is full _, uses_bos_token, _ = setup pipeline = self.get_pipeline() - assert pipeline.engine._freeze_first_position == uses_bos_token + assert prepends_bos_token(pipeline.tokenizer) == uses_bos_token def test_ort_single_token_prefill(self, setup): # Test the pipeline that uses ORT engine. The test covers the @@ -181,11 +182,8 @@ def test_ort_single_token_prefill(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache - assert cache_session.total_num_processed_tokens < self.sequence_length self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, ) @@ -215,11 +213,8 @@ def test_ort_multi_token_prefill(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache - assert cache_session.total_num_processed_tokens < self.sequence_length self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, ) @@ -249,16 +244,8 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache - assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( - "for this scenario, the kv cache should be full: " - "the total number of processed tokens should be " - "greater than the sequence length" - ) - self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 ) @@ -285,13 +272,11 @@ def test_deepsparse_single_token_prefill(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache - assert cache_session.total_num_processed_tokens < self.sequence_length + self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, - run_cache_validation=not self.internal_kv_cache, + run_cache_validation=False, ) def test_deepsparse_multi_token_prefill(self, setup): @@ -316,11 +301,8 @@ def test_deepsparse_multi_token_prefill(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache - assert cache_session.total_num_processed_tokens < self.sequence_length self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, run_cache_validation=not self.internal_kv_cache, ) @@ -347,18 +329,10 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache - assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( - "for this scenario, the kv cache should be full: " - "the total number of processed tokens should be " - "greater than the sequence length" - ) self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, - run_cache_validation=not self.internal_kv_cache, max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 ) @@ -417,13 +391,11 @@ def test_num_generated_predictions(self, setup): def _test_output( self, output: "TextGenerationOutput", # noqa F821 - cache_session: DecoderKVCache, torch_ground_truth: Tuple[numpy.ndarray, ...], + cache_session: Optional[DecoderKVCache] = None, max_logits_difference_threshold: Optional[float] = None, run_cache_validation: bool = True, ): - # extract numpy arrays from cached_inputs - kv_cache_array = list(cache_session.cached_inputs.values()) ( generated_logits, @@ -455,7 +427,9 @@ def _test_output( assert numpy.allclose(output.logits, target_logits, atol=_PRECISION) assert self.prompt + output.sequences[0] == generated_text - if run_cache_validation: + if run_cache_validation and cache_session: + # extract numpy arrays from cached_inputs + kv_cache_array = list(cache_session.cached_inputs.values()) self._test_kv_cache_state( expected_cache=kv_cache_array, target_cache=torch_ground_truth[2], From 44cb44ac22656961c458b9cc922096969cec343d Mon Sep 17 00:00:00 2001 From: Damian Date: Wed, 20 Sep 2023 10:40:36 +0000 Subject: [PATCH 4/7] in this form tests pass --- .../transformers/pipelines/text_generation.py | 19 +++++++----- .../pipelines/test_text_generation.py | 30 +++++++++++-------- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 85ac3e2cd3..b7bff080f8 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -613,7 +613,7 @@ def prompt_inference( if len(tokens) > self.prompt_sequence_length and self.enable_multitoken_prefill: for engine_inputs in self.engine_inputs_for_prefill( - tokens, num_total_processed_tokens=session.total_num_processed_tokens + tokens, kv_cache=session ): new_logits = self.multitoken_engine(engine_inputs, kv_cache=session) num_tokens_processed += self.prompt_sequence_length @@ -649,14 +649,16 @@ def autoregressive_inference( (with dimensions ['batch_size', 'num_tokens', 'vocab_size']) """ + num_total_processed_tokens = kv_cache.total_num_processed_tokens new_token = tokens[-1] # padding is added to left, so attention mask is 1s from the # right up to the number of total tokens (prompt + generated) attention_mask = numpy.zeros((1, self.sequence_length), dtype=numpy.int64) - num_tokens_processed = min(len(tokens), self.sequence_length) # cap by seq len - attention_mask[:, -num_tokens_processed:] = 1 - positions = numpy.array([[len(tokens)]], dtype=numpy.int64) - positions -= 1 + num_attention_entries_to_unmask = min( + num_total_processed_tokens + 1, self.sequence_length + ) # cap by seq len + attention_mask[:, -num_attention_entries_to_unmask:] = 1 + positions = numpy.array([[num_total_processed_tokens]], dtype=numpy.int64) input_ids = numpy.array([[new_token]]) causal_mask = create_causal_mask(input_ids, attention_mask) @@ -670,13 +672,13 @@ def autoregressive_inference( engine_inputs = [ engine_inputs_map[name] for name in self.engine.onnx_input_names_no_cache ] - + print(f"Token: {new_token} positions: {positions}") generated_logits = self.engine(engine_inputs, kv_cache) return generated_logits def engine_inputs_for_prefill( - self, tokens: List[int], num_total_processed_tokens: int + self, tokens: List[int], kv_cache: DecoderKVCache ) -> Generator[List[numpy.ndarray], None, None]: """ Takes a list of tokens and creates a generator @@ -709,7 +711,6 @@ def engine_inputs_for_prefill( :param tokens: the list of tokens to process :return: a generator of engine inputs """ - num_batches = len(tokens) // self.prompt_sequence_length token_batches = [ @@ -720,6 +721,7 @@ def engine_inputs_for_prefill( ] for idx, token_batch in enumerate(token_batches): + num_total_processed_tokens = kv_cache.total_num_processed_tokens engine_inputs = [] for name in self.multitoken_engine.onnx_input_names_no_cache: if name == "input_ids": @@ -756,6 +758,7 @@ def engine_inputs_for_prefill( input_ids=engine_inputs[0], attention_mask=engine_inputs[1] ) engine_inputs.append(causal_mask) + print(f"Token: {engine_inputs[0]} positions: {engine_inputs[2]}") yield engine_inputs diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index 0a8b36c01a..75b0197aca 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -20,6 +20,7 @@ from deepsparse import Pipeline from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache from deepsparse.transformers.utils.helpers import prepends_bos_token +from huggingface_hub import snapshot_download from tests.deepsparse.transformers.pipelines.helpers import TorchGroundTruthSource @@ -28,11 +29,7 @@ NATURAL_LANGUAGE_PROMPT = """ Didn't know what time it was, the lights were low I leaned back on my radio -Some cat was layin' down some rock 'n' roll -"Lotta soul," he said -Then the loud sound did seem to fade -Came back like a slow voice on a wave of phase -That weren't no DJ, that was hazy cosmic jive + """ CODE_LANGUAGE_PROMPT = """ @@ -59,14 +56,14 @@ def Fibonacci(n): "prompt, " "logits_max_diff_kv_cache_has_been_filled", [ - ( - "zoo:nlg/text_generation/codegen_mono-350m/pytorch/" - "huggingface/bigpython_bigquery_thepile/base-none", - "salesforce/codegen-350m-mono", - False, - CODE_LANGUAGE_PROMPT, - 13, - ), + # ( + # "zoo:nlg/text_generation/codegen_mono-350m/pytorch/" + # "huggingface/bigpython_bigquery_thepile/base-none", + # "salesforce/codegen-350m-mono", + # False, + # CODE_LANGUAGE_PROMPT, + # 13, + # ), # ( # "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/" # "opt_pretrain/base-none", @@ -75,6 +72,13 @@ def Fibonacci(n): # NATURAL_LANGUAGE_PROMPT, # 3.9, # ), + ( + snapshot_download("mgoin/TinyStories-33M-deepsparse"), + "roneneldan/TinyStories-33M", + False, + NATURAL_LANGUAGE_PROMPT, + 13, + ), ], scope="class", ) From 0b0e75c93ba104e6cb8b9cffb30ba7d2d535aa23 Mon Sep 17 00:00:00 2001 From: Damian Date: Wed, 20 Sep 2023 11:07:28 +0000 Subject: [PATCH 5/7] cleanup --- src/deepsparse/pipeline.py | 7 +- .../transformers/engines/nl_decoder_engine.py | 41 ++++++++-- .../transformers/pipelines/text_generation.py | 7 +- src/deepsparse/transformers/utils/__init__.py | 4 +- .../pipelines/test_text_generation.py | 82 ++++++++++++------- tests/test_data/pipeline_bench_config.json | 1 + tests/test_pipeline_benchmark.py | 1 + 7 files changed, 97 insertions(+), 46 deletions(-) diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index 96aea4ed8e..1fb09cca0f 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -254,10 +254,9 @@ def __call__(self, *args, **kwargs) -> BaseModel: # submit split batches to engine threadpool engine_forward_with_context = partial(self.engine_forward, context=context) - # batch_outputs = list( - # self.executor.map(engine_forward_with_context, batches) - # ) - batch_outputs = [engine_forward_with_context(x) for x in batches] + batch_outputs = list( + self.executor.map(engine_forward_with_context, batches) + ) # join together the batches of size `self._batch_size` engine_outputs = self.join_engine_outputs(batch_outputs, orig_batch_size) diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py index e49d0091bb..eefe6e5374 100644 --- a/src/deepsparse/transformers/engines/nl_decoder_engine.py +++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py @@ -106,8 +106,24 @@ def onnx_input_names_no_cache(self) -> List[str]: if not name.startswith(CACHE_INPUT_PREFIX) ] + @property + def onnx_input_names_cache(self) -> List[str]: + """ + :return: The cached input names for the onnx model + """ + return [ + name + for name in self.engine.input_names + if name.startswith(CACHE_INPUT_PREFIX) + ] + @property def cache_shape(self) -> Tuple[int, int, int, int]: + """ + :return: The shape of the kv cache inputs + for the onnx model. The shape is + (batch_size, num_heads, sequence_length, hidden_size) + """ cache_engine_input_index = next( i for i, name in enumerate(self.engine.input_names) @@ -125,7 +141,22 @@ def output_names(self) -> List[str]: def run( self, inputs: List[numpy.ndarray], val_inp: bool, kv_cache: DecoderKVCache ) -> List[numpy.ndarray]: - """ """ + """ + Run the engine with the given inputs. + If the kv_cache.engine_internal_cache=True, the internal + deepsparse kv cache management is enabled. In this case + the LIB.kv_cache class object will be passed to the engine + call as well. In this scenario also the inputs will not be + validated, even if the val_inp=True. This is because we + want to pass the empty kv cache inputs (batch_size=0) to + the engine. + + :param inputs: The inputs to run the engine with + :param val_inp: Whether the input is for validation or not + :param kv_cache: The kv cache object to use for the inference + + :return: The output of the engine + """ if bool(kv_cache.engine_internal_cache): # conventionally, before dispatching # inputs to the engine, we validate them @@ -236,13 +267,9 @@ def update_kv_cache( kv_cache.total_num_processed_tokens += input_ids_len return - cache_onnx_names = [ - name - for name in self.engine.input_names - if name.startswith(CACHE_INPUT_PREFIX) - ] kv_cache_state = { - name: array for name, array in zip(cache_onnx_names, kv_cache_state) + name: array + for name, array in zip(self.onnx_input_names_cache, kv_cache_state) } kv_cache.update( diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index d21fb865e7..e15e8ed3e3 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -742,7 +742,6 @@ def autoregressive_inference( engine_inputs = [ engine_inputs_map[name] for name in self.engine.onnx_input_names_no_cache ] - print(f"Token: {new_token} positions: {positions}") generated_logits = self.engine(engine_inputs, kv_cache) return generated_logits @@ -828,7 +827,6 @@ def engine_inputs_for_prefill( input_ids=engine_inputs[0], attention_mask=engine_inputs[1] ) engine_inputs.append(causal_mask) - print(f"Token: {engine_inputs[0]} positions: {engine_inputs[2]}") yield engine_inputs @@ -912,6 +910,11 @@ def causal_mask_input_present(model_path: str) -> bool: return is_causal_mask_input def get_kv_cache_decoder(self) -> DecoderKVCache: + """ + Initialize the kv cache decoder for the inference + + :return: the initialized kv cache decoder + """ engine = self.multitoken_engine or self.engine kv_cache_state = initialize_kv_cache_state( diff --git a/src/deepsparse/transformers/utils/__init__.py b/src/deepsparse/transformers/utils/__init__.py index 2caefa5216..c507d5bc85 100644 --- a/src/deepsparse/transformers/utils/__init__.py +++ b/src/deepsparse/transformers/utils/__init__.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. - +# flake8: noqa from .decoder_kv_cache import * from .helpers import * - -# flake8: noqa from .storage_kv_cache import * from .timings import * diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index 9b50632726..9eeacfb4ce 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -19,8 +19,6 @@ import pytest from deepsparse import Pipeline from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache -from deepsparse.transformers.utils.helpers import prepends_bos_token -from huggingface_hub import snapshot_download from tests.deepsparse.transformers.pipelines.helpers import TorchGroundTruthSource @@ -29,7 +27,11 @@ NATURAL_LANGUAGE_PROMPT = """ Didn't know what time it was, the lights were low I leaned back on my radio - +Some cat was layin' down some rock 'n' roll +"Lotta soul," he said +Then the loud sound did seem to fade +Came back like a slow voice on a wave of phase +That weren't no DJ, that was hazy cosmic jive """ CODE_LANGUAGE_PROMPT = """ @@ -56,33 +58,26 @@ def Fibonacci(n): "prompt, " "logits_max_diff_kv_cache_has_been_filled", [ - # ( - # "zoo:nlg/text_generation/codegen_mono-350m/pytorch/" - # "huggingface/bigpython_bigquery_thepile/base-none", - # "salesforce/codegen-350m-mono", - # False, - # CODE_LANGUAGE_PROMPT, - # 13, - # ), - # ( - # "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/" - # "opt_pretrain/base-none", - # "facebook/opt-1.3b", - # True, - # NATURAL_LANGUAGE_PROMPT, - # 3.9, - # ), ( - snapshot_download("mgoin/TinyStories-33M-deepsparse"), - "roneneldan/TinyStories-33M", + "zoo:nlg/text_generation/codegen_mono-350m/pytorch/" + "huggingface/bigpython_bigquery_thepile/base-none", + "salesforce/codegen-350m-mono", False, - NATURAL_LANGUAGE_PROMPT, + CODE_LANGUAGE_PROMPT, 13, ), + ( + "zoo:nlg/text_generation/opt-1.3b/pytorch/huggingface/" + "opt_pretrain/base-none", + "facebook/opt-1.3b", + True, + NATURAL_LANGUAGE_PROMPT, + 3.9, + ), ], scope="class", ) -# @pytest.mark.skip(reason="Those tests are too heavy to run as a normal part of the CI.") +@pytest.mark.skip(reason="Those tests are too heavy to run as a normal part of the CI.") class TestTextGenerationPipeline: """ This test suite is meant to test the main scenarios of @@ -158,7 +153,7 @@ def test_freeze_first_position(self, setup): # the kv cache is full _, uses_bos_token, _ = setup pipeline = self.get_pipeline() - assert prepends_bos_token(pipeline.tokenizer) == uses_bos_token + assert pipeline.engine._freeze_first_position == uses_bos_token def test_ort_single_token_prefill(self, setup): # Test the pipeline that uses ORT engine. The test covers the @@ -186,8 +181,11 @@ def test_ort_single_token_prefill(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens < self.sequence_length self._test_output( output=output, + cache_session=cache_session, torch_ground_truth=torch_ground_truth, ) @@ -217,8 +215,11 @@ def test_ort_multi_token_prefill(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens < self.sequence_length self._test_output( output=output, + cache_session=cache_session, torch_ground_truth=torch_ground_truth, ) @@ -248,8 +249,16 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( + "for this scenario, the kv cache should be full: " + "the total number of processed tokens should be " + "greater than the sequence length" + ) + self._test_output( output=output, + cache_session=cache_session, torch_ground_truth=torch_ground_truth, max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 ) @@ -276,11 +285,13 @@ def test_deepsparse_single_token_prefill(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens < self.sequence_length self._test_output( output=output, + cache_session=cache_session, torch_ground_truth=torch_ground_truth, - run_cache_validation=False, + run_cache_validation=not self.internal_kv_cache, ) def test_deepsparse_multi_token_prefill(self, setup): @@ -305,8 +316,11 @@ def test_deepsparse_multi_token_prefill(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens < self.sequence_length self._test_output( output=output, + cache_session=cache_session, torch_ground_truth=torch_ground_truth, run_cache_validation=not self.internal_kv_cache, ) @@ -333,10 +347,18 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) + cache_session = pipeline.engine.kv_cache + assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( + "for this scenario, the kv cache should be full: " + "the total number of processed tokens should be " + "greater than the sequence length" + ) self._test_output( output=output, + cache_session=cache_session, torch_ground_truth=torch_ground_truth, + run_cache_validation=not self.internal_kv_cache, max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 ) @@ -406,11 +428,13 @@ def test_num_generated_predictions(self, setup): def _test_output( self, output: "TextGenerationOutput", # noqa F821 + cache_session: DecoderKVCache, torch_ground_truth: Tuple[numpy.ndarray, ...], - cache_session: Optional[DecoderKVCache] = None, max_logits_difference_threshold: Optional[float] = None, run_cache_validation: bool = True, ): + # extract numpy arrays from cached_inputs + kv_cache_array = list(cache_session.cached_inputs.values()) ( generated_logits, @@ -441,9 +465,7 @@ def _test_output( assert numpy.allclose(score, target_logits[0], atol=_PRECISION) assert self.prompt + output.generations[0].text == generated_text - if run_cache_validation and cache_session: - # extract numpy arrays from cached_inputs - kv_cache_array = list(cache_session.cached_inputs.values()) + if run_cache_validation: self._test_kv_cache_state( expected_cache=kv_cache_array, target_cache=torch_ground_truth[2], diff --git a/tests/test_data/pipeline_bench_config.json b/tests/test_data/pipeline_bench_config.json index afd4db352d..5886762cea 100644 --- a/tests/test_data/pipeline_bench_config.json +++ b/tests/test_data/pipeline_bench_config.json @@ -2,6 +2,7 @@ "data_type": "dummy", "gen_sequence_length": 100, "input_image_shape": [500,500,3], + "data_folder": "/home/sadkins/imagenette2-320/", "recursive_search": true, "max_string_length": -1, "pipeline_kwargs": {}, diff --git a/tests/test_pipeline_benchmark.py b/tests/test_pipeline_benchmark.py index 782a1f8016..485599d044 100644 --- a/tests/test_pipeline_benchmark.py +++ b/tests/test_pipeline_benchmark.py @@ -95,6 +95,7 @@ def test_pipeline_benchmark( if res.stdout is not None: print(f"\n==== test_benchmark output ====\n{res.stdout}") assert res.returncode == 0 + assert "error" not in res.stdout.lower() assert "fail" not in res.stdout.lower() assert "total_inference" in res.stdout.lower() From 412cfc75674e9706e8d3d5049a99b223d8173a1c Mon Sep 17 00:00:00 2001 From: Damian Date: Wed, 20 Sep 2023 12:22:29 +0000 Subject: [PATCH 6/7] ready for review --- .../transformers/engines/nl_decoder_engine.py | 4 +- .../transformers/pipelines/text_generation.py | 4 + src/deepsparse/transformers/utils/__init__.py | 2 +- src/deepsparse/transformers/utils/helpers.py | 8 +- .../transformers/utils/storage_kv_cache.py | 92 ------------------- .../engine/test_nl_decoder_engine.py | 5 +- .../pipelines/test_text_generation.py | 30 +++--- .../utils/test_decoder_kv_cache.py | 18 +++- .../transformers/utils/test_helpers.py | 46 +++++++++- 9 files changed, 89 insertions(+), 120 deletions(-) delete mode 100644 src/deepsparse/transformers/utils/storage_kv_cache.py diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py index eefe6e5374..1db401b07e 100644 --- a/src/deepsparse/transformers/engines/nl_decoder_engine.py +++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py @@ -107,7 +107,7 @@ def onnx_input_names_no_cache(self) -> List[str]: ] @property - def onnx_input_names_cache(self) -> List[str]: + def onnx_input_names_cached(self) -> List[str]: """ :return: The cached input names for the onnx model """ @@ -269,7 +269,7 @@ def update_kv_cache( kv_cache_state = { name: array - for name, array in zip(self.onnx_input_names_cache, kv_cache_state) + for name, array in zip(self.onnx_input_names_cached, kv_cache_state) } kv_cache.update( diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index e15e8ed3e3..2ce73a6bf5 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -267,6 +267,7 @@ def __init__( self.tokenizer.pad_token = self.tokenizer.eos_token self.engine, self.multitoken_engine = self.initialize_engines() + self._debug = False def initialize_engines( self, @@ -653,6 +654,9 @@ def engine_forward( if streamer is not None: streamer.end() + if self._debug: + self._debug = dict(kv_cache=session) + return ( numpy.array([generated_tokens]), numpy.concatenate(generated_logits, axis=1), diff --git a/src/deepsparse/transformers/utils/__init__.py b/src/deepsparse/transformers/utils/__init__.py index c507d5bc85..80f24f1040 100644 --- a/src/deepsparse/transformers/utils/__init__.py +++ b/src/deepsparse/transformers/utils/__init__.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. + # flake8: noqa from .decoder_kv_cache import * from .helpers import * -from .storage_kv_cache import * from .timings import * diff --git a/src/deepsparse/transformers/utils/helpers.py b/src/deepsparse/transformers/utils/helpers.py index 92dd400079..6db0530557 100644 --- a/src/deepsparse/transformers/utils/helpers.py +++ b/src/deepsparse/transformers/utils/helpers.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import uuid -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy from transformers import AutoTokenizer @@ -48,7 +48,7 @@ def prepends_bos_token(tokenizer: AutoTokenizer) -> bool: def initialize_kv_cache_state( cache_shape: Tuple[int, int, int, int], - kv_cache_data_type: Any, # TODO: add type + kv_cache_data_type: numpy.dtype, output_names: List[str], length: Optional[int] = None, empty: bool = False, @@ -68,7 +68,7 @@ def initialize_kv_cache_state( """ batch_size, num_attention_heads, length_, hidden_dims = cache_shape - empty_kv_cache_tensor = numpy.zeros( + kv_cache_tensor = numpy.zeros( ( batch_size if not empty else 0, num_attention_heads, @@ -83,7 +83,7 @@ def initialize_kv_cache_state( for output_name in output_names if output_name.startswith(CACHE_OUTPUT_PREFIX) ] - return {key: empty_kv_cache_tensor for key in cache_keys} + return {key: kv_cache_tensor for key in cache_keys} def generate_session_id() -> str: diff --git a/src/deepsparse/transformers/utils/storage_kv_cache.py b/src/deepsparse/transformers/utils/storage_kv_cache.py deleted file mode 100644 index 6e525b9693..0000000000 --- a/src/deepsparse/transformers/utils/storage_kv_cache.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import Dict, Union - -from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache - - -_LOGGER = logging.getLogger(__name__) - -__all__ = ["SessionStorageKVCache"] - - -class SessionStorageKVCache: - """ - A storage that stores the kv cache sessions. - Each session is a DecoderKVCache object that - stores the state of the kv cache. - The storage is a dictionary that where keys are session_ids - and values are of all the active sessions. - """ - - def __init__(self): - self._memory: Dict[str, DecoderKVCache] = dict() - - def __len__(self): - return len(self._memory) - - def __str__(self): - return ( - f"{SessionStorageKVCache.__name__}:\n " - f"\tsessions: {[session_name for session_name in self._memory.keys()]}\n" - ) - - def has_session(self, session_id: str) -> bool: - """ - Check if the storage has a session with the given session id. - :param session_id: The identifier of the cache session. - :return: True if the storage has a session with the given session id. - """ - return session_id in self._memory - - def put(self, session: DecoderKVCache): - """ - Put the cache session in the storage. - - :param session: The session to store. - """ - session_id = session.id - if self.has_session(session_id): - _LOGGER.debug( - f"Session: {session_id} already exists in the storage. " - f"It will be overwritten." - ) - self._memory[session.id] = session - - def get(self, session_id: str) -> Union[DecoderKVCache, None]: - """ - Get the state of the kv cache for a session from the storage. - - :param session_id: The identifier of the cache session. - :return: The state of the kv cache for the session. - """ - session = self._memory.get(session_id) - if session is None: - _LOGGER.debug(f"No cache session found for session id: {session_id}") - return session - - def pop(self, session_id: str) -> DecoderKVCache: - """ - Pop the session correspond to session_id from the storage. - :param session_id: The identifier of the cache session. - """ - session = self._memory.pop(session_id, None) - if session is None: - raise ValueError( - f"Attempting to remove session: {session_id} from the storage. " - f"However, the session does not exist in the storage." - ) - return session diff --git a/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py b/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py index 7d80aa6ada..ff5b67aa75 100644 --- a/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py +++ b/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py @@ -47,9 +47,10 @@ def test_add_kv_cache_to_input(): with patch.object(NLDecoderEngine, "__init__", lambda x, y, z: None): nl_decoder_engine = NLDecoderEngine(None, None) nl_decoder_engine.engine = DummyEngine() - nl_decoder_engine.kv_cache = DummyKVCacheDecoder() nl_decoder_engine.kv_cache_enabled = True - result = nl_decoder_engine.add_kv_cache_to_input(inp) + result = nl_decoder_engine.add_kv_cache_to_input( + inp, kv_cache=DummyKVCacheDecoder + ) for (x, y) in zip(result, expected_result): assert np.array_equal(x, y) diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index 9eeacfb4ce..31b05916b1 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -19,6 +19,7 @@ import pytest from deepsparse import Pipeline from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache +from deepsparse.transformers.utils.helpers import prepends_bos_token from tests.deepsparse.transformers.pipelines.helpers import TorchGroundTruthSource @@ -153,7 +154,7 @@ def test_freeze_first_position(self, setup): # the kv cache is full _, uses_bos_token, _ = setup pipeline = self.get_pipeline() - assert pipeline.engine._freeze_first_position == uses_bos_token + assert prepends_bos_token(pipeline.tokenizer) == uses_bos_token def test_ort_single_token_prefill(self, setup): # Test the pipeline that uses ORT engine. The test covers the @@ -172,16 +173,16 @@ def test_ort_single_token_prefill(self, setup): model_path=self.model_stub, sequence_length=self.sequence_length, prompt_sequence_length=1, - force_max_tokens=True, engine_type="onnxruntime", ) + pipeline._debug = True output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache + cache_session = pipeline._debug.get("kv_cache") assert cache_session.total_num_processed_tokens < self.sequence_length self._test_output( output=output, @@ -206,16 +207,16 @@ def test_ort_multi_token_prefill(self, setup): model_path=self.model_stub, sequence_length=self.sequence_length, prompt_sequence_length=self.prompt_sequence_length, - force_max_tokens=True, engine_type="onnxruntime", ) + pipeline._debug = True output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache + cache_session = pipeline._debug.get("kv_cache") assert cache_session.total_num_processed_tokens < self.sequence_length self._test_output( output=output, @@ -243,13 +244,14 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup): force_max_tokens=True, engine_type="onnxruntime", ) + pipeline._debug = True output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache + cache_session = pipeline._debug.get("kv_cache") assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( "for this scenario, the kv cache should be full: " "the total number of processed tokens should be " @@ -276,16 +278,16 @@ def test_deepsparse_single_token_prefill(self, setup): model_path=self.model_stub, sequence_length=self.sequence_length, prompt_sequence_length=1, - force_max_tokens=True, internal_kv_cache=self.internal_kv_cache, ) + pipeline._debug = True output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache + cache_session = pipeline._debug.get("kv_cache") assert cache_session.total_num_processed_tokens < self.sequence_length self._test_output( output=output, @@ -307,16 +309,16 @@ def test_deepsparse_multi_token_prefill(self, setup): model_path=self.model_stub, sequence_length=self.sequence_length, prompt_sequence_length=self.prompt_sequence_length, - force_max_tokens=True, internal_kv_cache=self.internal_kv_cache, ) + pipeline._debug = True output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache + cache_session = pipeline._debug.get("kv_cache") assert cache_session.total_num_processed_tokens < self.sequence_length self._test_output( output=output, @@ -338,16 +340,16 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): model_path=self.model_stub, sequence_length=self.sequence_length_short, prompt_sequence_length=self.prompt_sequence_length, - force_max_tokens=True, internal_kv_cache=self.internal_kv_cache, ) + pipeline._debug = True output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache + cache_session = pipeline._debug.get("kv_cache") assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( "for this scenario, the kv cache should be full: " "the total number of processed tokens should be " @@ -433,8 +435,6 @@ def _test_output( max_logits_difference_threshold: Optional[float] = None, run_cache_validation: bool = True, ): - # extract numpy arrays from cached_inputs - kv_cache_array = list(cache_session.cached_inputs.values()) ( generated_logits, @@ -466,6 +466,8 @@ def _test_output( assert self.prompt + output.generations[0].text == generated_text if run_cache_validation: + # extract numpy arrays from cached_inputs + kv_cache_array = list(cache_session.cached_inputs.values()) self._test_kv_cache_state( expected_cache=kv_cache_array, target_cache=torch_ground_truth[2], diff --git a/tests/deepsparse/transformers/utils/test_decoder_kv_cache.py b/tests/deepsparse/transformers/utils/test_decoder_kv_cache.py index 8fcc40dd87..e93890548e 100644 --- a/tests/deepsparse/transformers/utils/test_decoder_kv_cache.py +++ b/tests/deepsparse/transformers/utils/test_decoder_kv_cache.py @@ -81,7 +81,6 @@ def setup( # initialize a session session.setup( - session_id="dummy_id", state=state, num_processed_tokens=num_processed_tokens, freeze_first_position=freeze_first_position, @@ -93,11 +92,10 @@ def test_session_attributes(self, setup): # check if the session attributes are set correctly state = session.cached_inputs - assert session.num_non_blank_entries == np.count_nonzero( + assert session.total_num_processed_tokens == np.count_nonzero( state["dummy_cache_name"].flatten() ) assert session.capacity == state["dummy_cache_name"].shape[2] - assert session.id == "dummy_id" def test_set_capacity(self, setup): session, _, _ = setup @@ -134,7 +132,19 @@ def _test_increase_capacity(session_): @staticmethod def _test_decrease_capacity(session_): - pass + session = copy.deepcopy(session_) + capacity = session.capacity + # decrease capacity by 3 + session.set_capacity(capacity - 3) + kv_cache_state = session.cached_inputs + # check if the capacity has been decreased by 3 + if session_._freeze_first_position: + bos_token = session_.cached_inputs["dummy_cache_name"][:, :, :1, :] + new_array = session_.cached_inputs["dummy_cache_name"][:, :, 4:, :] + target_array = np.concatenate([bos_token, new_array], axis=2) + else: + target_array = session_.cached_inputs["dummy_cache_name"][:, :, 3:, :] + assert np.array_equal(target_array, kv_cache_state["dummy_cache_name"]) @staticmethod def _test_constant_capacity(session_): diff --git a/tests/deepsparse/transformers/utils/test_helpers.py b/tests/deepsparse/transformers/utils/test_helpers.py index c96a004d6b..bc0fabb3d4 100644 --- a/tests/deepsparse/transformers/utils/test_helpers.py +++ b/tests/deepsparse/transformers/utils/test_helpers.py @@ -15,7 +15,10 @@ import numpy import pytest -from deepsparse.transformers.utils.helpers import create_causal_mask +from deepsparse.transformers.utils.helpers import ( + create_causal_mask, + initialize_kv_cache_state, +) @pytest.mark.parametrize( @@ -85,3 +88,44 @@ def test_create_causal_mask(input_ids, attention_mask, expected_causal_mask): causal_mask = create_causal_mask(input_ids, attention_mask) assert numpy.array_equal(causal_mask, expected_causal_mask[None, None, ...]) + + +@pytest.mark.parametrize( + "cache_shape, kv_cache_data_type, output_names, length, empty, expected_result", + [ + ( + (1, 2, 3, 4), + numpy.float32, + ["present.1", "present.2", "present.3"], + None, + False, + { + "past_key_values.1": numpy.zeros((1, 2, 3, 4)), + "past_key_values.2": numpy.zeros((1, 2, 3, 4)), + "past_key_values.3": numpy.zeros((1, 2, 3, 4)), + }, + ), + ( + (5, 6, 7, 8), + numpy.int8, + ["present.1", "present.2"], + 10, + True, + { + "past_key_values.1": numpy.zeros((0, 6, 10, 8), dtype=numpy.int8), + "past_key_values.2": numpy.zeros((0, 6, 10, 8), dtype=numpy.int8), + }, + ), + ], +) +def test_initialize_kv_cache_state( + cache_shape, kv_cache_data_type, output_names, length, empty, expected_result +): + # make sure that resulting Dict[str, numpy.ndarray] is the same + # as the expected_result + result = initialize_kv_cache_state( + cache_shape, kv_cache_data_type, output_names, length, empty + ) + assert result.keys() == expected_result.keys() + for key in result.keys(): + assert numpy.array_equal(result[key], expected_result[key]) From 86cacef25cd12dc1d0eade7799ee315454421fba Mon Sep 17 00:00:00 2001 From: Damian Date: Thu, 21 Sep 2023 06:46:01 +0000 Subject: [PATCH 7/7] PR review changes --- .../transformers/engines/nl_decoder_engine.py | 7 +-- .../transformers/pipelines/text_generation.py | 49 ++++++++++++++++--- src/deepsparse/transformers/utils/helpers.py | 13 +++++ .../pipelines/test_text_generation.py | 37 ++++++-------- 4 files changed, 71 insertions(+), 35 deletions(-) diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py index 1db401b07e..5ec5001c2f 100644 --- a/src/deepsparse/transformers/engines/nl_decoder_engine.py +++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py @@ -73,9 +73,7 @@ def __init__( input_ids_length=input_ids_length, ) - kv_cache_enabled = False if any(output_indices_to_be_cached): - kv_cache_enabled = True self.kv_cache_data_type = kv_cache_data_type if internal_kv_cache and engine_type == DEEPSPARSE_ENGINE: # inform the engine, that are using the kv cache @@ -91,7 +89,6 @@ def __init__( self.sequence_length = sequence_length self.input_ids_length = input_ids_length self.cache_length = sequence_length - input_ids_length - self.kv_cache_enabled = kv_cache_enabled self._engine_type = engine_type @property @@ -189,7 +186,7 @@ def __call__( :return: The generated token and corresponding logits """ timer = self.timer_manager.current - if self.kv_cache_enabled: + if kv_cache: # if model has kv cache enabled, we need # to add the kv cache state to the input inp = self.add_kv_cache_to_input(inp, kv_cache) @@ -197,7 +194,7 @@ def __call__( with timer.time(f"EXECUTE_ENGINE_SEQ_LEN_{self.sequence_length}"): out = self.run(inp, val_inp, kv_cache) - if self.kv_cache_enabled: + if kv_cache: with timer.time(TextGenerationTimings.KV_CACHE_UPDATE): logits, *kv_cache_state = out self.update_kv_cache( diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 2ce73a6bf5..20e0d546ad 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -183,6 +183,7 @@ class TextGenerationOutput(BaseModel): class Config: arbitrary_types_allowed = True + extra = "allow" @Pipeline.register( @@ -267,6 +268,8 @@ def __init__( self.tokenizer.pad_token = self.tokenizer.eos_token self.engine, self.multitoken_engine = self.initialize_engines() + + # auxiliary flag for devs to enable debug mode for the pipeline self._debug = False def initialize_engines( @@ -494,7 +497,7 @@ def process_engine_outputs( :param engine_outputs: the outputs from the engine :return: the output schema for the pipeline """ - generated_tokens, generated_logits, finished_reason = engine_outputs + generated_tokens, generated_logits, finished_reason, *debug = engine_outputs finished_reason = [f[0] for f in finished_reason] sequences = self.tokenizer.batch_decode( @@ -544,13 +547,26 @@ def _create_generated_text_output( ] generations = grouped_generations - return TextGenerationOutput( + outputs = dict( created=datetime.datetime.now(), prompts=prompts, generations=generations ) + if debug: + kv_cache_state, total_num_processed_tokens = debug + debug_params = dict( + kv_cache_state=kv_cache_state, + total_num_processed_tokens=total_num_processed_tokens, + ) + outputs.update(debug_params) + + return TextGenerationOutput(**outputs) + def engine_forward( self, engine_inputs: List[numpy.ndarray], context: Dict - ) -> Tuple[numpy.ndarray, numpy.ndarray, List[FinishReason]]: + ) -> Union[ + Tuple[numpy.ndarray, numpy.ndarray, List[FinishReason]], + Tuple[numpy.ndarray, numpy.ndarray, List[FinishReason], DecoderKVCache], + ]: """ Run the forward pass on the engine. @@ -654,15 +670,17 @@ def engine_forward( if streamer is not None: streamer.end() - if self._debug: - self._debug = dict(kv_cache=session) - - return ( + returns = ( numpy.array([generated_tokens]), numpy.concatenate(generated_logits, axis=1), finished_reason, ) + if self._debug is True: + return *returns, session + + return returns + def prompt_inference( self, engine_inputs: List[numpy.ndarray], @@ -676,6 +694,7 @@ def prompt_inference( :return: A tuple of: - The logits generated from the prompt (with dimensions ['batch_size', 'num_tokens', 'vocab_size']) + - The kv cache session for this inference run """ # get tokens by attention mask tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist() @@ -776,6 +795,9 @@ def engine_inputs_for_prefill( so that the attention_mask properly attends to the current input tokens, as well as the previous cache entries. + Note: the aformentioned sum must be capped + by the sequence length, as the maximum shape of the + attention mask is [batch_size, sequence_length]. - positions: derived directly from the input_ids @@ -805,6 +827,8 @@ def engine_inputs_for_prefill( engine_input = numpy.zeros( (1, self.sequence_length), dtype=numpy.int64 ) + # calculate the number of entries in attention mask + # that should be set to 1 num_attention_entries_to_unmask = min( num_total_processed_tokens + self.prompt_sequence_length, self.sequence_length, @@ -857,7 +881,7 @@ def join_engine_outputs( :param orig_batch_size: The original batch size :return: A list of joined outputs """ - tokens, logits, finish_reason = zip(*batch_outputs) + tokens, logits, finish_reason, *debug = zip(*batch_outputs) if self.cache_support_enabled: # if the model has kv cache, we need to account for # the fact that the predicted outputs may have @@ -889,6 +913,15 @@ def join_engine_outputs( tokens = numpy.concatenate(tokens, axis=0) logits = numpy.concatenate(logits, axis=0) + if debug: + sessions = debug[0] + kv_cache_state = numpy.stack(session.cached_inputs for session in sessions) + num_processed_tokens = numpy.stack( + session.total_num_processed_tokens for session in sessions + ) + + return [tokens, logits, finish_reason, kv_cache_state, num_processed_tokens] + return [tokens, logits, finish_reason] @staticmethod diff --git a/src/deepsparse/transformers/utils/helpers.py b/src/deepsparse/transformers/utils/helpers.py index 6db0530557..57d09e309e 100644 --- a/src/deepsparse/transformers/utils/helpers.py +++ b/src/deepsparse/transformers/utils/helpers.py @@ -65,9 +65,22 @@ def initialize_kv_cache_state( :param empty: if True, initialize an empty kv cache tensor with batch_size set to 0. Otherwise, initialize a kv cache tensor with zeros + + :return: dictionary of kv cache tensors, where the keys are the + output names of the kv cache tensors and the values are the + kv cache tensors of shape + (batch_size, num_attention_heads, length, hidden_dims) """ batch_size, num_attention_heads, length_, hidden_dims = cache_shape + # new kv cache tensor is either + # - non-empty tensor of zeros with shape + # (batch_size, num_attention_heads, length, hidden_dims), + # required for the external kv cache management + # or + # - empty tensor with shape + # (0, num_attention_heads, length, hidden_dims) + # required for the internal kv cache management kv_cache_tensor = numpy.zeros( ( batch_size if not empty else 0, diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index 31b05916b1..d8c4fde2a1 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -18,7 +18,6 @@ import pytest from deepsparse import Pipeline -from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache from deepsparse.transformers.utils.helpers import prepends_bos_token from tests.deepsparse.transformers.pipelines.helpers import TorchGroundTruthSource @@ -182,11 +181,9 @@ def test_ort_single_token_prefill(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline._debug.get("kv_cache") - assert cache_session.total_num_processed_tokens < self.sequence_length + assert output.total_num_processed_tokens[0] < self.sequence_length self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, ) @@ -216,11 +213,10 @@ def test_ort_multi_token_prefill(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline._debug.get("kv_cache") - assert cache_session.total_num_processed_tokens < self.sequence_length + + assert output.total_num_processed_tokens[0] < self.sequence_length self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, ) @@ -251,8 +247,8 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline._debug.get("kv_cache") - assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( + + assert output.total_num_processed_tokens[0] > self.sequence_length_short, ( "for this scenario, the kv cache should be full: " "the total number of processed tokens should be " "greater than the sequence length" @@ -260,7 +256,6 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup): self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 ) @@ -287,11 +282,10 @@ def test_deepsparse_single_token_prefill(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline._debug.get("kv_cache") - assert cache_session.total_num_processed_tokens < self.sequence_length + + assert output.total_num_processed_tokens[0] < self.sequence_length self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, run_cache_validation=not self.internal_kv_cache, ) @@ -312,17 +306,17 @@ def test_deepsparse_multi_token_prefill(self, setup): internal_kv_cache=self.internal_kv_cache, ) pipeline._debug = True + output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline._debug.get("kv_cache") - assert cache_session.total_num_processed_tokens < self.sequence_length + + assert output.total_num_processed_tokens[0] < self.sequence_length self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, run_cache_validation=not self.internal_kv_cache, ) @@ -349,8 +343,8 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline._debug.get("kv_cache") - assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( + + assert output.total_num_processed_tokens[0] > self.sequence_length_short, ( "for this scenario, the kv cache should be full: " "the total number of processed tokens should be " "greater than the sequence length" @@ -358,7 +352,6 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, run_cache_validation=not self.internal_kv_cache, max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 @@ -430,7 +423,6 @@ def test_num_generated_predictions(self, setup): def _test_output( self, output: "TextGenerationOutput", # noqa F821 - cache_session: DecoderKVCache, torch_ground_truth: Tuple[numpy.ndarray, ...], max_logits_difference_threshold: Optional[float] = None, run_cache_validation: bool = True, @@ -467,11 +459,12 @@ def _test_output( if run_cache_validation: # extract numpy arrays from cached_inputs - kv_cache_array = list(cache_session.cached_inputs.values()) + kv_cache_array = list(output.kv_cache_state[0].values()) + total_num_processed_tokens = output.total_num_processed_tokens[0] self._test_kv_cache_state( expected_cache=kv_cache_array, target_cache=torch_ground_truth[2], - total_num_processed_tokens=cache_session.total_num_processed_tokens, + total_num_processed_tokens=total_num_processed_tokens, ) @staticmethod