30
30
from vllm .sequence import IntermediateTensors
31
31
32
32
# yapf: disable
33
+ from .idefics2_vision_model import Idefics2VisionConfig
33
34
from .idefics2_vision_model import (
34
35
Idefics2VisionTransformer as Idefics3VisionTransformer )
35
36
# yapf: enable
@@ -50,6 +51,53 @@ class AriaImagePixelInputs(TypedDict):
50
51
"""
51
52
52
53
54
class AriaVisionTransformer(Idefics3VisionTransformer):
    """Vision encoder for Aria.

    Reuses Idefics3VisionTransformer wholesale, with one architectural
    difference: Aria applies no normalization after the final encoder
    layer, so ``post_layernorm`` is replaced with ``nn.Identity()``.
    """

    def __init__(
        self,
        config: Idefics2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, quant_config, prefix)
        # Unlike Idefics3VisionTransformer, Aria omits the final
        # LayerNorm, so we swap it for a no-op.
        self.post_layernorm = nn.Identity()

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, fusing q/k/v projections into
        ``qkv_proj`` and skipping the unused post-layernorm.

        Returns the set of parameter names that were loaded.
        """
        # (param_name, shard_name, shard_id): separate q/k/v checkpoint
        # weights are loaded into shards of the fused qkv_proj param.
        fused_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded: Set[str] = set()
        for name, weight in weights:

            # NOTE: post_layernorm is not used in Aria
            if "post_layernorm" in name:
                continue

            entry = next(
                (m for m in fused_mapping if m[1] in name), None)
            if entry is not None:
                fused_name, shard_name, shard_id = entry
                name = name.replace(shard_name, fused_name)
                param = params_dict[name]
                param.weight_loader(param, weight, shard_id)
            else:
                param = params_dict[name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, weight)
            loaded.add(name)
        return loaded
99
+
100
+
53
101
class AriaProjectorMLP (nn .Module ):
54
102
55
103
def __init__ (
@@ -228,8 +276,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
228
276
router_output = torch .nn .functional .linear (hidden_states ,
229
277
self .router_weight )
230
278
279
+ hidden_states_copy = hidden_states .clone ()
280
+ # NOTE: hidden_states will be modified inplace by `FusedMoE`
231
281
sparse_expert_output = self .experts (hidden_states , router_output )
232
- shared_expert_output = self .shared_experts (hidden_states )
282
+ shared_expert_output = self .shared_experts (hidden_states_copy )
233
283
234
284
return sparse_expert_output + shared_expert_output
235
285
@@ -445,7 +495,7 @@ def __init__(
445
495
quant_config = vllm_config .quant_config
446
496
447
497
self .config = config
448
- self .vision_tower = Idefics3VisionTransformer (
498
+ self .vision_tower = AriaVisionTransformer (
449
499
config .vision_config ,
450
500
quant_config ,
451
501
prefix = f"{ prefix } .vision_tower" ,
0 commit comments