|
21 | 21 | from transformers.integrations.tiktoken import TikTokenConverter
|
22 | 22 |
|
23 | 23 |
|
| 24 | +_OFFLINE_QUANT_COMPATIBLE = os.environ.get("OFFLINE_QUANT_COMPATIBLE", "0") == "1" |
| 25 | + |
24 | 26 | torch.serialization.add_safe_globals([io.BytesIO])
|
25 | 27 | # fmt: off
|
26 | 28 |
|
|
29 | 31 | # Still not sure what to do with those!
|
30 | 32 | # `None` means we drop the key
|
31 | 33 |
|
| 34 | + |
| 35 | +weight_postfix = ".weight" if _OFFLINE_QUANT_COMPATIBLE else "" |
32 | 36 | ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
|
33 | 37 | # CausalLM keys
|
34 | 38 | r"output.weight": r"language_model.lm_head.weight",
|
|
44 | 48 | r"layers.(\d+).attention.wqkv.weight": r"language_model.model.layers.\1.self_attn.qkv_proj.weight",
|
45 | 49 |
|
46 | 50 | # MoE keys: no simple MLP model.
|
47 | | - r"layers.(\d+).feed_forward.experts.moe_w_in_eD_F": r"language_model.model.layers.\1.feed_forward.experts.gate_proj", # will be fused with up |
48 | | - r"layers.(\d+).feed_forward.experts.moe_w_out_eF_D": r"language_model.model.layers.\1.feed_forward.experts.down_proj", # expert win |
49 | | - r"layers.(\d+).feed_forward.experts.moe_w_swiglu_eD_F": r"language_model.model.layers.\1.feed_forward.experts.up_proj", # fused with up |
| 51 | + r"layers.(\d+).feed_forward.experts.moe_w_in_eD_F": r"language_model.model.layers.\1.feed_forward.experts.gate_proj" + weight_postfix, # will be fused with up |
| 52 | + r"layers.(\d+).feed_forward.experts.moe_w_out_eF_D": r"language_model.model.layers.\1.feed_forward.experts.down_proj" + weight_postfix, # expert win |
| 53 | + r"layers.(\d+).feed_forward.experts.moe_w_swiglu_eD_F": r"language_model.model.layers.\1.feed_forward.experts.up_proj" + weight_postfix, # fused with up |
50 | 54 | r"layers.(\d+).feed_forward.router_DE": r"language_model.model.layers.\1.feed_forward.router.weight", # used for top
|
51 | 55 | r"layers.(\d+).feed_forward.w_in_shared_FD": r"language_model.model.layers.\1.feed_forward.shared_expert.gate_proj", # might need to be fused for efficiency?
|
52 | 56 | r"layers.(\d+).feed_forward.w_out_shared_DF": r"language_model.model.layers.\1.feed_forward.shared_expert.down_proj", # might need to be fused for efficiency?
|
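For reference (not part of the diff): a minimal sketch of how the regex mapping above rewrites an original checkpoint key, and how the new `weight_postfix` appends `.weight` to the per-expert projection keys when `OFFLINE_QUANT_COMPATIBLE=1`. `convert_key` is a hypothetical helper written only for this illustration, and only two mapping entries are reproduced; the actual script applies the full `ORIGINAL_TO_CONVERTED_KEY_MAPPING`.

```python
import re

# Hypothetical helper mirroring the mapping logic above; the real script may differ.
def convert_key(old_key: str, offline_quant_compatible: bool = False) -> str:
    weight_postfix = ".weight" if offline_quant_compatible else ""
    mapping = {
        r"layers.(\d+).feed_forward.experts.moe_w_in_eD_F":
            r"language_model.model.layers.\1.feed_forward.experts.gate_proj" + weight_postfix,
        r"layers.(\d+).feed_forward.router_DE":
            r"language_model.model.layers.\1.feed_forward.router.weight",
    }
    for pattern, replacement in mapping.items():
        if re.fullmatch(pattern, old_key):
            return re.sub(pattern, replacement, old_key)
    return old_key

print(convert_key("layers.3.feed_forward.experts.moe_w_in_eD_F"))
# language_model.model.layers.3.feed_forward.experts.gate_proj
print(convert_key("layers.3.feed_forward.experts.moe_w_in_eD_F", offline_quant_compatible=True))
# language_model.model.layers.3.feed_forward.experts.gate_proj.weight
```

With the postfix in place, the per-expert tensors produced later in the diff land on standard `Linear`-style parameter names, which is what an offline quantization tool expects.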
@@ -262,6 +266,7 @@ def write_model(
|
262 | 266 | pad_token_id=pad_token_id,
|
263 | 267 | tie_word_embeddings=False, # Constant set to False
|
264 | 268 | torch_dtype=torch_dtype,
|
| 269 | + for_llm_compressor=_OFFLINE_QUANT_COMPATIBLE, |
265 | 270 | **config_kwargs,
|
266 | 271 | )
|
267 | 272 | # default vision config from params
|
@@ -380,6 +385,16 @@ def write_model(
|
380 | 385 | v = new_key.replace("qkv", "v")
|
381 | 386 | tqdm.write(f"Processing: {key.ljust(50)} ->\t {v}, {values.shape}")
|
382 | 387 | state_dict[v] = values
|
| 388 | + elif _OFFLINE_QUANT_COMPATIBLE and "feed_forward.experts." in new_key: |
| 389 | + # for experts, we split the weights per expert for offline quantization and don't need to fuse them |
| 390 | + expert_lists = [] |
| 391 | + for k in current_parameter: |
| 392 | + expert_lists.append(list(k.reshape(num_experts, -1, k.shape[-1]).unbind(0))) # [#expert * IN, OUT] -> #experts * [IN, OUT] |
| 393 | + for i in range(num_experts): |
| 394 | + expert = torch.cat([expert_list[i] for expert_list in expert_lists], dim=concat_dim) |
| 395 | + expert_key = new_key.replace("experts.", f"experts.{i}.") |
| 396 | + state_dict[expert_key] = expert.transpose(0, 1).contiguous() # [OUT, IN] |
| 397 | + tqdm.write(f"Processing: {key.ljust(50)} ->\t {expert_key}, {state_dict[expert_key].shape}") |
383 | 398 | elif re.search(r"(gate|up)_proj", new_key):
|
384 | 399 | path = new_key.split(".")
|
385 | 400 | gate_key = re.sub(r"(gate|up)_proj", lambda m: "gate_proj", new_key)
|
@@ -408,6 +423,7 @@ def write_model(
|
408 | 423 | gate_up_proj = torch.cat((gate_proj, up_proj), dim=-1)
|
409 | 424 | new_key = new_key.replace("up_proj", "gate_up_proj")
|
410 | 425 | state_dict[new_key] = gate_up_proj.contiguous()
|
| 426 | + |
411 | 427 | tqdm.write(f"Processing: {key.ljust(50)} ->\t {new_key}, {state_dict[new_key].shape}")
|
412 | 428 | elif "down_proj" in new_key:
|
413 | 429 | current_parameter = torch.cat(current_parameter, dim=concat_dim)
|
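For clarity, a self-contained sketch of what the new `_OFFLINE_QUANT_COMPATIBLE` branch does to one expert tensor. Shapes and key names are toy values chosen for illustration, and sharding plus the `concat_dim` concatenation over multiple shards are ignored: the original checkpoint stores a fused `[num_experts * IN, OUT]` weight, which is split per expert and transposed to the `[OUT, IN]` layout of a standard linear weight, whereas the default path keeps the fused tensor (later concatenated with `up_proj` into `gate_up_proj`).

```python
import torch

# Toy dimensions, for illustration only.
num_experts, d_in, d_out = 4, 8, 16

# Fused expert weight as stored in the original checkpoint: [num_experts * IN, OUT].
fused = torch.randn(num_experts * d_in, d_out)

# Offline-quant-compatible path: one [OUT, IN] weight per expert, so each expert
# can later be quantized like an ordinary nn.Linear.
per_expert = fused.reshape(num_experts, d_in, d_out).unbind(0)  # num_experts x [IN, OUT]
split_state = {
    f"feed_forward.experts.{i}.gate_proj.weight": w.transpose(0, 1).contiguous()  # [OUT, IN]
    for i, w in enumerate(per_expert)
}
assert split_state["feed_forward.experts.0.gate_proj.weight"].shape == (d_out, d_in)

# Default path: keep the fused tensor under a single key for the batched experts module.
fused_state = {"feed_forward.experts.gate_proj": fused}
```

The per-expert split trades the batched layout for plain per-expert parameters, which is why the key mapping gains the `.weight` postfix only in this mode.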
|