[LLM] Add Qwen3 (for CI Test) #10534

Closed
wants to merge 14 commits into from
10 changes: 7 additions & 3 deletions README.md
@@ -33,17 +33,20 @@

## News 📢

* **2025.03.17 "Hands-on with Single-Machine Deployment of the Full DeepSeek-R1"** 🔥🔥🔥 Large-model inference and deployment in PaddlePaddle Framework 3.0 has been fully upgraded: multiple mainstream large models are supported, the full DeepSeek-R1 now runs on a single machine, and throughput has doubled! Everyone is welcome to try it out of the box. A prize campaign is now open: complete the DeepSeek-R1-MTP single-machine deployment task and submit a high-quality evaluation blog post to win a cash reward! 💰💰💰
Sign up [here](https://www.wjx.top/vm/OlzzmbG.aspx#); campaign details: https://github.com/PaddlePaddle/PaddleNLP/issues/10166 ; reference docs: https://github.com/PaddlePaddle/PaddleNLP/issues/10157 .
* **2025.04.29 PaddleNLP now supports the Qwen3 model family**: Qwen3 models support two thinking modes and were pretrained on roughly 36 trillion tokens across 119 languages and dialects. The family includes six dense models (Qwen3-32B, Qwen3-14B, Qwen3-8B, Qwen3-4B, Qwen3-1.7B, and Qwen3-0.6B) and weights for two MoE models: Qwen3-235B-A22B and Qwen3-30B-A3B.

* **2025.03.12 [PaddleNLP v3.0 Beta4](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v3.0.0-beta4)**: full support for popular reasoning models such as DeepSeek V3/R1/R1-Distill and QwQ-32B. **The full DeepSeek V3/R1 supports FP8, INT8, and 4-bit quantized inference as well as MTP speculative decoding.** Single-machine FP8 inference exceeds **1000 tokens/s**; 4-bit inference exceeds **2100 tokens/s**! A new inference/deployment image has been released with [one-click deployment](https://paddlenlp.readthedocs.io/zh/latest/llm/server/docs/general_model_inference.html) for popular models, and the inference deployment [documentation](https://paddlenlp.readthedocs.io/zh/latest/llm/docs/predict/index.html) has been fully refreshed. PP-UIE, our next-generation general information-extraction model, is [newly released](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/application/information_extraction) with support for 8K-length extraction. Added large-model Embedding training with INF-CL support for very large batch sizes. Added the [MergeKit](https://paddlenlp.readthedocs.io/zh/latest/llm/docs/mergekit.html) model-merging tool to reduce alignment cost. Low-resource training is fully optimized and runs smoothly on GPUs with only 16 GB of memory.

* **2025.03.06 PaddleNLP now supports the Qwen/QwQ-32B model**: with only 32B parameters, its mathematical reasoning, coding, and general abilities rival DeepSeek-R1 with 671B parameters (37B activated). With the PaddleNLP 3.0 suite you can run [fine-tuning](./llm/README.md) with multiple parallelism strategies, [high-performance inference and low-bit quantization](./llm/docs/predict/qwen.md), and [serving deployment](./llm/server/README.md).

* **2025.02.10 PaddleNLP now supports the DeepSeek-R1 series, [try it online](https://aistudio.baidu.com/projectdetail/8775758)**: built on the new PaddleNLP 3.0 suite, the DeepSeek-R1 series is now fully supported. With advanced distributed training capabilities such as data parallelism, group-sharded data parallelism, tensor parallelism, pipeline parallelism, and expert parallelism, combined with FlashMask, the Paddle framework's column-sparse attention-mask representation, the DeepSeek-R1 series trains with significantly lower memory consumption while achieving excellent training performance gains.

<details><summary> <b>Click to expand</b> </summary><div>

* **2025.03.17 "Hands-on with Single-Machine Deployment of the Full DeepSeek-R1"** 🔥🔥🔥 Large-model inference and deployment in PaddlePaddle Framework 3.0 has been fully upgraded: multiple mainstream large models are supported, the full DeepSeek-R1 now runs on a single machine, and throughput has doubled! Everyone is welcome to try it out of the box. A prize campaign is now open: complete the DeepSeek-R1-MTP single-machine deployment task and submit a high-quality evaluation blog post to win a cash reward! 💰💰💰
Sign up [here](https://www.wjx.top/vm/OlzzmbG.aspx#); campaign details: https://github.com/PaddlePaddle/PaddleNLP/issues/10166 ; reference docs: https://github.com/PaddlePaddle/PaddleNLP/issues/10157 .

* **2025.03.06 PaddleNLP now supports the Qwen/QwQ-32B model**: with only 32B parameters, its mathematical reasoning, coding, and general abilities rival DeepSeek-R1 with 671B parameters (37B activated). With the PaddleNLP 3.0 suite you can run [fine-tuning](./llm/README.md) with multiple parallelism strategies, [high-performance inference and low-bit quantization](./llm/docs/predict/qwen.md), and [serving deployment](./llm/server/README.md).

* **2025.02.20 🔥🔥 "PP-UIE Information Extraction Engine, Fully Upgraded"**: strengthened zero-shot learning enables efficient cold starts and transfer learning with very little or even no labeled data, significantly reducing annotation cost; long-text support allows information extraction from documents up to 8192 tokens, identifying key information across paragraphs for a complete understanding; a fully customizable end-to-end training and inference pipeline delivers 1.8x higher training efficiency than LLaMA-Factory.
Join us on Wednesday, February 26 at 19:00 for a deep dive into the new PP-UIE technical approach and its deployment features, advantages, and tips. Registration link: https://www.wjx.top/vm/mBKC6pb.aspx?udsid=606418

@@ -119,6 +122,7 @@
| [Qwen2.5](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-0.5B, Qwen/Qwen2.5-0.5B-Instruct, Qwen/Qwen2.5-1.5B, Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-3B, Qwen/Qwen2.5-3B-Instruct, Qwen/Qwen2.5-7B, Qwen/Qwen2.5-7B-Instruct, Qwen/Qwen2.5-7B-Instruct-1M, Qwen/Qwen2.5-14B, Qwen/Qwen2.5-14B-Instruct, Qwen/Qwen2.5-14B-Instruct-1M, Qwen/Qwen2.5-32B, Qwen/Qwen2.5-32B-Instruct, Qwen/Qwen2.5-72B, Qwen/Qwen2.5-72B-Instruct |
| [Qwen2.5-Math](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Math-1.5B, Qwen/Qwen2.5-Math-1.5B-Instruct, Qwen/Qwen2.5-Math-7B, Qwen/Qwen2.5-Math-7B-Instruct, Qwen/Qwen2.5-Math-72B, Qwen/Qwen2.5-Math-72B-Instruct, Qwen/Qwen2.5-Math-RM-72B |
| [Qwen2.5-Coder](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen2.5-Coder-1.5B, Qwen/Qwen2.5-Coder-1.5B-Instruct, Qwen/Qwen2.5-Coder-7B, Qwen/Qwen2.5-Coder-7B-Instruct |
| [Qwen3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/Qwen3-0.6B, Qwen/Qwen3-1.7B, Qwen/Qwen3-4B, Qwen/Qwen3-8B, Qwen/Qwen3-14B, Qwen/Qwen3-32B, Qwen/Qwen3-30B-A3B, Qwen/Qwen3-235B-A22B, Qwen/Qwen3-0.6B-Base, Qwen/Qwen3-1.7B-Base, Qwen/Qwen3-4B-Base, Qwen/Qwen3-8B-Base, Qwen/Qwen3-14B-Base, Qwen/Qwen3-30B-A3B-Base |
| [QwQ](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/qwen/) | Qwen/QwQ-32B, Qwen/QwQ-32B-Preview |
| [Yuan2](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/yuan/) | IEITYuan/Yuan2-2B, IEITYuan/Yuan2-51B, IEITYuan/Yuan2-102B |

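The README rows above only list the new checkpoint names; as a quick orientation, the sketch below shows how one of the newly listed Qwen3 checkpoints could be loaded through PaddleNLP's auto classes. This is an illustration rather than part of the diff: it assumes the PR is merged, the weights are downloadable, and a PaddleNLP 3.x `generate` API; exact generation kwargs may differ by version.

```python
# Minimal usage sketch (assumes this PR is merged and Qwen3 weights are available).
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen3-0.6B"  # one of the checkpoints listed in the table above
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, dtype="bfloat16")

inputs = tokenizer("Briefly explain mixture-of-experts routing.", return_tensors="pd")
# PaddleNLP's generate returns a (generated_ids, scores) tuple.
ids, _ = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(ids, skip_special_tokens=True)[0])
```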
2 changes: 2 additions & 0 deletions paddlenlp/transformers/__init__.py
@@ -261,6 +261,8 @@
from .qwen import *
from .qwen2 import *
from .qwen2_moe import *
from .qwen3 import *
from .qwen3_moe import *
from .reformer.configuration import *
from .reformer.modeling import *
from .reformer.tokenizer import *
4 changes: 4 additions & 0 deletions paddlenlp/transformers/auto/configuration.py
@@ -99,6 +99,8 @@
("qwen", "QWenConfig"),
("qwen2", "Qwen2Config"),
("qwen2_moe", "Qwen2MoeConfig"),
("qwen3", "Qwen3Config"),
("qwen3_moe", "Qwen3MoeConfig"),
("reformer", "ReformerConfig"),
("rembert", "RemBertConfig"),
("roberta", "RobertaConfig"),
@@ -190,6 +192,8 @@
("qwen", "QWen"),
("qwen2", "Qwen2"),
("qwen2_moe", "Qwen2Moe"),
("qwen3", "Qwen3"),
("qwen3_moe", "Qwen3Moe"),
("reformer", "Reformer"),
("rembert", "RemBert"),
("roberta", "Roberta"),
2 changes: 2 additions & 0 deletions paddlenlp/transformers/auto/modeling.py
@@ -124,7 +124,9 @@
("Mistral", "mistral"),
("Mixtral", "mixtral"),
("Qwen2", "qwen2"),
("Qwen3", "qwen3"),
("Qwen2Moe", "qwen2_moe"),
("Qwen3Moe", "qwen3_moe"),
("Gemma", "gemma"),
("Yuan", "yuan"),
("Mamba", "mamba"),
19 changes: 6 additions & 13 deletions paddlenlp/transformers/conversion_utils.py
@@ -38,7 +38,7 @@
from paddle.nn import Layer

from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
from paddlenlp.utils.env import CONFIG_NAME, PADDLE_WEIGHTS_NAME, PYTORCH_WEIGHTS_NAME
from paddlenlp.utils.env import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from paddlenlp.utils.import_utils import (
is_package_available,
is_torch_available,
@@ -1161,7 +1161,7 @@
file for file in os.listdir(os.path.dirname(weight_file)) if file.startswith("pytorch_model-")
]
state_dict = {}
for file in files:
for file in sorted(files):

sub_state_dict = load_torch(os.path.join(os.path.dirname(weight_file), file))
state_dict.update(sub_state_dict)
else:
@@ -1179,13 +1179,9 @@
all_layer_names.remove(name_mapping.source_name)

if all_layer_names:
logger.warning(f"there are {len(all_layer_names)} tensors not initialized:")
for layer_name in all_layer_names:
logger.warning(f"--- {layer_name}")
logger.warning(f"There are {len(all_layer_names)} tensors not initialized:")
logger.warning(f"Keys: {all_layer_names}")

model_weight_file = os.path.join(cache_dir, PADDLE_WEIGHTS_NAME)
if not os.path.isfile(model_weight_file):
paddle.save(state_dict, model_weight_file)
return state_dict

@classmethod
@@ -1553,10 +1549,7 @@
all_layer_names.remove(name_mapping.source_name)

if all_layer_names:
logger.warning(f"there are {len(all_layer_names)} tensors not initialized:")
for layer_name in all_layer_names:
logger.warning(f"--- {layer_name}")
logger.warning(f"There are {len(all_layer_names)} tensors not initialized:")
logger.warning(f"Keys: {all_layer_names}")


model_weight_file = os.path.join(input_dir, PADDLE_WEIGHTS_NAME)
paddle.save(state_dict, model_weight_file)
return state_dict
40 changes: 26 additions & 14 deletions paddlenlp/transformers/model_utils.py
@@ -807,12 +807,14 @@
return missing_keys, unexpected_keys


def faster_set_state_dict(model, state_dict, strict_dtype=True):
def faster_set_state_dict(model, state_dict, model_state_dict=None, strict_dtype=True):
if model_state_dict is None:
model_state_dict = model.state_dict()

# the state_dict will be destroyed.
unused_keys = set(state_dict.keys())
unset_keys = set(model.state_dict().keys())
unset_keys = set(model_state_dict.keys())
with paddle.no_grad():
for k, v in model.state_dict().items():
for k, v in model_state_dict.items():
if k in state_dict:
v_new = state_dict.pop(k)
if not isinstance(v_new, paddle.Tensor):
@@ -857,14 +859,14 @@
return error_msgs


def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, model_state_dict):
# torch will cast dtype in load_state_dict, but paddle strictly checks dtype
if len(start_prefix) > 0:
for key in list(state_dict.keys()):
if key.startswith(start_prefix):
state_dict[key.replace(start_prefix, "")] = state_dict.pop(key)

_convert_state_dict_dtype_and_shape(state_dict, model_to_load)
_convert_state_dict_dtype_and_shape(state_dict, model_state_dict)

error_msgs = []

@@ -874,20 +876,21 @@
# paddlenlp holds missing_keys, just ignore not-found warnings.
warnings.filterwarnings("ignore", message=r".*is not found in the provided dict.*")
warnings.filterwarnings("ignore", message=r".*paddle.to_tensor.*")
model_to_load.set_state_dict(state_dict)
# model_to_load.set_state_dict(state_dict)
faster_set_state_dict(model_to_load, state_dict, model_state_dict)
error_msgs.extend([str(x.message) for x in w])

del state_dict

return error_msgs


def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
def _convert_state_dict_dtype_and_shape(state_dict, model_state_dict):
# convert the dtype of state dict
def is_0d_or_1d(tensor):
return len(tensor.shape) == 0 or list(tensor.shape) == [1]

for key, value in model_to_load.state_dict().items():
for key, value in model_state_dict.items():
if key in list(state_dict.keys()):
if isinstance(state_dict[key], np.ndarray):
raise ValueError(
@@ -2034,12 +2037,15 @@
start_prefix = cls.base_model_prefix + "."
if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
model_to_load = getattr(model, cls.base_model_prefix)
base_model_expected_keys = list(model_to_load.state_dict().keys())
base_model_expected_keys = list(model_state_dict.keys())
if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
raise ValueError(
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
"properly saved?"
)
model_to_load_state_dict = model_to_load.state_dict()
else:
model_to_load_state_dict = model_state_dict

def _find_mismatched_keys(
state_dict,
@@ -2143,7 +2149,12 @@
keep_in_fp32_modules=keep_in_fp32_modules,
)
else:
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
error_msgs = _load_state_dict_into_model(
model_to_load,
state_dict,
start_prefix,
model_to_load_state_dict,
)
else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True

@@ -2156,9 +2167,7 @@
resume_state_dict = {}
if len(resolved_archive_file) > 1:
resolved_archive_file = tqdm(resolved_archive_file, desc="Loading checkpoint shards")
if low_cpu_mem_usage or quantization_linear_list is not None:
# model.state_dict() takes a long time
model_to_load_state_dict = model_to_load.state_dict()

for shard_file in resolved_archive_file:
pre_tensor_parallel_split = False
if quantization_linear_list is not None:
@@ -2270,7 +2279,9 @@
)
error_msgs += new_error_msgs
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
error_msgs += _load_state_dict_into_model(
model_to_load, state_dict, start_prefix, model_to_load_state_dict
)

# force memory release
del state_dict
@@ -2581,6 +2592,7 @@
for key in model.state_dict().keys():
if "quant_weight" in key:
quantization_linear_list.append(key[:-13])

model, missing_keys, unexpected_keys, mismatched_keys = cls._load_pretrained_model(
model=model,
state_dict=state_dict,
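The thread running through the model_utils.py changes is that `model.state_dict()` is materialized once (as `model_to_load_state_dict`) and passed into `_load_state_dict_into_model` and `faster_set_state_dict`, instead of being rebuilt inside each helper for every checkpoint shard. Below is a self-contained sketch of that caching pattern; the helper name is hypothetical and this is not the actual PaddleNLP internal code.

```python
import paddle
from paddle import nn

def assign_from_shard(model, shard_state, model_state_dict=None):
    # Accept a pre-computed state_dict so repeated calls (one per shard)
    # do not pay the cost of model.state_dict() every time.
    if model_state_dict is None:
        model_state_dict = model.state_dict()
    with paddle.no_grad():
        for name, param in model_state_dict.items():
            if name in shard_state:
                # Match the target dtype before assigning, then drop the shard entry.
                param.set_value(paddle.cast(shard_state.pop(name), param.dtype))

# Usage: compute the mapping once, reuse it for every shard.
model = nn.Linear(4, 4)
cached = model.state_dict()
shards = [{"weight": paddle.ones([4, 4])}, {"bias": paddle.zeros([4])}]
for shard in shards:
    assign_from_shard(model, shard, model_state_dict=cached)
```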
14 changes: 8 additions & 6 deletions paddlenlp/transformers/moe_gate.py
@@ -226,7 +226,7 @@
chosen_expert = topk_idx.reshape([-1])
# Shape: [seq_len * k, num_experts].
token_priority = F.one_hot(chosen_expert, self.num_experts).cast(paddle.int32)
token_priority = paddle.logical_and(token_priority > 0, token_priority.cumsum(axis=0) < capacity)
token_priority = paddle.logical_and(token_priority > 0, token_priority.cumsum(axis=0) <= capacity)
# Shape: [seq_len, num_experts].
token_priority = token_priority.reshape([-1, k, self.num_experts]).sum(axis=1)

@@ -532,12 +532,14 @@
token_priority = self._priority(top_idx, capacity)

# normalize gates
# gates_masked is equal to top_gate.
gates_masked = gates * mask
if self.training:
gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True)
denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps)
if self.norm_topk_prob:
gates_masked = gates_masked / denom_s
# if self.training:
gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True)
denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps)
if self.norm_topk_prob:
gates_masked = gates_masked / denom_s

gates_masked *= self.routed_scaling_factor

return (
capacity,
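To make the capacity fix in `_priority` concrete, here is a tiny standalone sketch (toy numbers, not PaddleNLP code) of the one-hot plus cumsum bookkeeping. The cumulative sum gives each token its 1-based arrival order at its chosen expert, so the old `< capacity` comparison rejected the token whose order equals `capacity`, and every expert admitted one token fewer than intended; `<= capacity` keeps exactly `capacity` tokens per expert.

```python
import paddle
import paddle.nn.functional as F

# Toy setup: 6 tokens, 4 experts, top-1 routing, capacity of 2 tokens per expert.
num_experts, capacity = 4, 2
chosen_expert = paddle.to_tensor([0, 1, 0, 0, 2, 1])  # expert picked for each token

# One-hot over experts; cumsum counts how many tokens so far chose the same expert.
token_priority = F.one_hot(chosen_expert, num_experts).cast(paddle.int32)
kept = paddle.logical_and(token_priority > 0, token_priority.cumsum(axis=0) <= capacity)
print(kept.astype("int32"))
# Row 3 (the third token routed to expert 0) exceeds capacity=2 and is dropped;
# with the old `< capacity` test, the second token per expert would also have been dropped.
```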
6 changes: 4 additions & 2 deletions paddlenlp/transformers/qwen2/configuration.py
@@ -101,7 +101,7 @@ def __init__(
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=32768,
seq_length=32768,
# seq_length=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
@@ -113,6 +113,7 @@
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
attention_bias=True,
attention_dropout=0.0,
rope_scaling_factor=1.0,
rope_scaling_type=None,
@@ -122,7 +123,7 @@
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.seq_length = seq_length
# self.seq_length = seq_length
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
@@ -141,6 +142,7 @@
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout

self.use_cache = use_cache
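The Qwen2Config change above introduces an `attention_bias` flag and drops `seq_length` from the constructor, presumably so the same attention stack can serve both Qwen2-style models (QKV bias on) and Qwen3-style models (bias off). A small hedged sketch of the flag in use follows; the sizes are illustrative only and do not correspond to a real checkpoint.

```python
# Assumes this PR's Qwen2Config signature; values below are toy, not a real model.
from paddlenlp.transformers import Qwen2Config

cfg = Qwen2Config(
    vocab_size=32000,
    hidden_size=512,
    intermediate_size=1024,
    num_hidden_layers=4,
    num_attention_heads=8,
    num_key_value_heads=8,
    attention_bias=True,  # Qwen2-style QKV bias; a Qwen3-style config would typically pass False
)
print(cfg.attention_bias)
```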