import sys
from abc import ABC, abstractmethod
+ from dataclasses import dataclass
from typing import (
    Optional,
    Sequence,

from .llama_types import *

+ @dataclass(eq=True, frozen=True)
+ class LlamaCacheKey:
+     """A key in a LlamaCache. Stores tokens to key by. Also stores
+     information about active LoRA adapters, because we need different
+     cached values for different active adapters, even for the same tokens."""
+     active_lora_adapters: Tuple[Tuple[str, float], ...]
+     tokens: Tuple[int, ...]
+
+     def __post_init__(self):
+         if not isinstance(self.tokens, tuple):
+             raise ValueError("tokens must be a tuple")


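For illustration only (not part of this patch): because the dataclass is declared with eq=True and frozen=True, keys compare and hash by value, which is what lets them serve as dictionary keys in the caches below. A minimal sketch, where the adapter path and scale are made-up values:

    # Hypothetical (path, scale) values, chosen only for illustration.
    a = LlamaCacheKey(active_lora_adapters=(("my-adapter.gguf", 0.8),), tokens=(1, 2, 3))
    b = LlamaCacheKey(active_lora_adapters=(("my-adapter.gguf", 0.8),), tokens=(1, 2, 3))
    assert a == b and hash(a) == hash(b)  # equal by value, so interchangeable as dict keys

    # __post_init__ rejects non-tuple tokens, e.g. a raw list.
    try:
        LlamaCacheKey(active_lora_adapters=(), tokens=[1, 2, 3])
    except ValueError:
        pass  # "tokens must be a tuple"
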
class BaseLlamaCache(ABC):
    """Base cache class for a llama.cpp model."""

    def __init__(self, capacity_bytes: int = (2 << 30)):
        self.capacity_bytes = capacity_bytes

+     def _convert_to_cache_key(self, key: Union[Sequence[int], LlamaCacheKey]) -> LlamaCacheKey:
+         """Convert raw token sequences to a LlamaCacheKey if needed."""
+         if isinstance(key, LlamaCacheKey):
+             return key
+         else:
+             return LlamaCacheKey(active_lora_adapters=(), tokens=tuple(key))
+
    @property
    @abstractmethod
    def cache_size(self) -> int:
        raise NotImplementedError

    def _find_longest_prefix_key(
        self,
-         key: Tuple[int, ...],
-     ) -> Optional[Tuple[int, ...]]:
+         key: LlamaCacheKey,
+     ) -> Optional[LlamaCacheKey]:
+         """Find the cached key with the longest matching token prefix. A match also requires that the
+         active LoRA adapters match exactly.
+
+         Args:
+             key (LlamaCacheKey): The key to find a prefix match for.
+
+         Returns:
+             Optional[LlamaCacheKey]: The key with the longest matching prefix, or None if no match is found.
+         """
        pass

    @abstractmethod
-     def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
+     def __getitem__(self, key: Union[Sequence[int], LlamaCacheKey]) -> "llama_cpp.llama.LlamaState":
+         """Retrieve a cached state by key, matching on the longest common token prefix. A match also
+         requires that the active LoRA adapters match exactly.
+
+         Args:
+             key: Key to look up. Raw token sequences are supported for backwards compatibility
+                 and assume no active LoRA adapters.
+
+         Returns:
+             llama_cpp.llama.LlamaState: The cached state for the entry sharing the longest token prefix.
+
+         Raises:
+             KeyError: If no prefix match is found.
+         """
        raise NotImplementedError

    @abstractmethod
-     def __contains__(self, key: Sequence[int]) -> bool:
+     def __contains__(self, key: Union[Sequence[int], LlamaCacheKey]) -> bool:
+         """Check if any cached key shares a token prefix with the given key.
+
+         Args:
+             key: Key to look up. Raw token sequences are supported for backwards compatibility
+                 and assume no active LoRA adapters.
+
+         Returns:
+             bool: True if any cached key shares a token prefix with this key.
+         """
        raise NotImplementedError

    @abstractmethod
    def __setitem__(
-         self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"
+         self, key: Union[Sequence[int], LlamaCacheKey], value: "llama_cpp.llama.LlamaState"
    ) -> None:
-         raise NotImplementedError
+         """Store a state keyed on its tokens and information about active LoRA adapters.

+         Args:
+             key: Key to store. Raw token sequences are supported for backwards compatibility
+                 and assume no active LoRA adapters.
+             value: The state to cache.
+         """
+         raise NotImplementedError


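A minimal usage sketch of this interface (illustrative only, not part of this patch), using the LlamaRAMCache defined below. It assumes the classes live in llama_cpp.llama_cache and that `llm` is an existing Llama instance whose state was captured with llm.save_state():

    from llama_cpp.llama_cache import LlamaRAMCache, LlamaCacheKey  # assumed module path

    cache = LlamaRAMCache(capacity_bytes=2 << 30)

    # Backwards-compatible path: a raw token sequence implies no active LoRA adapters.
    cache[[1, 2, 3]] = llm.save_state()

    # Lookups match on the longest shared token prefix, so an extended prompt still hits.
    assert [1, 2, 3, 4] in cache

    # A key with different active LoRA adapters never matches, even for identical tokens.
    other = LlamaCacheKey(active_lora_adapters=(("my-adapter.gguf", 1.0),), tokens=(1, 2, 3))
    assert other not in cache
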
class LlamaRAMCache(BaseLlamaCache):
    """Cache for a llama.cpp model using RAM."""
@@ -53,7 +109,7 @@ def __init__(self, capacity_bytes: int = (2 << 30)):
        super().__init__(capacity_bytes)
        self.capacity_bytes = capacity_bytes
        self.cache_state: OrderedDict[
-             Tuple[int, ...], "llama_cpp.llama.LlamaState"
+             LlamaCacheKey, "llama_cpp.llama.LlamaState"
        ] = OrderedDict()

    @property
@@ -62,34 +118,33 @@ def cache_size(self):
    def _find_longest_prefix_key(
        self,
-         key: Tuple[int, ...],
-     ) -> Optional[Tuple[int, ...]]:
+         key: LlamaCacheKey,
+     ) -> Optional[LlamaCacheKey]:
        min_len = 0
-         min_key = None
-         keys = (
-             (k, llama_cpp.llama.Llama.longest_token_prefix(k, key))
-             for k in self.cache_state.keys()
-         )
-         for k, prefix_len in keys:
+         min_key: Optional[LlamaCacheKey] = None
+         for k in self.cache_state.keys():
+             if k.active_lora_adapters != key.active_lora_adapters:
+                 continue
+             if len(k.tokens) < min_len:
+                 continue  # Optimization: cannot beat the best match found so far
+             prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k.tokens, key.tokens)
            if prefix_len > min_len:
                min_len = prefix_len
                min_key = k
        return min_key

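The helper llama_cpp.llama.Llama.longest_token_prefix is not shown in this diff; the matching above assumes it returns the number of leading tokens the two sequences share. A minimal stand-in with that assumed behavior, for illustration only:

    # Hypothetical stand-in, not the library implementation.
    def longest_token_prefix(a, b):
        n = 0
        for x, y in zip(a, b):
            if x != y:
                break
            n += 1
        return n

    assert longest_token_prefix((1, 2, 3), (1, 2, 4, 5)) == 2  # shared prefix: (1, 2)
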
-     def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
-         key = tuple(key)
+     def __getitem__(self, key: Union[Sequence[int], LlamaCacheKey]) -> "llama_cpp.llama.LlamaState":
+         key = self._convert_to_cache_key(key)
        _key = self._find_longest_prefix_key(key)
        if _key is None:
            raise KeyError("Key not found")
        value = self.cache_state[_key]
        self.cache_state.move_to_end(_key)
        return value

-     def __contains__(self, key: Sequence[int]) -> bool:
-         return self._find_longest_prefix_key(tuple(key)) is not None
+     def __contains__(self, key: Union[Sequence[int], LlamaCacheKey]) -> bool:
+         return self._find_longest_prefix_key(self._convert_to_cache_key(key)) is not None

-     def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
-         key = tuple(key)
+     def __setitem__(self, key: Union[Sequence[int], LlamaCacheKey], value: "llama_cpp.llama.LlamaState"):
+         key = self._convert_to_cache_key(key)
        if key in self.cache_state:
            del self.cache_state[key]
        self.cache_state[key] = value
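The RAM cache's recency bookkeeping above relies on OrderedDict ordering: __getitem__ calls move_to_end on a hit and __setitem__ deletes and re-inserts an existing key, which keeps entries ordered from least to most recently used. A tiny illustration of that OrderedDict behavior (illustrative only, not part of this patch):

    from collections import OrderedDict

    d = OrderedDict([("a", 1), ("b", 2), ("c", 3)])
    d.move_to_end("a")           # "a" was just used
    assert next(iter(d)) == "b"  # "b" is now the least recently used entry
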
@@ -116,19 +171,24 @@ def cache_size(self):
    def _find_longest_prefix_key(
        self,
-         key: Tuple[int, ...],
-     ) -> Optional[Tuple[int, ...]]:
+         key: LlamaCacheKey,
+     ) -> Optional[LlamaCacheKey]:
        min_len = 0
-         min_key: Optional[Tuple[int, ...]] = None
+         min_key: Optional[LlamaCacheKey] = None
        for k in self.cache.iterkeys():  # type: ignore
-             prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k, key)
+             if not isinstance(k, LlamaCacheKey):
+                 print("LlamaDiskCache: disk cache keys must be LlamaCacheKey objects; skipping")
+                 continue
+             if k.active_lora_adapters != key.active_lora_adapters:
+                 continue
+             if len(k.tokens) < min_len:
+                 continue  # Optimization: cannot beat the best match found so far
+             prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k.tokens, key.tokens)
            if prefix_len > min_len:
                min_len = prefix_len
-                 min_key = k  # type: ignore
+                 min_key = k
        return min_key

-     def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
-         key = tuple(key)
+     def __getitem__(self, key: Union[Sequence[int], LlamaCacheKey]) -> "llama_cpp.llama.LlamaState":
+         key = self._convert_to_cache_key(key)
        _key = self._find_longest_prefix_key(key)
        if _key is None:
            raise KeyError("Key not found")
@@ -138,12 +198,12 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        # self.cache.push(_key, side="front")  # type: ignore
        return value

-     def __contains__(self, key: Sequence[int]) -> bool:
-         return self._find_longest_prefix_key(tuple(key)) is not None
+     def __contains__(self, key: Union[Sequence[int], LlamaCacheKey]) -> bool:
+         return self._find_longest_prefix_key(self._convert_to_cache_key(key)) is not None

-     def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
+     def __setitem__(self, key: Union[Sequence[int], LlamaCacheKey], value: "llama_cpp.llama.LlamaState"):
        print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
-         key = tuple(key)
+         key = self._convert_to_cache_key(key)
        if key in self.cache:
            print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
            del self.cache[key]