Commit 1bd94df

fix: correct Aria model output
Signed-off-by: xffxff <[email protected]>
Parent: 68ad4e3

2 files changed: 51 additions & 3 deletions

examples/offline_inference/vision_language.py

Lines changed: 2 additions & 1 deletion

@@ -28,9 +28,10 @@ def run_aria(question: str, modality: str):
     llm = LLM(model=model_name,
               max_model_len=4096,
               max_num_seqs=2,
+              dtype="bfloat16",
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 
-    prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
+    prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
               "<|im_end|>\n<|im_start|>assistant\n")
 
     stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
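For reference, a minimal standalone sketch of what the corrected template renders to; the sample question is just an illustrative placeholder, not part of the commit:

```python
# Render the corrected Aria chat prompt for a sample question. The
# placeholder tokens come straight from the diff above; note that after
# the fix the question follows <fim_suffix> directly, with no stray "\n".
question = "What is in this image?"  # hypothetical example input

prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
          "<|im_end|>\n<|im_start|>assistant\n")

print(prompt)
# <|im_start|>user
# <fim_prefix><|img|><fim_suffix>What is in this image?<|im_end|>
# <|im_start|>assistant
```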

vllm/model_executor/models/aria.py

Lines changed: 49 additions & 2 deletions

@@ -30,6 +30,7 @@
 from vllm.sequence import IntermediateTensors
 
 # yapf: disable
+from .idefics2_vision_model import Idefics2VisionConfig
 from .idefics2_vision_model import (
     Idefics2VisionTransformer as Idefics3VisionTransformer)
 # yapf: enable
@@ -50,6 +51,50 @@ class AriaImagePixelInputs(TypedDict):
     """
 
 
+class AriaVisionTransformer(Idefics3VisionTransformer):
+
+    def __init__(
+        self,
+        config: Idefics2VisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, quant_config, prefix)
+        self.post_layernorm = nn.Identity()
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+
+            # NOTE: post_layernorm is not used in Aria
+            if "post_layernorm" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
 class AriaProjectorMLP(nn.Module):
 
     def __init__(
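The load_weights loop above leans on Python's for/else: the else branch runs only when the inner loop finishes without hitting break, i.e. when a checkpoint weight matched none of the stacked q/k/v mappings. A self-contained sketch of the same pattern on toy data; all names below are hypothetical placeholders, not real Aria checkpoint keys:

```python
# Sketch of the for/else remapping pattern used in load_weights above:
# q/k/v checkpoint weights are redirected into a fused "qkv_proj"
# parameter, post_layernorm is skipped, and everything else loads
# under its original name.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

checkpoint_names = [  # toy examples
    "encoder.layers.0.self_attn.q_proj.weight",
    "encoder.layers.0.self_attn.k_proj.weight",
    "encoder.layers.0.self_attn.v_proj.weight",
    "encoder.layers.0.mlp.fc1.weight",
    "post_layernorm.weight",
]

for name in checkpoint_names:
    if "post_layernorm" in name:  # dropped, as in the diff above
        continue
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        print(name.replace(weight_name, param_name), "<- shard", shard_id)
        break
    else:  # no break: not a q/k/v weight, load under its own name
        print(name, "<- direct")
```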
@@ -228,8 +273,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         router_output = torch.nn.functional.linear(hidden_states,
                                                    self.router_weight)
 
+        hidden_states_copy = hidden_states.clone()
+        # NOTE: hidden_states will be modified inplace by `FusedMoE`
         sparse_expert_output = self.experts(hidden_states, router_output)
-        shared_expert_output = self.shared_experts(hidden_states)
+        shared_expert_output = self.shared_experts(hidden_states_copy)
 
         return sparse_expert_output + shared_expert_output
@@ -445,7 +492,7 @@ def __init__(
         quant_config = vllm_config.quant_config
 
         self.config = config
-        self.vision_tower = Idefics3VisionTransformer(
+        self.vision_tower = AriaVisionTransformer(
             config.vision_config,
             quant_config,
             prefix=f"{prefix}.vision_tower",
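The vision-tower swap matters because the subclass neutralizes the final post-layernorm by replacing it with nn.Identity(), without overriding forward(). A minimal sketch of that subclassing pattern, using a toy base class rather than the actual Idefics code:

```python
import torch
import torch.nn as nn


class ToyVisionTransformer(nn.Module):
    # Toy stand-in for Idefics3VisionTransformer.
    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        self.encoder = nn.Linear(dim, dim)
        self.post_layernorm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.post_layernorm(self.encoder(x))


class ToyAriaVisionTransformer(ToyVisionTransformer):
    # Same trick as AriaVisionTransformer: swap the final norm for a
    # no-op so the encoder output passes through un-normalized.
    def __init__(self, dim: int = 8) -> None:
        super().__init__(dim)
        self.post_layernorm = nn.Identity()


x = torch.randn(1, 8)
print(ToyAriaVisionTransformer()(x).shape)  # torch.Size([1, 8])
```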
