    # Unused keys in load hooks (explicitly removed)
    r"layers.(\d+).attention.wqkv._extra_state": None,
    r"layers.(\d+).attention.wo._extra_state": None,
+
+   # MLP layer variant
+   r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.feed_forward.gate_proj.weight",  # might need to be fused for efficiency?
+   r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.feed_forward.up_proj.weight",  # might need to be fused for efficiency?
+   # r"layers.(\d+).feed_forward.mlp.fc1_weight": r"language_model.model.layers.\1.feed_forward.gate_up_proj.weight",
+   r"layers.(\d+).feed_forward.mlp.fc2_weight": r"language_model.model.layers.\1.feed_forward.down_proj.weight",
+   r"layers.(\d+).feed_forward.mlp.layer_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
+
    # Vision encoder mapping
    r"vision_embeddings.vision_encoder.conv1._linear": r"vision_model.patch_embedding.linear",
    r"vision_embeddings.vision_adapter.mlp.c_fc": r"vision_model.vision_adapter.mlp.fc1",
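As context for the mapping above: each entry rewrites an original checkpoint key into its Hugging Face name via a regex substitution, and a None target marks a key that is dropped. A minimal sketch of that lookup, assuming re.sub semantics (the helper name and the trimmed-down table are illustrative, not taken from the script):

import re

KEY_MAPPING = {
    r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.feed_forward.gate_proj.weight",
    r"layers.(\d+).attention.wo._extra_state": None,  # explicitly removed
}

def remap_key(old_key):
    for pattern, replacement in KEY_MAPPING.items():
        if re.fullmatch(pattern, old_key):
            # None means the original key has no counterpart and is discarded.
            return None if replacement is None else re.sub(pattern, replacement, old_key)
    return old_key  # anything unmapped passes through unchanged

print(remap_key("layers.3.feed_forward.w1.weight"))
# -> language_model.model.layers.3.feed_forward.gate_proj.weight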
@@ -142,6 +150,9 @@ def get_concat_dim(key):
        "experts.gate_proj",
        "experts.up_proj",
        "expert.down_proj",
+       # "feed_forward.up_proj",
+       # "feed_forward.gate_proj",
+       "feed_forward.down_proj",
        "global_gate_stats",
        # vision dim1 sharded stuff
        "mlp.fc2.weight",  # covers all rowparallels across vis
@@ -166,6 +177,20 @@ def safe_load(filename):
    return shard


+# Unpack mlp projections - possibly to be removed when they are fused
+def preprocess_keys(state_dict):
+    new_state_dict = dict()
+    for key, value in state_dict.items():
+        if "mlp.fc1_weight" in key:
+            prefix = key.split("mlp.fc1_weight")[0]
+            w1, w3 = value.chunk(2, dim=0)
+            new_state_dict[prefix + "w1.weight"] = w1
+            new_state_dict[prefix + "w3.weight"] = w3
+        else:
+            new_state_dict[key] = value
+    return new_state_dict
+
+
def write_model(
    model_path,
    input_base_path,
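A quick aside, not part of the diff: preprocess_keys assumes the fused fc1 weight stacks w1 on top of w3 along dim 0, so chunk(2, dim=0) splits it back into the two projections. A minimal shape check under that assumption (the sizes are placeholders):

import torch

fused_fc1 = torch.randn(2 * 8192, 5120)   # assumed layout: [w1; w3] stacked along dim 0
w1, w3 = fused_fc1.chunk(2, dim=0)
print(w1.shape, w3.shape)                 # torch.Size([8192, 5120]) torch.Size([8192, 5120])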
@@ -194,14 +219,17 @@ def write_model(
    rms_norm_eps = params["norm_eps"]
    rope_theta = params["rope_theta"]

-   # some constans from original code
-   rope_scaling = {
-       "rope_type": "llama3",
-       "factor": 8.0,
-       "low_freq_factor": 1.0,
-       "high_freq_factor": 4.0,
-       "original_max_position_embeddings": 8192,
-   }
+   config_kwargs = {}
+   if params["use_scaled_rope"]:
+       # some constants from the original code
+       rope_scaling = {
+           "rope_type": "llama3",
+           "factor": 8.0,
+           "low_freq_factor": 1.0,
+           "high_freq_factor": 4.0,
+           "original_max_position_embeddings": 8192,
+       }
+       config_kwargs.update(dict(rope_scaling=rope_scaling))

    # compute additional params for weight conversion
    num_heads_per_shard = num_heads // num_shards
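Background, not from the diff: the "llama3" rope_scaling block rescales the RoPE inverse frequencies as a function of wavelength, keeping short wavelengths intact, dividing long wavelengths by the factor, and interpolating in between. The sketch below follows the commonly published Llama 3 scaling rule and is meant only to illustrate what these constants do, not to mirror the exact code path the config triggers:

import math
import torch

def scale_rope_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192):
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * math.pi / inv_freq
    # Smoothly interpolate in the medium band, then override the two extremes.
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    scaled = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    scaled = torch.where(wavelen < high_freq_wavelen, inv_freq, scaled)          # short wavelengths: unchanged
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, scaled)  # long wavelengths: scaled down
    return scaled

inv_freq = 1.0 / (500000.0 ** (torch.arange(0, 128, 2).float() / 128))
print(scale_rope_inv_freq(inv_freq)[:3], scale_rope_inv_freq(inv_freq)[-3:])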
@@ -211,9 +239,10 @@ def write_model(
    num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA

    num_experts = params["moe_args"]["num_experts"]
+   interleave_moe_layer_step = params["moe_args"].get("interleave_moe_layer_step", 1)

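An aside, not from the diff: interleave_moe_layer_step controls how MoE layers are interleaved with dense layers, and the default of 1 makes every layer an MoE layer. One plausible reading of the step (an assumption here, shown only to illustrate the parameter) is that every step-th decoder layer routes through the experts:

def moe_layer_indices(num_hidden_layers, interleave_moe_layer_step):
    # Hypothetical helper, not taken from the script.
    return list(range(interleave_moe_layer_step - 1, num_hidden_layers, interleave_moe_layer_step))

print(moe_layer_indices(8, 1))  # [0, 1, 2, 3, 4, 5, 6, 7] -> all layers are MoE
print(moe_layer_indices(8, 2))  # [1, 3, 5, 7] -> every other layer is MoE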
    bos_token_id = 200000
-   eos_token_id = [200001, 200002, 200003] if instruct else 200001
+   eos_token_id = [200001, 200002, 200003, 200008] if instruct else 200001
    pad_token_id = 200008

    text_config = Llama4TextConfig(
@@ -224,13 +253,16 @@ def write_model(
        rope_theta=rope_theta,
        num_hidden_layers=num_layers,
        intermediate_size=8192,
-       rope_scaling=rope_scaling,
+       intermediate_size_mlp=16384,
        num_local_experts=num_experts,
+       interleave_moe_layer_step=interleave_moe_layer_step,
+       use_qk_norm=params["use_qk_norm"],
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
        tie_word_embeddings=False,  # Constant set to False
        torch_dtype=torch_dtype,
+       **config_kwargs,
    )
    # default vision config from params

@@ -273,6 +305,7 @@ def write_model(
        safe_load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"))
        for i in tqdm(range(num_shards), desc="Loading shards", unit="shard")
    ]
+   loaded = [preprocess_keys(d) for d in loaded]

    all_keys_raw = list(loaded[0].keys())
    repeated_keys = []
@@ -354,7 +387,7 @@ def write_model(
        if gate_key == new_key:
            state_dict[new_key] = torch.cat(current_parameter, dim=concat_dim)
        elif new_key == up_key:
-           if "shared" in new_key:
+           if "experts" not in new_key:
                gate_proj = state_dict.pop(gate_key)
                up_proj = torch.cat(current_parameter, dim=concat_dim)
                state_dict[gate_key] = gate_proj
@@ -365,11 +398,11 @@ def write_model(
            else:
                gate_proj = state_dict.pop(gate_key)
                gate_proj = [
-                   gate_proj.reshape(16, -1, 8, 1024)[:, :, k, :].reshape(16, -1, 1024) for k in range(8)
+                   gate_proj.reshape(num_experts, -1, 8, 1024)[:, :, k, :].reshape(num_experts, -1, 1024) for k in range(8)
                ]
                gate_proj = torch.cat(gate_proj, dim=-1)

-               up_proj = [k.reshape(16, -1, 8, 1024).reshape(16, -1, 1024) for k in current_parameter]
+               up_proj = [k.reshape(num_experts, -1, 8, 1024).reshape(num_experts, -1, 1024) for k in current_parameter]
                up_proj = torch.cat(up_proj, dim=-1)

                gate_up_proj = torch.cat((gate_proj, up_proj), dim=-1)
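A shape-level aside, not part of the change: the reshape/index/cat pattern above regroups eight interleaved 1024-column blocks per expert before the gate and up projections are fused, and the fix replaces the hardcoded 16 experts with num_experts. A toy illustration of the same pattern (all sizes here are stand-ins):

import torch

num_experts, rows, shards, block = 4, 3, 8, 1024
w = torch.randn(num_experts, rows * shards * block)  # stand-in for one fused expert weight blob

pieces = [
    w.reshape(num_experts, -1, shards, block)[:, :, k, :].reshape(num_experts, -1, block)
    for k in range(shards)
]
regrouped = torch.cat(pieces, dim=-1)
print(regrouped.shape)  # torch.Size([4, 3, 8192]): per-expert rows with the shard blocks laid out contiguously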
@@ -432,10 +465,7 @@ def write_model(
    print("Loading the checkpoint in a Llama4 model.")
    state_dict.pop("")
    model.load_state_dict(state_dict, strict=True, assign=True)
-   print("Model reloaded successfully. Checking logits...")
-   # ipdb.set_trace()
-   # zero_out = model.forward(inputs_embeds=torch.zeros((1,743, 4096)))
-   # ipdb.set_trace()
+   print("Model reloaded successfully.")
    print("Saving the model.")
    model.save_pretrained(model_path, safe_serialization=safe_serialization)
    del state_dict, model
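A note, not from the diff: load_state_dict(..., strict=True, assign=True) assigns the checkpoint tensors to the module's parameters instead of copying them into pre-allocated storage, which is the usual way to populate a model whose parameters were created without real storage (e.g. on the meta device). A minimal, self-contained illustration (the Linear layer is purely for demonstration):

import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(4, 4, bias=False)      # parameter exists only as a shape, no storage

checkpoint = {"weight": torch.randn(4, 4)}
layer.load_state_dict(checkpoint, strict=True, assign=True)  # checkpoint tensor becomes the parameter
print(layer.weight.device)                   # cpu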
@@ -448,8 +478,7 @@ def write_model(
    model = Llama4ForConditionalGeneration.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager"
    )
-   # ipdb.set_trace()
-   model.eval()
+
    model.generation_config.top_p = 0.9
    model.generation_config.temperature = 0.6
    print("Model reloaded successfully.")
@@ -458,7 +487,7 @@ def write_model(

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    inputs = tokenizer(["Roses are red,"], return_tensors="pt").to(model.device)
-   out = model.generate(**inputs, max_new_tokens=10)
+   out = model.generate(**inputs, max_new_tokens=4)
    print(tokenizer.batch_decode(out))
    # generation config
    if instruct: