Commit 0350540

Merge pull request #10 from huggingface/refactor-configuration
Refactor configuration
2 parents f51ccec + 41a4d8a commit 0350540

13 files changed: +324 -392 lines

src/transformers/__init__.py

+2-6
@@ -2690,12 +2690,7 @@
     _import_structure["models.mllama"].extend(
         [
             "MllamaForConditionalGeneration",
-            "MllamaPreTrainedModel",
-        ]
-    )
-    _import_structure["models.mllama"].extend(
-        [
-            "MllamaForConditionalGeneration",
+            "MllamaForCausalLM",
             "MllamaPreTrainedModel",
         ]
     )
@@ -7265,6 +7260,7 @@
         )
         from .models.mllama import (
             MllamaForConditionalGeneration,
+            MllamaForCausalLM,
             MllamaPreTrainedModel,
         )
         from .models.mobilebert import (
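
With the two hunks above, MllamaForCausalLM is added to the lazy _import_structure and to the corresponding direct imports further down the file, so it becomes importable from the package root alongside MllamaForConditionalGeneration. A minimal sketch of what that enables, assuming this branch is installed:

    # Sketch only: check that the new top-level export resolves through the lazy module.
    from transformers import MllamaForCausalLM, MllamaForConditionalGeneration

    print(MllamaForCausalLM.__module__)  # transformers.models.mllama.modeling_mllama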

src/transformers/models/auto/modeling_auto.py

+1-2
@@ -323,7 +323,6 @@
         ("mega", "MegaForMaskedLM"),
         ("megatron-bert", "MegatronBertForPreTraining"),
         ("mllama", "MllamaForConditionalGeneration"),
-        ("mllama", "MllamaForConditionalGeneration"),
         ("mobilebert", "MobileBertForPreTraining"),
         ("mpnet", "MPNetForMaskedLM"),
         ("mpt", "MptForCausalLM"),

@@ -496,6 +495,7 @@
         ("megatron-bert", "MegatronBertForCausalLM"),
         ("mistral", "MistralForCausalLM"),
         ("mixtral", "MixtralForCausalLM"),
+        ("mllama", "MllamaForCausalLM"),
         ("mpt", "MptForCausalLM"),
         ("musicgen", "MusicgenForCausalLM"),
         ("musicgen_melody", "MusicgenMelodyForCausalLM"),

@@ -734,7 +734,6 @@
         ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
         ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
         ("mllama", "MllamaForConditionalGeneration"),
-        ("mllama", "MllamaForConditionalGeneration"),
         ("paligemma", "PaliGemmaForConditionalGeneration"),
         ("pix2struct", "Pix2StructForConditionalGeneration"),
         ("qwen2_vl", "Qwen2VLForConditionalGeneration"),

src/transformers/models/mllama/__init__.py

+2
@@ -35,6 +35,7 @@
 else:
     _import_structure["modeling_mllama"] = [
         "MllamaForConditionalGeneration",
+        "MllamaForCausalLM",
         "MllamaPreTrainedModel",
     ]

@@ -59,6 +60,7 @@
 else:
     from .modeling_mllama import (
         MllamaForConditionalGeneration,
+        MllamaForCausalLM,
         MllamaPreTrainedModel,
     )

src/transformers/models/mllama/configuration_mllama.py

+4-13
@@ -87,41 +87,32 @@ def __init__(
         layer_norm_eps=1e-6,
         attention_dropout=0.0,
         num_global_layers=8,
-        vision_chunk_size=448,
         projection_dim=4096,
-        vision_input_dim=1280,
         vision_output_dim=7680,
-        return_intermediate=None,
+        intermediate_layers_indices=[3, 7, 15, 23, 30],
         max_num_tiles=4,  # same as vision max num chunks? yes ;-)
         norm_eps=1.0e-5,
         in_channels=3,
         supported_aspect_ratios=None,
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
         self.num_hidden_layers = num_hidden_layers
         self.intermediate_size = intermediate_size
         self.num_channels = num_channels
         self.image_size = image_size
         self.layer_norm_eps = layer_norm_eps
         self.vision_output_dim = vision_output_dim
-        self.vision_chunk_size = vision_chunk_size
         self.patch_size = patch_size
         self.projection_dim = projection_dim
-        self.vision_input_dim = vision_input_dim
-        if return_intermediate is None:
-            return_intermediate = [3, 7, 15, 23, 30]
-        self.return_intermediate = return_intermediate
+        self.intermediate_layers_indices = intermediate_layers_indices
         self.num_global_layers = num_global_layers
         self.max_num_tiles = max_num_tiles
         self.norm_eps = norm_eps
         self.in_channels = in_channels
-
-        self.hidden_size = vision_input_dim
         self.attention_heads = num_attention_heads
-        self.intermediate_size = 4 * vision_input_dim
-        self.hidden_act = hidden_act
         self.supported_aspect_ratios = supported_aspect_ratios

     @property
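
In short, the vision config now forwards **kwargs to the base config class, renames return_intermediate to intermediate_layers_indices, and drops vision_input_dim and vision_chunk_size in favor of the standard hidden_size / intermediate_size / image_size fields (previously hidden_size and intermediate_size were silently overwritten from vision_input_dim). A hedged sketch of the refactored constructor, reusing the removed defaults (vision_input_dim=1280, vision_chunk_size=448) as illustrative values:

    from transformers.models.mllama.configuration_mllama import MllamaVisionConfig

    vision_config = MllamaVisionConfig(
        hidden_size=1280,                                # was vision_input_dim
        intermediate_size=4 * 1280,                      # was derived as 4 * vision_input_dim
        image_size=448,                                  # was vision_chunk_size
        intermediate_layers_indices=[3, 7, 15, 23, 30],  # was return_intermediate
        num_global_layers=8,
        max_num_tiles=4,
    )
    print(vision_config.intermediate_layers_indices)  # [3, 7, 15, 23, 30]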

src/transformers/models/mllama/convert_mllama_weights_to_hf.py

+16-8
@@ -78,10 +78,14 @@
     r"vision_model.vision_encoder.(global_transformer|transformer).resblocks.(\d+).ln_1": r"vision_model.\1.layers.\2.input_layernorm",
     r"vision_model.vision_encoder.(global_transformer|transformer).resblocks.(\d+).ln_2": r"vision_model.\1.layers.\2.post_attention_layernorm",
     r"vision_model.vision_encoder.global_transformer.resblocks.(\d+).(gate_ffn|gate_attn)": r"vision_model.global_transformer.layers.\1.\2",
-    r'vision_model.vision_encoder.ln_(pre|post).(weight|bias)': r'vision_model.vision_encoder.ln_\1.\2',
+    r'vision_model.vision_encoder.ln_(pre|post).(weight|bias)': r'vision_model.vision_encoder.layernorm_\1.\2',
     r'vision_model.vision_encoder.positional_embedding\b': r'vision_model.gated_positional_embedding.embedding',
-    r'vision_model.vision_encoder.gated_positional_embedding\b': r'vision_model.gated_positional_embedding.tile_embedding',
+    r'vision_model.vision_encoder.gated_positional_embedding\b': r'vision_model.gated_positional_embedding.tile_embedding.weight',
     r'vision_model.vision_encoder.gated_positional_embedding_gate': r'vision_model.gated_positional_embedding.gate',
+    r"vision_model.vision_encoder.pre_tile_pos_embed.embedding": r"vision_model.pre_tile_positional_embedding.embedding.weight",
+    r"vision_model.vision_encoder.post_tile_pos_embed.embedding": r"vision_model.post_tile_positional_embedding.embedding.weight",
+    r"vision_model.vision_encoder.pre_tile_pos_embed.gate": r"vision_model.pre_tile_positional_embedding.gate",
+    r"vision_model.vision_encoder.post_tile_pos_embed.gate": r"vision_model.post_tile_positional_embedding.gate",
     r"vision_model.vision_encoder.(?=\w)": r"vision_model.",
 }
 # fmt: on

@@ -159,6 +163,7 @@ def pre_compute_positional_embedding(embedding):
         aspect_ratio_id = i + 1  # we keep 0 index for padding
         current_embedding = embedding[:height, :width].reshape(height * width, num_patches, hidden_size)
         precomputed_embeddings[aspect_ratio_id, : height * width] = current_embedding
+    precomputed_embeddings = precomputed_embeddings.flatten(1)
     return precomputed_embeddings

@@ -230,6 +235,7 @@ def write_model(
     num_channels = 3
     # intermediate size: 28672 for 90B, 5120 for 11B
     intermediate_size = compute_intermediate_size(dim, multiple_of=params["multiple_of"])
+    intermediate_layers_indices = [3, 7, 15, 23, 30]  # TODO: Check for 90B model

     # vision model
     n_layers_vision = 32  # constant

@@ -338,7 +344,9 @@
         elif new_key.endswith("gate"):
             state_dict[new_key] = current_parameter[0].view(1)

-        elif "tile_pos_embed.embedding" in new_key or "gated_positional_embedding.tile_embedding" in new_key:
+        elif (
+            "tile_positional_embedding.embedding" in new_key or "gated_positional_embedding.tile_embedding" in new_key
+        ):
             # pre-compute the embeddings
             state_dict[new_key] = pre_compute_positional_embedding(current_parameter)

@@ -360,20 +368,20 @@
     # Write configs
     config_parameters = {CONFIG_KEY_MAPPING[key]: params[key] for key in CONFIG_KEY_MAPPING.keys()}
     vision_config = MllamaVisionConfig(
+        hidden_size=dim_vision,  # Constant, taken directly from your notes
+        intermediate_size=dim_vision * 4,
         num_hidden_layers=n_layers_vision,
-        vision_input_dim=dim_vision,  # Constant, taken directly from your notes
-        return_intermediate=[3, 7, 15, 23, 30],  # Based on return_intermediate indices
-        num_global_layers=n_layers_vision_global,
-        vision_chunk_size=params["vision_chunk_size"],
         num_attention_heads=n_heads_vision,
+        num_global_layers=n_layers_vision_global,
+        intermediate_layers_indices=intermediate_layers_indices,  # Based on return_intermediate indices
+        image_size=params["vision_chunk_size"],
         max_num_tiles=4,
         supported_aspect_ratios=get_all_supported_aspect_ratios(4),
     )
     text_config = MllamaTextConfig(
         **config_parameters,
         num_hidden_layers=len(cross_layer_shift) + n_layers,
         cross_attention_layers=cross_layer_shift,
-        vision_input_dim=dim_vision,  # Constant, aligned with vision config
         attention_bias=False,  # Constant set to False
         tie_word_embeddings=False,  # Constant set to False
         intermediate_size=intermediate_size,
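
Two of the changes above go together: several renamed targets now end in .embedding.weight or .tile_embedding.weight, i.e. the precomputed tile positional embeddings are stored as nn.Embedding weights indexed by aspect_ratio_id, which is why pre_compute_positional_embedding gains the final flatten(1). A shape-only sketch of that step, with tiny illustrative dimensions rather than values read from a real checkpoint:

    import torch

    # Stand-in shape: (num_aspect_ratio_ids, max_num_tiles, num_patches, hidden_size)
    precomputed = torch.zeros(9, 4, 16, 8)

    # One flat row per aspect_ratio_id, matching an nn.Embedding weight layout.
    flat = precomputed.flatten(1)  # -> (9, 4 * 16 * 8)
    tile_embedding = torch.nn.Embedding.from_pretrained(flat, freeze=False)
    row = tile_embedding(torch.tensor([3]))  # embedding row for aspect_ratio_id == 3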

src/transformers/models/mllama/dummy_convert.py

-130
This file was deleted.

src/transformers/models/mllama/image_processing_mllama.py

-2
@@ -343,8 +343,6 @@ def pack_aspect_ratios(aspect_ratios: List[List[Tuple[int, int]]], pad_value: in
         The aspect ratios stacked into a numpy array with shape (batch_size, max_num_images, 2).
     """
     batch_size = len(aspect_ratios)
-
-    # TODO: in original code there is also max_images = max(max_images, 1)
     max_num_images = max([len(row) for row in aspect_ratios])

     aspect_ratios_stacked = np.full((batch_size, max_num_images, 2), pad_value, dtype=np.int64)
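
For context, the deleted TODO pointed out that the original code also clamps the image count with max_images = max(max_images, 1); what remains simply pads ragged per-sample aspect-ratio lists into a fixed (batch_size, max_num_images, 2) array. An illustrative sketch of that padding, mirroring the context lines above rather than the full function:

    import numpy as np

    aspect_ratios = [[(1, 2), (2, 2)], [(1, 1)]]  # two samples with different image counts
    batch_size = len(aspect_ratios)
    max_num_images = max(len(row) for row in aspect_ratios)

    packed = np.full((batch_size, max_num_images, 2), 1, dtype=np.int64)  # pad_value=1 assumed
    for i, row in enumerate(aspect_ratios):
        packed[i, : len(row)] = row

    print(packed.shape)  # (2, 2, 2)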
