
Commit b49b89c

jeejeelee authored and mzusman committed

[Misc][LoRA] Fix LoRA weight mapper (vllm-project#11495)

Signed-off-by: Jee Jee Li <[email protected]>

1 parent a7237ac · commit b49b89c

File tree: 5 files changed (+31, −29)

  tests/lora/test_lora_checkpoints.py
  tests/lora/test_qwen2vl.py
  vllm/lora/models.py
  vllm/lora/utils.py
  vllm/lora/worker_manager.py

tests/lora/test_lora_checkpoints.py (+10, −5)

@@ -74,7 +74,7 @@ def test_load_checkpoints(
             embedding_padding_modules=embed_padding_modules)
 
 
-def test_lora_weights_mapping(baichuan_lora_files, ):
+def test_lora_weights_mapping(baichuan_lora_files):
     supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
     packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
     embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
@@ -86,10 +86,14 @@ def test_lora_weights_mapping(baichuan_lora_files, ):
         else:
             expected_lora_modules.append(module)
 
-    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
-        "model.": "language_model.model.",
-    }, )
-
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.": "language_model.model.",
+        },
+        orig_to_new_substr={
+            ".layers.": ".baichuan_layers.",
+        },
+    )
     lora_model = LoRAModel.from_local_checkpoint(
         baichuan_lora_files,
         expected_lora_modules,
@@ -101,3 +105,4 @@ def test_lora_weights_mapping(baichuan_lora_files, ):
     )
     for name in lora_model.loras:
         assert name.startswith(hf_to_vllm_mapper.orig_to_new_prefix["model."])
+        assert ".baichuan_layers." in name

tests/lora/test_qwen2vl.py (+5, −1)

@@ -22,7 +22,7 @@
 
 # After fine-tuning with LoRA, all generated content should start begin `A`.
 EXPECTED_OUTPUT = [
-    "A stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.",  # noqa: E501
+    "A red stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.",  # noqa: E501
     "A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.",  # noqa: E501
 ]
 
@@ -76,3 +76,7 @@ def test_qwen2vl_lora(qwen2vl_lora_files):
     output1 = do_sample(llm, qwen2vl_lora_files, lora_id=1)
     for i in range(len(EXPECTED_OUTPUT)):
         assert EXPECTED_OUTPUT[i].startswith(output1[i])
+
+    output2 = do_sample(llm, qwen2vl_lora_files, lora_id=2)
+    for i in range(len(EXPECTED_OUTPUT)):
+        assert EXPECTED_OUTPUT[i].startswith(output2[i])
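One detail that is easy to misread in these assertions: they check that the generated text is a prefix of the expected sentence, not the reverse, which tolerates generations cut short by the token budget. A tiny self-contained illustration (made-up strings):

# A truncated generation is a prefix of the full expected sentence,
# so expected.startswith(output) is the right direction to assert.
expected = "A red stop sign stands prominently in the foreground."
output = "A red stop sign stands"  # generation stopped at the token limit
assert expected.startswith(output)
assert not output.startswith(expected)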

vllm/lora/models.py (+2, −1)

@@ -231,7 +231,8 @@ def from_local_checkpoint(
             with safetensors.safe_open(lora_tensor_path,
                                        framework="pt") as f:  # type: ignore
                 for lora_module in f.keys():  # noqa
-                    module_name, _, _ = parse_fine_tuned_lora_name(lora_module)
+                    module_name, _, _ = parse_fine_tuned_lora_name(
+                        lora_module, weights_mapper)
                     part_name = module_name.split(".")[-1]
                     if part_name not in expected_lora_modules:
                         unexpected_modules.append(module_name)
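Passing weights_mapper into parse_fine_tuned_lora_name here keeps the module_name recorded during validation consistent with the mapped names used once the adapter is actually registered. The check itself only compares the last path component against the expected modules; a minimal illustration with hypothetical values follows (the real expected_lora_modules is built from the model's supported_lora_modules and packed_modules_mapping):

# Hypothetical values for illustration only.
expected_lora_modules = ["W_pack", "o_proj", "gate_up_proj", "down_proj"]

module_name = "language_model.model.baichuan_layers.0.self_attn.W_pack"
part_name = module_name.split(".")[-1]  # "W_pack"
assert part_name in expected_lora_modules  # otherwise flagged as unexpected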

vllm/lora/utils.py (+12, −22)

@@ -1,4 +1,3 @@
-import copy
 import os
 import re
 from typing import List, Optional, Set, Tuple, Type, Union
@@ -32,7 +31,6 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.utils import WeightsMapper
-from vllm.utils import print_warning_once
 
 logger = init_logger(__name__)
 
@@ -112,36 +110,28 @@ def parse_fine_tuned_lora_name(
         is_bias whether the tensor is lora bias.
     """
 
-    w_mapper = None
-    if weights_mapper:
-        w_mapper = copy.deepcopy(weights_mapper)
-        # TODO: Currently only supports mapping for prefix, mapping for
-        # substr and subfix will be supported in the future.
-        for attr, mapping in [
-            ("orig_to_new_substr", w_mapper.orig_to_new_substr),
-            ("orig_to_new_suffix", w_mapper.orig_to_new_suffix),
-        ]:
-            if mapping:
-                print_warning_once(
-                    f"vLLM currently does not support mapping of LoRA weights "
-                    f"for {mapping}.")
-                setattr(w_mapper, attr, {})
-
-    mapper = (lambda name: w_mapper._map_name(name)
-              if w_mapper is not None else name)
+    # LoRA weight qualified name always starts with `base_model.model.`,
+    # so we remove the prefix `base_model.model.` to make the following
+    # mapping correctly.
+    if "base_model.model." in name:
+        name = name.replace("base_model.model.", "")
+        name = weights_mapper._map_name(name) if weights_mapper else name
+        # recover the prefix `base_model.model.`
+        name = "base_model.model." + name
+
     parts = name.split(".")
     if parts[-1] == "weight" and (parts[-2] == "lora_A"
                                   or parts[-2] == "lora_B"):
         new_name = ".".join(parts[2:-2])
-        return mapper(new_name), parts[-2] == "lora_A", False
+        return new_name, parts[-2] == "lora_A", False
 
     if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
         new_name = ".".join(parts[2:-1])
-        return mapper(new_name), parts[-1] == "lora_embedding_A", False
+        return new_name, parts[-1] == "lora_embedding_A", False
 
     if parts[-1] == "bias":
         new_name = ".".join(parts[2:-2])
-        return mapper(new_name), False, True
+        return new_name, False, True
 
     raise ValueError(f"{name} is unsupported LoRA weight")
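The replacement is small enough to trace end to end: strip the `base_model.model.` adapter prefix, run the whole remaining name through the mapper (prefix, substring, and suffix rules all apply now, rather than prefix only), restore the prefix, then parse as before. A self-contained walkthrough, with a plain function standing in for weights_mapper._map_name using the Baichuan test's rules so it runs without vLLM:

# Stand-in for weights_mapper._map_name with the test's mapping rules.
def map_name(name: str) -> str:
    if name.startswith("model."):
        name = "language_model.model." + name[len("model."):]
    return name.replace(".layers.", ".baichuan_layers.")

key = "base_model.model.model.layers.0.self_attn.W_pack.lora_A.weight"

# 1. Strip the adapter prefix so the mapper sees a bare module path.
name = key.replace("base_model.model.", "")
# 2. Apply the HF -> vLLM mapping.
name = map_name(name)
# 3. Restore the prefix, then parse exactly as the function does.
name = "base_model.model." + name
parts = name.split(".")
module_name = ".".join(parts[2:-2])
print(module_name)  # language_model.model.baichuan_layers.0.self_attn.W_pack
print(parts[-2] == "lora_A")  # True

The result is exactly what test_lora_weights_mapping asserts: the parsed module name starts with "language_model.model." and contains ".baichuan_layers.".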

vllm/lora/worker_manager.py (+2, −0)

@@ -91,6 +91,8 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
                     packed_modules_mapping[module])
             else:
                 expected_lora_modules.append(module)
+
+        expected_lora_modules = list(set(expected_lora_modules))
         lora_path = get_adapter_absolute_path(lora_request.lora_path)
 
         # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
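The added dedup presumably guards against the same submodule being collected twice when the packed-modules expansion and the plain list overlap. One side note on the choice: list(set(...)) does not preserve insertion order; if ordering ever mattered here, dict.fromkeys is the usual stable alternative. A small sketch with hypothetical module names:

# Hypothetical duplicate produced by the packed-modules expansion.
expected_lora_modules = ["qkv_proj", "o_proj", "qkv_proj"]

deduped = list(dict.fromkeys(expected_lora_modules))  # order-preserving dedup
print(deduped)  # ['qkv_proj', 'o_proj']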
