Skip to content

[Bugfix] Adjust mllama to regional compilation #15112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 19, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions vllm/model_executor/models/mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,8 +1070,8 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds

for decoder_layer in self.layers:
if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
for idx, decoder_layer in enumerate(self.layers):
if idx in self.cross_attention_layers:
if not skip_cross_attention:
hidden_states = decoder_layer(
hidden_states=hidden_states,
Expand All @@ -1081,16 +1081,13 @@ def forward(
full_text_row_masked_out_mask=
full_text_row_masked_out_mask,
)
elif isinstance(decoder_layer, LlamaDecoderLayer):
else:
hidden_states, residual = decoder_layer(
positions=positions,
hidden_states=hidden_states,
residual=None,
)
hidden_states = hidden_states + residual
else:
raise ValueError(
f"Unknown decoder layer type {type(decoder_layer)}")
hidden_states = self.norm(hidden_states)
return hidden_states

Expand Down Expand Up @@ -1551,4 +1548,4 @@ def convert_dense_cross_attention_mask_to_tensor(
full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
mask *= full_text_mask
# (num_prompt_tokens, num_encoder_tokens)
return mask
return mask