Skip to content

Commit 765a5f7

Browse files
committed
initial implementation
1 parent 5d59d23 commit 765a5f7

File tree

3 files changed

+160
-0
lines changed

3 files changed

+160
-0
lines changed

src/deepsparse/transformers/utils/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from .decoder_kv_cache import *
16+
1517
# flake8: noqa
1618
from .kv_cache_ort import *
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Dict, Optional
16+
17+
from deepsparse.pipeline import SUPPORTED_PIPELINE_ENGINES
18+
from deepsparse.transformers.utils.kv_cache_ort import KVCacheORT
19+
20+
21+
__all__ = ["DecoderKVCache"]
22+
23+
24+
class DecoderKVCache:
25+
def __init__(self, engine_type: str):
26+
"""
27+
The goal of DecoderKVCache is to provide a common
28+
interface for the KVCache objects used
29+
by the NLDecoderEngine
30+
31+
:param engine_type: The engine type to use for the decoder
32+
"""
33+
if engine_type not in SUPPORTED_PIPELINE_ENGINES:
34+
raise ValueError(f"Unsupported engine type: {engine_type}")
35+
elif engine_type != "onnxruntime":
36+
raise NotImplementedError(f"Unsupported engine type: {engine_type}")
37+
self._kv_cache_type = KVCacheORT
38+
39+
self._kv_cache = None
40+
self._session_id = None
41+
self._frozen_position = None
42+
self._num_tokens = None
43+
44+
def setup_session(
45+
self,
46+
session_id: str,
47+
state: Dict[str, Any],
48+
num_tokens: int,
49+
frozen_position=Optional[int],
50+
):
51+
"""
52+
Setup the session that will be used to transform
53+
the input and output cache values
54+
55+
:param session_id: The session id to use for the current
56+
session
57+
:param state: The state of the cache. This is a dictionary
58+
that maps the name of the cache array to the cache array.
59+
The cache tensor is a numpy array of shape
60+
[batch_size, num_heads, sequence_length, hidden_size]
61+
:param num_tokens: The number of tokens processed so far,
62+
corresponding to the number of "non-blank" entries in the
63+
kv cache array.
64+
:param frozen_position: The position along the sequence length axis
65+
that is frozen and thus, once it is occupied by a "non-blank"
66+
cache entry, it cannot be removed from the cache.
67+
"""
68+
self.session_id = session_id
69+
self._num_tokens = num_tokens
70+
self._frozen_position = frozen_position
71+
self._initialize_kv_cache(state)
72+
73+
def update_session(self, state: Dict[str, Any]):
74+
"""
75+
Update the session with the new state of the cache
76+
77+
:param state: The state of the cache. This is a dictionary
78+
that maps the name of the cache array to the cache array.
79+
The cache tensor is a numpy array of shape
80+
[batch_size, num_heads, sequence_length, hidden_size]
81+
"""
82+
self._num_tokens += 1
83+
self._initialize_kv_cache(state)
84+
85+
@property
86+
def session_id(self):
87+
if self._session_id is None:
88+
raise ValueError("Attempted to access session_id before setting up session")
89+
return self._session_id
90+
91+
@property
92+
def cached_inputs(self):
93+
if self._kv_cache is None:
94+
raise ValueError(
95+
"Attempted to access cached inputs before setting up session"
96+
)
97+
# TODO: Not sure whether this is the appropriate place
98+
# to invoke the shift_last method, to reconsider
99+
self._kv_cache.shift_last()
100+
return self._kv_cache.state
101+
102+
@session_id.setter
103+
def session_id(self, session_id: str):
104+
self._session_id = session_id
105+
106+
def _initialize_kv_cache(self, state: Dict[str, Any]):
107+
self._kv_cache = KVCacheORT(
108+
state=state,
109+
num_tokens=self._num_tokens,
110+
frozen_position=self._frozen_position,
111+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
17+
import pytest
18+
from deepsparse.transformers.utils import DecoderKVCache
19+
20+
21+
@pytest.mark.parametrize(
22+
"state, num_tokens, state_shifted, new_state, new_state_shifted",
23+
[
24+
(
25+
{"dummy_cache_name": np.array([[[[0], [0], [1], [2], [3]]]])},
26+
3,
27+
{"dummy_cache_name": np.array([[[[0], [1], [2], [3]]]])},
28+
{"dummy_cache_name": np.array([[[[0], [1], [2], [3], [4]]]])},
29+
{"dummy_cache_name": np.array([[[[1], [2], [3], [4]]]])},
30+
),
31+
],
32+
)
33+
def test_kv_cache_ort_shift(
34+
state, num_tokens, state_shifted, new_state, new_state_shifted
35+
):
36+
decoder_kv_cache = DecoderKVCache(engine_type="onnxruntime")
37+
decoder_kv_cache.setup_session(
38+
session_id="some_id", state=state, num_tokens=num_tokens
39+
)
40+
cache = decoder_kv_cache.cached_inputs
41+
for k, v in cache.items():
42+
assert np.array_equal(v, state_shifted[k])
43+
44+
decoder_kv_cache.update_session(state=new_state)
45+
cache = decoder_kv_cache.cached_inputs
46+
for k, v in cache.items():
47+
assert np.array_equal(v, new_state_shifted[k])

0 commit comments

Comments
 (0)