Skip to content

[Text Generation] KV Cache internal Deepsparse support #1135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 35 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
0bcf1ea
fix kv cache
SageMoore Jul 20, 2023
d9487bc
Merge branch 'main' into kv-cache-fixes
dbogunowicz Jul 24, 2023
1485478
refactor
dbogunowicz Jul 24, 2023
23c97c8
add validation pathway
dbogunowicz Jul 24, 2023
2deb33b
avx2 support
dbogunowicz Jul 25, 2023
4526499
add import
dbogunowicz Jul 25, 2023
1898a56
Merge remote-tracking branch 'origin/main' into kv-cache-fixes
dbogunowicz Jul 31, 2023
f41689a
initial commit
dbogunowicz Jul 31, 2023
4ed646a
initial commit
dbogunowicz Jul 31, 2023
7f34062
initial implementation
dbogunowicz Aug 1, 2023
817a1fa
Merge branch 'main' into kv-cache-fixes
dbogunowicz Aug 1, 2023
95b0082
problems with multitoken prefill
dbogunowicz Aug 1, 2023
db566c9
Merge branch 'main' into kv-cache-fixes
dbogunowicz Aug 2, 2023
ebf76fc
its working
dbogunowicz Aug 2, 2023
c8e54b8
Merge branch 'kv-cache-fixes' of https://github.com/neuralmagic/deeps…
dbogunowicz Aug 2, 2023
36c6664
Merge branch 'feature/damian/fb_testing' into tests/damian/decoder_kv…
dbogunowicz Aug 3, 2023
2353cb2
almost there...
dbogunowicz Aug 3, 2023
124a922
Merge branch 'main' into kv-cache-fixes
dbogunowicz Aug 3, 2023
aac5b85
finally all tests pass
dbogunowicz Aug 4, 2023
7ee2577
just need to change to stub
dbogunowicz Aug 4, 2023
ef38160
Merge remote-tracking branch 'origin/tests/damian/decoder_kv_cache' i…
dbogunowicz Aug 4, 2023
21b9456
Merge remote-tracking branch 'origin/tests/feature/nl_dec_engine' int…
dbogunowicz Aug 4, 2023
caef2f7
fix bad merge
dbogunowicz Aug 4, 2023
ffdc7fb
Merge branch 'main' into kv-cache-fixes
dbogunowicz Aug 7, 2023
8d98f73
Merge remote-tracking branch 'origin/tests/damian/llms' into kv-cache…
dbogunowicz Aug 7, 2023
f6b6807
added some tests
dbogunowicz Aug 7, 2023
a80c46d
ready for review
dbogunowicz Aug 7, 2023
c055873
Merge branch 'feature/damian/fb_testing' into tests/damian/decoder_kv…
dbogunowicz Aug 7, 2023
d851ab4
Merge branch 'tests/damian/decoder_kv_cache' into tests/feature/nl_de…
dbogunowicz Aug 7, 2023
0a42d3f
Merge branch 'tests/feature/nl_dec_engine' into tests/damian/llms
dbogunowicz Aug 7, 2023
9e6ea03
Merge branch 'tests/damian/llms' into kv-cache-fixes
dbogunowicz Aug 7, 2023
0cbb1a3
finish rebase
dbogunowicz Aug 8, 2023
1a524c7
Merge remote-tracking branch 'origin/feature/damian/fb_testing' into …
dbogunowicz Aug 9, 2023
098c7cf
full support
dbogunowicz Aug 9, 2023
b0a7c96
Merge branch 'feature/damian/fb_testing' into kv-cache-fixes
dbogunowicz Aug 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/deepsparse/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def __init__(
num_streams: int = None,
scheduler: Scheduler = None,
input_shapes: List[List[int]] = None,
cached_outputs: List[bool] = None,
):
BaseEngine.construct(
self, model, batch_size, num_cores, num_streams, scheduler, input_shapes
Expand All @@ -316,6 +317,7 @@ def __init__(
self._num_streams,
self._scheduler.value,
None,
cached_outputs,
)
else:
self._eng_net = LIB.deepsparse_engine(
Expand All @@ -325,6 +327,7 @@ def __init__(
self._num_streams,
self._scheduler.value,
None,
cached_outputs,
)

def __call__(
Expand Down
37 changes: 31 additions & 6 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from typing import Any, Dict, List, Optional, Tuple

Expand Down Expand Up @@ -87,7 +86,7 @@ def __init__(
self.kv_cache_data_type = kv_cache_data_type
if use_deepsparse_cache and engine_type == DEEPSPARSE_ENGINE:
# inform the engine that we are using the kv cache
engine_args["cache_output_bools"] = output_indices_to_be_cached
engine_args["cached_outputs"] = output_indices_to_be_cached

self.engine = create_engine(
onnx_file_path=onnx_file_path,
Expand All @@ -105,6 +104,7 @@ def __init__(
)
self._freeze_first_position = self._should_freeze_first_position(tokenizer)
self._session_id = generate_session_id()
self._engine_type = engine_type

@property
def session_id(self) -> str:
Expand Down Expand Up @@ -140,6 +140,32 @@ def num_non_blank_cache_entries(self) -> int:
"""
return self.kv_cache.num_non_blank_entries

def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]:
    """
    Execute the engine on the given inputs.

    When the deepsparse-internal kv cache management is enabled
    (the kv cache holds a LIB.kv_cache object), that object is
    passed to the low-level engine call together with the inputs.

    :param inputs: The inputs to run the engine with
    :param val_inp: Whether the input is for validation or not

    :return: The output of the engine
    """
    kv_cache = self.kv_cache
    if kv_cache is not None and kv_cache._kv_cache is not None:
        # model has kv cache support, as well as deepsparse
        # internal management of the kv cache
        if val_inp:
            self.engine._validate_inputs(inputs)
        return self.engine._eng_net.execute_list_out(
            inputs, kv_cache._kv_cache
        )

    return self.engine.run(inputs, val_inp)

def __call__(
self,
inp: List[numpy.ndarray],
Expand All @@ -159,7 +185,7 @@ def __call__(
# to the input
inp = self.add_kv_cache_to_input(inp)

out = self.engine.run(inp, val_inp)
out = self.run(inp, val_inp)

if self.kv_cache:
logits, *kv_cache_state = out
Expand Down Expand Up @@ -192,10 +218,9 @@ def transfer_cache_state(self, cache: DecoderKVCache):
:param cache: The `DecoderKVCache` object to transfer to the engine
from
"""
cache_to_copy = copy.deepcopy(cache)
target_cache_capacity = self.sequence_length - self.input_ids_length
cache_to_copy.set_capacity(target_cache_capacity)
self.kv_cache = cache_to_copy
cache.set_capacity(target_cache_capacity)
self.kv_cache = cache

def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
"""
Expand Down
13 changes: 3 additions & 10 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from transformers import TextStreamer

from deepsparse import Pipeline
from deepsparse.cpu import cpu_avx512_compatible
from deepsparse.pipeline import DEEPSPARSE_ENGINE
from deepsparse.transformers.engines import NLDecoderEngine
from deepsparse.transformers.pipelines import TransformersPipeline
Expand Down Expand Up @@ -146,22 +145,16 @@ def __init__(
**kwargs,
):
kwargs_engine_type = kwargs.get("engine_type", DEEPSPARSE_ENGINE)
if not cpu_avx512_compatible() and kwargs_engine_type == DEEPSPARSE_ENGINE:
warnings.warn(
"AVX512 support not detected, disabling internal management "
"of KV cache which may affect performance. To enable full "
"performance, deploy on an AVX512-compatible system."
)
use_deepsparse_cache = False

if use_deepsparse_cache:
if kwargs_engine_type != DEEPSPARSE_ENGINE:
raise ValueError(
_LOGGER.warning(
"`use_deepsparse_cache` is set to True "
"but the chosen `engine_type` "
f"is {kwargs_engine_type}. "
f"Make sure to set `engine_type` to {DEEPSPARSE_ENGINE}"
f"The optimized kv cache management is disabled."
)
use_deepsparse_cache = False

super().__init__(
**kwargs, _delay_engine_initialize=True, _delay_overwriting_inputs=True
Expand Down
6 changes: 5 additions & 1 deletion src/deepsparse/transformers/utils/decoder_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import numpy

from deepsparse.engine import LIB


__all__ = ["DecoderKVCache", "SEQUENCE_LENGTH_AXIS"]

Expand Down Expand Up @@ -78,7 +80,9 @@ def setup(
self.total_num_processed_tokens = num_processed_tokens

if self._use_deepsparse_cache:
raise NotImplementedError("DeepSparse cache is not supported yet.")
prev_num_tokens = self.total_num_processed_tokens
num_frozen_tokens = int(self._freeze_first_position)
self._kv_cache = LIB.kv_cache(prev_num_tokens, num_frozen_tokens)

def update(
self,
Expand Down
15 changes: 13 additions & 2 deletions tests/deepsparse/transformers/pipelines/test_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def _initialize_kv_cache_state(model, length=0):
return kv_cache


@pytest.mark.parametrize(
"use_deepsparse_cache",
[True, False],
)
@pytest.mark.parametrize(
"model_stub, model_name, uses_bos_token",
[
Expand All @@ -67,14 +71,15 @@ def _initialize_kv_cache_state(model, length=0):
scope="class",
)
@pytest.mark.skip(
reason="Those tests are to heavy to " "run as a normal part of the CI."
reason="Those tests are too heavy to " "run as a normal part of the CI."
)
class TestTextGenerationPipeline:
@pytest.fixture
def setup(self, model_stub, model_name, uses_bos_token):
def setup(self, model_stub, model_name, uses_bos_token, use_deepsparse_cache):

self.max_generated_tokens = 16
self.model = Model(model_stub)
self.use_deepsparse_cache = use_deepsparse_cache

pipeline = Pipeline.create(
task="text_generation",
Expand Down Expand Up @@ -125,6 +130,12 @@ def test_model_output_sequences(self, setup):

def test_model_output_cache(self, setup):
    # Check the kv cache state produced by the pipeline for both a
    # short and a long prompt. Only meaningful when the cache is
    # managed in Python, so skip under internal deepsparse caching.
    pipeline, model_name, _, short_prompt, long_prompt = setup
    if self.use_deepsparse_cache:
        pytest.skip(
            "Running pipeline with internal "
            "deepsparse cache will not result "
            "in meaningful cache entries."
        )
    for prompt in (short_prompt, long_prompt):
        self._test_cache_state(prompt, pipeline, model_name)

Expand Down