diff --git a/src/deepsparse/engine.py b/src/deepsparse/engine.py
index 624b6f5cf5..4ec05f29e1 100644
--- a/src/deepsparse/engine.py
+++ b/src/deepsparse/engine.py
@@ -300,6 +300,7 @@ def __init__(
         num_streams: int = None,
         scheduler: Scheduler = None,
         input_shapes: List[List[int]] = None,
+        cached_outputs: List[bool] = None,
     ):
         BaseEngine.construct(
             self, model, batch_size, num_cores, num_streams, scheduler, input_shapes
@@ -316,6 +317,7 @@ def __init__(
                     self._num_streams,
                     self._scheduler.value,
                     None,
+                    cached_outputs,
                 )
         else:
             self._eng_net = LIB.deepsparse_engine(
@@ -325,6 +327,7 @@ def __init__(
                 self._num_streams,
                 self._scheduler.value,
                 None,
+                cached_outputs,
             )
 
     def __call__(
diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py
index 8bcb4d4ea5..d76c7fb8d7 100644
--- a/src/deepsparse/transformers/engines/nl_decoder_engine.py
+++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import logging
 from typing import Any, Dict, List, Optional, Tuple
 
@@ -87,7 +86,7 @@ def __init__(
         self.kv_cache_data_type = kv_cache_data_type
         if use_deepsparse_cache and engine_type == DEEPSPARSE_ENGINE:
             # inform the engine, that are using the kv cache
-            engine_args["cache_output_bools"] = output_indices_to_be_cached
+            engine_args["cached_outputs"] = output_indices_to_be_cached
 
         self.engine = create_engine(
             onnx_file_path=onnx_file_path,
@@ -105,6 +104,7 @@ def __init__(
         )
         self._freeze_first_position = self._should_freeze_first_position(tokenizer)
         self._session_id = generate_session_id()
+        self._engine_type = engine_type
 
     @property
     def session_id(self) -> str:
@@ -140,6 +140,32 @@ def num_non_blank_cache_entries(self) -> int:
         """
         return self.kv_cache.num_non_blank_entries
 
+    def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]:
+        """
+        Run the engine with the given inputs.
+
+        If the internal deepsparse kv cache management is enabled,
+        the LIB.kv_cache class object will be passed to the engine
+        call as well.
+
+        :param inputs: The inputs to run the engine with
+        :param val_inp: Whether to validate the inputs or not
+
+        :return: The output of the engine
+        """
+
+        if self.kv_cache is not None:
+            if self.kv_cache._kv_cache is not None:
+                if val_inp:
+                    self.engine._validate_inputs(inputs)
+                # model has kv cache support, as well as deepsparse
+                # internal management of the kv cache
+                return self.engine._eng_net.execute_list_out(
+                    inputs, self.kv_cache._kv_cache
+                )
+
+        return self.engine.run(inputs, val_inp)
+
     def __call__(
         self,
         inp: List[numpy.ndarray],
@@ -159,7 +185,7 @@ def __call__(
         # to the input
         inp = self.add_kv_cache_to_input(inp)
 
-        out = self.engine.run(inp, val_inp)
+        out = self.run(inp, val_inp)
 
         if self.kv_cache:
             logits, *kv_cache_state = out
@@ -192,10 +218,9 @@ def transfer_cache_state(self, cache: DecoderKVCache):
 
         :param cache: The `DecoderKVCache` object to transfer to the engine from
         """
-        cache_to_copy = copy.deepcopy(cache)
         target_cache_capacity = self.sequence_length - self.input_ids_length
-        cache_to_copy.set_capacity(target_cache_capacity)
-        self.kv_cache = cache_to_copy
+        cache.set_capacity(target_cache_capacity)
+        self.kv_cache = cache
 
     def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
         """
diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py
index d01cde211f..391f5835a5 100644
--- a/src/deepsparse/transformers/pipelines/text_generation.py
+++ b/src/deepsparse/transformers/pipelines/text_generation.py
@@ -24,7 +24,6 @@
 from transformers import TextStreamer
 
 from deepsparse import Pipeline
-from deepsparse.cpu import cpu_avx512_compatible
 from deepsparse.pipeline import DEEPSPARSE_ENGINE
 from deepsparse.transformers.engines import NLDecoderEngine
 from deepsparse.transformers.pipelines import TransformersPipeline
@@ -146,22 +145,16 @@ def __init__(
         **kwargs,
     ):
         kwargs_engine_type = kwargs.get("engine_type", DEEPSPARSE_ENGINE)
-        if not cpu_avx512_compatible() and kwargs_engine_type == DEEPSPARSE_ENGINE:
-            warnings.warn(
-                "AVX512 support not detected, disabling internal management "
-                "of KV cache which may affect performance. To enable full "
-                "performance, deploy on an AVX512-compatible system."
-            )
-            use_deepsparse_cache = False
 
         if use_deepsparse_cache:
             if kwargs_engine_type != DEEPSPARSE_ENGINE:
-                raise ValueError(
+                _LOGGER.warning(
                     "`use_deepsparse_cache` is set to True "
                     "but the chosen `engine_type` "
                     f"is {kwargs_engine_type}. "
-                    f"Make sure to set `engine_type` to {DEEPSPARSE_ENGINE}"
+                    f"The optimized kv cache management is disabled."
                 )
+                use_deepsparse_cache = False
 
         super().__init__(
             **kwargs, _delay_engine_initialize=True, _delay_overwriting_inputs=True
diff --git a/src/deepsparse/transformers/utils/decoder_kv_cache.py b/src/deepsparse/transformers/utils/decoder_kv_cache.py
index 9233233255..b031d40f6f 100644
--- a/src/deepsparse/transformers/utils/decoder_kv_cache.py
+++ b/src/deepsparse/transformers/utils/decoder_kv_cache.py
@@ -16,6 +16,8 @@
 
 import numpy
 
+from deepsparse.engine import LIB
+
 
 __all__ = ["DecoderKVCache", "SEQUENCE_LENGTH_AXIS"]
 
@@ -78,7 +80,9 @@ def setup(
         self.total_num_processed_tokens = num_processed_tokens
 
         if self._use_deepsparse_cache:
-            raise NotImplementedError("DeepSparse cache is not supported yet.")
+            prev_num_tokens = self.total_num_processed_tokens
+            num_frozen_tokens = int(self._freeze_first_position)
+            self._kv_cache = LIB.kv_cache(prev_num_tokens, num_frozen_tokens)
 
     def update(
         self,
diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py
index b4bbb42992..2bfe899594 100644
--- a/tests/deepsparse/transformers/pipelines/test_text_generation.py
+++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py
@@ -49,6 +49,10 @@ def _initialize_kv_cache_state(model, length=0):
     return kv_cache
 
 
+@pytest.mark.parametrize(
+    "use_deepsparse_cache",
+    [True, False],
+)
 @pytest.mark.parametrize(
     "model_stub, model_name, uses_bos_token",
     [
@@ -67,14 +71,15 @@ def _initialize_kv_cache_state(model, length=0):
     scope="class",
 )
 @pytest.mark.skip(
-    reason="Those tests are to heavy to " "run as a normal part of the CI."
+    reason="Those tests are too heavy to " "run as a normal part of the CI."
 )
 class TestTextGenerationPipeline:
     @pytest.fixture
-    def setup(self, model_stub, model_name, uses_bos_token):
+    def setup(self, model_stub, model_name, uses_bos_token, use_deepsparse_cache):
         self.max_generated_tokens = 16
         self.model = Model(model_stub)
+        self.use_deepsparse_cache = use_deepsparse_cache
 
         pipeline = Pipeline.create(
             task="text_generation",
@@ -125,6 +130,12 @@ def test_model_output_sequences(self, setup):
 
     def test_model_output_cache(self, setup):
         pipeline, model_name, _, short_prompt, long_prompt = setup
+        if self.use_deepsparse_cache:
+            pytest.skip(
+                "Running pipeline with internal "
+                "deepsparse cache will not result "
+                "in meaningful cache entries."
+            )
         self._test_cache_state(short_prompt, pipeline, model_name)
         self._test_cache_state(long_prompt, pipeline, model_name)
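For context, a minimal usage sketch of the path this diff wires up. The model stub and prompt are hypothetical placeholders, not part of the patch; `Pipeline.create`, the `text_generation` task, and the `use_deepsparse_cache` flag are taken from the patched `text_generation.py` above:

```python
from deepsparse import Pipeline

# Hypothetical model stub -- substitute any text generation model
# stub/path supported by your deepsparse build.
pipeline = Pipeline.create(
    task="text_generation",
    model_path="zoo:some/text_generation/model",
    use_deepsparse_cache=True,  # KV cache managed inside the engine via LIB.kv_cache
)

# With the default DEEPSPARSE_ENGINE, NLDecoderEngine.run() now dispatches to
# engine._eng_net.execute_list_out(inputs, kv_cache._kv_cache) instead of the
# regular Engine.run() path.
generated = pipeline(sequences="def fibonacci(n):")
print(generated.sequences[0])
```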
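Relatedly, a sketch of the softened engine-type check (model stub again hypothetical): where `use_deepsparse_cache=True` combined with a non-deepsparse engine previously raised a `ValueError`, the pipeline now only logs a warning and falls back to the external, numpy-based cache management:

```python
from deepsparse import Pipeline

# No longer raises; logs "... The optimized kv cache management is disabled."
# and resets use_deepsparse_cache to False before initialization continues.
pipeline = Pipeline.create(
    task="text_generation",
    model_path="zoo:some/text_generation/model",  # hypothetical stub
    engine_type="onnxruntime",
    use_deepsparse_cache=True,
)
```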