Skip to content

Commit 4d0fd0b

Browse files
ZHUIgongel
andauthored
Support calling SFT training without cloning PaddleNLP (#9516)
* init. * refactor * fix import. * fix error. * refactor * fix * refactor. * fix * refine * fix argument update. * refactor. * refactor. * refactor * fix typo. * fix. * add missing. * support short call sft training. * add missing file. * fix import * add init file. --------- Co-authored-by: gongenlei <[email protected]>
1 parent 2c1387f commit 4d0fd0b

File tree

6 files changed

+445
-9
lines changed

6 files changed

+445
-9
lines changed

.pre-commit-config.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
exclude: 'slm/model_zoo/gpt-3'
1+
# pre-commit interprets `exclude` as a Python regular expression, so the two
# excluded directories must be joined with `|` (alternation); `;` would be
# matched as a literal character and exclude nothing.
exclude: 'slm/model_zoo/gpt-3|csrc/third_party'
22
repos:
33
# For Python files
44
- repo: https://github.com/psf/black.git
@@ -61,4 +61,4 @@ repos:
6161
entry: python scripts/codestyle/check_dead_links.py
6262
language: python
6363
files: \.(md|markdown|rst)$
64-
pass_filenames: true
64+
pass_filenames: true

README.md

+16
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,22 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_finetune.py
207207
```
208208

209209
更多大模型全流程步骤,请参考[飞桨大模型套件](./llm)介绍。
210+
另外,我们还提供了快速微调方式,无需 clone 源代码:
211+
212+
```python
213+
from paddlenlp.trl import SFTConfig, SFTTrainer
214+
from datasets import load_dataset
215+
216+
dataset = load_dataset("ZHUI/alpaca_demo", split="train")
217+
218+
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT", device="gpu")
219+
trainer = SFTTrainer(
220+
args=training_args,
221+
model="Qwen/Qwen2.5-0.5B",
222+
train_dataset=dataset,
223+
)
224+
trainer.train()
225+
```
210226

211227
更多 PaddleNLP 内容可参考:
212228

paddlenlp/trl/extras/__init__.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# https://github.com/huggingface/trl/blob/c10cc8995b6fd45f3a876ec98cade97251abe733/trl/extras/dataset_formatting.py#L74
16+
17+
import logging
18+
from typing import Callable, Literal, Optional, Union
19+
20+
from datasets import Dataset, Value
21+
22+
from ...transformers import AutoTokenizer
23+
24+
# Reference feature schemas, keyed by format name. They are compared against
# `dataset.features` (or one of its columns) to recognise a supported dataset
# layout; every field in every schema is a plain string value.
FORMAT_MAPPING = {
    # A conversation column: a list of {"content", "role"} string records.
    "chatml": [{column: Value(dtype="string", id=None) for column in ("content", "role")}],
    # A whole dataset made of exactly the "completion" and "prompt" columns.
    "instruction": {column: Value(dtype="string", id=None) for column in ("completion", "prompt")},
    # A whole dataset made of exactly the "src" and "tgt" columns.
    "paddlenlp": {column: Value(dtype="string", id=None) for column in ("src", "tgt")},
}
29+
30+
31+
def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]):
    r"""
    Build a formatter that renders conversation examples with the tokenizer's chat template.

    Args:
        tokenizer (AutoTokenizer): Object whose `apply_chat_template(conversation, tokenize=False)` is used
            to render each conversation.
        messages_field (str): Name of the field holding the conversation(s), "messages" or "conversations".

    Returns:
        Callable: `format_dataset(examples)`. When `examples[messages_field]` is a batch (its first element
        is itself a list), a list with one rendered result per conversation is returned; otherwise the field
        is treated as a single conversation and its rendered result is returned directly.
    """

    def format_dataset(examples):
        conversations = examples[messages_field]
        # A batch is a list of conversations, i.e. a list whose items are lists of messages;
        # a single example is just one list of message dicts.
        if isinstance(conversations[0], list):
            return [tokenizer.apply_chat_template(conversation, tokenize=False) for conversation in conversations]
        return tokenizer.apply_chat_template(conversations, tokenize=False)

    return format_dataset
47+
48+
49+
def instructions_formatting_function(tokenizer: AutoTokenizer):
    r"""
    Build a formatter that renders `prompt`/`completion` examples with the tokenizer's chat template.

    Each pair is presented to the chat template as a two-turn conversation: the prompt as the "user"
    message followed by the completion as the "assistant" message.

    Args:
        tokenizer (AutoTokenizer): Object whose `apply_chat_template(conversation, tokenize=False)` is used
            to render each pair.

    Returns:
        Callable: `format_dataset(examples)`. When `examples["prompt"]` is a list (a batch), a list with one
        rendered result per pair is returned; otherwise the single pair's rendered result is returned.
    """

    def render(prompt, completion):
        conversation = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": completion},
        ]
        return tokenizer.apply_chat_template(conversation, tokenize=False)

    def format_dataset(examples):
        prompts = examples["prompt"]
        if isinstance(prompts, list):
            # Batched input: pair each prompt with the completion at the same position.
            # NOTE: zip stops at the shorter column; both columns are assumed to have equal length.
            return [render(prompt, completion) for prompt, completion in zip(prompts, examples["completion"])]
        return render(prompts, examples["completion"])

    return format_dataset
73+
74+
75+
def paddlenlp_instructions_formatting_function(tokenizer: AutoTokenizer):
    r"""
    Build a formatter that renders `src`/`tgt` examples with the tokenizer's chat template.

    Each pair is presented to the chat template as a two-turn conversation: `src` as the "user"
    message followed by `tgt` as the "assistant" message.

    Args:
        tokenizer (AutoTokenizer): Object whose `apply_chat_template(conversation, tokenize=False)` is used
            to render each pair.

    Returns:
        Callable: `format_dataset(examples)`. When `examples["src"]` is a list (a batch), a list with one
        rendered result per pair is returned; otherwise the single pair's rendered result is returned.
    """

    def render(source, target):
        conversation = [
            {"role": "user", "content": source},
            {"role": "assistant", "content": target},
        ]
        return tokenizer.apply_chat_template(conversation, tokenize=False)

    def format_dataset(examples):
        sources = examples["src"]
        if isinstance(sources, list):
            # Batched input: pair each source with the target at the same position.
            # NOTE: zip stops at the shorter column; both columns are assumed to have equal length.
            return [render(source, target) for source, target in zip(sources, examples["tgt"])]
        return render(sources, examples["tgt"])

    return format_dataset
99+
100+
101+
def get_formatting_func_from_dataset(dataset: Union[Dataset], tokenizer: AutoTokenizer) -> Optional[Callable]:
    r"""
    Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
    - `ChatML` with [{"role": str, "content": str}] stored in a "messages" or "conversations" column
    - `instruction` with [{"prompt": str, "completion": str}]
    - `paddlenlp` with [{"src": str, "tgt": str}]

    Args:
        dataset (Dataset): User dataset
        tokenizer (AutoTokenizer): Tokenizer used for formatting

    Returns:
        Callable: Formatting function if the dataset format is supported else None
    """
    if not isinstance(dataset, Dataset):
        return None

    features = dataset.features

    # "messages" takes precedence over "conversations"; a column only counts as
    # ChatML when its schema matches the reference exactly.
    for field in ("messages", "conversations"):
        if field in features and features[field] == FORMAT_MAPPING["chatml"]:
            logging.info("Formatting dataset with chatml format")
            return conversations_formatting_function(tokenizer, field)

    # The remaining formats are matched against the schema of the whole dataset,
    # so they can never apply to a dataset that has a conversation column.
    if features == FORMAT_MAPPING["instruction"]:
        logging.info("Formatting dataset with instruction format")
        return instructions_formatting_function(tokenizer)
    if features == FORMAT_MAPPING["paddlenlp"]:
        logging.info("Formatting dataset with paddlenlp format")
        return paddlenlp_instructions_formatting_function(tokenizer)

    return None

paddlenlp/trl/sft_config.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from dataclasses import dataclass, field
16-
from typing import Optional
16+
from typing import Any, Optional
1717

1818
from paddlenlp.trainer import TrainingArguments
1919
from paddlenlp.trainer.trainer_utils import IntervalStrategy
@@ -49,6 +49,19 @@ class SFTConfig(TrainingArguments):
4949
default="",
5050
metadata={"help": "Configs to unify hybrid parallel checkpoint.\n"},
5151
)
52+
dataset_text_field: str = "text"
53+
learning_rate: float = 2.0e-5
54+
max_seq_length: int = field(
55+
default=2048,
56+
metadata={
57+
"help": "The maximum length that model input tokens can have. When Zero Padding is set to True, it's also the maximum length for Zero Padding data stream"
58+
},
59+
)
60+
dataset_num_proc: Optional[int] = None
61+
dataset_batch_size: int = 1000
62+
model_init_kwargs: Optional[dict[str, Any]] = None
63+
dataset_kwargs: Optional[dict[str, Any]] = None
64+
eval_packing: Optional[bool] = None
5265

5366
def __post_init__(self):
5467
super().__post_init__()

0 commit comments

Comments
 (0)