
Commit 42cafa2

JohnGiorgi authored and jeejeelee committed

[Misc][LoRA] Support Rank Stabilized LoRA (RSLoRA) (vllm-project#6909)

Signed-off-by: Jee Jee Li <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
1 parent f684303 commit 42cafa2
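Background: classic LoRA scales the low-rank update by lora_alpha / r, while rank-stabilized LoRA (rsLoRA, https://arxiv.org/abs/2312.03732) scales it by lora_alpha / sqrt(r), which keeps the update's magnitude stable as the rank grows. A minimal standalone sketch of the two rules this commit wires into PEFTHelper (plain Python, not vLLM code):

import math

def lora_scaling(lora_alpha: float, r: int, use_rslora: bool = False) -> float:
    # rsLoRA divides by sqrt(r); classic LoRA divides by r.
    if use_rslora:
        return lora_alpha / math.sqrt(r)
    return lora_alpha / r

assert lora_scaling(16, 8) == 2.0                                   # 16 / 8
assert abs(lora_scaling(16, 8, use_rslora=True) - 5.656854) < 1e-6  # 16 / sqrt(8)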

File tree

tests/lora/test_lora_manager.py
vllm/lora/lora.py
vllm/lora/models.py
vllm/lora/peft_helper.py

4 files changed, +30 -22 lines changed

tests/lora/test_lora_manager.py

Lines changed: 13 additions & 7 deletions

@@ -1,4 +1,5 @@
 import json
+import math
 import os
 from typing import Dict, List

@@ -50,6 +51,18 @@ def test_peft_helper(sql_lora_files):
         "embed_tokens",
         "lm_head",
     ]
+    scaling = peft_helper.lora_alpha / peft_helper.r
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
+
+    # test RSLoRA
+    config = dict(r=8,
+                  lora_alpha=16,
+                  target_modules=["gate_proj"],
+                  use_rslora=True)
+    peft_helper = PEFTHelper.from_dict(config)
+
+    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3

     expected_error = "vLLM only supports modules_to_save being None."
     with pytest.raises(ValueError, match=expected_error):
@@ -60,13 +73,6 @@ def test_peft_helper(sql_lora_files):
             modules_to_save=["lm_head"],
         )
         PEFTHelper.from_dict(config)
-    expected_error = "vLLM does not yet support RSLoRA."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(r=8,
-                      lora_alpha=16,
-                      target_modules=["gate_proj"],
-                      use_rslora=True)
-        PEFTHelper.from_dict(config)

     expected_error = "vLLM does not yet support DoRA."
     with pytest.raises(ValueError, match=expected_error):
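For reference, the dict the new test feeds to PEFTHelper.from_dict mirrors the shape of the adapter_config.json that PEFT writes for an rsLoRA adapter. A hypothetical minimal config, loaded the same way:

import json

# Hypothetical minimal adapter config; the field names match the test above.
raw = '{"r": 8, "lora_alpha": 16, "target_modules": ["gate_proj"], "use_rslora": true}'
config = json.loads(raw)
# PEFTHelper.from_dict(config) would then yield a scaling factor of 16 / sqrt(8).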

vllm/lora/lora.py

Lines changed: 3 additions & 9 deletions

@@ -67,15 +67,9 @@ def from_config(
         peft_helper: PEFTHelper,
         embeddings_tensor: Optional[torch.Tensor] = None,
     ) -> "LoRALayerWeights":
-        return cls(
-            module_name,
-            peft_helper.r,
-            peft_helper.lora_alpha,
-            None,
-            None,
-            None,
-            embeddings_tensor,
-        )
+        return cls(module_name, peft_helper.r, peft_helper.lora_alpha, None,
+                   None, None, embeddings_tensor,
+                   peft_helper.vllm_lora_scaling_factor)

     @classmethod
     def create_dummy_lora_weights(
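For intuition, the scaling factor threaded through from_config here is the multiplier applied to the low-rank update at inference time. A rough sketch of the effective math, with hypothetical tensor names (not vLLM's actual kernels):

import math
import torch

def lora_delta(x: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor,
               scaling: float) -> torch.Tensor:
    # LoRA update: scaling * x @ A^T @ B^T, with A of shape (r, d) and B of shape (d, r).
    return (x @ lora_a.T @ lora_b.T) * scaling

r, alpha, d = 8, 16, 32
x = torch.randn(4, d)
lora_a, lora_b = torch.randn(r, d), torch.randn(d, r)
y = lora_delta(x, lora_a, lora_b, alpha / math.sqrt(r))  # rsLoRA scaling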

vllm/lora/models.py

Lines changed: 1 addition & 1 deletion

@@ -173,7 +173,7 @@ def from_lora_tensors(
         return cls(lora_model_id,
                    peft_helper.r,
                    loras,
-                   scaling_factor=peft_helper.vllm_scaling_factor)
+                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)

     @classmethod
     def from_local_checkpoint(
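Note the rename: vllm_long_context_scaling_factor is the context-extension factor for long-context LoRA, distinct from the new per-adapter weight scaling. Its value follows the ceil formula in peft_helper.py below; a hypothetical example:

import math

# Hypothetical long-context adapter: trained at 16k context on a 4k base model.
context_length = 16384
vllm_max_position_embeddings = 4096
long_context_scaling_factor = float(
    math.ceil(context_length / vllm_max_position_embeddings))
assert long_context_scaling_factor == 4.0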

vllm/lora/peft_helper.py

Lines changed: 13 additions & 5 deletions

@@ -4,6 +4,8 @@
 from dataclasses import MISSING, dataclass, field, fields
 from typing import Literal, Optional, Union

+from vllm.utils import print_info_once
+

 @dataclass
 class PEFTHelper:
@@ -14,21 +16,22 @@ class PEFTHelper:

     bias: Literal["none", "all", "lora_only"] = field(default="none")
     modules_to_save: Optional[list[str]] = field(default=None)
+    # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
     use_rslora: bool = field(default=False)
+    # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
     use_dora: bool = field(default=False)
-    # long lora field
+    # long context lora field
     context_length: int = field(default=0)
     # Extra vllm field, start with 'vllm_' to avoid conflict
+    vllm_lora_scaling_factor: float = field(default=1.0)
     vllm_max_position_embeddings: Optional[int] = field(default=False)
-    vllm_scaling_factor: Optional[float] = field(default=None)
+    vllm_long_context_scaling_factor: Optional[float] = field(default=None)

     def _validate_features(self):
         error_msg = []

         if self.modules_to_save:
             error_msg.append("vLLM only supports modules_to_save being None.")
-        if self.use_rslora:
-            error_msg.append("vLLM does not yet support RSLoRA.")

         if self.use_dora:
             error_msg.append("vLLM does not yet support DoRA.")
@@ -38,10 +41,15 @@ def _validate_features(self):

     def __post_init__(self):
         self._validate_features()
+        if self.use_rslora:
+            print_info_once("Loading LoRA weights trained with rsLoRA.")
+            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
+        else:
+            self.vllm_lora_scaling_factor = self.lora_alpha / self.r
         if self.context_length:
             if self.vllm_max_position_embeddings is None:
                 self.vllm_max_position_embeddings = self.context_length
-            self.vllm_scaling_factor = float(
+            self.vllm_long_context_scaling_factor = float(
                 math.ceil(self.context_length /
                           self.vllm_max_position_embeddings))
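A quick usage sketch of the new behavior, assuming PEFTHelper is imported from vllm.lora.peft_helper as defined in this diff:

import math
from vllm.lora.peft_helper import PEFTHelper  # module path per this diff

classic = PEFTHelper.from_dict(
    dict(r=8, lora_alpha=16, target_modules=["gate_proj"]))
assert abs(classic.vllm_lora_scaling_factor - 16 / 8) < 1e-3

rs = PEFTHelper.from_dict(
    dict(r=8, lora_alpha=16, target_modules=["gate_proj"], use_rslora=True))
assert abs(rs.vllm_lora_scaling_factor - 16 / math.sqrt(8)) < 1e-3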
