@@ -1070,8 +1070,8 @@ def forward(
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = inputs_embeds

-        for decoder_layer in self.layers:
-            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
+        for idx, decoder_layer in enumerate(self.layers):
+            if idx in self.cross_attention_layers:
                 if not skip_cross_attention:
                     hidden_states = decoder_layer(
                         hidden_states=hidden_states,
@@ -1081,16 +1081,13 @@ def forward(
                         full_text_row_masked_out_mask=
                         full_text_row_masked_out_mask,
                     )
-            elif isinstance(decoder_layer, LlamaDecoderLayer):
+            else:
                 hidden_states, residual = decoder_layer(
                     positions=positions,
                     hidden_states=hidden_states,
                     residual=None,
                 )
                 hidden_states = hidden_states + residual
-            else:
-                raise ValueError(
-                    f"Unknown decoder layer type {type(decoder_layer)}")
         hidden_states = self.norm(hidden_states)
         return hidden_states

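The change above swaps per-layer `isinstance()` checks for membership in a configured index set. The sketch below is a minimal, self-contained illustration of that dispatch pattern using stub modules; the class names and constructor are placeholders, not vLLM's real `MllamaCrossAttentionDecoderLayer` / `LlamaDecoderLayer`, and the real wiring of `self.cross_attention_layers` lives elsewhere in `mllama.py`.

import torch
import torch.nn as nn


class _CrossAttentionLayerStub(nn.Module):
    """Stand-in for a cross-attention decoder layer."""

    def forward(self, hidden_states, **kwargs):
        return hidden_states


class _SelfAttentionLayerStub(nn.Module):
    """Stand-in for a plain decoder layer; returns (hidden_states, residual)."""

    def forward(self, positions, hidden_states, residual):
        return hidden_states, torch.zeros_like(hidden_states)


class _TextModelSketch(nn.Module):
    """Hypothetical model showing index-driven layer dispatch."""

    def __init__(self, num_layers=4, cross_attention_layers=(1, 3)):
        super().__init__()
        # The index set consulted by `idx in self.cross_attention_layers`.
        self.cross_attention_layers = list(cross_attention_layers)
        self.layers = nn.ModuleList([
            _CrossAttentionLayerStub() if i in self.cross_attention_layers
            else _SelfAttentionLayerStub() for i in range(num_layers)
        ])

    def forward(self, positions, hidden_states, skip_cross_attention=True):
        for idx, decoder_layer in enumerate(self.layers):
            if idx in self.cross_attention_layers:
                # Cross-attention layers are selected by index, not by type.
                if not skip_cross_attention:
                    hidden_states = decoder_layer(hidden_states=hidden_states)
            else:
                hidden_states, residual = decoder_layer(
                    positions=positions,
                    hidden_states=hidden_states,
                    residual=None,
                )
                hidden_states = hidden_states + residual
        return hidden_states


model = _TextModelSketch()
out = model(positions=None, hidden_states=torch.randn(5, 16))
print(out.shape)  # torch.Size([5, 16])

Dispatching on indices rather than layer types keeps the loop working even if layers are wrapped (for example by quantization or compilation), since a wrapper would defeat an `isinstance()` check but leaves the configured index set intact.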
@@ -1551,4 +1548,4 @@ def convert_dense_cross_attention_mask_to_tensor(
     full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
     mask *= full_text_mask
     # (num_prompt_tokens, num_encoder_tokens)
-    return mask
+    return mask
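As a quick sanity check of what the full-text row mask does in this hunk, the toy snippet below applies the same two lines to a small mask. It assumes `ninf` is the dtype's most negative finite value (`torch.finfo(dtype).min`), as additive attention masks are commonly built; with a literal `-inf`, multiplying a fully masked row by zero would produce NaNs instead of zeros.

import torch

dtype = torch.float32
ninf = torch.finfo(dtype).min  # assumption: finite "negative infinity"

# Shape (num_prompt_tokens=3, num_encoder_tokens=2): the last text token
# attends to no encoder tokens at all.
mask = torch.tensor([
    [0.0, 0.0],
    [0.0, ninf],
    [ninf, ninf],
], dtype=dtype)

# 1.0 for rows with at least one unmasked entry, else 0.0.
full_text_mask = (mask != ninf).any(dim=-1).type_as(mask)[..., None]
mask = mask * full_text_mask

print(full_text_mask.squeeze(-1))  # tensor([1., 1., 0.])
print(mask[2])                     # tensor([0., 0.]) -- fully masked row zeroed out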