Commit 6003eeb

Adjust mllama to regional compilation (vllm-project#999)
This PR cherry-picks vllm-project#15112 from upstream and fixes cos/sin preparation in the rotary embedding layers to match regional compilation.

Signed-off-by: Jan Kaniecki <[email protected]>
Parent: ceeedae · Commit: 6003eeb

File tree: 1 file changed (+3, −6 lines)


vllm/model_executor/models/mllama.py

Lines changed: 3 additions & 6 deletions
@@ -1075,11 +1075,11 @@ def forward(
 
         if is_hpu:
             for idx, decoder_layer in enumerate(self.layers):
-                if isinstance(decoder_layer, LlamaDecoderLayer):
+                if idx not in self.cross_attention_layers:
                     self.layers[idx].self_attn.rotary_emb.prepare_cos_sin(
                         positions)
         for idx, decoder_layer in enumerate(self.layers):
-            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
+            if idx in self.cross_attention_layers:
                 if not skip_cross_attention:
                     hidden_states = decoder_layer(
                         hidden_states=hidden_states,
@@ -1091,7 +1091,7 @@ def forward(
                         kv_cache=kv_caches[idx],
                         attn_metadata=attn_metadata,
                     )
-            elif isinstance(decoder_layer, LlamaDecoderLayer):
+            else:
                 hidden_states, residual = decoder_layer(
                     positions=positions,
                     hidden_states=hidden_states,
@@ -1100,9 +1100,6 @@ def forward(
                     residual=None,
                 )
                 hidden_states = hidden_states + residual
-            else:
-                raise ValueError(
-                    f"Unknown decoder layer type {type(decoder_layer)}")
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
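The essence of the change is dispatching on a static set of layer indices (`self.cross_attention_layers`) instead of `isinstance()` checks, so each branch of the loop is decided by plain integers known at model-build time. The stripped-down sketch below illustrates that pattern in isolation; the stub layer classes and the `Model` scaffolding are hypothetical stand-ins, not the real vLLM modules.

```python
# Minimal sketch of index-based layer dispatch, assuming the model records
# which layer indices hold cross-attention layers when it is constructed.
# SelfAttnLayer / CrossAttnLayer are stand-in stubs for illustration only.

class SelfAttnLayer:
    def __call__(self, x):
        return x + 1  # stand-in for a Llama self-attention decoder layer

class CrossAttnLayer:
    def __call__(self, x):
        return x * 2  # stand-in for an Mllama cross-attention decoder layer

class Model:
    def __init__(self):
        # Cross-attention layers sit at fixed indices; branching on this
        # static set (rather than isinstance on the layer object) keeps the
        # control flow identical for every layer of a given kind, which is
        # what regional compilation relies on.
        self.cross_attention_layers = {1, 3}
        self.layers = [
            CrossAttnLayer() if i in self.cross_attention_layers
            else SelfAttnLayer()
            for i in range(4)
        ]

    def forward(self, x, skip_cross_attention=False):
        for idx, layer in enumerate(self.layers):
            if idx in self.cross_attention_layers:
                if not skip_cross_attention:
                    x = layer(x)
            else:
                x = layer(x)
        return x

model = Model()
print(model.forward(0))  # layers apply +1, *2, +1, *2 in order -> 6
```

With `skip_cross_attention=True` only the self-attention stubs run, mirroring how the real model skips cross-attention when no image tokens are present.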
