
Commit 0240402

[Misc]Add BNB quantization for MolmoForCausalLM (#11551)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent 55509c2 commit 0240402
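
This commit wires Molmo into vLLM's in-flight BitsAndBytes loading path. A minimal usage sketch, assuming the existing `quantization`/`load_format` BNB options; the checkpoint name is illustrative:

from vllm import LLM

# In-flight 4-bit BNB quantization of a Molmo checkpoint
# (requires bitsandbytes>=0.45.0 per the loader change below).
llm = LLM(
    model="allenai/Molmo-7B-D-0924",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
    trust_remote_code=True,
)
print(llm.generate("Describe BitsAndBytes quantization in one sentence."))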

File tree

2 files changed: +83 -33 lines changed

vllm/model_executor/model_loader/loader.py
vllm/model_executor/models/molmo.py

vllm/model_executor/model_loader/loader.py

Lines changed: 18 additions & 8 deletions
@@ -11,7 +11,8 @@
 import warnings
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
+from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
+                    Tuple, cast)
 
 import gguf
 import huggingface_hub
@@ -706,6 +707,8 @@ def __init__(self, load_config: LoadConfig):
         # Store all module names (from transformers) that support
         # BNB quantization.
         self.target_modules: List[str] = []
+        # mapping weight names from transformers to vllm.
+        self.weight_mapper: Callable = lambda name: name
 
     def _get_weight_files(
             self,
@@ -763,9 +766,12 @@ def _prepare_weights(self, model_name_or_path: str,
 
     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
         if use_safetensors:
-            return safetensors_weights_iterator(hf_weights_files)
+            iterator = safetensors_weights_iterator(hf_weights_files)
         else:
-            return pt_weights_iterator(hf_weights_files)
+            iterator = pt_weights_iterator(hf_weights_files)
+        for name, param in iterator:
+            # mapping weight names from transformers to vllm.
+            yield self.weight_mapper(name), param
 
     def _get_quantized_weights_iterator(
         self,
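
The renaming is done lazily: whichever raw iterator is chosen, every (name, tensor) pair passes through self.weight_mapper before being yielded, so downstream loading code only ever sees vLLM-style names. A standalone sketch of that wrapper pattern; the remap_names helper and the example weight names are illustrative, not loader code:

from typing import Callable, Iterable, Tuple

import torch


def remap_names(
    weights: Iterable[Tuple[str, torch.Tensor]],
    mapper: Callable[[str], str] = lambda name: name,
) -> Iterable[Tuple[str, torch.Tensor]]:
    # Wrap a (name, tensor) iterator and rename each entry on the fly,
    # mirroring how _hf_weight_iter applies self.weight_mapper.
    for name, param in weights:
        yield mapper(name), param


# Example: rewrite a transformers-style prefix into a vLLM-style one.
raw = [("model.transformer.blocks.0.ff_proj.weight", torch.zeros(2))]
mapped = list(remap_names(raw, lambda n: n.replace("model.transformer.", "model.")))
print(mapped[0][0])  # model.blocks.0.ff_proj.weight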
@@ -782,12 +788,12 @@ def _get_quantized_weights_iterator(
         try:
             import bitsandbytes
 
-            if bitsandbytes.__version__ < "0.44.0":
+            if bitsandbytes.__version__ < "0.45.0":
                 raise ImportError("bitsandbytes version is wrong. Please "
-                                  "install bitsandbytes>=0.44.0.")
+                                  "install bitsandbytes>=0.45.0.")
         except ImportError as err:
-            raise ImportError("Please install bitsandbytes>=0.44.0 via "
-                              "`pip install bitsandbytes>=0.44.0` to use "
+            raise ImportError("Please install bitsandbytes>=0.45.0 via "
+                              "`pip install bitsandbytes>=0.45.0` to use "
                               "bitsandbytes quantizer.") from err
 
         hf_weights_files, use_safetensors = self._prepare_weights(
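
Note that the gate compares bitsandbytes.__version__ against "0.45.0" as plain strings. A hedged alternative sketch using packaging's Version, which avoids string-ordering surprises (as strings, "0.9.0" sorts above "0.45.0" even though it is the older release); this is only an illustration, not what the commit ships:

from packaging.version import Version

import bitsandbytes

# Version-aware comparison instead of lexicographic string comparison.
if Version(bitsandbytes.__version__) < Version("0.45.0"):
    raise ImportError("Please install bitsandbytes>=0.45.0.")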
@@ -991,7 +997,7 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
             if isinstance(module, (LinearBase, )):
                 last_name = name.split(".")[-1]
                 if sub_modules := inverse_stacked_mapping.get(last_name, []):
-                    # Map vllm's names to transformers' names.
+                    # Map vllm's names to transformers's names.
                     for sub_name in sub_modules:
                         self.target_modules.append(
                             name.replace(last_name, sub_name))
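
The loop above expands each fused vLLM module name back into the per-shard names that transformers/BNB configs expect, via an inverse of the stacked-params mapping. A minimal self-contained sketch of that inversion and expansion; the mapping entries reuse the Molmo values added in molmo.py, while the example module path is illustrative:

from typing import Dict, List, Tuple

stacked_mapping: Dict[str, Tuple[str, int]] = {
    "gate_proj": ("merged_linear", 0),
    "up_proj": ("merged_linear", 1),
}

# Invert: fused module name -> per-shard (transformers-side) names.
inverse_stacked_mapping: Dict[str, List[str]] = {}
for shard, (fused, _) in stacked_mapping.items():
    inverse_stacked_mapping.setdefault(fused, []).append(shard)

target_modules: List[str] = []
name = "model.vision_backbone.image_projector.merged_linear"
last_name = name.split(".")[-1]
if sub_modules := inverse_stacked_mapping.get(last_name, []):
    for sub_name in sub_modules:
        target_modules.append(name.replace(last_name, sub_name))

print(target_modules)
# ['model.vision_backbone.image_projector.gate_proj',
#  'model.vision_backbone.image_projector.up_proj']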
@@ -1013,6 +1019,10 @@ def _load_weights(self, model_config: ModelConfig,
                 f"Model {type(model).__name__} does not support BitsAndBytes "
                 "quantization yet.")
 
+        # For some models like Molmo, we need to use hf_to_vllm_mapper
+        # to ensure correct loading of weights.
+        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
+            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
         # Modules whose weights might have fused on disk
         # we need their output_sizes to make shard in flight correctly with TP
         self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
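
The loader only needs the mapper object to expose a _map_name(name) hook. A hedged sketch of such an object; SimpleNameMapper and its prefix table are hypothetical stand-ins, not vLLM's actual hf_to_vllm_mapper implementation:

from dataclasses import dataclass, field
from typing import Dict


@dataclass
class SimpleNameMapper:
    # Hypothetical stand-in: only exposes _map_name(name), which is all
    # the BNB loader calls through self.weight_mapper.
    prefix_map: Dict[str, str] = field(default_factory=dict)

    def _map_name(self, name: str) -> str:
        for old, new in self.prefix_map.items():
            if name.startswith(old):
                return new + name[len(old):]
        return name


mapper = SimpleNameMapper({"model.transformer.": "model."})
print(mapper._map_name("model.transformer.blocks.0.attn_out.weight"))
# model.blocks.0.attn_out.weight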

vllm/model_executor/models/molmo.py

Lines changed: 65 additions & 25 deletions
@@ -461,30 +461,71 @@ def forward(
         return output
 
 
-class MolmoMLP(nn.Module):
+class SwiGLU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, gate = x.chunk(2, dim=-1)
+        # Note that the order is reversed compared to
+        # SiluAndMul.
+        return x * F.silu(gate)
+
+
+class LanuageModelMLP(nn.Module):
     """Molmo's LLM mlp."""
 
     def __init__(self,
                  config: PretrainedConfig,
                  input_dim: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 proj_name: str = "gate_up_proj") -> None:
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size // 2
 
-        # Molmo's LLM proj weights are already merged into the disk, while
-        # image_projector proj is separate. If the same proj_name were used, it
-        # would create ambiguity and make it difficult to support BNB and LoRA.
-        self.proj_name = proj_name
-        setattr(
-            self, proj_name,
-            MergedColumnParallelLinear(
-                input_dim or self.hidden_size,
-                [self.intermediate_size] * 2,
-                bias=False,
-                quant_config=quant_config,
-            ))
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_dim or self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
+        # Activation function.
+        self.act_fn = SwiGLU()
+        # Feed-forward output projection.
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class ImageProjectorMLP(nn.Module):
+    """Molmo's image_projector mlp."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        input_dim: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size // 2
+
+        self.merged_linear = MergedColumnParallelLinear(
+            input_dim or self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
         # Activation function.
         self.act_fn = SiluAndMul()
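
SwiGLU here consumes its fused input as (up, gate), the reverse of SiluAndMul's (gate, up) convention, which is why load_weights no longer has to swap the two halves of the merged gate_up_proj weight (see the hunk further down). A small sanity check of that equivalence in plain torch; silu_and_mul_ref is a local reference for the ordering, not the vLLM op:

import torch
import torch.nn.functional as F


class SwiGLU(torch.nn.Module):
    # Same definition as in the diff above: split into (up, gate).
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, gate = x.chunk(2, dim=-1)
        return x * F.silu(gate)


def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Reference for the SiluAndMul ordering: silu(first half) * second half.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up


up, gate = torch.randn(4, 8), torch.randn(4, 8)
out_swiglu = SwiGLU()(torch.cat([up, gate], dim=-1))      # (up, gate) order
out_ref = silu_and_mul_ref(torch.cat([gate, up], dim=-1))  # (gate, up) order
assert torch.allclose(out_swiglu, out_ref)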

@@ -500,7 +541,7 @@ def forward(
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        gate_up, _ = getattr(self, self.proj_name)(x)
+        gate_up, _ = self.merged_linear(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -523,9 +564,7 @@ def __init__(
                                           prefix=f"{prefix}.self_attn")
 
         # MLP block.
-        self.mlp = MolmoMLP(config,
-                            quant_config=quant_config,
-                            proj_name="gate_up_proj")
+        self.mlp = LanuageModelMLP(config, quant_config=quant_config)
 
         # LayerNorm
         assert config.layer_norm_type == "rms"
@@ -617,11 +656,10 @@ def __init__(
             vision_config,
             nlayers=len(self.vit_layers),
             quant_config=quant_config)
-        self.image_projector = MolmoMLP(
+        self.image_projector = ImageProjectorMLP(
            config,
            input_dim=vision_config.image_emb_dim,
            quant_config=quant_config,
-            proj_name="merged_linear",
        )
 
         image_dim = vision_config.image_emb_dim * len(self.vit_layers)
@@ -842,10 +880,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
         loaded_params: Set[str] = set()
 
         for name, loaded_weight in weights:
-            if "gate_up_proj" in name:
-                up_proj, gate_proj = loaded_weight.chunk(2, dim=0)
-                loaded_weight = torch.cat([gate_proj, up_proj], dim=0)
-
             if name.endswith(".bias") and name not in params_dict:
                 continue
             if is_pp_missing_parameter(name, self):
@@ -1157,6 +1191,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         },
     )
 
+    # BitandBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        "gate_proj": ("merged_linear", 0),
+        "up_proj": ("merged_linear", 1),
+    }
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
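
A hedged sketch of how a mapping of this shape is typically consumed: each per-shard checkpoint name is rewritten to the fused parameter name plus a shard index. resolve_shard and the example weight name are illustrative, not code from this commit:

from typing import Dict, Optional, Tuple

# Same shape as the mapping added above: shard name -> (fused param, shard id).
STACKED: Dict[str, Tuple[str, int]] = {
    "gate_proj": ("merged_linear", 0),
    "up_proj": ("merged_linear", 1),
}


def resolve_shard(weight_name: str) -> Tuple[str, Optional[int]]:
    # Map a per-shard checkpoint name onto the fused parameter it lands in.
    for shard, (fused, shard_id) in STACKED.items():
        if shard in weight_name:
            return weight_name.replace(shard, fused), shard_id
    return weight_name, None


print(resolve_shard("vision_backbone.image_projector.gate_proj.weight"))
# ('vision_backbone.image_projector.merged_linear.weight', 0)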
