
Commit 73be551

[Misc][LoRA] Add PEFTHelper for LoRA (vllm-project#11003)

Authored by jeejeelee; committed by weilong.yu
Signed-off-by: Jee Jee Li <[email protected]>
1 parent f571614 commit 73be551

File tree

4 files changed: +160 -28 lines
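
Before the per-file diffs, a minimal sketch of the idea (mine, not part of the commit): the loose config["r"] / config["lora_alpha"] / scaling-factor reads that used to live in vllm/lora/models.py are collected into a single validated PEFTHelper object built from adapter_config.json. The config dict below is illustrative.

    from vllm.lora.peft_helper import PEFTHelper

    # Illustrative stand-in for json.load(open(".../adapter_config.json")).
    config = {
        "r": 8,
        "lora_alpha": 16,
        "target_modules": ["q_proj", "v_proj"],
        "bias": "none",
    }
    peft_helper = PEFTHelper.from_dict(config)
    print(peft_helper.r, peft_helper.lora_alpha, peft_helper.target_modules)

    # LoRAModel.from_lora_tensors(...) and LoRALayerWeights.from_config(...) now
    # accept this object instead of separate rank / lora_alpha / scaling_factor
    # arguments, as the diffs below show.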

tests/lora/test_lora_manager.py

Lines changed: 55 additions & 3 deletions

@@ -1,3 +1,4 @@
+import json
 import os
 from typing import Dict, List

@@ -13,6 +14,7 @@
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager)
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                       WorkerLoRAManager)
@@ -30,18 +32,68 @@
 ]


+def test_peft_helper(sql_lora_files):
+    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
+    with open(lora_config_path) as f:
+        config = json.load(f)
+    peft_helper = PEFTHelper.from_dict(config)
+    assert peft_helper.r == 8
+    assert peft_helper.lora_alpha == 16
+    assert peft_helper.target_modules == [
+        "q_proj",
+        "v_proj",
+        "k_proj",
+        "o_proj",
+        "gate_proj",
+        "up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+
+    expected_error = "vLLM only supports modules_to_save being None."
+    with pytest.raises(ValueError, match=expected_error):
+        config = dict(
+            r=8,
+            lora_alpha=16,
+            target_modules=["gate_proj"],
+            modules_to_save=["lm_head"],
+        )
+        PEFTHelper.from_dict(config)
+    expected_error = "vLLM does not yet support RSLoRA."
+    with pytest.raises(ValueError, match=expected_error):
+        config = dict(r=8,
+                      lora_alpha=16,
+                      target_modules=["gate_proj"],
+                      use_rslora=True)
+        PEFTHelper.from_dict(config)
+
+    expected_error = "vLLM does not yet support DoRA."
+    with pytest.raises(ValueError, match=expected_error):
+        config = dict(r=8,
+                      lora_alpha=16,
+                      target_modules=["gate_proj"],
+                      use_dora=True)
+        PEFTHelper.from_dict(config)
+
+
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_from_lora_tensors(sql_lora_files, device):
     tensors = load_file(
         os.path.join(sql_lora_files, "adapter_model.safetensors"))
     new_embeddings = load_file(
         os.path.join(sql_lora_files, "new_embeddings.safetensors"))
+
+    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
+    with open(lora_config_path) as f:
+        config = json.load(f)
+
+    peft_helper = PEFTHelper.from_dict(config)
     lora_model = LoRAModel.from_lora_tensors(
         1,
-        8,
-        16,
         tensors,
-        device,
+        peft_helper=peft_helper,
+        device=device,
         embeddings=new_embeddings,
         embedding_modules=EMBEDDING_MODULES,
         embedding_padding_modules=EMBEDDING_PADDING_MODULES)

vllm/lora/lora.py

Lines changed: 18 additions & 0 deletions

@@ -4,6 +4,7 @@
 import torch
 import torch.types

+from vllm.lora.peft_helper import PEFTHelper
 from vllm.utils import is_pin_memory_available


@@ -59,6 +60,23 @@ def extra_vocab_size(self) -> int:
         return self.embeddings_tensor.shape[
             0] if self.embeddings_tensor is not None else 0

+    @classmethod
+    def from_config(
+        cls,
+        module_name: str,
+        peft_helper: PEFTHelper,
+        embeddings_tensor: Optional[torch.Tensor] = None,
+    ) -> "LoRALayerWeights":
+        return cls(
+            module_name,
+            peft_helper.r,
+            peft_helper.lora_alpha,
+            None,
+            None,
+            None,
+            embeddings_tensor,
+        )
+
     @classmethod
     def create_dummy_lora_weights(
         cls,
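
For orientation, a hedged sketch of how the new classmethod is meant to be called; the module name and config values below are made up, and from_lora_tensors (see models.py) fills in lora_a / lora_b from the safetensors file afterwards.

    from vllm.lora.lora import LoRALayerWeights
    from vllm.lora.peft_helper import PEFTHelper

    peft_helper = PEFTHelper.from_dict({
        "r": 8,
        "lora_alpha": 16,
        "target_modules": ["q_proj"],
    })

    # Equivalent to the old inline spelling
    #   LoRALayerWeights(module_name, 8, 16, None, None, None, None)
    # but rank and alpha are now read from one place.
    lora_weights = LoRALayerWeights.from_config(
        "model.layers.0.self_attn.q_proj",  # hypothetical module name
        peft_helper,
    )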

vllm/lora/models.py

Lines changed: 17 additions & 25 deletions

@@ -21,6 +21,7 @@
                              LinearScalingRotaryEmbeddingWithLora,
                              LoRAMapping)
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                              is_regex_target_modules,
@@ -104,14 +105,12 @@ def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
     def from_lora_tensors(
         cls,
         lora_model_id: int,
-        rank: int,
-        lora_alpha: int,
         tensors: Dict[str, torch.Tensor],
+        peft_helper: PEFTHelper,
         device: str = "cuda",
         dtype: Optional[torch.dtype] = None,
         embeddings: Optional[Dict[str, torch.Tensor]] = None,
         target_embedding_padding: Optional[int] = None,
-        scaling_factor: Optional[float] = None,
         embedding_modules: Optional[Dict[str, str]] = None,
         embedding_padding_modules: Optional[List[str]] = None,
     ) -> "LoRAModel":
@@ -135,10 +134,9 @@ def from_lora_tensors(
                 if pin_memory:
                     lora_embeddings_tensor = (
                         lora_embeddings_tensor.pin_memory())
-                loras[module_name] = LoRALayerWeights(module_name, rank,
-                                                      lora_alpha, None, None,
-                                                      None,
-                                                      lora_embeddings_tensor)
+                loras[module_name] = LoRALayerWeights.from_config(
+                    module_name, peft_helper, lora_embeddings_tensor)
+
             if is_bias:
                 loras[module_name].bias = tensor.to(device=device,
                                                     dtype=dtype).t()
@@ -170,7 +168,11 @@ def from_lora_tensors(

         for lora in loras.values():
             lora.optimize()
-        return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)
+
+        return cls(lora_model_id,
+                   peft_helper.r,
+                   loras,
+                   scaling_factor=peft_helper.vllm_scaling_factor)

     @classmethod
     def from_local_checkpoint(
@@ -212,6 +214,9 @@ def from_local_checkpoint(
                 "new_embeddings.bin")
         with open(lora_config_path) as f:
             config = json.load(f)
+
+        config["vllm_max_position_embeddings"] = max_position_embeddings
+        peft_helper = PEFTHelper.from_dict(config)
         if os.path.isfile(lora_tensor_path):
             tensors: Dict[str, torch.Tensor] = {}
             # Find unexpected modules.
@@ -242,7 +247,7 @@ def from_local_checkpoint(
             # When a bin file is provided, we rely on config to find unexpected
             # modules.
             unexpected_modules = []
-            target_modules = config["target_modules"]
+            target_modules = peft_helper.target_modules
             if not isinstance(target_modules, list):
                 target_modules = [target_modules]
             for module in target_modules:
@@ -256,7 +261,7 @@ def from_local_checkpoint(
             # https://github.com/vllm-project/vllm/pull/5909. But there's no
             # other better mechanism.
             if unexpected_modules and not is_regex_target_modules(
-                    config["target_modules"], expected_lora_modules):
+                    peft_helper.target_modules, expected_lora_modules):
                 raise ValueError(
                     f"While loading {lora_dir}, expected"
                     f" target modules in {expected_lora_modules}"
@@ -274,30 +279,17 @@ def from_local_checkpoint(
             embeddings = torch.load(new_embeddings_bin_file_path,
                                     map_location=device)

-        rank = config["r"]
-        lora_alpha = config["lora_alpha"]
-        context_length = config.get("context_length", None)
-        scaling_factor = None
-        if context_length:
-            if max_position_embeddings is None:
-                max_position_embeddings = context_length
-            scaling_factor = float(
-                math.ceil(context_length / max_position_embeddings))
-
         return cls.from_lora_tensors(
             lora_model_id=get_lora_id()
             if lora_model_id is None else lora_model_id,
-            rank=rank,
-            lora_alpha=lora_alpha,
             tensors=tensors,
+            peft_helper=peft_helper,
             device=device,
             dtype=dtype,
             embeddings=embeddings,
             target_embedding_padding=target_embedding_padding,
-            scaling_factor=scaling_factor,
             embedding_modules=embedding_modules,
-            embedding_padding_modules=embedding_padding_modules,
-        )
+            embedding_padding_modules=embedding_padding_modules)


 class LoRAModelManager(AdapterModelManager):
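
The long-context (context_length) handling removed above does not disappear; it moves into PEFTHelper.__post_init__, with from_local_checkpoint injecting the base model's limit into the config dict first. A hedged sketch with illustrative numbers:

    import math

    from vllm.lora.peft_helper import PEFTHelper

    config = {
        "r": 8,
        "lora_alpha": 16,
        "target_modules": ["q_proj", "v_proj"],
        "context_length": 16384,  # hypothetical long-context (LongLoRA) adapter
    }
    # from_local_checkpoint now does this before building the helper:
    config["vllm_max_position_embeddings"] = 4096

    helper = PEFTHelper.from_dict(config)
    # __post_init__ reproduces the removed inline logic:
    #   scaling_factor = float(math.ceil(context_length / max_position_embeddings))
    assert helper.vllm_scaling_factor == float(math.ceil(16384 / 4096))  # 4.0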

vllm/lora/peft_helper.py

Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
+# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py
+
+import math
+from dataclasses import MISSING, dataclass, field, fields
+from typing import Literal, Optional, Union
+
+
+@dataclass
+class PEFTHelper:
+    # Required fields
+    r: int
+    lora_alpha: int
+    target_modules: Union[list[str], str]
+
+    bias: Literal["none", "all", "lora_only"] = field(default="none")
+    modules_to_save: Optional[list[str]] = field(default=None)
+    use_rslora: bool = field(default=False)
+    use_dora: bool = field(default=False)
+    # long lora field
+    context_length: int = field(default=0)
+    # Extra vllm field, start with 'vllm_' to avoid conflict
+    vllm_max_position_embeddings: Optional[int] = field(default=False)
+    vllm_scaling_factor: Optional[float] = field(default=None)
+
+    def _validate_features(self):
+        error_msg = []
+
+        if self.modules_to_save:
+            error_msg.append("vLLM only supports modules_to_save being None.")
+        if self.use_rslora:
+            error_msg.append("vLLM does not yet support RSLoRA.")
+
+        if self.use_dora:
+            error_msg.append("vLLM does not yet support DoRA.")
+
+        if error_msg:
+            raise ValueError(f"{', '.join(error_msg)}")
+
+    def __post_init__(self):
+        self._validate_features()
+        if self.context_length:
+            if self.vllm_max_position_embeddings is None:
+                self.vllm_max_position_embeddings = self.context_length
+            self.vllm_scaling_factor = float(
+                math.ceil(self.context_length /
+                          self.vllm_max_position_embeddings))
+
+    @classmethod
+    def from_dict(cls, config_dict: dict) -> "PEFTHelper":
+        # Get all field information from the class
+        class_fields = {f.name: f for f in fields(cls)}
+        # Check for required fields
+        required_fields = {
+            name
+            for name, f in class_fields.items()
+            if f.default is MISSING and f.default_factory is MISSING
+        }
+
+        # Identify any missing required fields
+        missing_fields = required_fields - set(config_dict.keys())
+        if missing_fields:
+            raise ValueError(
+                f"Missing required configuration fields: {missing_fields}")
+
+        # Filter out fields that aren't defined in the class
+        filtered_dict = {
+            k: v
+            for k, v in config_dict.items() if k in class_fields
+        }
+        return cls(**filtered_dict)
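
A quick sketch of the two behaviours that from_dict and _validate_features add (the dicts are illustrative, and the same error paths are exercised by test_peft_helper above): unsupported PEFT features are rejected, and keys that PEFTHelper does not model, such as peft_type in a real adapter_config.json, are silently dropped.

    from vllm.lora.peft_helper import PEFTHelper

    # 1) Unsupported features raise ValueError via __post_init__ -> _validate_features.
    try:
        PEFTHelper.from_dict({
            "r": 8,
            "lora_alpha": 16,
            "target_modules": ["gate_proj"],
            "use_dora": True,
        })
    except ValueError as e:
        print(e)  # vLLM does not yet support DoRA.

    # 2) Unknown keys are filtered out; missing required keys (r, lora_alpha,
    #    target_modules) would raise instead.
    helper = PEFTHelper.from_dict({
        "r": 16,
        "lora_alpha": 32,
        "target_modules": ["q_proj", "k_proj", "v_proj"],
        "peft_type": "LORA",  # present in real configs, not a PEFTHelper field
    })
    print(helper.use_rslora)  # False (default)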
