Commit d434c77

feat: Segment cache by active LoRAs; change key format
1 parent 5dc0a1e

2 files changed: +98 −34 lines

Diff for: llama_cpp/llama.py (+8 −4)

@@ -34,6 +34,7 @@
 from .llama_types import *
 from .llama_grammar import LlamaGrammar
 from .llama_cache import (
+    LlamaCacheKey,
     BaseLlamaCache,
     LlamaCache,  # type: ignore
     LlamaDiskCache,  # type: ignore
@@ -407,7 +408,7 @@ def __init__(
         # Dict from LoRA path to wrapper
         self._lora_adapters_paths: Dict[str, internals.LlamaLoraAdapter] = {}
         # Immutable value representing active adapters for use as a key
-        self._lora_adapters_active: Tuple[Tuple[str, float]] = ()
+        self._lora_adapters_active: Tuple[Tuple[str, float], ...] = ()
 
         if self.lora_adapters:
             for lora_path, scale in self.lora_adapters.copy().items():
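
The annotation fix in this hunk is easy to miss: without the trailing ellipsis, the old type describes a tuple containing exactly one (path, scale) pair, while the new form admits any number of pairs, including none. A minimal illustration of the distinction (the adapter paths are made up for the example):

    from typing import Tuple

    # Tuple[Tuple[str, float]] admits exactly one (path, scale) pair;
    # Tuple[Tuple[str, float], ...] admits zero or more, which is what the field holds.
    OneAdapterOnly = Tuple[Tuple[str, float]]
    AnyAdapters = Tuple[Tuple[str, float], ...]

    no_adapters: AnyAdapters = ()
    two_adapters: AnyAdapters = (("style.gguf", 0.8), ("tone.gguf", 0.5))  # hypothetical paths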
@@ -1315,7 +1316,8 @@ def logit_bias_processor(
 
         if self.cache:
             try:
-                cache_item = self.cache[(self._lora_adapters_active, prompt_tokens)]
+                cache_key = LlamaCacheKey(active_lora_adapters=self._lora_adapters_active, tokens=tuple(prompt_tokens))
+                cache_item = self.cache[cache_key]
                 cache_prefix_len = Llama.longest_token_prefix(
                     cache_item.input_ids.tolist(), prompt_tokens
                 )
@@ -1653,15 +1655,17 @@ def logit_bias_processor(
             if self.cache:
                 if self.verbose:
                     print("Llama._create_completion: cache save", file=sys.stderr)
-                self.cache[(self._lora_adapters_active, prompt_tokens + completion_tokens)] = self.save_state()
+                cache_key = LlamaCacheKey(active_lora_adapters=self._lora_adapters_active, tokens=tuple(prompt_tokens + completion_tokens))
+                self.cache[cache_key] = self.save_state()
                 if self.verbose:
                     print("Llama._create_completion: cache saved", file=sys.stderr)
             return
 
         if self.cache:
             if self.verbose:
                 print("Llama._create_completion: cache save", file=sys.stderr)
-            self.cache[(self._lora_adapters_active, prompt_tokens + completion_tokens)] = self.save_state()
+            cache_key = LlamaCacheKey(active_lora_adapters=self._lora_adapters_active, tokens=tuple(prompt_tokens + completion_tokens))
+            self.cache[cache_key] = self.save_state()
 
         text_str = text.decode("utf-8", errors="ignore")
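
Taken together, the llama.py changes replace the ad-hoc (adapters, tokens) tuple with an explicit key object at both the lookup and the save sites. A minimal sketch of constructing such a key outside the class, assuming the package is installed; the adapter path and token ids below are illustrative, not taken from the commit:

    from llama_cpp.llama_cache import LlamaCacheKey

    # Active adapters are an immutable tuple of (path, scale) pairs, so the key stays hashable.
    active_adapters = (("./adapters/style-lora.gguf", 0.8),)  # hypothetical adapter path
    prompt_tokens = [1, 15043, 3186]                          # illustrative token ids

    key = LlamaCacheKey(
        active_lora_adapters=active_adapters,
        tokens=tuple(prompt_tokens),  # __post_init__ rejects anything that is not a tuple
    )

    # Frozen dataclass with eq=True: identical fields hash and compare equal, so the same
    # prompt under a different adapter set maps to a distinct cache entry.
    assert key == LlamaCacheKey(active_lora_adapters=active_adapters, tokens=(1, 15043, 3186))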

Diff for: llama_cpp/llama_cache.py (+90 −30)

@@ -1,5 +1,6 @@
 import sys
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import (
     Optional,
     Sequence,
@@ -13,38 +14,93 @@
 
 from .llama_types import *
 
+@dataclass(eq=True, frozen=True)
+class LlamaCacheKey:
+    """A key in a LlamaCache. Stores tokens to key by. Also stores
+    information about active LoRA adapters, because we need different
+    cached values for different active adapters, even for the same tokens."""
+    active_lora_adapters: Tuple[Tuple[str, float], ...]
+    tokens: Tuple[int, ...]
+
+    def __post_init__(self):
+        if not isinstance(self.tokens, tuple):
+            raise ValueError("tokens must be a tuple")
 
 class BaseLlamaCache(ABC):
     """Base cache class for a llama.cpp model."""
 
     def __init__(self, capacity_bytes: int = (2 << 30)):
         self.capacity_bytes = capacity_bytes
 
+    def _convert_to_cache_key(self, key: Union[Sequence[int], LlamaCacheKey]) -> LlamaCacheKey:
+        """Convert raw tokens to a key if needed"""
+        if type(key) == LlamaCacheKey:
+            return key
+        else:
+            return LlamaCacheKey(active_lora_adapters=(), tokens=tuple(key))
+
     @property
     @abstractmethod
     def cache_size(self) -> int:
         raise NotImplementedError
 
     def _find_longest_prefix_key(
         self,
-        key: Tuple[int, ...],
-    ) -> Optional[Tuple[int, ...]]:
+        key: LlamaCacheKey,
+    ) -> Optional[LlamaCacheKey]:
+        """Find the cached key with the longest matching token prefix. A match also requires that the active
+        LoRA adapters match exactly.
+
+        Args:
+            key (LlamaCacheKey): The key to find a prefix match for.
+
+        Returns:
+            Optional[LlamaCacheKey]: The key with the longest matching prefix, or None if no match found.
+        """
         pass
 
     @abstractmethod
-    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
+    def __getitem__(self, key: Union[Sequence[int], LlamaCacheKey]) -> "llama_cpp.llama.LlamaState":
+        """Retrieve a cached state by key, matching on the longest common token prefix. A match also requires
+        that the active LoRA adapters match exactly.
+
+        Args:
+            key: Key to look up. Raw token sequences are supported for backwards compatibility
+                and assume no active LoRA adapters.
+
+        Returns:
+            llama_cpp.llama.LlamaState: The cached state for the entry sharing the longest token prefix.
+
+        Raises:
+            KeyError: If no prefix match is found.
+        """
         raise NotImplementedError
 
     @abstractmethod
-    def __contains__(self, key: Sequence[int]) -> bool:
+    def __contains__(self, key: Union[Sequence[int], LlamaCacheKey]) -> bool:
+        """Check if any cached key shares a token prefix with the given key.
+
+        Args:
+            key: Key to look up. Raw token sequences are supported for backwards compatibility
+                and assume no active LoRA adapters.
+
+        Returns:
+            bool: True if any cached key shares a token prefix with this key.
+        """
         raise NotImplementedError
 
     @abstractmethod
     def __setitem__(
-        self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"
+        self, key: Union[Sequence[int], LlamaCacheKey], value: "llama_cpp.llama.LlamaState"
     ) -> None:
-        raise NotImplementedError
+        """Store a state keyed on its tokens and information about active LoRA adapters.
 
+        Args:
+            key: Key to store. Raw token sequences are supported for backwards compatibility
+                and assume no active LoRA adapters
+            value: The state to cache
+        """
+        raise NotImplementedError
 
 class LlamaRAMCache(BaseLlamaCache):
     """Cache for a llama.cpp model using RAM."""
@@ -53,7 +109,7 @@ def __init__(self, capacity_bytes: int = (2 << 30)):
         super().__init__(capacity_bytes)
         self.capacity_bytes = capacity_bytes
         self.cache_state: OrderedDict[
-            Tuple[int, ...], "llama_cpp.llama.LlamaState"
+            LlamaCacheKey, "llama_cpp.llama.LlamaState"
         ] = OrderedDict()
 
     @property
@@ -62,34 +118,33 @@ def cache_size(self):
 
     def _find_longest_prefix_key(
         self,
-        key: Tuple[int, ...],
-    ) -> Optional[Tuple[int, ...]]:
+        key: LlamaCacheKey,
+    ) -> Optional[LlamaCacheKey]:
         min_len = 0
-        min_key = None
-        keys = (
-            (k, llama_cpp.llama.Llama.longest_token_prefix(k, key))
-            for k in self.cache_state.keys()
-        )
-        for k, prefix_len in keys:
+        min_key: Optional[LlamaCacheKey] = None
+        for k in self.cache_state.keys():
+            if k.active_lora_adapters != key.active_lora_adapters: continue
+            if len(k.tokens) < min_len: continue  # Optimization
+            prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k.tokens, key.tokens)
             if prefix_len > min_len:
                 min_len = prefix_len
                 min_key = k
         return min_key
 
-    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
-        key = tuple(key)
+    def __getitem__(self, key: Union[Sequence[int], LlamaCacheKey]) -> "llama_cpp.llama.LlamaState":
+        key = self._convert_to_cache_key(key)
         _key = self._find_longest_prefix_key(key)
         if _key is None:
             raise KeyError("Key not found")
         value = self.cache_state[_key]
         self.cache_state.move_to_end(_key)
         return value
 
-    def __contains__(self, key: Sequence[int]) -> bool:
+    def __contains__(self, key: Union[Sequence[int], LlamaCacheKey]) -> bool:
         return self._find_longest_prefix_key(tuple(key)) is not None
 
-    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
-        key = tuple(key)
+    def __setitem__(self, key: Union[Sequence[int], LlamaCacheKey], value: "llama_cpp.llama.LlamaState"):
+        key = self._convert_to_cache_key(key)
         if key in self.cache_state:
             del self.cache_state[key]
         self.cache_state[key] = value
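
The rewritten _find_longest_prefix_key above first filters on an exact match of the active adapters and only then compares token prefixes. A minimal sketch of that selection logic; the lookup never reads the stored values, so placeholders stand in for real LlamaState objects, and the adapter paths and token ids are illustrative:

    from llama_cpp.llama_cache import LlamaCacheKey, LlamaRAMCache

    cache = LlamaRAMCache()
    base = (("base-lora.gguf", 1.0),)    # hypothetical adapter set
    other = (("other-lora.gguf", 1.0),)  # a different hypothetical adapter set

    # Populate the key index directly; the prefix search only inspects the keys.
    cache.cache_state[LlamaCacheKey(active_lora_adapters=base, tokens=(1, 2, 3))] = None
    cache.cache_state[LlamaCacheKey(active_lora_adapters=base, tokens=(1, 2, 3, 4, 5))] = None
    cache.cache_state[LlamaCacheKey(active_lora_adapters=other, tokens=(1, 2, 3, 4, 5, 6))] = None

    query = LlamaCacheKey(active_lora_adapters=base, tokens=(1, 2, 3, 4, 5, 6, 7))
    match = cache._find_longest_prefix_key(query)

    # The longest prefix under the same adapter set wins; the longer entry stored
    # under a different adapter set is never considered.
    assert match is not None and match.tokens == (1, 2, 3, 4, 5)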
@@ -116,19 +171,24 @@ def cache_size(self):
 
     def _find_longest_prefix_key(
         self,
-        key: Tuple[int, ...],
-    ) -> Optional[Tuple[int, ...]]:
+        key: LlamaCacheKey,
+    ) -> Optional[LlamaCacheKey]:
         min_len = 0
         min_key: Optional[Tuple[int, ...]] = None
         for k in self.cache.iterkeys():  # type: ignore
-            prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k, key)
+            if not isinstance(k, LlamaCacheKey):
+                print("LlamaDiskCache: Disk cache keys must be LlamaCacheKey objects: skipping")
+                continue
+            if k.active_lora_adapters != key.active_lora_adapters: continue
+            if len(k.tokens) < min_len: continue  # Optimization
+            prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k.tokens, key.tokens)
             if prefix_len > min_len:
                 min_len = prefix_len
-                min_key = k  # type: ignore
+                min_key = k
         return min_key
 
-    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
-        key = tuple(key)
+    def __getitem__(self, key: Union[Sequence[int], LlamaCacheKey]) -> "llama_cpp.llama.LlamaState":
+        key = self._convert_to_cache_key(key)
         _key = self._find_longest_prefix_key(key)
         if _key is None:
             raise KeyError("Key not found")
@@ -138,12 +198,12 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
         # self.cache.push(_key, side="front")  # type: ignore
         return value
 
-    def __contains__(self, key: Sequence[int]) -> bool:
-        return self._find_longest_prefix_key(tuple(key)) is not None
+    def __contains__(self, key: Union[Sequence[int], LlamaCacheKey]) -> bool:
+        return self._find_longest_prefix_key(self._convert_to_cache_key(key)) is not None
 
-    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
+    def __setitem__(self, key: Union[Sequence[int], LlamaCacheKey], value: "llama_cpp.llama.LlamaState"):
         print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
-        key = tuple(key)
+        key = self._convert_to_cache_key(key)
         if key in self.cache:
             print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
             del self.cache[key]
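
A side effect of the new key format, visible in the isinstance guard above, is that entries written by older versions (which keyed the disk cache on bare token tuples) are skipped with a warning rather than reinterpreted, so they simply stop matching. A small sketch of that interaction, assuming the diskcache dependency and a throwaway cache directory:

    import tempfile

    from llama_cpp.llama_cache import LlamaCacheKey, LlamaDiskCache

    with tempfile.TemporaryDirectory() as cache_dir:
        disk_cache = LlamaDiskCache(cache_dir=cache_dir)

        # Simulate a stale entry in the old key format: a bare token tuple,
        # written directly to the underlying diskcache store.
        disk_cache.cache[(1, 2, 3)] = b"old-format state"

        # Lookups now use LlamaCacheKey; the old-format key is skipped, so nothing matches.
        query = LlamaCacheKey(active_lora_adapters=(), tokens=(1, 2, 3, 4))
        assert disk_cache._find_longest_prefix_key(query) is None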
