Commit d56e1ab

[Model] use AutoWeightsLoader for stablelm,starcoder2,zamba2
Signed-off-by: rongfu.leng <[email protected]>
1 parent 97ae6d7 commit d56e1ab

File tree: 3 files changed, +135 -121 lines
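All three files follow the same pattern: the detailed per-parameter loop (stacked q/k/v and gate/up shard mapping, GPTQ bias skips, pipeline-parallel checks) moves into the inner *Model class's load_weights, and the *ForCausalLM wrapper's load_weights becomes a thin delegation through AutoWeightsLoader. The sketch below is a simplified, self-contained illustration of that delegation shape in plain PyTorch; ToyModel, ToyForCausalLM, and the inline prefix filtering are invented for illustration and are not vLLM's actual AutoWeightsLoader implementation.

from typing import Iterable, Set, Tuple

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Stands in for StablelmModel / Starcoder2Model / Zamba2Model."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4, bias=False)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        # The inner model owns the detailed loading (shard mapping, etc.).
        params = dict(self.named_parameters())
        loaded: Set[str] = set()
        for name, tensor in weights:
            if name in params:
                params[name].data.copy_(tensor)
                loaded.add(name)
        return loaded


class ToyForCausalLM(nn.Module):
    """Stands in for the *ForCausalLM wrappers after this commit."""

    def __init__(self) -> None:
        super().__init__()
        self.model = ToyModel()

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        # Roughly what AutoWeightsLoader automates: drop skipped names, strip
        # the child-module prefix, and delegate to the child's load_weights.
        skip_prefixes = ("rotary_emb.inv_freq",)
        inner = (
            (name[len("model."):], tensor) for name, tensor in weights
            if name.startswith("model.")
            and not any(name[len("model."):].startswith(p) for p in skip_prefixes)
        )
        return {"model." + n for n in self.model.load_weights(inner)}


if __name__ == "__main__":
    ckpt = [("model.proj.weight", torch.zeros(4, 4)),
            ("model.rotary_emb.inv_freq", torch.zeros(2))]
    print(ToyForCausalLM().load_weights(ckpt))  # {'model.proj.weight'}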

vllm/model_executor/models/stablelm.py

Lines changed: 50 additions & 44 deletions
@@ -44,7 +44,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -253,6 +253,45 @@ def forward(
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class StablelmForCausalLM(nn.Module, SupportsPP):
 
@@ -308,46 +347,13 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self,
+            # Models trained using ColossalAI may include these tensors in
+            # the checkpoint. Skip them.
+            skip_prefixes=[
+                "rotary_emb.inv_freq", "rotary_emb.cos_cached",
+                "rotary_emb.sin_cached"
+            ],
+        )
+        return loader.load_weights(weights)
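The stacked_params_mapping that moves into StablelmModel.load_weights above exists because checkpoints store q_proj, k_proj and v_proj (and gate_proj/up_proj) separately, while vLLM keeps fused qkv_proj and gate_up_proj parameters whose weight_loader copies each checkpoint shard into its slice. Below is a plain-PyTorch toy of that idea with made-up sizes and a made-up fused_qkv_shard_loader helper; it is not vLLM's QKVParallelLinear loader.

import torch

HIDDEN = 8  # toy size: q, k and v each map HIDDEN -> HIDDEN


def fused_qkv_shard_loader(param: torch.Tensor, loaded_weight: torch.Tensor,
                           shard_id: str) -> None:
    """Copy one q/k/v checkpoint shard into its slice of the fused weight."""
    offset = {"q": 0, "k": HIDDEN, "v": 2 * HIDDEN}[shard_id]
    param[offset:offset + HIDDEN].copy_(loaded_weight)


fused_qkv = torch.empty(3 * HIDDEN, HIDDEN)  # plays the role of qkv_proj.weight
checkpoint = {
    "q_proj.weight": torch.randn(HIDDEN, HIDDEN),
    "k_proj.weight": torch.randn(HIDDEN, HIDDEN),
    "v_proj.weight": torch.randn(HIDDEN, HIDDEN),
}
# Mirrors the (param_name, shard_name, shard_id) triples in the diff above.
for shard_id, ckpt_name in (("q", "q_proj.weight"), ("k", "k_proj.weight"),
                            ("v", "v_proj.weight")):
    fused_qkv_shard_loader(fused_qkv, checkpoint[ckpt_name], shard_id)

assert torch.equal(fused_qkv[:HIDDEN], checkpoint["q_proj.weight"])
assert torch.equal(fused_qkv[2 * HIDDEN:], checkpoint["v_proj.weight"])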

vllm/model_executor/models/starcoder2.py

Lines changed: 45 additions & 39 deletions
@@ -45,7 +45,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -256,6 +256,41 @@ def forward(
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class Starcoder2ForCausalLM(nn.Module, SupportsPP):
 
@@ -319,41 +354,12 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
-
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-
-                if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self,
+            # Models trained using ColossalAI may include these tensors in
+            # the checkpoint. Skip them.
+            skip_prefixes=([
+                "rotary_emb.inv_freq", "lm_head.weight"
+            ] if self.config.tie_word_embeddings else ["rotary_emb.inv_freq"]),
+        )
+        return loader.load_weights(weights)
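The one behavioral wrinkle in the Starcoder2 wrapper is the tie_word_embeddings case: when embeddings are tied, the checkpoint's lm_head.weight duplicates the embedding matrix, so the new code adds it to skip_prefixes instead of filtering it inside the loop. A small hypothetical sketch of weight tying, just to show why that checkpoint entry is redundant:

import torch
import torch.nn as nn

vocab, hidden = 100, 16
embed = nn.Embedding(vocab, hidden)
lm_head = nn.Linear(hidden, vocab, bias=False)
lm_head.weight = embed.weight  # tied: both modules share one Parameter

# Loading the embedding weight is enough; lm_head sees the same storage,
# so a separate "lm_head.weight" checkpoint entry can be skipped.
embed.weight.data.copy_(torch.randn(vocab, hidden))
assert lm_head.weight.data_ptr() == embed.weight.data_ptr()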

vllm/model_executor/models/zamba2.py

Lines changed: 40 additions & 38 deletions
@@ -39,7 +39,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 
 class Zamba2LoRA(nn.Module):
@@ -777,6 +777,37 @@ def forward(
         hidden_states = self.final_layernorm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for chkpt_weight_name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in chkpt_weight_name:
+                    continue
+                chkpt_weight_name = chkpt_weight_name.replace(
+                    weight_name, param_name)
+                param = params_dict[chkpt_weight_name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if chkpt_weight_name not in params_dict:
+                    continue
+                param = params_dict[chkpt_weight_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(chkpt_weight_name)
+        return loaded_params
+
 
 class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
     """Zamba2 model with causal language modeling head.
@@ -787,6 +818,12 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
     - Support for model parallelism and quantization
     - Sampling capabilities for text generation
     """
+    # To ensure correct weight loading and mapping.
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
+        "A_log": "A",
+        "0.weight": "A.weight",
+        "1.weight": "B.weight",
+    })
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         """Initialize the Zamba2 model for causal language modeling.
@@ -992,40 +1029,5 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
-
-        weights_dict = {}
-        for key, loaded_weight in weights:
-            if "A_log" in key:
-                key = key.replace("A_log", "A")
-            elif "adapter_list" in key:
-                key = key.replace("0.weight", "A.weight")
-                key = key.replace("1.weight", "B.weight")
-            weights_dict[key] = loaded_weight
-
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for chkpt_weight_name, loaded_weight in weights_dict.items():
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in chkpt_weight_name:
-                    continue
-                chkpt_weight_name = chkpt_weight_name.replace(
-                    weight_name, param_name)
-                param = params_dict[chkpt_weight_name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                if chkpt_weight_name not in params_dict:
-                    continue
-                param = params_dict[chkpt_weight_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(chkpt_weight_name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
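For Zamba2, the imperative key rewriting that used to live inside load_weights (A_log to A, and the adapter 0.weight/1.weight to A.weight/B.weight renames) becomes a declarative class-level WeightsMapper that the loader applies via the mapper= argument. Below is a hypothetical, simplified illustration of substring-based renaming; the example names are invented and this is not vLLM's actual WeightsMapper implementation.

from typing import Dict, Iterable, Iterator, Tuple

import torch

ORIG_TO_NEW_SUBSTR: Dict[str, str] = {
    "A_log": "A",
    "0.weight": "A.weight",
    "1.weight": "B.weight",
}


def remap_names(weights: Iterable[Tuple[str, torch.Tensor]]
                ) -> Iterator[Tuple[str, torch.Tensor]]:
    """Rewrite checkpoint names by substring before they reach the loader."""
    for name, tensor in weights:
        for old, new in ORIG_TO_NEW_SUBSTR.items():
            if old in name:
                name = name.replace(old, new)
        yield name, tensor


# e.g. "layers.0.mamba.A_log" -> "layers.0.mamba.A",
#      "adapter_list.0.weight" -> "adapter_list.A.weight"
print(dict(remap_names([("layers.0.mamba.A_log", torch.zeros(1)),
                        ("adapter_list.0.weight", torch.zeros(1))])))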
