
Commit c3a9cb0

feat: expose llama.cpp LoRA hot-swapping
1 parent 7ecdd94 commit c3a9cb0

9 files changed: +333 -116 lines

Diff for: CHANGELOG.md (+2)

@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+- feat: Add hot-swapping for LoRA adapters
+
 ## [0.3.2]
 
 - feat: Update llama.cpp to ggerganov/llama.cpp@74d73dc85cc2057446bf63cc37ff649ae7cebd80

Diff for: docs/api-reference.md (+1)

@@ -22,6 +22,7 @@ High-level Python bindings for llama.cpp.
 - __call__
 - create_chat_completion
 - create_chat_completion_openai_v1
+- set_lora_adapter_scale
 - set_cache
 - save_state
 - load_state
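
The diff for llama_cpp/llama.py, where the high-level method lives, is not shown in this view, so the exact signature of set_lora_adapter_scale is an assumption here based on the method name and the low-level helpers further down. A minimal usage sketch with placeholder model and adapter paths:

from llama_cpp import Llama

# Hypothetical paths; substitute your own GGUF model and LoRA adapter.
llm = Llama(model_path="./models/base-model.gguf")

# Assumed signature (lora_path, scale): enable the adapter at full strength,
# then hot-swap by changing the scale without reloading the base model.
llm.set_lora_adapter_scale("./adapters/my-lora.gguf", 1.0)
print(llm("Q: What is a LoRA adapter? A:", max_tokens=32)["choices"][0]["text"])

# Assumed: a scale of 0.0 effectively disables the adapter again.
llm.set_lora_adapter_scale("./adapters/my-lora.gguf", 0.0)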

Diff for: examples/low_level_api/common.py (+56 -18)

@@ -3,10 +3,11 @@
 import re
 
 from dataclasses import dataclass, field
-from typing import List
-
-# Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
+from typing import List, Sequence, Tuple
+import typing
 
+# Based on https://github.com/ggerganov/llama.cpp/blob/master/common/common.cpp
+# and https://github.com/ggerganov/llama.cpp/blob/master/common/arg.cpp
 
 @dataclass
 class GptParams:
@@ -40,8 +41,8 @@ class GptParams:
     input_suffix: str = ""
     antiprompt: List[str] = field(default_factory=list)
 
-    lora_adapter: str = ""
-    lora_base: str = ""
+    lora: List[str] = None
+    lora_scaled: List[Tuple[str, float]] = None
 
     memory_f16: bool = True
     random_prompt: bool = False
@@ -257,16 +258,56 @@ def gpt_params_parse(argv=None):
     parser.add_argument(
         "--lora",
         type=str,
-        default="",
-        help="apply LoRA adapter (implies --no-mmap)",
-        dest="lora_adapter",
-    )
-    parser.add_argument(
-        "--lora-base",
-        type=str,
-        default="",
-        help="optional model to use as a base for the layers modified by the LoRA adapter",
-        dest="lora_base",
+        action="append",
+        default=[],
+        help="path to LoRA adapter (can be repeated to use multiple adapters)",
+        metavar="FNAME",
+        dest="lora",
+    )
+
+    class MultiTupleAction(argparse.Action):
+        """Action for handling multiple arguments as tuples with type conversion"""
+        def __init__(self,
+                     option_strings: Sequence[str],
+                     dest: str,
+                     nargs: int = None,
+                     type: Tuple = None,
+                     metavar: Tuple = None,
+                     **kwargs):
+            self.tuple_type = type
+            super().__init__(
+                option_strings=option_strings,
+                dest=dest,
+                type=str,  # We will fix
+                nargs=nargs,
+                metavar=metavar,
+                **kwargs
+            )
+
+        def __call__(self, parser, namespace, values, option_string=None):
+            if len(values) != self.nargs:
+                parser.error(
+                    f'{option_string} requires {len(self.metavar)} arguments: '
+                    f'{" ".join(self.metavar)}'
+                )
+
+            converted_values = tuple(value_type(value) for value_type, value in zip(typing.get_args(self.tuple_type), values))
+            # Initialize list if needed
+            if not hasattr(namespace, self.dest):
+                setattr(namespace, self.dest, [])
+
+            # Add the converted tuple to the list
+            getattr(namespace, self.dest).append(converted_values)
+
+    parser.add_argument(
+        "--lora-scaled",
+        action=MultiTupleAction,
+        nargs=2,
+        type=Tuple[str, float],
+        help="path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
+        metavar=('FNAME', 'SCALE'),
+        dest='lora_scaled',
+        default=[],
     )
 
     parser.add_argument(
@@ -375,9 +416,6 @@ def gpt_params_parse(argv=None):
     delattr(args, "logit_bias_str")
     params = GptParams(**vars(args))
 
-    if params.lora_adapter:
-        params.use_mmap = False
-
     if logit_bias_str != None:
         for i in logit_bias_str:
             if m := re.match(r"(\d+)([-+]\d+)", i):
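
With these changes, --lora collects plain adapter paths and --lora-scaled collects (path, scale) tuples via the MultiTupleAction defined above, and both flags can be repeated. A small sketch of the resulting parse, assuming it is run from the examples/low_level_api directory and using made-up adapter file names:

# Hypothetical invocation of the example's argument parser.
from common import gpt_params_parse

params = gpt_params_parse([
    "--lora", "./adapters/style.gguf",
    "--lora-scaled", "./adapters/domain.gguf", "0.5",
    "--lora-scaled", "./adapters/tone.gguf", "0.25",
])

print(params.lora)         # ['./adapters/style.gguf']
print(params.lora_scaled)  # [('./adapters/domain.gguf', 0.5), ('./adapters/tone.gguf', 0.25)]

The SCALE argument is converted to a float by zipping the raw strings against typing.get_args(Tuple[str, float]) inside MultiTupleAction.__call__.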

Diff for: examples/low_level_api/low_level_api_chat_cpp.py (+8 -16)

@@ -93,22 +93,14 @@ def __init__(self, params: GptParams) -> None:
         if self.params.ignore_eos:
             self.params.logit_bias[llama_cpp.llama_token_eos()] = -float("inf")
 
-        if len(self.params.lora_adapter) > 0:
-            if (
-                llama_cpp.llama_apply_lora_from_file(
-                    self.ctx,
-                    self.params.lora_adapter.encode("utf8"),
-                    (
-                        self.params.lora_base.encode("utf8")
-                        if len(self.params.lora_base) > 0
-                        else None
-                    ),
-                    self.params.n_threads,
-                )
-                != 0
-            ):
-                print("error: failed to apply lora adapter")
-                return
+        for lora_path, scale in [(pth, 1.0) for pth in self.params.lora] + self.params.lora_scaled:
+            lora_adapter = llama_cpp.llama_lora_adapter_init(
+                self.model,
+                lora_path.encode("utf8"))
+            if lora_adapter is None:
+                raise RuntimeError(f"error: failed to load lora adapter '{lora_path}'")
+            if scale != 0.0:
+                llama_cpp.llama_lora_adapter_set(self.ctx, lora_adapter, scale)
 
         print(file=sys.stderr)
         print(
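
The key difference from the removed llama_apply_lora_from_file path is that the adapter is no longer merged into the model weights at load time; it stays a separate object attached to the context, which is what makes hot-swapping possible. A minimal sketch of what that enables, reusing the names from the loop above (the scale values are arbitrary examples):

# After loading an adapter as above, its scale can be changed between
# generations without reloading the model or the adapter file.
llama_cpp.llama_lora_adapter_set(self.ctx, lora_adapter, 0.0)   # temporarily disable
# ... run some generations without the adapter ...
llama_cpp.llama_lora_adapter_set(self.ctx, lora_adapter, 0.75)  # re-enable at reduced strength
# llama_cpp.llama_lora_adapter_clear(self.ctx)                  # detach every adapter from the context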

Diff for: llama_cpp/_internals.py (+54)

@@ -285,6 +285,18 @@ def kv_cache_seq_keep(self, seq_id: int):
     def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
         llama_cpp.llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift)
 
+    def lora_adapter_set(self, adapter: LlamaLoraAdapter, scale: float):
+        return_code = llama_cpp.llama_lora_adapter_set(self.ctx, adapter.lora_adapter, scale)
+        if return_code != 0:
+            raise RuntimeError(f"lora_adapter_set returned {return_code}")
+
+    def lora_adapter_remove(self, adapter: LlamaLoraAdapter) -> bool:
+        return_code = llama_cpp.llama_lora_adapter_remove(self.ctx, adapter.lora_adapter)
+        return return_code != 0
+
+    def lora_adapter_clear(self):
+        llama_cpp.llama_lora_adapter_clear(self.ctx)
+
     def get_state_size(self) -> int:
         return llama_cpp.llama_get_state_size(self.ctx)
 
@@ -861,3 +873,45 @@ def close(self):
 
     def __del__(self):
         self.close()
+
+class LlamaLoraAdapter:
+    """Intermediate Python wrapper for a llama.cpp llama_lora_adapter.
+    NOTE: For stability it's recommended you use the Llama class instead."""
+
+    def __init__(
+        self,
+        model: LlamaModel,
+        lora_path: str,
+        *,
+        verbose: bool = True,
+    ):
+        self.model = model
+        self.lora_path = lora_path
+
+        lora_adapter = None
+
+        if not os.path.exists(lora_path):
+            raise ValueError(f"LoRA adapter path does not exist: {lora_path}")
+
+        with suppress_stdout_stderr(disable=verbose):
+            lora_adapter = llama_cpp.llama_lora_adapter_init(
+                self.model.model,
+                self.lora_path.encode("utf-8"),
+            )
+
+        if lora_adapter is None:
+            raise RuntimeError(
+                f"Failed to initialize LoRA adapter from lora path: {self.lora_path}"
+            )
+
+        # The llama_lora_adapter will be freed by the llama_model as part of its
+        # lifecycle. The llama_model destructor destroys each llama_lora_adapter,
+        # and the destructor for llama_lora_adapter calls llama_lora_adapter_free.
+        # All we do here is clear the wrapped reference when the LlamaModel wrapper
+        # is closed, so that the LlamaLoraAdapter wrapper reference is cleared
+        # when the llama_lora_adapters are freed.
+        def clear_lora_adapter():
+            self.lora_adapter = None
+        self.model._exit_stack.callback(clear_lora_adapter)
+
+        self.lora_adapter = lora_adapter