Commit 528dbca

[Model][Bugfix]: correct Aria model output (#12309)
Signed-off-by: xffxff <[email protected]>
1 parent cd7b6f0 commit 528dbca

File tree

examples/offline_inference/vision_language.py
vllm/model_executor/models/aria.py

2 files changed: +54 -3 lines changed


examples/offline_inference/vision_language.py

Lines changed: 2 additions & 1 deletion
@@ -28,9 +28,10 @@ def run_aria(question: str, modality: str):
     llm = LLM(model=model_name,
               max_model_len=4096,
               max_num_seqs=2,
+              dtype="bfloat16",
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 
-    prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
+    prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
               "<|im_end|>\n<|im_start|>assistant\n")
 
     stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
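Put together, the corrected example behaves roughly like the standalone sketch below. This is a minimal sketch assuming vLLM's offline LLM/SamplingParams API and the rhymes-ai/Aria checkpoint used by this example file; the image path and sampling values are illustrative, while the prompt template, dtype, and stop_token_ids come from the diff above.

    from PIL import Image
    from vllm import LLM, SamplingParams

    llm = LLM(model="rhymes-ai/Aria",
              max_model_len=4096,
              max_num_seqs=2,
              dtype="bfloat16")  # part of the fix: run Aria in bfloat16

    question = "What is shown in this image?"
    # Part of the fix: no "\n" between <fim_suffix> and the question.
    prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
              "<|im_end|>\n<|im_start|>assistant\n")

    image = Image.open("example.jpg")  # illustrative local image path
    params = SamplingParams(
        temperature=0.2,  # illustrative sampling values
        max_tokens=128,
        stop_token_ids=[93532, 93653, 944, 93421, 1019, 93653, 93519])

    outputs = llm.generate(
        {"prompt": prompt, "multi_modal_data": {"image": image}}, params)
    print(outputs[0].outputs[0].text)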

vllm/model_executor/models/aria.py

Lines changed: 52 additions & 2 deletions
@@ -30,6 +30,7 @@
 from vllm.sequence import IntermediateTensors
 
 # yapf: disable
+from .idefics2_vision_model import Idefics2VisionConfig
 from .idefics2_vision_model import (
     Idefics2VisionTransformer as Idefics3VisionTransformer)
 # yapf: enable
@@ -50,6 +51,53 @@ class AriaImagePixelInputs(TypedDict):
     """
 
 
+class AriaVisionTransformer(Idefics3VisionTransformer):
+
+    def __init__(
+        self,
+        config: Idefics2VisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, quant_config, prefix)
+        # Unlike Idefics3VisionTransformer which uses LayerNorm after the
+        # final layer, Aria omits this normalization, so we replace it with an
+        # Identity layer
+        self.post_layernorm = nn.Identity()
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+
+            # NOTE: post_layernorm is not used in Aria
+            if "post_layernorm" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
 class AriaProjectorMLP(nn.Module):
 
     def __init__(
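A detail worth calling out in the new load_weights is Python's for/else: the else branch runs only when the inner loop finishes without break, i.e. when the weight name matches none of the stacked-parameter patterns, in which case the weight is loaded directly. A standalone sketch of that dispatch, with hypothetical checkpoint names:

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id) -- same mapping as above
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]

    def classify(name: str) -> str:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            result = f"stacked {name.replace(weight_name, param_name)} shard={shard_id}"
            break
        else:  # no break: nothing matched, use the plain per-parameter loader
            result = f"direct {name}"
        return result

    # hypothetical weight names, for illustration only
    print(classify("encoder.layers.0.self_attn.q_proj.weight"))
    # stacked encoder.layers.0.self_attn.qkv_proj.weight shard=q
    print(classify("encoder.layers.0.mlp.fc1.weight"))
    # direct encoder.layers.0.mlp.fc1.weight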
@@ -228,8 +276,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         router_output = torch.nn.functional.linear(hidden_states,
                                                    self.router_weight)
 
+        hidden_states_copy = hidden_states.clone()
+        # NOTE: hidden_states will be modified inplace by `FusedMoE`
         sparse_expert_output = self.experts(hidden_states, router_output)
-        shared_expert_output = self.shared_experts(hidden_states)
+        shared_expert_output = self.shared_experts(hidden_states_copy)
 
         return sparse_expert_output + shared_expert_output

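This hunk is the core of the output bug: FusedMoE writes into its input buffer, so the shared experts previously received hidden states that the sparse-expert pass had already overwritten. A toy reproduction of the aliasing problem, using an in-place stand-in rather than the real FusedMoE kernel:

    import torch

    def fused_moe_standin(x: torch.Tensor) -> torch.Tensor:
        x.mul_(2.0)  # stand-in: writes its result into the input buffer
        return x

    # buggy: both expert branches alias the same buffer
    hidden = torch.ones(4)
    sparse_out = fused_moe_standin(hidden)
    shared_input = hidden  # already mutated -> wrong shared-expert input

    # fixed, as in this commit: clone before the in-place call
    hidden = torch.ones(4)
    hidden_copy = hidden.clone()
    sparse_out = fused_moe_standin(hidden)
    assert torch.equal(hidden_copy, torch.ones(4))  # original values preserved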
@@ -445,7 +495,7 @@ def __init__(
         quant_config = vllm_config.quant_config
 
         self.config = config
-        self.vision_tower = Idefics3VisionTransformer(
+        self.vision_tower = AriaVisionTransformer(
             config.vision_config,
             quant_config,
             prefix=f"{prefix}.vision_tower",
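On the vision side, the pattern of the fix is subclass-and-override: reuse the shared Idefics vision transformer but swap its trailing post_layernorm for nn.Identity, since Aria has no final normalization. A toy version of that pattern (class names here are illustrative):

    import torch
    from torch import nn

    class ToyVisionEncoder(nn.Module):
        def __init__(self, dim: int = 8):
            super().__init__()
            self.post_layernorm = nn.LayerNorm(dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.post_layernorm(x)

    class ToyAriaEncoder(ToyVisionEncoder):
        def __init__(self, dim: int = 8):
            super().__init__(dim)
            self.post_layernorm = nn.Identity()  # Aria: skip final LayerNorm

    x = torch.randn(2, 8)
    assert torch.equal(ToyAriaEncoder()(x), x)  # passes through unchanged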
