30
30
from vllm .sequence import IntermediateTensors
31
31
32
32
# yapf: disable
33
+ from .idefics2_vision_model import Idefics2VisionConfig
33
34
from .idefics2_vision_model import (
34
35
Idefics2VisionTransformer as Idefics3VisionTransformer )
35
36
# yapf: enable
@@ -50,6 +51,50 @@ class AriaImagePixelInputs(TypedDict):
50
51
"""
51
52
52
53
54
class AriaVisionTransformer(Idefics3VisionTransformer):
    """Vision tower for Aria, built on the Idefics3 vision transformer.

    Identical to ``Idefics3VisionTransformer`` except that the trailing
    ``post_layernorm`` is swapped for ``nn.Identity`` — Aria does not apply
    a final layernorm — and the matching checkpoint weights are skipped
    when loading.
    """

    def __init__(
        self,
        config: Idefics2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, quant_config, prefix)
        # Aria never uses the final layernorm, so make it a no-op.
        self.post_layernorm = nn.Identity()

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors, fusing q/k/v shards into ``qkv_proj``.

        Returns:
            The set of parameter names that were actually loaded.
        """
        # (fused param name, checkpoint shard name, shard id)
        qkv_shard_map = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        named_params = dict(self.named_parameters())
        seen: Set[str] = set()
        for name, tensor in weights:
            # NOTE: post_layernorm is not used in Aria
            if "post_layernorm" in name:
                continue
            for fused_name, shard_name, shard_id in qkv_shard_map:
                if shard_name not in name:
                    continue
                # Redirect the per-shard weight into the fused parameter.
                name = name.replace(shard_name, fused_name)
                target = named_params[name]
                target.weight_loader(target, tensor, shard_id)
                break
            else:
                # Non-fused parameter: use its custom loader if it has one.
                target = named_params[name]
                loader = getattr(target, "weight_loader",
                                 default_weight_loader)
                loader(target, tensor)
            seen.add(name)
        return seen
53
98
class AriaProjectorMLP (nn .Module ):
54
99
55
100
def __init__ (
@@ -228,8 +273,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
228
273
router_output = torch .nn .functional .linear (hidden_states ,
229
274
self .router_weight )
230
275
276
+ hidden_states_copy = hidden_states .clone ()
277
+ # NOTE: hidden_states will be modified inplace by `FusedMoE`
231
278
sparse_expert_output = self .experts (hidden_states , router_output )
232
- shared_expert_output = self .shared_experts (hidden_states )
279
+ shared_expert_output = self .shared_experts (hidden_states_copy )
233
280
234
281
return sparse_expert_output + shared_expert_output
235
282
@@ -445,7 +492,7 @@ def __init__(
445
492
quant_config = vllm_config .quant_config
446
493
447
494
self .config = config
448
- self .vision_tower = Idefics3VisionTransformer (
495
+ self .vision_tower = AriaVisionTransformer (
449
496
config .vision_config ,
450
497
quant_config ,
451
498
prefix = f"{ prefix } .vision_tower" ,
0 commit comments