Skip to content

Commit 6a308fc

Browse files
alex-jw-brooks authored and jeejeelee committed
[Bugfix] Fix Lora Name Parsing (vllm-project#17196)
Signed-off-by: Alex-Brooks <[email protected]> Co-authored-by: Jee Jee Li <[email protected]>
1 parent cd39586 commit 6a308fc

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

tests/lora/test_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@ def test_parse_fine_tuned_lora_name_valid():
3939
False,
4040
False,
4141
),
42+
(
43+
"language_model.layers.9.mlp.down_proj.lora_A.weight",
44+
"language_model.layers.9.mlp.down_proj",
45+
True,
46+
False,
47+
),
48+
(
49+
"language_model.layers.9.mlp.down_proj.lora_B.weight",
50+
"language_model.layers.9.mlp.down_proj",
51+
False,
52+
False,
53+
),
4254
}
4355
for name, module_name, is_lora_a, is_bias in fixture:
4456
assert (module_name, is_lora_a,

vllm/lora/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def parse_fine_tuned_lora_name(
114114
is_bias whether the tensor is lora bias.
115115
"""
116116

117-
# LoRA weight qualified name always starts with `base_model.model.`,
117+
# LoRA weight qualified name usually starts with `base_model.model.`,
118118
# so we remove the prefix `base_model.model.` to make the following
119119
# mapping correctly.
120120
if "base_model.model." in name:
@@ -123,18 +123,23 @@ def parse_fine_tuned_lora_name(
123123
# recover the prefix `base_model.model.`
124124
name = "base_model.model." + name
125125

126+
# In some situations, we may not start with `base_model.model.`.
127+
# If we don't (e.g., ibm-granite/granite-speech-3.3-8b),
128+
# we should keep the prefix intact.
129+
start_index = 2 if "base_model.model." in name else 0
130+
126131
parts = name.split(".")
127132
if parts[-1] == "weight" and (parts[-2] == "lora_A"
128133
or parts[-2] == "lora_B"):
129-
new_name = ".".join(parts[2:-2])
134+
new_name = ".".join(parts[start_index:-2])
130135
return new_name, parts[-2] == "lora_A", False
131136

132137
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
133-
new_name = ".".join(parts[2:-1])
138+
new_name = ".".join(parts[start_index:-1])
134139
return new_name, parts[-1] == "lora_embedding_A", False
135140

136141
if parts[-1] == "bias":
137-
new_name = ".".join(parts[2:-2])
142+
new_name = ".".join(parts[start_index:-2])
138143
return new_name, False, True
139144

140145
raise ValueError(f"{name} is unsupported LoRA weight")

0 commit comments

Comments (0)