[Misc][LoRA] Add PEFTHelper for LoRA #11003

Merged: 8 commits, Dec 10, 2024

58 changes: 55 additions & 3 deletions tests/lora/test_lora_manager.py
@@ -1,3 +1,4 @@
import json
import os
from typing import Dict, List

@@ -13,6 +14,7 @@
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager)
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager)
@@ -30,18 +32,68 @@
]


def test_peft_helper(sql_lora_files):
lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
with open(lora_config_path) as f:
config = json.load(f)
peft_helper = PEFTHelper.from_dict(config)
assert peft_helper.r == 8
assert peft_helper.lora_alpha == 16
assert peft_helper.target_modules == [
"q_proj",
"v_proj",
"k_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]

expected_error = "vLLM only supports modules_to_save being None."
with pytest.raises(ValueError, match=expected_error):
config = dict(
r=8,
lora_alpha=16,
target_modules=["gate_proj"],
modules_to_save=["lm_head"],
)
PEFTHelper.from_dict(config)
expected_error = "vLLM does not yet support RSLoRA."
with pytest.raises(ValueError, match=expected_error):
config = dict(r=8,
lora_alpha=16,
target_modules=["gate_proj"],
use_rslora=True)
PEFTHelper.from_dict(config)

expected_error = "vLLM does not yet support DoRA."
with pytest.raises(ValueError, match=expected_error):
config = dict(r=8,
lora_alpha=16,
target_modules=["gate_proj"],
use_dora=True)
PEFTHelper.from_dict(config)


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
tensors = load_file(
os.path.join(sql_lora_files, "adapter_model.safetensors"))
new_embeddings = load_file(
os.path.join(sql_lora_files, "new_embeddings.safetensors"))

lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
with open(lora_config_path) as f:
config = json.load(f)

peft_helper = PEFTHelper.from_dict(config)
    lora_model = LoRAModel.from_lora_tensors(
        1,
-       8,
-       16,
        tensors,
-       device,
+       peft_helper=peft_helper,
+       device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)
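
For orientation, a minimal sketch of the kind of config dict `PEFTHelper.from_dict` accepts (values are illustrative; in the test above the dict comes from the fixture's `adapter_config.json`). Keys that are not `PEFTHelper` fields, such as PEFT's `peft_type`, are dropped by `from_dict`, and omitting a required field raises `ValueError`:

```python
from vllm.lora.peft_helper import PEFTHelper

config = {
    "peft_type": "LORA",  # not a PEFTHelper field -> silently dropped
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "v_proj"],
    "bias": "none",
}
helper = PEFTHelper.from_dict(config)
assert helper.r == 8 and helper.modules_to_save is None

# Missing a required field ("r" here) is reported explicitly.
try:
    PEFTHelper.from_dict({"lora_alpha": 16, "target_modules": ["q_proj"]})
except ValueError as e:
    assert "Missing required configuration fields" in str(e)
```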
18 changes: 18 additions & 0 deletions vllm/lora/lora.py
@@ -4,6 +4,7 @@
import torch
import torch.types

from vllm.lora.peft_helper import PEFTHelper
from vllm.utils import is_pin_memory_available


@@ -59,6 +60,23 @@ def extra_vocab_size(self) -> int:
return self.embeddings_tensor.shape[
0] if self.embeddings_tensor is not None else 0

@classmethod
def from_config(
cls,
module_name: str,
peft_helper: PEFTHelper,
embeddings_tensor: Optional[torch.Tensor] = None,
) -> "LoRALayerWeights":
return cls(
module_name,
peft_helper.r,
peft_helper.lora_alpha,
None,
None,
None,
embeddings_tensor,
)

@classmethod
def create_dummy_lora_weights(
cls,
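
As a rough sketch of the equivalence the diff above relies on: `from_config` forwards `peft_helper.r` and `peft_helper.lora_alpha` into the existing positional constructor, so the two calls below build the same weights object (the module name is illustrative):

```python
from vllm.lora.lora import LoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper

peft_helper = PEFTHelper(r=8, lora_alpha=16, target_modules=["q_proj"])

# Old style: rank and alpha passed positionally.
old_style = LoRALayerWeights("layers.0.self_attn.q_proj", 8, 16,
                             None, None, None, None)
# New style: the same values are read off the PEFTHelper.
new_style = LoRALayerWeights.from_config("layers.0.self_attn.q_proj",
                                         peft_helper, None)
assert old_style.module_name == new_style.module_name
```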
42 changes: 17 additions & 25 deletions vllm/lora/models.py
@@ -21,6 +21,7 @@
LinearScalingRotaryEmbeddingWithLora,
LoRAMapping)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
is_regex_target_modules,
@@ -104,14 +105,12 @@ def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
def from_lora_tensors(
cls,
lora_model_id: int,
-       rank: int,
-       lora_alpha: int,
        tensors: Dict[str, torch.Tensor],
+       peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[Dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
-       scaling_factor: Optional[float] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
) -> "LoRAModel":
@@ -135,10 +134,9 @@ def from_lora_tensors(
if pin_memory:
lora_embeddings_tensor = (
lora_embeddings_tensor.pin_memory())
-               loras[module_name] = LoRALayerWeights(module_name, rank,
-                                                     lora_alpha, None, None,
-                                                     None,
-                                                     lora_embeddings_tensor)
+               loras[module_name] = LoRALayerWeights.from_config(
+                   module_name, peft_helper, lora_embeddings_tensor)

if is_bias:
loras[module_name].bias = tensor.to(device=device,
dtype=dtype).t()
@@ -170,7 +168,11 @@ def from_lora_tensors(

for lora in loras.values():
lora.optimize()
-       return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)
+       return cls(lora_model_id,
+                  peft_helper.r,
+                  loras,
+                  scaling_factor=peft_helper.vllm_scaling_factor)

@classmethod
def from_local_checkpoint(
@@ -212,6 +214,9 @@ def from_local_checkpoint(
"new_embeddings.bin")
with open(lora_config_path) as f:
config = json.load(f)

config["vllm_max_position_embeddings"] = max_position_embeddings
peft_helper = PEFTHelper.from_dict(config)
if os.path.isfile(lora_tensor_path):
tensors: Dict[str, torch.Tensor] = {}
# Find unexpected modules.
@@ -242,7 +247,7 @@ def from_local_checkpoint(
# When a bin file is provided, we rely on config to find unexpected
# modules.
unexpected_modules = []
-           target_modules = config["target_modules"]
+           target_modules = peft_helper.target_modules
if not isinstance(target_modules, list):
target_modules = [target_modules]
for module in target_modules:
@@ -256,7 +261,7 @@
# https://github.com/vllm-project/vllm/pull/5909. But there's no
# other better mechanism.
if unexpected_modules and not is_regex_target_modules(
-                   config["target_modules"], expected_lora_modules):
+                   peft_helper.target_modules, expected_lora_modules):
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
@@ -274,30 +279,17 @@
embeddings = torch.load(new_embeddings_bin_file_path,
map_location=device)

-       rank = config["r"]
-       lora_alpha = config["lora_alpha"]
-       context_length = config.get("context_length", None)
-       scaling_factor = None
-       if context_length:
-           if max_position_embeddings is None:
-               max_position_embeddings = context_length
-           scaling_factor = float(
-               math.ceil(context_length / max_position_embeddings))

return cls.from_lora_tensors(
lora_model_id=get_lora_id()
if lora_model_id is None else lora_model_id,
-           rank=rank,
-           lora_alpha=lora_alpha,
            tensors=tensors,
+           peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
-           scaling_factor=scaling_factor,
            embedding_modules=embedding_modules,
-           embedding_padding_modules=embedding_padding_modules,
-       )
+           embedding_padding_modules=embedding_padding_modules)


class LoRAModelManager(AdapterModelManager):
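
Condensed sketch of the loading flow after this change (the config values, `lora_model_id`, and the `vllm_max_position_embeddings` value are illustrative; in `from_local_checkpoint` the dict comes from `adapter_config.json` and the tensors from the adapter's weight files):

```python
from vllm.lora.models import LoRAModel
from vllm.lora.peft_helper import PEFTHelper

config = {"r": 8, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]}
# from_local_checkpoint injects the model's max_position_embeddings before
# building the helper, so any long-LoRA scaling is resolved in __post_init__.
config["vllm_max_position_embeddings"] = 4096
peft_helper = PEFTHelper.from_dict(config)
assert peft_helper.vllm_scaling_factor is None  # no context_length set

# Rank, alpha, and scaling factor now travel on the helper instead of being
# passed as separate arguments.
lora_model = LoRAModel.from_lora_tensors(
    lora_model_id=1,
    tensors={},  # empty here; normally the adapter's LoRA weight tensors
    peft_helper=peft_helper,
    device="cpu",
)
```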
70 changes: 70 additions & 0 deletions vllm/lora/peft_helper.py
@@ -0,0 +1,70 @@
# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py

import math
from dataclasses import MISSING, dataclass, field, fields
from typing import Literal, Optional, Union


@dataclass
class PEFTHelper:
# Required fields
r: int
lora_alpha: int
target_modules: Union[list[str], str]

bias: Literal["none", "all", "lora_only"] = field(default="none")
modules_to_save: Optional[list[str]] = field(default=None)
use_rslora: bool = field(default=False)
use_dora: bool = field(default=False)
    # Long LoRA field
    context_length: int = field(default=0)
    # Extra vLLM fields, prefixed with 'vllm_' to avoid name conflicts
    vllm_max_position_embeddings: Optional[int] = field(default=None)
    vllm_scaling_factor: Optional[float] = field(default=None)

def _validate_features(self):
error_msg = []

if self.modules_to_save:
error_msg.append("vLLM only supports modules_to_save being None.")
if self.use_rslora:
error_msg.append("vLLM does not yet support RSLoRA.")

if self.use_dora:
error_msg.append("vLLM does not yet support DoRA.")

if error_msg:
raise ValueError(f"{', '.join(error_msg)}")

def __post_init__(self):
self._validate_features()
if self.context_length:
if self.vllm_max_position_embeddings is None:
self.vllm_max_position_embeddings = self.context_length
self.vllm_scaling_factor = float(
math.ceil(self.context_length /
self.vllm_max_position_embeddings))

@classmethod
def from_dict(cls, config_dict: dict) -> "PEFTHelper":
# Get all field information from the class
class_fields = {f.name: f for f in fields(cls)}
# Check for required fields
required_fields = {
name
for name, f in class_fields.items()
if f.default is MISSING and f.default_factory is MISSING
}

# Identify any missing required fields
missing_fields = required_fields - set(config_dict.keys())
if missing_fields:
raise ValueError(
f"Missing required configuration fields: {missing_fields}")

# Filter out fields that aren't defined in the class
filtered_dict = {
k: v
for k, v in config_dict.items() if k in class_fields
}
return cls(**filtered_dict)
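
Behavior sketch for the long-LoRA scaling in `__post_init__` and the feature validation above (values are illustrative):

```python
import math

from vllm.lora.peft_helper import PEFTHelper

# Long-context adapter: vllm_scaling_factor = ceil(context_length / max_pos).
helper = PEFTHelper(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj"],
    context_length=16384,
    vllm_max_position_embeddings=4096,
)
assert helper.vllm_scaling_factor == float(math.ceil(16384 / 4096))  # 4.0

# Unsupported PEFT features are rejected up front with a clear message.
try:
    PEFTHelper.from_dict({
        "r": 8,
        "lora_alpha": 16,
        "target_modules": ["q_proj"],
        "use_dora": True,
    })
except ValueError as e:
    assert "DoRA" in str(e)
```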