
Commit 196c34b

[Misc] Move weights mapper (#11443)

Signed-off-by: Jee Jee Li <[email protected]>
Parent: 5c79632

8 files changed: +74 −68 lines

tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
Lines changed: 3 additions & 2 deletions

@@ -13,6 +13,7 @@
 
 
 class MyGemma2Embedding(nn.Module):
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -62,8 +63,8 @@ def pooler(
         return self._pooler(hidden_states, pooling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
-        weights = hf_to_vllm_mapper.apply(weights)
+
+        weights = self.hf_to_vllm_mapper.apply(weights)
         weights = ((name, data) for name, data in weights
                    if not name.startswith("lm_head."))
         return self.model.load_weights(weights)
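Every file in this commit applies the same refactor: the WeightsMapper that each load_weights() used to build locally becomes a class-level hf_to_vllm_mapper attribute, and load_weights() reads it through self. As a rough sketch of what the prefix mapping plus the lm_head filter above do to checkpoint names, here is a toy stand-in (not vLLM's actual WeightsMapper; the checkpoint name is only illustrative):

# Toy stand-in for the prefix mapping used above; not vLLM code.
from typing import Dict, Iterable, Iterator, Tuple

import torch


class ToyPrefixMapper:
    """Rewrite parameter names using the first matching prefix rule."""

    def __init__(self, orig_to_new_prefix: Dict[str, str]):
        self.orig_to_new_prefix = orig_to_new_prefix

    def apply(
        self, weights: Iterable[Tuple[str, torch.Tensor]]
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        for name, tensor in weights:
            for old, new in self.orig_to_new_prefix.items():
                if name.startswith(old):
                    name = new + name[len(old):]
                    break
            yield name, tensor


mapper = ToyPrefixMapper({"model.": ""})
checkpoint = [
    ("model.layers.0.self_attn.q_proj.weight", torch.empty(0)),  # illustrative name
    ("lm_head.weight", torch.empty(0)),
]
kept = [(name, data) for name, data in mapper.apply(checkpoint)
        if not name.startswith("lm_head.")]      # same filter as in the diff
print([name for name, _ in kept])                # ['layers.0.self_attn.q_proj.weight']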

vllm/model_executor/models/aria.py
Lines changed: 10 additions & 10 deletions

@@ -521,6 +521,15 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     This model combines a vision tower, a multi-modal projector, and a language
     model to perform tasks that involve both image and text inputs.
     """
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "language_model.model": "language_model",
+            "language_model.lm_head": "lm_head",
+        },
+        orig_to_new_suffix={
+            "router.weight": "router_weight",
+        },
+    )
 
     def __init__(
         self,
@@ -662,15 +671,6 @@ def sample(
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapper = WeightsMapper(
-            orig_to_new_prefix={
-                "language_model.model": "language_model",
-                "language_model.lm_head": "lm_head",
-            },
-            orig_to_new_suffix={
-                "router.weight": "router_weight",
-            },
-        )
 
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights, mapper=hf_to_vllm_mapper)
+        loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
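aria.py is the one file that also carries an orig_to_new_suffix rule. A hand-worked illustration on a hypothetical checkpoint name (this makes no claim about the order in which WeightsMapper applies its rule groups; for this name the order does not matter):

# Hand-worked rename; the checkpoint name below is hypothetical.
name = "language_model.model.layers.3.mlp.experts.router.weight"

# prefix rule: "language_model.model" -> "language_model"
assert name.startswith("language_model.model")
name = "language_model" + name[len("language_model.model"):]

# suffix rule: "router.weight" -> "router_weight"
assert name.endswith("router.weight")
name = name[:-len("router.weight")] + "router_weight"

print(name)  # language_model.layers.3.mlp.experts.router_weight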

vllm/model_executor/models/bert.py
Lines changed: 2 additions & 2 deletions

@@ -409,6 +409,7 @@ class BertEmbeddingModel(nn.Module):
         model: An instance of BertModel used for forward operations.
         _pooler: An instance of Pooler used for pooling operations.
     """
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -441,8 +442,7 @@ def pooler(
         return self._pooler(hidden_states, pooling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
-        weights = hf_to_vllm_mapper.apply(weights)
+        weights = self.hf_to_vllm_mapper.apply(weights)
         weights = ((name, data) for name, data in weights
                    if not name.startswith("lm_head."))
         self.model.load_weights(weights)

vllm/model_executor/models/molmo.py
Lines changed: 30 additions & 28 deletions

@@ -1123,6 +1123,34 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
 @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
 class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            # vision backbone mapping
+            "image_projector.w1.": "image_projector.gate_proj.",
+            "image_projector.w3.": "image_projector.up_proj.",
+            "image_projector.w2.": "image_projector.down_proj.",
+            # language backbone mapping
+            "att_proj": "self_attn.qkv_proj",
+            "attn_out": "self_attn.o_proj",
+            "q_norm": "self_attn.q_norm",
+            "k_norm": "self_attn.k_norm",
+            "ff_proj": "mlp.gate_up_proj",
+            "ff_out": "mlp.down_proj",
+            "attn_norm": "input_layernorm",
+            "ff_norm": "post_attention_layernorm",
+        },
+        orig_to_new_prefix={
+            # vision backbone mapping
+            "model.vision_backbone.": "vision_backbone.",
+            # language backbone mapping
+            "model.transformer.blocks.": "model.layers.",
+            "model.transformer.ln_f.": "model.norm.",
+            # lm_head is renamed to model.transformer.mlp.down_proj firstly,
+            # we need to run a second renaming for it
+            "model.transformer.mlp.down_proj.": "lm_head.",
+        },
+    )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -1298,36 +1326,10 @@ def sample(
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapper = WeightsMapper(
-            orig_to_new_substr={
-                # vision backbone mapping
-                "image_projector.w1.": "image_projector.gate_proj.",
-                "image_projector.w3.": "image_projector.up_proj.",
-                "image_projector.w2.": "image_projector.down_proj.",
-                # language backbone mapping
-                "att_proj": "self_attn.qkv_proj",
-                "attn_out": "self_attn.o_proj",
-                "q_norm": "self_attn.q_norm",
-                "k_norm": "self_attn.k_norm",
-                "ff_proj": "mlp.gate_up_proj",
-                "ff_out": "mlp.down_proj",
-                "attn_norm": "input_layernorm",
-                "ff_norm": "post_attention_layernorm",
-            },
-            orig_to_new_prefix={
-                # vision backbone mapping
-                "model.vision_backbone.": "vision_backbone.",
-                # language backbone mapping
-                "model.transformer.blocks.": "model.layers.",
-                "model.transformer.ln_f.": "model.norm.",
-                # lm_head is renamed to model.transformer.mlp.down_proj firstly,
-                # we need to run a second renaming for it
-                "model.transformer.mlp.down_proj.": "lm_head.",
-            },
-        )
+
         loader = AutoWeightsLoader(self)
         weights = _get_weights_with_merged_embedding(weights)
-        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
 def _get_weights_with_merged_embedding(
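molmo.py is the largest move because its mapper mixes substring and prefix rules. A hand-worked rename of one presumed Molmo checkpoint name under the rules above (a sketch, not WeightsMapper internals; for this name the two rule groups commute, so their order does not matter):

# Hand-worked rename of a presumed Molmo checkpoint name.
name = "model.transformer.blocks.0.att_proj.weight"

# prefix rule: "model.transformer.blocks." -> "model.layers."
assert name.startswith("model.transformer.blocks.")
name = "model.layers." + name[len("model.transformer.blocks."):]

# substr rule: "att_proj" -> "self_attn.qkv_proj"
name = name.replace("att_proj", "self_attn.qkv_proj")

print(name)  # model.layers.0.self_attn.qkv_proj.weight

Per the inline comment preserved in the diff, the output head is the one name that appears to need both passes: the "ff_out" substring rule first rewrites it to "model.transformer.mlp.down_proj.", and the extra prefix rule then maps that onto "lm_head.".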

vllm/model_executor/models/phi3v.py
Lines changed: 8 additions & 8 deletions

@@ -408,6 +408,13 @@ def _get_dummy_mm_inputs(
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
 class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.vision_embed_tokens.wte": "embed_tokens",
+            "model.vision_embed_tokens.": "vision_embed_tokens.",
+            "lm_head.": "language_model.lm_head.",
+            "model.": "language_model.model.",
+        })
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -616,17 +623,10 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        hf_to_vllm_mapper = WeightsMapper(
-            orig_to_new_prefix={
-                "model.vision_embed_tokens.wte": "embed_tokens",
-                "model.vision_embed_tokens.": "vision_embed_tokens.",
-                "lm_head.": "language_model.lm_head.",
-                "model.": "language_model.model.",
-            })
 
         loader = AutoWeightsLoader(self)
         autoloaded_weights = loader.load_weights(weights,
-                                                 mapper=hf_to_vllm_mapper)
+                                                 mapper=self.hf_to_vllm_mapper)
 
         # The HF config doesn't specify whether these are tied,
         # so we detect it this way

vllm/model_executor/models/qwen2.py
Lines changed: 3 additions & 2 deletions

@@ -529,6 +529,8 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
     embedding_modules = {}
     embedding_padding_modules = []
 
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -577,8 +579,7 @@ def pooler(
         return self._pooler(hidden_states, pooling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
-        weights = hf_to_vllm_mapper.apply(weights)
+        weights = self.hf_to_vllm_mapper.apply(weights)
         weights = ((name, data) for name, data in weights
                    if not name.startswith("lm_head."))
         self.model.load_weights(weights)

vllm/model_executor/models/telechat2.py
Lines changed: 14 additions & 13 deletions

@@ -31,6 +31,19 @@
 
 class TeleChat2Model(LlamaModel):
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "transformer.": "model.",
+        },
+        orig_to_new_substr={
+            ".h.": ".layers.",
+            ".self_attention.": ".self_attn.",
+            ".word_embeddings.": ".embed_tokens.",
+            ".dense.": ".o_proj.",
+            ".ln_f.": ".norm.",
+        },
+    )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # 1. Initialize the LlamaModel with bias
         vllm_config.model_config.hf_config.bias = True
@@ -111,21 +124,9 @@ def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
 
-        hf_to_vllm_mapper = WeightsMapper(
-            orig_to_new_prefix={
-                "transformer.": "model.",
-            },
-            orig_to_new_substr={
-                ".h.": ".layers.",
-                ".self_attention.": ".self_attn.",
-                ".word_embeddings.": ".embed_tokens.",
-                ".dense.": ".o_proj.",
-                ".ln_f.": ".norm.",
-            },
-        )
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=(["lm_head."]
                            if self.config.tie_word_embeddings else None),
         )
-        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

vllm/model_executor/models/ultravox.py
Lines changed: 4 additions & 3 deletions

@@ -302,6 +302,9 @@ def forward(
 @MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -494,9 +497,7 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        hf_to_vllm_mapper = WeightsMapper(
-            orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
 
         loader = AutoWeightsLoader(self,
                                    ignore_unexpected_prefixes=["audio_tower."])
-        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
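Across all eight files the mechanical effect is identical: hf_to_vllm_mapper now lives on the class, so self.hf_to_vllm_mapper inside load_weights() resolves through ordinary attribute lookup, and the mapping is also reachable from the class object without instantiating a model. A toy sketch of that lookup with fake classes (no vLLM imports):

# Toy sketch of the attribute lookup; not vLLM code.
class FakeWeightsMapper:
    def __init__(self, orig_to_new_prefix):
        self.orig_to_new_prefix = orig_to_new_prefix


class FakeUltravoxModel:
    # class attribute, mirroring the real hf_to_vllm_mapper in the diff above
    hf_to_vllm_mapper = FakeWeightsMapper(
        {"audio_tower.model.encoder.": "audio_tower."})

    def load_weights(self):
        # no instance attribute of this name exists, so the lookup falls
        # through to the class-level object
        return self.hf_to_vllm_mapper


assert FakeUltravoxModel().load_weights() is FakeUltravoxModel.hf_to_vllm_mapper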
