
Commit c487c62

Merge pull request #4 from huggingface/moe-128

128 experts

2 parents: b5373e2 + b077bb5

File tree: 4 files changed, +82 lines, -41 lines

src/transformers/modeling_utils.py

Lines changed: 2 additions & 2 deletions
@@ -5883,10 +5883,10 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict):
         generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
         param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1

-        total_byte_count[device] += param_byte_count
+        parameter_count[device] += param_byte_count

     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, byte_count in total_byte_count.items():
+    for device, byte_count in parameter_count.items():
         if device.type == "cuda":
             index = device.index if device.index is not None else torch.cuda.current_device()
             device_memory = torch.cuda.mem_get_info(index)[0]
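Aside from the rename, the surrounding code's job is to total up bytes per target device and then make one large allocation so the CUDA caching allocator reserves memory before the real weights are loaded. A standalone sketch of that idea, with made-up sizes and a hypothetical `warmup_caching_allocator` helper (not the transformers function):

```python
from collections import defaultdict

import torch


def warmup_caching_allocator(param_bytes_by_device):
    """Reserve memory per device up front so later weight loads reuse cached blocks instead of calling cudaMalloc."""
    byte_count_per_device = defaultdict(int)
    for device, num_bytes in param_bytes_by_device:
        byte_count_per_device[torch.device(device)] += num_bytes

    for device, byte_count in byte_count_per_device.items():
        if device.type == "cuda" and torch.cuda.is_available():
            index = device.index if device.index is not None else torch.cuda.current_device()
            free_bytes = torch.cuda.mem_get_info(index)[0]
            # Allocate at most what is currently free; fp16 elements are 2 bytes each.
            _ = torch.empty(min(byte_count, free_bytes) // 2, dtype=torch.float16, device=device)


# Two 4 MiB parameters destined for cuda:0 (illustrative sizes only).
warmup_caching_allocator([("cuda:0", 4 * 2**20), ("cuda:0", 4 * 2**20)])
```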

src/transformers/models/llama4/configuration_llama4.py

Lines changed: 8 additions & 1 deletion
@@ -142,6 +142,7 @@ class Llama4TextConfig(PretrainedConfig):
         "layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise",  # row because not linear
         "layers.*.feed_forward.experts.down_proj": "local_colwise",  # col because not linear
         "layers.*.feed_forward.experts": "local",
+        "layers.*.feed_forward.down_proj": "local_rowwise",
         "layers.*.feed_forward": "gather",
     }

@@ -150,6 +151,7 @@ def __init__(
         vocab_size=202048,
         hidden_size=5120,
         intermediate_size=8192,
+        intermediate_size_mlp=16384,
         num_hidden_layers=48,
         num_attention_heads=40,
         num_key_value_heads=8,
@@ -167,10 +169,12 @@ def __init__(
         attention_dropout=0.0,
         num_experts_per_tok=1,
         num_local_experts=16,
+        interleave_moe_layer_step=1,
+        use_qk_norm=True,
         output_router_logits=False,
         router_aux_loss_coef=0.001,
         router_jitter_noise=0.0,
-        rope_scaling="llama3",
+        rope_scaling=None,
         **kwargs,
     ):
         super().__init__(
@@ -184,6 +188,7 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
+        self.intermediate_size_mlp = intermediate_size_mlp
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.rope_scaling = rope_scaling
@@ -201,9 +206,11 @@ def __init__(
         self.rope_theta = rope_theta
         self.attention_dropout = attention_dropout
         self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+        self.use_qk_norm = use_qk_norm

         self.num_experts_per_tok = num_experts_per_tok
         self.num_local_experts = num_local_experts
+        self.interleave_moe_layer_step = interleave_moe_layer_step
         self.output_router_logits = output_router_logits
         self.router_aux_loss_coef = router_aux_loss_coef
         self.router_jitter_noise = router_jitter_noise
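The new arguments can be exercised directly to sanity-check the defaults; a hedged example (the values are illustrative, only the argument names come from this diff, and the import path assumes the in-tree module layout):

```python
from transformers.models.llama4.configuration_llama4 import Llama4TextConfig

config = Llama4TextConfig(
    num_hidden_layers=4,             # small illustrative value
    num_local_experts=128,           # the point of this PR
    interleave_moe_layer_step=2,     # every 2nd layer is MoE, the rest use the dense MLP
    intermediate_size=8192,          # expert width
    intermediate_size_mlp=16384,     # dense-MLP width on non-MoE layers
    use_qk_norm=True,
    rope_scaling=None,               # new default; pass a llama3-style dict to re-enable scaling
)
print(config.interleave_moe_layer_step, config.intermediate_size_mlp, config.use_qk_norm)
```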

src/transformers/models/llama4/convert_llama4_weights_to_hf.py

Lines changed: 49 additions & 20 deletions
@@ -55,6 +55,14 @@
     # Unused keys in load hooks (explicitly removed)
     r'layers.(\d+).attention.wqkv._extra_state': None,
     r'layers.(\d+).attention.wo._extra_state': None,
+
+    # MLP layer variant
+    r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.feed_forward.gate_proj.weight",  # might need to be fused for efficiency?
+    r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.feed_forward.up_proj.weight",  # might need to be fused for efficiency?
+    # r"layers.(\d+).feed_forward.mlp.fc1_weight": r"language_model.model.layers.\1.feed_forward.gate_up_proj.weight",
+    r"layers.(\d+).feed_forward.mlp.fc2_weight": r"language_model.model.layers.\1.feed_forward.down_proj.weight",
+    r"layers.(\d+).feed_forward.mlp.layer_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
+
     # Vision encoder mapping
     r"vision_embeddings.vision_encoder.conv1._linear": r"vision_model.patch_embedding.linear",
     r'vision_embeddings.vision_adapter.mlp.c_fc': r"vision_model.vision_adapter.mlp.fc1",
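These entries are regex patterns whose `\1` backreference re-inserts the captured layer index when the original key is renamed. A minimal sketch of how such a table is applied (the `convert_key` helper here is illustrative, not the converter's actual code):

```python
import re

# Two entries copied from the table above; \1 re-inserts the captured layer index.
KEY_MAPPING = {
    r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.feed_forward.gate_proj.weight",
    r"layers.(\d+).feed_forward.mlp.layer_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
}


def convert_key(old_key: str) -> str:
    """Return the renamed key, or the original one if no pattern matches."""
    for pattern, replacement in KEY_MAPPING.items():
        new_key, n = re.subn(pattern, replacement, old_key)
        if n:
            return new_key
    return old_key


print(convert_key("layers.7.feed_forward.w1.weight"))
# -> language_model.model.layers.7.feed_forward.gate_proj.weight
```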
@@ -142,6 +150,9 @@ def get_concat_dim(key):
         "experts.gate_proj",
         "experts.up_proj",
         "expert.down_proj",
+        # "feed_forward.up_proj",
+        # "feed_forward.gate_proj",
+        "feed_forward.down_proj",
         "global_gate_stats",
         # vision dim1 sharded stuff
         "mlp.fc2.weight",  # covers all rowparallels across vis
@@ -166,6 +177,20 @@ def safe_load(filename):
     return shard


+# Unpack mlp projections - possibly to be removed when they are fused
+def preprocess_keys(state_dict):
+    new_state_dict = dict()
+    for key, value in state_dict.items():
+        if "mlp.fc1_weight" in key:
+            prefix = key.split("mlp.fc1_weight")[0]
+            w1, w3 = value.chunk(2, dim=0)
+            new_state_dict[prefix + "w1.weight"] = w1
+            new_state_dict[prefix + "w3.weight"] = w3
+        else:
+            new_state_dict[key] = value
+    return new_state_dict
+
+
 def write_model(
     model_path,
     input_base_path,
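`preprocess_keys` splits the fused `fc1_weight` (gate and up projections stacked along dim 0) into separate `w1`/`w3` tensors so the key-mapping table above can route them to `gate_proj`/`up_proj`. A usage sketch with dummy shapes (sizes are illustrative, and it assumes the function above is in scope):

```python
import torch

# Toy shapes: fused fc1 is [2 * intermediate, hidden]; chunk(2, dim=0) recovers w1 (gate) and w3 (up).
dummy = {
    "layers.0.feed_forward.mlp.fc1_weight": torch.randn(12, 4),  # intermediate=6, hidden=4
    "layers.0.feed_forward.mlp.fc2_weight": torch.randn(4, 6),
}
out = preprocess_keys(dummy)
print(sorted(out.keys()))
# ['layers.0.feed_forward.mlp.fc2_weight', 'layers.0.feed_forward.w1.weight', 'layers.0.feed_forward.w3.weight']
print(out["layers.0.feed_forward.w1.weight"].shape)  # torch.Size([6, 4])
```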
@@ -194,14 +219,17 @@ def write_model(
     rms_norm_eps = params["norm_eps"]
     rope_theta = params["rope_theta"]

-    # some constans from original code
-    rope_scaling = {
-        "rope_type": "llama3",
-        "factor": 8.0,
-        "low_freq_factor": 1.0,
-        "high_freq_factor": 4.0,
-        "original_max_position_embeddings": 8192,
-    }
+    config_kwargs = {}
+    if params["use_scaled_rope"]:
+        # some constans from original code
+        rope_scaling = {
+            "rope_type": "llama3",
+            "factor": 8.0,
+            "low_freq_factor": 1.0,
+            "high_freq_factor": 4.0,
+            "original_max_position_embeddings": 8192,
+        }
+        config_kwargs.update(dict(rope_scaling=rope_scaling))

     # compute additional params for weight conversion
     num_heads_per_shard = num_heads // num_shards
@@ -211,9 +239,10 @@ def write_model(
     num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA

     num_experts = params["moe_args"]["num_experts"]
+    interleave_moe_layer_step = params["moe_args"].get("interleave_moe_layer_step", 1)

     bos_token_id = 200000
-    eos_token_id = [200001, 200002, 200003] if instruct else 200001
+    eos_token_id = [200001, 200002, 200003, 200008] if instruct else 200001
     pad_token_id = 200008

     text_config = Llama4TextConfig(
@@ -224,13 +253,16 @@ def write_model(
         rope_theta=rope_theta,
         num_hidden_layers=num_layers,
         intermediate_size=8192,
-        rope_scaling=rope_scaling,
+        intermediate_size_mlp=16384,
         num_local_experts=num_experts,
+        interleave_moe_layer_step=interleave_moe_layer_step,
+        use_qk_norm=params["use_qk_norm"],
         bos_token_id=bos_token_id,
         eos_token_id=eos_token_id,
         pad_token_id=pad_token_id,
         tie_word_embeddings=False,  # Constant set to False
         torch_dtype=torch_dtype,
+        **config_kwargs,
     )
     # default vision config frmo params

@@ -273,6 +305,7 @@ def write_model(
         safe_load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"))
         for i in tqdm(range(num_shards), desc="Loading shards", unit="shard")
     ]
+    loaded = [preprocess_keys(d) for d in loaded]

     all_keys_raw = list(loaded[0].keys())
     repeated_keys = []
@@ -354,7 +387,7 @@ def write_model(
             if gate_key == new_key:
                 state_dict[new_key] = torch.cat(current_parameter, dim=concat_dim)
             elif new_key == up_key:
-                if "shared" in new_key:
+                if "experts" not in new_key:
                     gate_proj = state_dict.pop(gate_key)
                     up_proj = torch.cat(current_parameter, dim=concat_dim)
                     state_dict[gate_key] = gate_proj
@@ -365,11 +398,11 @@ def write_model(
                 else:
                     gate_proj = state_dict.pop(gate_key)
                     gate_proj = [
-                        gate_proj.reshape(16, -1, 8, 1024)[:, :, k, :].reshape(16, -1, 1024) for k in range(8)
+                        gate_proj.reshape(num_experts, -1, 8, 1024)[:, :, k, :].reshape(num_experts, -1, 1024) for k in range(8)
                     ]
                     gate_proj = torch.cat(gate_proj, dim=-1)

-                    up_proj = [k.reshape(16, -1, 8, 1024).reshape(16, -1, 1024) for k in current_parameter]
+                    up_proj = [k.reshape(num_experts, -1, 8, 1024).reshape(num_experts, -1, 1024) for k in current_parameter]
                     up_proj = torch.cat(up_proj, dim=-1)

                     gate_up_proj = torch.cat((gate_proj, up_proj), dim=-1)
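The only change in the hunk above is replacing the hard-coded 16 with `num_experts`, so the same shard de-interleaving works for 128-expert checkpoints. The reshape views the packed dimension as 8 shard slices of width 1024 and regroups them slice by slice; a toy version with small made-up sizes:

```python
import torch

num_experts, rows, num_shards, shard_width = 4, 3, 8, 2  # toy sizes; the real script uses shard_width=1024

# Per expert, the flattened last dim interleaves `num_shards` slices of width `shard_width`.
packed = torch.randn(num_experts, rows * num_shards * shard_width)

# Same regrouping as above: pull out slice k from every row, then concatenate slice by slice.
slices = [
    packed.reshape(num_experts, -1, num_shards, shard_width)[:, :, k, :].reshape(num_experts, -1, shard_width)
    for k in range(num_shards)
]
regrouped = torch.cat(slices, dim=-1)
print(regrouped.shape)  # torch.Size([4, 3, 16]) -> (num_experts, rows, num_shards * shard_width)
```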
@@ -432,10 +465,7 @@ def write_model(
     print("Loading the checkpoint in a Llama4 model.")
     state_dict.pop("")
     model.load_state_dict(state_dict, strict=True, assign=True)
-    print("Model reloaded successfully. Checking logits...")
-    # ipdb.set_trace()
-    # zero_out = model.forward(inputs_embeds=torch.zeros((1,743, 4096)))
-    # ipdb.set_trace()
+    print("Model reloaded successfully.")
     print("Saving the model.")
     model.save_pretrained(model_path, safe_serialization=safe_serialization)
     del state_dict, model
@@ -448,8 +478,7 @@ def write_model(
         model = Llama4ForConditionalGeneration.from_pretrained(
             model_path, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager"
         )
-        # ipdb.set_trace()
-        model.eval()
+
        model.generation_config.top_p = 0.9
        model.generation_config.temperature = 0.6
        print("Model reloaded successfully.")
@@ -458,7 +487,7 @@ def write_model(

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        inputs = tokenizer(["Roses are red,"], return_tensors="pt").to(model.device)
-        out = model.generate(**inputs, max_new_tokens=10)
+        out = model.generate(**inputs, max_new_tokens=4)
        print(tokenizer.batch_decode(out))
        # generation config
        if instruct:

src/transformers/models/llama4/modeling_llama4.py

Lines changed: 23 additions & 18 deletions
@@ -96,18 +96,23 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

 # Phi3MLP
 class Llama4TextMLP(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, intermediate_size=None):
         super().__init__()
+
+        if intermediate_size is None:
+            intermediate_size = config.intermediate_size
+
         self.config = config
-        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+        self.gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
+        self.up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
+        self.down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
         self.activation_fn = ACT2FN[config.hidden_act]

     def forward(self, x):
         down_proj = self.activation_fn(self.gate_proj(x)) * self.up_proj(x)
         return self.down_proj(down_proj)

+
 class Llama4TextL2Norm(torch.nn.Module):
     def __init__(self, dim: int=None, eps: float = 1e-6):
         super().__init__()
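With the optional `intermediate_size` argument, the same `Llama4TextMLP` class can serve both the expert width and the wider dense layers. A quick hedged instantiation using a stand-in config object with toy values (assuming the class above is in scope):

```python
from types import SimpleNamespace

import torch

# Stand-in config exposing only the fields Llama4TextMLP reads; the values are illustrative.
cfg = SimpleNamespace(hidden_size=64, intermediate_size=128, hidden_act="silu")

expert_mlp = Llama4TextMLP(cfg)                        # falls back to config.intermediate_size (128)
dense_mlp = Llama4TextMLP(cfg, intermediate_size=256)  # explicit width for non-MoE layers

x = torch.randn(2, 5, 64)
print(expert_mlp(x).shape, dense_mlp(x).shape)  # both torch.Size([2, 5, 64])
```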
@@ -193,7 +198,7 @@ class Llama4TextRotaryEmbedding(nn.Module):
     def __init__(self, config: Llama4TextConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
-        self.rope_type = "llama3"
+        self.rope_type = "llama3" if config.rope_scaling is not None else "default"

         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings
@@ -319,7 +324,8 @@ def __init__(self, config, layer_idx):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.qk_norm = Llama4TextL2Norm()
+        if self.config.use_qk_norm:
+            self.qk_norm = Llama4TextL2Norm()

     def forward(
         self,
@@ -341,16 +347,7 @@ def forward(
             query_states, key_states, position_embeddings.to(query_states.device)
         )

-        # because L2 is computed on the shards, we need to find an appropriate reshape
-        # here, to make sure in TP but also non TP settings. Logits diverge otherwise
-        if query_states.shape[-1] == self.num_attention_heads * self.head_dim:
-            query_states = self.qk_norm(
-                query_states.view(input_shape[0], input_shape[1], self.pretraining_tp, -1)
-            ).reshape(hidden_shape)
-            key_states = self.qk_norm(
-                key_states.view(input_shape[0], input_shape[1], self.pretraining_tp, -1)
-            ).reshape((*input_shape, self.pretraining_tp, -1))
-        else:
+        if self.config.use_qk_norm:
             query_states = self.qk_norm(query_states)
             key_states = self.qk_norm(key_states)
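`use_qk_norm` now gates both the module's creation and its use, and the L2 norm is applied to the full query/key tensors rather than per-TP-shard views. For intuition, `Llama4TextL2Norm` is essentially an RMS-style normalization over the last (head) dimension; a hedged restatement, not the exact module code:

```python
import torch


def l2_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x / sqrt(mean(x**2) + eps), computed over the last dimension.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


q = torch.randn(1, 16, 8, 128)  # (batch, seq, heads, head_dim) -- illustrative shape
print(l2_norm(q).pow(2).mean(-1)[0, 0, 0])  # ~= 1.0 after normalization
```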

@@ -394,7 +391,11 @@ def __init__(self, config, layer_idx):
         self.hidden_size = config.hidden_size

         self.self_attn = Llama4TextAttention(config, layer_idx)
-        self.feed_forward = Llama4TextMoe(config)
+        self.is_moe_layer = (layer_idx + 1) % config.interleave_moe_layer_step == 0
+        if self.is_moe_layer:
+            self.feed_forward = Llama4TextMoe(config)
+        else:
+            self.feed_forward = Llama4TextMLP(config, intermediate_size=config.intermediate_size_mlp)

         self.input_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
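The `(layer_idx + 1) % config.interleave_moe_layer_step == 0` predicate makes every `interleave_moe_layer_step`-th layer a MoE layer, so the default of 1 keeps every layer MoE (the previous behaviour). A tiny check of the placement rule, factored into a hypothetical helper:

```python
def moe_layer_indices(num_hidden_layers, interleave_moe_layer_step):
    # Same predicate as in the decoder layer's __init__ above.
    return [
        layer_idx
        for layer_idx in range(num_hidden_layers)
        if (layer_idx + 1) % interleave_moe_layer_step == 0
    ]


print(moe_layer_indices(8, 1))  # [0, 1, 2, 3, 4, 5, 6, 7] -> every layer is MoE
print(moe_layer_indices(8, 2))  # [1, 3, 5, 7] -> MoE every other layer, dense MLP elsewhere
```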
@@ -457,7 +458,11 @@ def forward(
         residual = hidden_states

         hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states, router_logits = self.feed_forward(hidden_states)
+        hidden_states = self.feed_forward(hidden_states)
+        if self.is_moe_layer:
+            hidden_states, router_logits = hidden_states
+        else:
+            router_logits = None
         hidden_states = residual + hidden_states.view(residual.shape)

         outputs = (hidden_states,)
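Since dense layers now return a plain tensor while MoE layers still return `(hidden_states, router_logits)`, the forward pass needs the unpacking guard shown above. The same pattern in isolation, with dummy tensors (shapes are illustrative):

```python
import torch


def unpack_feed_forward(output, is_moe_layer):
    # Mirrors the guard above: only MoE layers carry router logits.
    if is_moe_layer:
        hidden_states, router_logits = output
    else:
        hidden_states, router_logits = output, None
    return hidden_states, router_logits


moe_out = (torch.randn(2, 4, 8), torch.randn(2 * 4, 128))  # (hidden_states, router_logits)
dense_out = torch.randn(2, 4, 8)

print(unpack_feed_forward(moe_out, True)[1] is None)     # False
print(unpack_feed_forward(dense_out, False)[1] is None)  # True
```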
