diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py index 30176b3b10..d5f5dfa91b 100644 --- a/src/deepsparse/transformers/engines/nl_decoder_engine.py +++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py @@ -64,7 +64,7 @@ def __init__( sampling_temperature: float = 1.0, deterministic: bool = True, engine_context: Optional[Context] = None, - use_deepsparse_cache=False, + use_deepsparse_cache: bool = False, ): # flag to indicate if the model is quantized or not self.kv_cache_data_type = None @@ -93,10 +93,12 @@ def __init__( engine_args=engine_args, context=engine_context, ) + self.sequence_length = sequence_length self.sampling_temperature = sampling_temperature self.deterministic = deterministic self.input_ids_length = input_ids_length + self.cache_length = sequence_length - input_ids_length self.kv_cache_enabled = kv_cache_enabled self.kv_cache = ( DecoderKVCache(use_deepsparse_cache) if kv_cache_enabled else None @@ -134,35 +136,41 @@ def onnx_input_names_no_cache(self) -> List[str]: @property def num_non_blank_cache_entries(self) -> int: """ - :return a number of non-blank entries in the - kv cache + :return A number of non-blank entries in the + kv cache """ return self.kv_cache.num_non_blank_entries + @property + def internal_cache_active(self) -> bool: + """ + :return: Whether the internal kv cache is active + """ + return self.kv_cache_enabled and self.kv_cache.engine_internal_cache is not None + def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]: """ Run the engine with the given inputs. - If the internal deepsparse kv cache management is enable, + If the self.internal_cache_active=True, the internal + deepsparse kv cache management is enabled. In this case the LIB.kv_cache class object will be passed to the engine call as well. :param inputs: The inputs to run the engine with :param val_inp: Whether the input is for validation or not - :return: The output of the engine """ - if self.kv_cache is not None: - if self.kv_cache._kv_cache is not None: - if val_inp: - self.engine._validate_inputs(inputs) - # model has kv cache support, as well as deepsparse - # internal management of the kv cache - return self.engine._eng_net.execute_list_out( - inputs, self.kv_cache._kv_cache - ) - + if self.internal_cache_active: + # validate the inputs if needed + if val_inp: + self.engine._validate_inputs(inputs) + # run the engine with the LIB.kv_cache object + return self.engine._eng_net.execute_list_out( + inputs, self.kv_cache.engine_internal_cache + ) + # run the engine without the LIB.kv_cache object return self.engine.run(inputs, val_inp) def __call__( @@ -180,8 +188,8 @@ def __call__( :return: The generated token and corresponding logits """ if self.kv_cache: - # if kv cache is enabled, we need to add the kv cache state - # to the input + # if model has kv cache enabled, we need + # to add the kv cache state to the input inp = self.add_kv_cache_to_input(inp) out = self.run(inp, val_inp) @@ -217,8 +225,7 @@ def transfer_cache_state(self, cache: DecoderKVCache): :param cache: The `DecoderKVCache` object to transfer to the engine from """ - target_cache_capacity = self.sequence_length - self.input_ids_length - cache.set_capacity(target_cache_capacity) + cache.set_capacity(self.cache_length) self.kv_cache = cache def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray: @@ -241,9 +248,7 @@ def reset_kv_cache(self): """ Resets the kv cache state. """ - kv_cache_state = self._initialize_kv_cache_state( - self.sequence_length - self.input_ids_length - ) + kv_cache_state = self._initialize_kv_cache_state(self.cache_length) self.kv_cache.setup( session_id=self._session_id, state=kv_cache_state, @@ -255,13 +260,27 @@ def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray] """ Takes the input and adds the past kv cache state to it. + If the internal kv cache is enabled, the kv cache state + will always be reinitialized to zeros. This is just to make sure + that the input shapes of the kv cache arrays to the + model are correct, the actual values are + being tracked internally inside the engine. + + If the internal kv cache is disabled, we need to + fetch the kv cache state as numpy arrays + from the current session, or initialize it if required. + + :param inp: The input to the model :return The input with the kv cache state added to it """ - kv_cache_state = self.kv_cache.cached_inputs - if kv_cache_state is None: - self.reset_kv_cache() + if self.internal_cache_active: + kv_cache_state = self._initialize_kv_cache_state(self.cache_length) + else: kv_cache_state = self.kv_cache.cached_inputs + if kv_cache_state is None: + self.reset_kv_cache() + kv_cache_state = self.kv_cache.cached_inputs for idx, input_name in enumerate(self.onnx_input_names_no_cache): kv_cache_state[input_name] = inp[idx] @@ -277,9 +296,17 @@ def update_kv_cache( """ Updates the state of the kv cache + If the internal kv cache is enabled, we refrain from + updating the kv cache state as it is being tracked internally + inside the engine. We only update the number of tokens processed. + :param kv_cache_state: The state of the kv cache storage :param input_ids_len: The length of input_ids """ + if self.internal_cache_active: + self.kv_cache.total_num_processed_tokens += input_ids_len + return + cache_onnx_names = [ name for name in self.engine.input_names diff --git a/src/deepsparse/transformers/utils/decoder_kv_cache.py b/src/deepsparse/transformers/utils/decoder_kv_cache.py index 957c8c9a2e..8d1548c852 100644 --- a/src/deepsparse/transformers/utils/decoder_kv_cache.py +++ b/src/deepsparse/transformers/utils/decoder_kv_cache.py @@ -31,8 +31,9 @@ def __init__(self, use_deepsparse_cache: bool = False): The goal this object is to handle the manipulation of the key value cache. - :param use_deepsparse_cache: If set to True, the `kv_cache` object - from the deepsparse.LIB will be loaded as an attribute. + :param use_deepsparse_cache: If set to True, the + `kv_cache` object from the deepsparse.LIB will + be loaded as an engine_internal_cache attribute. This object is used to handle the manipulation of the key/value buffers on the DeepSparse engine side. """ @@ -45,7 +46,7 @@ def __init__(self, use_deepsparse_cache: bool = False): self._session_id = None self._freeze_first_position = None self._state = None - self._kv_cache = None + self.engine_internal_cache = None def setup( self, @@ -82,7 +83,9 @@ def setup( if self._use_deepsparse_cache: prev_num_tokens = self.total_num_processed_tokens num_frozen_tokens = int(self._freeze_first_position) - self._kv_cache = LIB.kv_cache(prev_num_tokens, num_frozen_tokens) + self.engine_internal_cache = LIB.kv_cache( + prev_num_tokens, num_frozen_tokens + ) def update( self, diff --git a/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py b/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py index 1769d479eb..f4c8cc2f97 100644 --- a/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py +++ b/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py @@ -25,6 +25,7 @@ class DummyKVCacheDecoder: "past_key_values_1": np.array([10, 11, 12]), "past_key_values_2": np.array([13, 14, 15]), } + engine_internal_cache = None class DummyEngine: @@ -62,6 +63,7 @@ def test_add_kv_cache_to_input(): nl_decoder_engine = NLDecoderEngine(None, None) nl_decoder_engine.engine = DummyEngine() nl_decoder_engine.kv_cache = DummyKVCacheDecoder() + nl_decoder_engine.kv_cache_enabled = True result = nl_decoder_engine.add_kv_cache_to_input(inp) for (x, y) in zip(result, expected_result):