
Commit a3a3ee4

[Misc] Merge bitsandbytes_stacked_params_mapping and packed_modules_mapping (#11924)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent 87054a5 commit a3a3ee4

24 files changed: +49 −200 lines

vllm/model_executor/model_loader/loader.py
Lines changed: 11 additions & 16 deletions

@@ -39,7 +39,8 @@
 from vllm.model_executor.model_loader.tensorizer import (
     TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
     serialize_vllm_model, tensorizer_weights_iterator)
-from vllm.model_executor.model_loader.utils import (get_model_architecture,
+from vllm.model_executor.model_loader.utils import (ParamMapping,
+                                                     get_model_architecture,
                                                      set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
@@ -983,21 +984,11 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
 
     def _get_bnb_target_modules(self, model: nn.Module) -> None:
 
-        # TODO: Maybe we can replace bitsandbytes_stacked_params_mapping with
-        # packed_modules_mapping.
-        inverse_stacked_mapping: Dict[str, List[str]] = {}
-        for orig, (
-                packed,
-                idx,
-        ) in model.bitsandbytes_stacked_params_mapping.items():
-            if packed not in inverse_stacked_mapping:
-                inverse_stacked_mapping[packed] = []
-            inverse_stacked_mapping[packed].insert(idx, orig)
-
         for name, module in model.named_modules():
             if isinstance(module, (LinearBase, )):
                 last_name = name.split(".")[-1]
-                if sub_modules := inverse_stacked_mapping.get(last_name, []):
+                if sub_modules := self.modules_mapping.packed_mapping.get(
+                        last_name, []):
                     # Map vllm's names to transformers's names.
                     for sub_name in sub_modules:
                         self.target_modules.append(
@@ -1018,15 +1009,19 @@ def _load_weights(self, model_config: ModelConfig,
                 "The required method 'load_weights' is not defined in class"
                 f" {type(model).__name__}.")
 
-        if not hasattr(model, "bitsandbytes_stacked_params_mapping"):
+        if not hasattr(model, "packed_modules_mapping"):
             raise AttributeError(
                 f"Model {type(model).__name__} does not support BitsAndBytes "
-                "quantization yet.")
+                "quantization yet. No 'packed_modules_mapping' found.")
+
+        self.modules_mapping = ParamMapping(
+            copy.deepcopy(model.packed_modules_mapping))
 
         # For some models like Molmo, we need to use hf_to_vllm_mapper
        # to ensure correct loading of weights.
         if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
             self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
+
         # Modules whose weights might have fused on disk
         # we need their output_sizes to make shard in flight correctly with TP
         self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
@@ -1109,7 +1104,7 @@ def _load_weights(self, model_config: ModelConfig,
             for shard_name, (
                     weight_name,
                     index,
-            ) in model.bitsandbytes_stacked_params_mapping.items():
+            ) in self.modules_mapping.inverse_packed_mapping.items():
                 shard_pos = quant_param_name.find(shard_name)
                 # Some models, such as MiniCPM V2.5/2.6, contain both
                 # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'

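For readers following the change, the lookup that _get_bnb_target_modules now performs can be sketched in isolation. The snippet below is a minimal, self-contained illustration only; the module names and the final string substitution are assumptions made for the example, not copied from the diff:

packed_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

# Hypothetical linear-module names of the kind named_modules() would yield.
linear_module_names = [
    "model.layers.0.self_attn.qkv_proj",
    "model.layers.0.mlp.gate_up_proj",
]

target_modules = []
for name in linear_module_names:
    last_name = name.split(".")[-1]
    if sub_modules := packed_mapping.get(last_name, []):
        # Map vLLM's packed name back to the per-shard checkpoint names
        # (the exact substitution used by the loader is an assumption here).
        for sub_name in sub_modules:
            target_modules.append(name.replace(last_name, sub_name))

print(target_modules)
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj',
#  'model.layers.0.self_attn.v_proj', 'model.layers.0.mlp.gate_proj',
#  'model.layers.0.mlp.up_proj']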
vllm/model_executor/model_loader/utils.py
Lines changed: 25 additions & 1 deletion

@@ -1,6 +1,7 @@
 """Utilities for selecting and loading models."""
 import contextlib
-from typing import Tuple, Type
+from dataclasses import dataclass, field
+from typing import Dict, List, Tuple, Type
 
 import torch
 from torch import nn
@@ -49,3 +50,26 @@ def get_model_architecture(
 
 def get_architecture_class_name(model_config: ModelConfig) -> str:
     return get_model_architecture(model_config)[1]
+
+
+@dataclass
+class ParamMapping:
+    """
+    A class to handle parameter mapping for model weight loading.
+    It creates a bidirectional mapping between packed parameters and their
+    constituent parts.
+    """
+    packed_mapping: Dict[str, List[str]]
+    inverse_packed_mapping: Dict[str, Tuple[str,
+                                            int]] = field(default_factory=dict)
+
+    def __post_init__(self):
+        for packed_name, sub_params in self.packed_mapping.items():
+            # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
+            if len(sub_params) == 1 and sub_params[0] == packed_name:
+                continue
+            for index, param_name in enumerate(sub_params):
+                self.inverse_packed_mapping[param_name] = (
+                    packed_name,
+                    index,
+                )

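To make the new helper concrete, here is a small usage sketch of ParamMapping, assuming a vLLM build that contains this commit. The input is a Llama-style packed_modules_mapping, and the inverse view it builds reproduces what the deleted per-model bitsandbytes_stacked_params_mapping attributes used to spell out by hand:

from vllm.model_executor.model_loader.utils import ParamMapping

# Llama-style packed mapping, as declared on the model classes.
mapping = ParamMapping(packed_mapping={
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
})

# __post_init__ builds the inverse view the BitsAndBytes loader iterates over:
# {'q_proj': ('qkv_proj', 0), 'k_proj': ('qkv_proj', 1),
#  'v_proj': ('qkv_proj', 2), 'gate_proj': ('gate_up_proj', 0),
#  'up_proj': ('gate_up_proj', 1)}
print(mapping.inverse_packed_mapping)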
vllm/model_executor/models/baichuan.py
Lines changed: 0 additions & 7 deletions

@@ -350,13 +350,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_modules = {}
     embedding_padding_modules = []
 
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
-
     def __init__(
         self,
         *,

vllm/model_executor/models/exaone.py
Lines changed: 0 additions & 8 deletions

@@ -430,14 +430,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "lm_head": "output_embeddings",
     }
     embedding_padding_modules = ["lm_head"]
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "c_fc_0": ("gate_up_proj", 0),
-        "c_fc_1": ("gate_up_proj", 1),
-    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

vllm/model_executor/models/falcon.py
Lines changed: 3 additions & 3 deletions

@@ -409,9 +409,9 @@ def forward(
 
 
 class FalconForCausalLM(nn.Module, SupportsPP):
-
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {}
+    packed_modules_mapping = {
+        "query_key_value": ["query_key_value"],
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

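Falcon's self-referential entry deserves a note: query_key_value is already stored as one fused weight in the checkpoint, so it maps to itself, and ParamMapping.__post_init__ deliberately skips such entries when building the inverse view, meaning no in-flight shard splitting is attempted for it. A short sketch, again assuming a vLLM build with this commit:

from vllm.model_executor.model_loader.utils import ParamMapping

# Falcon's attention projection is a single fused tensor on disk,
# so the packed name maps to itself rather than to q/k/v shards.
falcon_mapping = ParamMapping(
    packed_mapping={"query_key_value": ["query_key_value"]})

# Self-contained entries are skipped, so the inverse mapping stays empty
# while the module itself is still picked up as a quantization target.
assert falcon_mapping.inverse_packed_mapping == {}
print(falcon_mapping.packed_mapping)  # {'query_key_value': ['query_key_value']}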
vllm/model_executor/models/gemma.py
Lines changed: 0 additions & 9 deletions

@@ -349,15 +349,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "gate_up_proj",
         "down_proj",
     ]
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
 
     # Gemma does not apply LoRA to the embedding layer.
     embedding_modules = {}

vllm/model_executor/models/gemma2.py
Lines changed: 0 additions & 10 deletions

@@ -399,16 +399,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_modules = {}
     embedding_padding_modules = []
 
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config

vllm/model_executor/models/granite.py
Lines changed: 0 additions & 8 deletions

@@ -362,14 +362,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "lm_head": "output_embeddings",
     }
     embedding_padding_modules = ["lm_head"]
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

vllm/model_executor/models/idefics3.py
Lines changed: 0 additions & 10 deletions

@@ -662,16 +662,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         "down_proj",
     ]
 
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
-
     embedding_modules = {}
     embedding_padding_modules = []

vllm/model_executor/models/llama.py
Lines changed: 0 additions & 10 deletions

@@ -478,16 +478,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     }
     embedding_padding_modules = ["lm_head"]
 
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
-
     # Mistral/Llama models can also be loaded with --load-format mistral
     # from consolidated.safetensors checkpoints
     mistral_mapping = {

vllm/model_executor/models/llava.py
Lines changed: 4 additions & 8 deletions

@@ -463,14 +463,10 @@ def init_vision_tower_for_llava(
                                         info=_build_llava_or_pixtral_hf_info,
                                         dummy_inputs=LlavaDummyInputsBuilder)
 class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
+
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

vllm/model_executor/models/minicpm.py
Lines changed: 0 additions & 10 deletions

@@ -534,16 +534,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     }
     embedding_padding_modules = ["lm_head"]
 
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config

vllm/model_executor/models/minicpm3.py
Lines changed: 0 additions & 6 deletions

@@ -241,11 +241,5 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
     # `embedding_modules` and `embedding_padding_modules`
     # are inherited from MiniCPMForCausalLM
 
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
-
     def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
         return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)

vllm/model_executor/models/minicpmv.py
Lines changed: 0 additions & 20 deletions

@@ -761,16 +761,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
         "kv_proj",
     ]
 
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
-
     embedding_modules = {}
     embedding_padding_modules = []
 
@@ -881,16 +871,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
         "kv_proj",
     ]
 
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
-
     embedding_modules = {}
     embedding_padding_modules = []

vllm/model_executor/models/mllama.py
Lines changed: 3 additions & 8 deletions

@@ -1107,14 +1107,9 @@ def forward(
 @INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

vllm/model_executor/models/molmo.py
Lines changed: 0 additions & 6 deletions

@@ -1193,12 +1193,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
     embedding_modules = {}
     embedding_padding_modules = []
 
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        "gate_proj": ("merged_linear", 0),
-        "up_proj": ("merged_linear", 1),
-    }
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config

vllm/model_executor/models/nemotron.py
Lines changed: 0 additions & 6 deletions

@@ -395,12 +395,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "lm_head": "output_embeddings",
     }
     embedding_padding_modules = ["lm_head"]
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

vllm/model_executor/models/opt.py
Lines changed: 3 additions & 7 deletions

@@ -329,13 +329,9 @@ def forward(
 
 
 class OPTForCausalLM(nn.Module, SupportsPP):
-
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

vllm/model_executor/models/phi.py
Lines changed: 0 additions & 8 deletions

@@ -279,14 +279,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "fc2",
     ]
 
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-    }
-
     embedding_modules = {}
     embedding_padding_modules = []

vllm/model_executor/models/phi3.py
Lines changed: 0 additions & 4 deletions

@@ -14,7 +14,3 @@ class Phi3ForCausalLM(LlamaForCausalLM):
             "gate_up_proj",
         ],
     }
-
-    # BitandBytes specific attributes
-    # Initialize an empty dict when there is no stacked parameter mapping.
-    bitsandbytes_stacked_params_mapping = {}
