
Commit 8363cd0

[Bugfix] Adjust mllama to regional compilation (#15112)
Signed-off-by: Jan Kaniecki <[email protected]>
1 parent 6c5a319 commit 8363cd0

File tree

1 file changed: +4 −7 lines changed

vllm/model_executor/models/mllama.py

Lines changed: 4 additions & 7 deletions
@@ -1070,8 +1070,8 @@ def forward(
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = inputs_embeds
 
-        for decoder_layer in self.layers:
-            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
+        for idx, decoder_layer in enumerate(self.layers):
+            if idx in self.cross_attention_layers:
                 if not skip_cross_attention:
                     hidden_states = decoder_layer(
                         hidden_states=hidden_states,
@@ -1081,16 +1081,13 @@ def forward(
                         full_text_row_masked_out_mask=
                         full_text_row_masked_out_mask,
                     )
-            elif isinstance(decoder_layer, LlamaDecoderLayer):
+            else:
                 hidden_states, residual = decoder_layer(
                     positions=positions,
                     hidden_states=hidden_states,
                     residual=None,
                 )
                 hidden_states = hidden_states + residual
-            else:
-                raise ValueError(
-                    f"Unknown decoder layer type {type(decoder_layer)}")
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
@@ -1551,4 +1548,4 @@ def convert_dense_cross_attention_mask_to_tensor(
     full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
     mask *= full_text_mask
     # (num_prompt_tokens, num_encoder_tokens)
-    return mask
+    return mask
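
A minimal sketch (not part of the commit) of why the switch from isinstance checks to index-based dispatch matters under regional compilation: if each decoder layer gets wrapped in a compiled module, the wrapped object is no longer an instance of MllamaCrossAttentionDecoderLayer or LlamaDecoderLayer, so type checks can stop matching, while the indices stored in self.cross_attention_layers stay valid. The CompiledWrapper class and toy layers below are hypothetical stand-ins, not vLLM APIs.

    # Hypothetical stand-ins; regional compilation is modeled as wrapping each layer.
    class LlamaDecoderLayer:
        def __call__(self, x):
            return x + 1

    class MllamaCrossAttentionDecoderLayer:
        def __call__(self, x):
            return x * 2

    class CompiledWrapper:
        """Delegates to the wrapped layer but has a different type."""
        def __init__(self, layer):
            self._layer = layer
        def __call__(self, x):
            return self._layer(x)

    cross_attention_layers = [1]  # indices of the cross-attention layers
    layers = [CompiledWrapper(LlamaDecoderLayer()),
              CompiledWrapper(MllamaCrossAttentionDecoderLayer())]

    x = 3
    for idx, layer in enumerate(layers):
        # isinstance(layer, MllamaCrossAttentionDecoderLayer) is always False here,
        # but the index check still picks out the cross-attention layer.
        if idx in cross_attention_layers:
            x = layer(x)  # cross-attention path
        else:
            x = layer(x)  # self-attention path
    print(x)  # (3 + 1) * 2 = 8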
