
Commit bb29b63

Isotr0py authored and rasmith committed
[Misc] Add BNB support to GLM4-V model (vllm-project#12184)
Signed-off-by: Isotr0py <[email protected]>
1 parent c358fa9 commit bb29b63

File tree

3 files changed: +60 -53 lines changed

vllm/model_executor/model_loader/loader.py
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/glm4_vision_encoder.py


vllm/model_executor/model_loader/loader.py

Lines changed: 11 additions & 4 deletions
@@ -1105,15 +1105,22 @@ def _load_weights(self, model_config: ModelConfig,
                     weight_name,
                     index,
             ) in self.modules_mapping.inverse_packed_mapping.items():
-                shard_pos = quant_param_name.find(shard_name)
                 # Some models, such as MiniCPM V2.5/2.6, contain both
                 # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
                 # from being incorrectly identified as being present in
                 # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
-                if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
+                shard_pos = quant_param_name.find(shard_name)
+                can_correct_rename = (shard_pos > 0) and (
+                    quant_param_name[shard_pos - 1] == ".")
+                # If the quant_param_name is packed, it won't occur in the
+                # param_dict before renaming.
+                new_quant_param_name = quant_param_name.replace(
+                    shard_name, weight_name)
+                need_rename = (quant_param_name not in param_dict) \
+                    and (new_quant_param_name in param_dict)
+                if can_correct_rename and need_rename:
                     shard_index = index
-                    quant_param_name = quant_param_name.replace(
-                        shard_name, weight_name)
+                    quant_param_name = new_quant_param_name
                     break

         # Models like Clip/Siglip may skip some layers in initialization,
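
The guard added here is easiest to see in isolation. A minimal, self-contained sketch of the renaming rule (the function name and the param_dict contents are illustrative, not vllm's actual state):

def maybe_rename(quant_param_name, shard_name, weight_name, param_dict):
    # Require a '.' right before the match so that 'kv_proj' is not
    # mistaken for the tail of 'qkv_proj'.
    shard_pos = quant_param_name.find(shard_name)
    can_correct_rename = (shard_pos > 0
                          and quant_param_name[shard_pos - 1] == ".")
    new_name = quant_param_name.replace(shard_name, weight_name)
    # New in this commit: rename only when the original name is not itself
    # a known parameter while the renamed one is.
    need_rename = (quant_param_name not in param_dict
                   and new_name in param_dict)
    return new_name if (can_correct_rename and need_rename) else quant_param_name

params = {"vpm.encoder.layers.0.self_attn.qkv_proj.weight"}
print(maybe_rename("vpm.encoder.layers.0.self_attn.kv_proj.weight",
                   "kv_proj", "qkv_proj", params))
# vpm.encoder.layers.0.self_attn.qkv_proj.weight  (renamed)
print(maybe_rename("vpm.encoder.layers.0.self_attn.qkv_proj.weight",
                   "kv_proj", "qkv_proj", params))
# vpm.encoder.layers.0.self_attn.qkv_proj.weight  (left untouched)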

vllm/model_executor/models/chatglm.py

Lines changed: 47 additions & 48 deletions
@@ -41,7 +41,7 @@
 from vllm.transformers_utils.configs import ChatGLMConfig

 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -605,9 +605,50 @@ def forward(
             return IntermediateTensors({"hidden_states": hidden_states})
         return hidden_states

+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
+            ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if "rotary_pos_emb.inv_freq" in name:
+                    continue
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+

 class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):

+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={".word_embeddings": ""}, )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
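
The old implementation removed in the next hunk buffered gate_proj and dense_h_to_4h tensors and concatenated them by hand; the new load_weights above instead lets the fused parameter's own weight_loader place each shard. A rough sketch of how the mapping routes a checkpoint name (the checkpoint name is illustrative):

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
    ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
]

def route(name):
    """Map a checkpoint weight name to (vllm param name, shard id)."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            # The merged parameter's weight_loader copies this shard
            # into its slice of merged_proj.
            return name.replace(weight_name, param_name), shard_id
    return name, None  # falls through to default_weight_loader

print(route("transformer.vision.linear_proj.gate_proj.weight"))
# ('transformer.vision.linear_proj.merged_proj.weight', 0)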
@@ -660,52 +701,9 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens

-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        # Merge two ColumnParallelLinear into one MergedColumnParallelLinear
-        merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
-            "transformer.vision.linear_proj.merged_proj.weight": {
-                "transformer.vision.linear_proj.gate_proj.weight": None,
-                "transformer.vision.linear_proj.dense_h_to_4h.weight": None,
-            }
-        }
-
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            is_weight_to_be_merge = False
-            for _, merged_weight_dict in merged_weights_dict.items():
-                if name in merged_weight_dict:
-                    assert merged_weight_dict[name] is None
-                    merged_weight_dict[name] = loaded_weight
-                    is_weight_to_be_merge = True
-            if is_weight_to_be_merge:
-                continue
-            if "rotary_pos_emb.inv_freq" in name:
-                continue
-            if "word_embeddings" in name:
-                name = name.replace(".word_embeddings", "")
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-
-        for combined_name, merged_weight_dict in merged_weights_dict.items():
-            if combined_name in params_dict:
-                param = params_dict[combined_name]
-                combined_weight = torch.cat(list(merged_weight_dict.values()),
-                                            dim=0)
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, combined_weight)
-                loaded_params.add(combined_name)
-        return loaded_params
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


 class ChatGLM(ChatGLMBaseModel):
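
With per-submodule load_weights in place, the base model can delegate to vllm's AutoWeightsLoader, which walks named submodules and dispatches each weight to the owning module's loader. The only model-specific fix-up left is the name mapping; a rough approximation of what WeightsMapper(orig_to_new_substr=...) does to each checkpoint name (the real class lives in vllm/model_executor/models/utils.py; this is a sketch, not its implementation):

def map_name(name, orig_to_new_substr):
    # Approximation: replace each configured substring wherever it occurs.
    for orig, new in orig_to_new_substr.items():
        name = name.replace(orig, new)
    return name

print(map_name("transformer.embedding.word_embeddings.weight",
               {".word_embeddings": ""}))
# transformer.embedding.weight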
@@ -726,6 +724,7 @@ class ChatGLM(ChatGLMBaseModel):


 class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
+
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
         "dense_h_to_4h": ["dense_h_to_4h"],
@@ -777,7 +776,7 @@ def __new__(
     ) -> None:
         config = vllm_config.model_config.hf_config
         # Initialize VL
-        if hasattr(config, "visual"):
+        if hasattr(config, "vision_config"):
             return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
         # Initialize LLM
         else:

vllm/model_executor/models/glm4_vision_encoder.py

Lines changed: 2 additions & 1 deletion
@@ -42,7 +42,8 @@ def forward(self, images: torch.Tensor) -> torch.Tensor:
         torch.Tensor
             Transformed tensor with shape (B, L, D)
         """
-        images = images.to(self.proj.weight.device)
+        images = images.to(device=self.proj.weight.device,
+                           dtype=self.proj.weight.dtype)
         x = self.proj(images)
         x = x.flatten(2).transpose(1, 2)
         cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
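
The dtype half of this cast is what BNB support needs: with 4-bit quantization the projection typically runs in half precision, while image preprocessing hands the encoder float32 pixels. A small sketch of the mismatch the extra dtype= argument avoids (the Conv2d shape is illustrative; GLM4-V's real patch embedding differs):

import torch

# Stand-in for the patch-embedding projection, loaded in half precision.
proj = torch.nn.Conv2d(3, 8, kernel_size=14, stride=14).half()
images = torch.rand(1, 3, 224, 224)  # float32, as produced by preprocessing

# proj(images) would fail with a float32-vs-float16 mismatch; casting to
# the weight's device *and* dtype keeps the input compatible.
images = images.to(device=proj.weight.device, dtype=proj.weight.dtype)
x = proj(images)  # OK: both sides are float16 now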
