Commit 87abef5

Merge pull request #13 from huggingface/meta_vllm
Meta vllm
2 parents c487c62 + 5b8dd83 commit 87abef5

File tree: 5 files changed (+42, -10 lines)

src/transformers/modeling_utils.py

Lines changed: 1 addition & 0 deletions
@@ -530,6 +530,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     "F32": torch.float32,
     "F64": torch.float64,
     "I64": torch.int64,
+    "F8_E4M3": torch.float8_e4m3fn
 }

 if is_torch_greater_or_equal("2.1.0"):
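
This is the only change to modeling_utils.py: the safetensors dtype table gains an "F8_E4M3" entry so fp8 shards (for example, checkpoints prepared for LLM Compressor / vLLM) can be materialized. Below is a minimal sketch of how such a string-to-dtype table is used when allocating a tensor for a shard; the table is an illustrative excerpt and empty_tensor_for_shard is a hypothetical helper, not the actual transformers internals.

import torch

# Illustrative excerpt of the dtype table the diff extends; the real table lives in
# transformers.modeling_utils and covers many more entries.
STR_TO_TORCH_DTYPE = {
    "F32": torch.float32,
    "F64": torch.float64,
    "I64": torch.int64,
    "F8_E4M3": torch.float8_e4m3fn,  # new: 8-bit float, 4 exponent / 3 mantissa bits (torch >= 2.1)
}

def empty_tensor_for_shard(dtype_str: str, shape: tuple) -> torch.Tensor:
    """Allocate an empty tensor with the dtype string recorded in a safetensors header."""
    return torch.empty(shape, dtype=STR_TO_TORCH_DTYPE[dtype_str])

t = empty_tensor_for_shard("F8_E4M3", (4, 8))
print(t.dtype)  # torch.float8_e4m3fn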

src/transformers/models/llama4/configuration_llama4.py

Lines changed: 6 additions & 0 deletions
@@ -175,6 +175,7 @@ def __init__(
         router_aux_loss_coef=0.001,
         router_jitter_noise=0.0,
         rope_scaling=None,
+        for_llm_compressor=False,
         **kwargs,
     ):
         super().__init__(
@@ -215,6 +216,8 @@ def __init__(
         self.router_aux_loss_coef = router_aux_loss_coef
         self.router_jitter_noise = router_jitter_noise

+        self.for_llm_compressor = for_llm_compressor
+

 class Llama4Config(PretrainedConfig):
     r"""
@@ -288,6 +291,9 @@ class Llama4Config(PretrainedConfig):
             The aux loss factor for the total loss.
         router_jitter_noise (`float`, *optional*, defaults to 0.0):
            Amount of noise to add to the router.
+        for_llm_compressor: (`bool`, *optional*, defaults to `False`):
+            Whether this config is for a checkpoint that aims to use LLM compressor for fp8 quantization.
+            If `True`, the model MoE part would swap to use Linear instead of FusedMoE.

    ```python
    >>> from transformers import Llama4Model, Llama4Config
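
The new flag is pure configuration state; it is only read later by the MoE module in modeling_llama4.py. A hypothetical usage sketch follows, assuming this branch of transformers is installed (the argument does not exist in the released Llama4TextConfig):

# Hypothetical usage; `for_llm_compressor` only exists on this branch.
from transformers import Llama4TextConfig

config = Llama4TextConfig(for_llm_compressor=True)

# The MoE block reads this attribute in __init__ (see modeling_llama4.py below):
# unfused per-expert Linear MLPs when True, the fused Llama4TextExperts module otherwise.
print(config.for_llm_compressor)  # True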

src/transformers/models/llama4/convert_llama4_weights_to_hf.py

Lines changed: 19 additions & 3 deletions
@@ -21,6 +21,8 @@
 from transformers.integrations.tiktoken import TikTokenConverter


+_OFFLINE_QUANT_COMPATIBLE = os.environ.get("OFFLINE_QUANT_COMPATIBLE", "0") == "1"
+
 torch.serialization.add_safe_globals([io.BytesIO])
 # fmt: off

@@ -29,6 +31,8 @@
 # Still not sure what to do with those!
 # `None` means we drop the key

+
+weight_postfix = ".weight" if _OFFLINE_QUANT_COMPATIBLE else ""
 ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
     # CausalLM keys
     r"output.weight": r"language_model.lm_head.weight",
@@ -44,9 +48,9 @@
     r"layers.(\d+).attention.wqkv.weight": r"language_model.model.layers.\1.self_attn.qkv_proj.weight",

     # MoE keys: no simple MLPmodel.
-    r"layers.(\d+).feed_forward.experts.moe_w_in_eD_F": r"language_model.model.layers.\1.feed_forward.experts.gate_proj", # will be fused with up
-    r"layers.(\d+).feed_forward.experts.moe_w_out_eF_D": r"language_model.model.layers.\1.feed_forward.experts.down_proj", # expert win
-    r"layers.(\d+).feed_forward.experts.moe_w_swiglu_eD_F": r"language_model.model.layers.\1.feed_forward.experts.up_proj", # fused with up
+    r"layers.(\d+).feed_forward.experts.moe_w_in_eD_F": r"language_model.model.layers.\1.feed_forward.experts.gate_proj" + weight_postfix, # will be fused with up
+    r"layers.(\d+).feed_forward.experts.moe_w_out_eF_D": r"language_model.model.layers.\1.feed_forward.experts.down_proj" + weight_postfix, # expert win
+    r"layers.(\d+).feed_forward.experts.moe_w_swiglu_eD_F": r"language_model.model.layers.\1.feed_forward.experts.up_proj" + weight_postfix, # fused with up
     r"layers.(\d+).feed_forward.router_DE": r"language_model.model.layers.\1.feed_forward.router.weight", # used for top
     r"layers.(\d+).feed_forward.w_in_shared_FD": r"language_model.model.layers.\1.feed_forward.shared_expert.gate_proj", # might need to be fused for efficiency?
     r"layers.(\d+).feed_forward.w_out_shared_DF": r"language_model.model.layers.\1.feed_forward.shared_expert.down_proj", # might need to be fused for efficiency?
@@ -262,6 +266,7 @@ def write_model(
         pad_token_id=pad_token_id,
         tie_word_embeddings=False, # Constant set to False
         torch_dtype=torch_dtype,
+        for_llm_compressor=_OFFLINE_QUANT_COMPATIBLE,
         **config_kwargs,
     )
     # default vision config frmo params
@@ -380,6 +385,16 @@ def write_model(
             v = new_key.replace("qkv", "v")
             tqdm.write(f"Processing: {key.ljust(50)} ->\t {v}, {values.shape}")
             state_dict[v] = values
+        elif _OFFLINE_QUANT_COMPATIBLE and "feed_forward.experts." in new_key:
+            # for experts, we need to split expert for offline quantiation purpose and don't need to fuse
+            expert_lists = []
+            for k in current_parameter:
+                expert_lists.append(list(k.reshape(num_experts, -1, k.shape[-1]).unbind(0))) # [#expert * IN, OUT] -> #experts * [IN, OUT]
+            for i in range(num_experts):
+                expert = torch.cat([expert_list[i] for expert_list in expert_lists], dim=concat_dim)
+                expert_key = new_key.replace("experts.", f"experts.{i}.")
+                state_dict[expert_key] = expert.transpose(0,1).contiguous() #[OUT, IN]
+                tqdm.write(f"Processing: {key.ljust(50)} ->\t {expert_key}, {state_dict[expert_key].shape}")
         elif re.search(r"(gate|up)_proj", new_key):
             path = new_key.split(".")
             gate_key = re.sub(r"(gate|up)_proj", lambda m: "gate_proj", new_key)
@@ -408,6 +423,7 @@ def write_model(
             gate_up_proj = torch.cat((gate_proj, up_proj), dim=-1)
             new_key = new_key.replace("up_proj", "gate_up_proj")
             state_dict[new_key] = gate_up_proj.contiguous()
+
             tqdm.write(f"Processing: {key.ljust(50)} ->\t {new_key}, {state_dict[new_key].shape}")
         elif "down_proj" in new_key:
             current_parameter = torch.cat(current_parameter, dim=concat_dim)
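
The new elif branch above unstacks each fused expert tensor so every expert can later be quantized offline as an ordinary Linear layer. Here is a toy sketch of that reshape-and-transpose step, with made-up sizes and a single shard (the real branch first concatenates several sharded pieces along concat_dim):

import torch

# Made-up sizes; the checkpoint stores each expert weight fused as [num_experts * IN, OUT].
num_experts, dim_in, dim_out = 4, 6, 8
fused = torch.randn(num_experts * dim_in, dim_out)

# Mirror of the splitting logic: one [IN, OUT] block per expert, transposed to [OUT, IN]
# so each block drops straight into an nn.Linear weight.
per_expert = list(fused.reshape(num_experts, -1, fused.shape[-1]).unbind(0))

state_dict = {}
for i in range(num_experts):
    expert_key = f"language_model.model.layers.0.feed_forward.experts.{i}.gate_proj.weight"
    state_dict[expert_key] = per_expert[i].transpose(0, 1).contiguous()  # [OUT, IN]

print(state_dict["language_model.model.layers.0.feed_forward.experts.0.gate_proj.weight"].shape)
# torch.Size([8, 6])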

src/transformers/models/llama4/modeling_llama4.py

Lines changed: 15 additions & 6 deletions
@@ -47,7 +47,6 @@
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -61,7 +60,6 @@
 _CHECKPOINT_FOR_DOC = "meta-ai/Llama-4-17B"
 _CONFIG_FOR_DOC = "Llama4Config"

-
 class Llama4TextExperts(nn.Module):
     def __init__(self, config: Llama4Config):
         super().__init__()
@@ -153,7 +151,12 @@ def __init__(self, config):
         super().__init__()
         self.top_k = config.num_experts_per_tok
         self.hidden_dim = config.hidden_size
-        self.experts = Llama4TextExperts(config)
+        self.num_experts = config.num_local_experts
+        self.for_llm_compressor = config.for_llm_compressor
+        if self.for_llm_compressor:
+            self.experts = nn.ModuleList([Llama4TextMLP(config) for _ in range(self.num_experts)])
+        else:
+            self.experts = Llama4TextExperts(config)
         self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
         self.shared_expert = Llama4TextMLP(config)

@@ -184,8 +187,14 @@ def forward(self, hidden_states):
         )
         # we gather inputs corresponding to each expert based on the router indices
         routed_in = routed_in * router_scores.reshape(-1, 1)
-        routed_out = self.experts(routed_in) # routed in is "sorted" / ready for EP
-
+        expert_routed_out_list = []
+        if self.for_llm_compressor:
+            routed_in = routed_in.reshape(self.num_experts, -1, routed_in.shape[-1])
+            for expert_idx in range(self.num_experts):
+                expert_routed_out_list.append(self.experts[expert_idx](routed_in[expert_idx]))
+            routed_out = torch.cat(expert_routed_out_list, dim=0)
+        else:
+            routed_out = self.experts(routed_in)
         out = self.shared_expert(hidden_states)
         # now that we finished expert computation -> we scatter add because we gathered previously
         # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound
@@ -1706,7 +1715,7 @@ def forward(
             projected_vision_flat = self.multi_modal_projector(vision_flat)

             special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
-            final_mask = special_image_mask.to(inputs_embeds.device)
+            final_mask = special_image_mask.to(inputs_embeds.device)
             inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))

             final_mask_1d = final_mask[..., 0].reshape(-1)
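
On the modeling side, the for_llm_compressor path swaps the single fused expert computation for a Python loop over per-expert MLPs applied to the "sorted" routed input. A self-contained sketch of just that loop follows, using plain nn.Linear modules as stand-ins for Llama4TextMLP and made-up sizes:

import torch
from torch import nn

torch.manual_seed(0)
num_experts, hidden, tokens_per_expert = 4, 16, 3

# Stand-ins for the per-expert MLPs (the branch uses Llama4TextMLP; plain Linear here).
experts = nn.ModuleList([nn.Linear(hidden, hidden, bias=False) for _ in range(num_experts)])

# `routed_in` arrives "sorted": the first block of rows belongs to expert 0, the next to expert 1, ...
routed_in = torch.randn(num_experts * tokens_per_expert, hidden)

# Mirror of the for_llm_compressor path: reshape to [num_experts, tokens, hidden],
# run each expert on its own block, and concatenate the results back in order.
blocks = routed_in.reshape(num_experts, -1, hidden)
routed_out = torch.cat([experts[i](blocks[i]) for i in range(num_experts)], dim=0)
print(routed_out.shape)  # torch.Size([12, 16])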

src/transformers/models/llama4/processing_llama4.py

Lines changed: 1 addition & 1 deletion
@@ -279,7 +279,7 @@ def _prompt_split_image(self, aspect_ratio, num_patches_per_chunk):
                        img_string += "<|tile_x_separator|>"

                img_string += "<|tile_y_separator|>"
-        # img_string += "<|image|>"
+        img_string += "<|image|>"
         img_string += "<|patch|>" * num_patches_per_chunk
         img_string += "<|image_end|>"
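
With the <|image|> marker re-enabled, the per-image prompt string now carries the global-view token between the tiled patches and the final patch run. Below is a standalone sketch of the resulting layout; the surrounding loop structure is assumed from the released Llama 4 processor and simplified here.

# Standalone sketch of the prompt layout after this change; the real method lives on the
# Llama 4 processor class and handles more cases.
def prompt_split_image(aspect_ratio, num_patches_per_chunk):
    ratio_h, ratio_w = aspect_ratio
    img_string = "<|image_start|>"
    if ratio_h * ratio_w > 1:
        for yy in range(ratio_h):
            for xx in range(ratio_w):
                img_string += "<|patch|>" * num_patches_per_chunk
                if xx < ratio_w - 1:
                    img_string += "<|tile_x_separator|>"
            img_string += "<|tile_y_separator|>"
    img_string += "<|image|>"  # re-enabled by this commit: marks the global view
    img_string += "<|patch|>" * num_patches_per_chunk
    img_string += "<|image_end|>"
    return img_string

print(prompt_split_image((1, 2), num_patches_per_chunk=1))
# <|image_start|><|patch|><|tile_x_separator|><|patch|><|tile_y_separator|><|image|><|patch|><|image_end|>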
