    r"vision_model.vision_encoder.(global_transformer|transformer).resblocks.(\d+).ln_1": r"vision_model.\1.layers.\2.input_layernorm",
    r"vision_model.vision_encoder.(global_transformer|transformer).resblocks.(\d+).ln_2": r"vision_model.\1.layers.\2.post_attention_layernorm",
    r"vision_model.vision_encoder.global_transformer.resblocks.(\d+).(gate_ffn|gate_attn)": r"vision_model.global_transformer.layers.\1.\2",
-    r'vision_model.vision_encoder.ln_(pre|post).(weight|bias)': r'vision_model.vision_encoder.ln_\1.\2',
+    r'vision_model.vision_encoder.ln_(pre|post).(weight|bias)': r'vision_model.vision_encoder.layernorm_\1.\2',
    r'vision_model.vision_encoder.positional_embedding\b': r'vision_model.gated_positional_embedding.embedding',
-    r'vision_model.vision_encoder.gated_positional_embedding\b': r'vision_model.gated_positional_embedding.tile_embedding',
+    r'vision_model.vision_encoder.gated_positional_embedding\b': r'vision_model.gated_positional_embedding.tile_embedding.weight',
    r'vision_model.vision_encoder.gated_positional_embedding_gate': r'vision_model.gated_positional_embedding.gate',
+    r"vision_model.vision_encoder.pre_tile_pos_embed.embedding": r"vision_model.pre_tile_positional_embedding.embedding.weight",
+    r"vision_model.vision_encoder.post_tile_pos_embed.embedding": r"vision_model.post_tile_positional_embedding.embedding.weight",
+    r"vision_model.vision_encoder.pre_tile_pos_embed.gate": r"vision_model.pre_tile_positional_embedding.gate",
+    r"vision_model.vision_encoder.post_tile_pos_embed.gate": r"vision_model.post_tile_positional_embedding.gate",
    r"vision_model.vision_encoder.(?=\w)": r"vision_model.",
}
# fmt: on
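
# For orientation, a minimal sketch (not code from this PR) of how a regex
# key-mapping table like the one above is typically applied when renaming
# checkpoint keys. The helper name below is hypothetical.
import re

def rename_keys(old_keys, key_mapping):
    """Run every (pattern, replacement) pair over each original key."""
    new_keys = {}
    for old_key in old_keys:
        new_key = old_key
        for pattern, replacement in key_mapping.items():
            new_key = re.sub(pattern, replacement, new_key)
        new_keys[old_key] = new_key  # original name -> converted name
    return new_keys
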
@@ -159,6 +163,7 @@ def pre_compute_positional_embedding(embedding):
        aspect_ratio_id = i + 1  # we keep 0 index for padding
        current_embedding = embedding[:height, :width].reshape(height * width, num_patches, hidden_size)
        precomputed_embeddings[aspect_ratio_id, : height * width] = current_embedding
+    precomputed_embeddings = precomputed_embeddings.flatten(1)
    return precomputed_embeddings
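
# Quick shape check for the new .flatten(1) call above, with small illustrative
# sizes (the numbers are assumptions, not values from this PR). Flattening lets
# the precomputed table be stored as the 2-D weight of an nn.Embedding indexed
# by aspect_ratio_id, matching the new *.embedding.weight key names.
import torch

max_aspect_ratio_id, max_num_tiles, num_patches, hidden_size = 8, 4, 5, 8
table = torch.zeros(max_aspect_ratio_id + 1, max_num_tiles, num_patches, hidden_size)
flat = table.flatten(1)  # merge (tiles, patches, hidden) into one trailing axis
assert flat.shape == (max_aspect_ratio_id + 1, max_num_tiles * num_patches * hidden_size)
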
@@ -230,6 +235,7 @@ def write_model(
    num_channels = 3
    # intermediate size: 28672 for 90B, 5120 for 11B
    intermediate_size = compute_intermediate_size(dim, multiple_of=params["multiple_of"])
+    intermediate_layers_indices = [3, 7, 15, 23, 30]  # TODO: Check for 90B model

    # vision model
    n_layers_vision = 32  # constant
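
# compute_intermediate_size is not shown in this diff. Llama-family conversion
# scripts usually derive the FFN width roughly as sketched below; the default
# arguments here are assumptions, not necessarily this script's signature.
def compute_intermediate_size(dim, multiple_of=1024, ffn_dim_multiplier=1.3):
    hidden_dim = 4 * int(2 * dim / 3)  # SwiGLU keeps ~2/3 of the 4x expansion
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  # round up
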
@@ -338,7 +344,9 @@ def write_model(
        elif new_key.endswith("gate"):
            state_dict[new_key] = current_parameter[0].view(1)

-        elif "tile_pos_embed.embedding" in new_key or "gated_positional_embedding.tile_embedding" in new_key:
+        elif (
+            "tile_positional_embedding.embedding" in new_key or "gated_positional_embedding.tile_embedding" in new_key
+        ):
            # pre-compute the embeddings
            state_dict[new_key] = pre_compute_positional_embedding(current_parameter)
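
# Side note on current_parameter[0].view(1) above, as a self-contained example:
# indexing a 1-element tensor gives a 0-d scalar, and .view(1) restores a 1-D
# tensor of shape (1,), which is presumably the layout the converted gate
# parameters expect.
import torch

gate = torch.tensor([0.25])
assert gate[0].view(1).shape == (1,)
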
@@ -360,20 +368,20 @@ def write_model(
    # Write configs
    config_parameters = {CONFIG_KEY_MAPPING[key]: params[key] for key in CONFIG_KEY_MAPPING.keys()}
    vision_config = MllamaVisionConfig(
+        hidden_size=dim_vision,  # Constant, taken directly from your notes
+        intermediate_size=dim_vision * 4,
        num_hidden_layers=n_layers_vision,
-        vision_input_dim=dim_vision,  # Constant, taken directly from your notes
-        return_intermediate=[3, 7, 15, 23, 30],  # Based on return_intermediate indices
-        num_global_layers=n_layers_vision_global,
-        vision_chunk_size=params["vision_chunk_size"],
        num_attention_heads=n_heads_vision,
+        num_global_layers=n_layers_vision_global,
+        intermediate_layers_indices=intermediate_layers_indices,  # Based on return_intermediate indices
+        image_size=params["vision_chunk_size"],
        max_num_tiles=4,
        supported_aspect_ratios=get_all_supported_aspect_ratios(4),
    )
    text_config = MllamaTextConfig(
        **config_parameters,
        num_hidden_layers=len(cross_layer_shift) + n_layers,
        cross_attention_layers=cross_layer_shift,
-        vision_input_dim=dim_vision,  # Constant, aligned with vision config
        attention_bias=False,  # Constant set to False
        tie_word_embeddings=False,  # Constant set to False
        intermediate_size=intermediate_size,
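
# Likely continuation (a sketch, not lines from this diff): merging the two
# sub-configs into a single MllamaConfig and saving it next to the converted
# weights. MllamaConfig and save_pretrained do exist in transformers; the
# output path and exact call site are assumptions.
from transformers import MllamaConfig

model_path = "mllama-hf"  # hypothetical output directory
config = MllamaConfig(vision_config=vision_config.to_dict(), text_config=text_config.to_dict())
config.save_pretrained(model_path)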