Commit 8fb58eb

Add instruction data from belle (#5718)
* Add instruction data from belle
* add databuilder
* resolve some question
* fix oom & belle url
* fix prediction_loss_only
1 parent 2cb2441 commit 8fb58eb

File tree

5 files changed: +332 -1 lines changed


examples/language_model/glm/data.py

+49
@@ -15,6 +15,55 @@
import numpy as np


def custom_instruction_convert_example(example, tokenizer, data_args, is_test=True, is_do_generation=False):
    instruction = ""
    input = ""
    output = ""
    if "instruction" in example and "output" in example:
        instruction = example["instruction"]
        output = example["output"]
    else:
        assert False, "instruction and output are not in the input dictionary."
    if "input" in example:
        input = example["input"]

    # Build the prompt: chat-style tasks keep the raw text, others get a "Human: ... Assistant:" template.
    if "chat" in data_args.task_name:
        example["text_a"] = instruction + input
    else:
        example["text_a"] = "Human: " + instruction + input + "\n Assistant: "
    example["text_b"] = output
    # Encode the source, insert the GLM [gMASK] token before the final token, and pad to src_length.
    inputs = tokenizer.encode(example["text_a"], max_length=data_args.src_length - 1, truncation=True)
    inputs["input_ids"] = inputs["input_ids"][:-1] + [tokenizer.gmask_token_id] + inputs["input_ids"][-1:]
    pad_length = data_args.src_length - len(inputs["input_ids"])
    inputs["input_ids"] = np.array([inputs["input_ids"] + [tokenizer.pad_token_id] * pad_length])
    inputs["attention_mask"] = np.array([inputs["attention_mask"] + [1] + [0] * pad_length])
    sep = inputs["input_ids"].shape[1]

    inputs = tokenizer.build_inputs_for_generation(
        inputs,
        max_gen_length=data_args.tgt_length,
        targets=" " + example["text_b"] if not is_test or not is_do_generation else None,
        padding="max_length",
    )
    for input_name in inputs.keys():
        inputs[input_name] = inputs[input_name].squeeze(0)
    if is_test:
        inputs["position_ids"] = inputs["position_ids"][:, : inputs["input_ids"].shape[-1]]
    # Labels cover only the target span; loss_mask zeroes out the source positions.
    labels = tokenizer.encode(
        " " + example["text_b"], add_special_tokens=False, max_length=data_args.tgt_length - 1
    )["input_ids"]
    loss_mask = [0] * sep + [1] * len(labels) + [0] * (data_args.tgt_length - len(labels))
    labels = (
        [0] * sep
        + labels
        + [tokenizer.eop_token_id]
        + [tokenizer.pad_token_id] * (data_args.tgt_length - len(labels) - 1)
    )
    inputs["label_ids"] = labels
    inputs["loss_mask"] = loss_mask
    return inputs


def custom_convert_example(example, tokenizer, data_args, is_test=True):
    source = None
    title = None
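
For orientation, here is a minimal sketch (not part of the commit) of how the new converter could be exercised on a single BELLE-style record. The FakeDataArgument stub and the sample record are assumptions for illustration; the real script builds data_args with PdArgumentParser, as shown in the new training file below.

# Illustrative only: a hypothetical stand-in for DataArgument with the same defaults used above.
from dataclasses import dataclass

from paddlenlp.transformers import AutoTokenizer

from data import custom_instruction_convert_example


@dataclass
class FakeDataArgument:  # hypothetical helper, mirrors the DataArgument defaults
    task_name: str = "school_math_0.25M"
    src_length: int = 608
    tgt_length: int = 160


tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-2b")
record = {"instruction": "一个三角形的底是5厘米，高是4厘米，求它的面积。", "input": "", "output": "面积 = 5 * 4 / 2 = 10 平方厘米。"}
features = custom_instruction_convert_example(record, tokenizer, FakeDataArgument(), is_test=False)
# features now holds padded input_ids / attention_mask plus label_ids and loss_mask,
# with the prompt rendered as "Human: ... \n Assistant: ".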
examples/language_model/glm/ (new instruction finetuning script)

+162
@@ -0,0 +1,162 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass, field
from functools import partial

import paddle
from data import custom_instruction_convert_example
from utils import GLMTrainer

from paddlenlp.data import DefaultDataCollator
from paddlenlp.datasets import load_dataset
from paddlenlp.layers import LoRAConfig, LoRAModel
from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint
from paddlenlp.transformers import AutoModelForConditionalGeneration, AutoTokenizer
from paddlenlp.utils.log import logger


@dataclass
class DataArgument:
    task_name: str = field(default="school_math_0.25M", metadata={"help": "The name of task."})
    data_name: str = field(default="bellegroup", metadata={"help": "The name of data."})
    src_length: int = field(default=608, metadata={"help": "The max length of source text."})
    tgt_length: int = field(default=160, metadata={"help": "The max length of target text."})
    min_tgt_length: int = field(default=55, metadata={"help": "The min length of target text."})
    length_penalty: float = field(default=0.7, metadata={"help": "The length penalty."})
    no_repeat_ngram_size: int = field(default=3, metadata={"help": "The no repeat ngram size."})
    num_beams: int = field(default=5, metadata={"help": "The number of beams."})
    select_topk: bool = field(default=True, metadata={"help": "Whether to select top k tokens for generation."})
    top_p: float = field(
        default=0.0, metadata={"help": "The cumulative probability for top-p-filtering in the 'sampling' strategy."}
    )
    top_k: int = field(
        default=0,
        metadata={
            "help": "The number of highest probability tokens to keep for top-k-filtering in the 'sampling' strategy."
        },
    )
    no_block_position: bool = field(default=False)


@dataclass
class ModelArgument:
    model_name_or_path: str = field(
        default="THUDM/glm-2b", metadata={"help": "Build-in pretrained model name or the path to local model."}
    )
    label_smoothing: float = field(default=0.1, metadata={"help": "The label smoothing parameter."})
    lr_decay_ratio: float = field(default=0.1, metadata={"help": "The ratio for learning rate decrease"})
    lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"})


def main():
    parser = PdArgumentParser((ModelArgument, DataArgument, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")
    setattr(training_args, "label_smoothing", model_args.label_smoothing)
    setattr(training_args, "lr_decay_ratio", model_args.lr_decay_ratio)

    paddle.set_device(training_args.device)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}"
    )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 1:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    dtype = None
    if training_args.fp16_opt_level == "O2":
        if training_args.fp16:
            dtype = "float16"
        if training_args.bf16:
            dtype = "bfloat16"

    # Load the pretrained language model.
    model = AutoModelForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        output_predict=True,
        parallel_output=True,
        load_state_as_np=True,
        dtype=dtype,  # todo enable set dtype to avoid additional mem usage
        tensor_parallel_degree=training_args.tensor_parallel_degree,
        tensor_parallel_rank=training_args.tensor_parallel_rank,
    )
    if model_args.lora:
        # TODO: hardcode parameters for now. Change after MergedLoRA is introduced
        lora_config = LoRAConfig(
            target_modules=[".*query_key_value.*"],
            r=4,
            lora_alpha=8,
            merge_weights=True,
            enable_lora_list=[[True, False, True]],
            tensor_parallel_degree=training_args.tensor_parallel_degree,
        )
        model = LoRAModel(model, lora_config)
        model.mark_only_lora_as_trainable()
        model.print_trainable_parameters()

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

    # Load the dataset.
    train_ds, dev_ds = load_dataset(data_args.data_name, data_args.task_name, splits=["train", "dev"])

    trans_func = partial(custom_instruction_convert_example, tokenizer=tokenizer, data_args=data_args)
    train_ds = train_ds.map(partial(trans_func, is_test=False, is_do_generation=False))
    test_ds = dev_ds.map(partial(trans_func, is_do_generation=False))
    collate_fn = DefaultDataCollator()

    trainer = GLMTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=dev_ds,
        tokenizer=tokenizer,
        do_generation=False,
        data_collator=collate_fn,
    )
    if training_args.fp16_opt_level == "O2":
        trainer.disable_autocast_context_manager()

    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
        trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1)
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    if training_args.do_eval:
        eval_result = trainer.evaluate(test_ds)
        trainer.log_metrics("test", eval_result)


if __name__ == "__main__":
    main()
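
A hedged launch example: the script's filename is not shown on this page, so finetune_instruction_generation.py below is a placeholder, and the hyperparameter values are illustrative rather than taken from the commit. The custom flags map to the DataArgument/ModelArgument fields above; the rest are standard PaddleNLP TrainingArguments.

# Placeholder script name and illustrative values; adjust to the actual file and your hardware.
python finetune_instruction_generation.py \
    --model_name_or_path THUDM/glm-2b \
    --data_name bellegroup \
    --task_name school_math_0.25M \
    --output_dir ./checkpoints/glm-belle \
    --src_length 608 \
    --tgt_length 160 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --learning_rate 3e-5 \
    --lora \
    --do_train \
    --do_eval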

examples/language_model/glm/utils.py

+12 -1
@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import UserDict
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import paddle
import paddle.nn as nn
from paddle import Tensor

@@ -91,6 +91,17 @@ def lr_lambda(current_step: int):
        self.lr_scheduler = LambdaDecay(self.args.learning_rate, lr_lambda, last_epoch=-1)
        return self.lr_scheduler

    def log(self, logs: Dict[str, float], **kwargs) -> None:
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 4)

        # Report perplexity derived from the evaluation loss alongside the raw metrics.
        if "eval_loss" in logs:
            logs["eval_ppl"] = np.exp(logs["eval_loss"])
        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs, **kwargs)

    @paddle.no_grad()
    def generate(
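
As a quick sanity check on the new eval_ppl entry: perplexity is just the exponential of the evaluation loss. A tiny sketch with a made-up value, not taken from the commit:

import numpy as np

eval_loss = 1.25              # illustrative value only
eval_ppl = np.exp(eval_loss)  # ≈ 3.49; this is what the overridden log() reports as eval_ppl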

paddlenlp/datasets/__init__.py

+1
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .bellegroup import *
from .cail2018_small import *
from .cblue import *
from .chnsenticorp import *

paddlenlp/datasets/bellegroup.py

+108
@@ -0,0 +1,108 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url

from ..utils.env import DATA_HOME
from .dataset import DatasetBuilder

__all__ = ["BelleGroup"]


class BelleGroup(DatasetBuilder):
    """
    From https://github.com/LianjiaTech/BELLE/tree/main
    """

    BUILDER_CONFIGS = {
        "generated_chat_0.4M": {
            "url": "https://paddlenlp.bj.bcebos.com/datasets/BelleGroup/generated_chat_0.4M.zip",
            "md5": "9bb71d4f2aa99acede2a0c3a8e761905",
            "splits": {
                "train": [os.path.join("generated_chat_0.4M", "train.json"), "47ea511025fbda9ffd6e5178677bb027"],
                "dev": [os.path.join("generated_chat_0.4M", "dev.json"), "d7bd4b71cdb006b9de90ebb634ca1179"],
            },
        },
        "school_math_0.25M": {
            "url": "https://paddlenlp.bj.bcebos.com/datasets/BelleGroup/school_math_0.25M.zip",
            "md5": "10076cbdc0a7436d55481f0234db8609",
            "splits": {
                "train": [os.path.join("school_math_0.25M", "train.json"), "e5a36fc9deb015254686c51e21528683"],
                "dev": [os.path.join("school_math_0.25M", "dev.json"), "99e967c38e39ed919327c011d9f6288f"],
            },
        },
        "train_2M_CN": {
            "url": "https://paddlenlp.bj.bcebos.com/datasets/BelleGroup/train_2M_CN.zip",
            "md5": "da88aca71eb9f454fab39db6a7e851e6",
            "splits": {
                "train": [os.path.join("train_2M_CN", "train.json"), "83e2917701a31ecf5152e4e9f234fcd0"],
                "dev": [os.path.join("train_2M_CN", "dev.json"), "74f67f04e30896aeccc10930a7dc1f40"],
            },
        },
        "train_1M_CN": {
            "url": "https://paddlenlp.bj.bcebos.com/datasets/BelleGroup/train_1M_CN.zip",
            "md5": "65380b542e8ddb4db8f8d2be0f28795c",
            "splits": {
                "train": [os.path.join("train_1M_CN.zip", "train.json"), "489886aba320c74a1fdfad43c652635b"],
                "dev": [os.path.join("train_1M_CN.zip", "dev.json"), "7bbf382aeab89f4398b2beca984e20e8"],
            },
        },
        "train_0.5M_CN": {
            "url": "https://paddlenlp.bj.bcebos.com/datasets/BelleGroup/train_0.5M_CN.zip",
            "md5": "45be55109ca9595efa36eaaed7c475d3",
            "splits": {
                "train": [os.path.join("train_0.5M_CN.zip", "train.json"), "61dc155956622c8389265de33b439757"],
                "dev": [os.path.join("train_0.5M_CN.zip", "dev.json"), "72617388fbc4897cb2952df3e5303c2b"],
            },
        },
        "multiturn_chat_0.8M": {
            "url": "https://paddlenlp.bj.bcebos.com/datasets/BelleGroup/multiturn_chat_0.8M.zip",
            "md5": "974bc42c5920e5722146a89dce2b10cc",
            "splits": {
                "train": [os.path.join("multiturn_chat_0.8M", "train.json"), "27e3a7ecff0f4a199f6e7119909988e9"],
                "dev": [os.path.join("multiturn_chat_0.8M", "dev.json"), "8fec175ea5e71cc78498d8ca3c1d5e66"],
            },
        },
    }

    def _get_data(self, mode, **kwargs):
        builder_config = self.BUILDER_CONFIGS[self.name]

        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        filename, data_hash = builder_config["splits"][mode]
        fullname = os.path.join(default_root, filename)
        if not os.path.exists(fullname) or (data_hash and not md5file(fullname) == data_hash):
            get_path_from_url(builder_config["url"], default_root, builder_config["md5"])

        return fullname

    def _read(self, filename, *args):
        with open(filename, "r", encoding="utf8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue

                json_data = json.loads(line)

                yield {
                    "instruction": json_data["instruction"],
                    "input": json_data["input"],
                    "output": json_data["output"],
                }
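
A minimal usage sketch (not part of the commit) of the newly registered builder, mirroring the load_dataset call in the training script; the first access downloads and md5-checks the archive under PaddleNLP's DATA_HOME:

from paddlenlp.datasets import load_dataset

# "bellegroup" resolves to the BelleGroup builder; "school_math_0.25M" is one of its configs.
train_ds, dev_ds = load_dataset("bellegroup", "school_math_0.25M", splits=["train", "dev"])
print(train_ds[0])  # {"instruction": ..., "input": ..., "output": ...}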
