@@ -461,30 +461,71 @@ def forward(
         return output
 
 
-class MolmoMLP(nn.Module):
+class SwiGLU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, gate = x.chunk(2, dim=-1)
+        # Note that the order is reversed compared to
+        # SiluAndMul.
+        return x * F.silu(gate)
+
+
+class LanuageModelMLP(nn.Module):
     """Molmo's LLM mlp."""
 
     def __init__(self,
                  config: PretrainedConfig,
                  input_dim: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 proj_name: str = "gate_up_proj") -> None:
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size // 2
 
-        # Molmo's LLM proj weights are already merged into the disk, while
-        # image_projector proj is separate. If the same proj_name were used, it
-        # would create ambiguity and make it difficult to support BNB and LoRA.
-        self.proj_name = proj_name
-        setattr(
-            self, proj_name,
-            MergedColumnParallelLinear(
-                input_dim or self.hidden_size,
-                [self.intermediate_size] * 2,
-                bias=False,
-                quant_config=quant_config,
-            ))
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_dim or self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
+        # Activation function.
+        self.act_fn = SwiGLU()
+        # Feed-forward output projection.
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class ImageProjectorMLP(nn.Module):
+    """Molmo's image_projector mlp."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        input_dim: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size // 2
+
+        self.merged_linear = MergedColumnParallelLinear(
+            input_dim or self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
         # Activation function.
         self.act_fn = SiluAndMul()
         # Feed-forward output projection.
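
A quick note on the new SwiGLU module above: it splits its input as [up, gate] and returns up * silu(gate), whereas vLLM's SiluAndMul splits as [gate, up] and returns silu(gate) * up. Below is a minimal standalone sketch (not part of the diff) showing the two are equivalent once the halves are swapped; this matters because Molmo's merged on-disk weights appear to be stored in [up, gate] order (see the gate/up reorder removed from load_weights further down), so SwiGLU can consume them without any reshuffling at load time.

import torch
import torch.nn.functional as F

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # vLLM's SiluAndMul convention: input is [gate, up], output silu(gate) * up.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up

def swiglu(x: torch.Tensor) -> torch.Tensor:
    # The SwiGLU added above: input is [up, gate], output up * silu(gate).
    up, gate = x.chunk(2, dim=-1)
    return up * F.silu(gate)

up, gate = torch.randn(4, 8), torch.randn(4, 8)
assert torch.allclose(swiglu(torch.cat([up, gate], dim=-1)),
                      silu_and_mul(torch.cat([gate, up], dim=-1)))
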
@@ -500,7 +541,7 @@ def forward(
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        gate_up, _ = getattr(self, self.proj_name)(x)
+        gate_up, _ = self.merged_linear(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -523,9 +564,7 @@ def __init__(
                                         prefix=f"{prefix}.self_attn")
 
         # MLP block.
-        self.mlp = MolmoMLP(config,
-                            quant_config=quant_config,
-                            proj_name="gate_up_proj")
+        self.mlp = LanuageModelMLP(config, quant_config=quant_config)
 
         # LayerNorm
         assert config.layer_norm_type == "rms"
@@ -617,11 +656,10 @@ def __init__(
             vision_config,
             nlayers=len(self.vit_layers),
             quant_config=quant_config)
-        self.image_projector = MolmoMLP(
+        self.image_projector = ImageProjectorMLP(
             config,
             input_dim=vision_config.image_emb_dim,
             quant_config=quant_config,
-            proj_name="merged_linear",
         )
 
         image_dim = vision_config.image_emb_dim * len(self.vit_layers)
@@ -842,10 +880,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
         loaded_params: Set[str] = set()
 
         for name, loaded_weight in weights:
-            if "gate_up_proj" in name:
-                up_proj, gate_proj = loaded_weight.chunk(2, dim=0)
-                loaded_weight = torch.cat([gate_proj, up_proj], dim=0)
-
             if name.endswith(".bias") and name not in params_dict:
                 continue
             if is_pp_missing_parameter(name, self):
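
For reference, a small standalone sketch (not vLLM code) of the tensor reorder that the deleted lines performed. Per the removed variable names, the merged checkpoint weight is stacked as [up; gate] along dim 0, and the old code flipped it to [gate; up] so SiluAndMul would see the gate half first; with SwiGLU splitting activations as [up, gate] at runtime, the weight can now be loaded verbatim.

import torch

# Illustrative shapes only: intermediate size 16, hidden size 8.
up_w, gate_w = torch.randn(16, 8), torch.randn(16, 8)
loaded_weight = torch.cat([up_w, gate_w], dim=0)  # on-disk [up; gate] stacking

# What the deleted lines did: swap the halves to [gate; up] for SiluAndMul.
up_proj, gate_proj = loaded_weight.chunk(2, dim=0)
reordered = torch.cat([gate_proj, up_proj], dim=0)
assert torch.equal(reordered, torch.cat([gate_w, up_w], dim=0))
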
@@ -1157,6 +1191,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         },
     )
 
+    # BitsAndBytes specific attributes
+    bitsandbytes_stacked_params_mapping = {
+        "gate_proj": ("merged_linear", 0),
+        "up_proj": ("merged_linear", 1),
+    }
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
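
The new bitsandbytes_stacked_params_mapping tells the BitsAndBytes loading path which shard of a merged parameter each original sub-projection occupies: gate_proj lands in shard 0 of merged_linear and up_proj in shard 1. A rough, hypothetical sketch of how such a table is typically consumed follows; this is not vLLM's actual loader, and the example weight name is made up.

from typing import Dict, Optional, Tuple

# The mapping added above, reproduced for the sketch.
stacked_params_mapping: Dict[str, Tuple[str, int]] = {
    "gate_proj": ("merged_linear", 0),
    "up_proj": ("merged_linear", 1),
}

def resolve_target(name: str) -> Tuple[str, Optional[int]]:
    """Map a checkpoint weight name to (target parameter name, shard index)."""
    for shard_key, (merged_name, shard_id) in stacked_params_mapping.items():
        if shard_key in name:
            return name.replace(shard_key, merged_name), shard_id
    return name, None  # not part of a merged parameter; load as-is

print(resolve_target("image_projector.gate_proj.weight"))
# -> ('image_projector.merged_linear.weight', 0)
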