
Commit df81a99

add full finetune code from internlm2
1 parent 252adc7 commit df81a99

File tree: 1 file changed, +222 -0 lines changed
# Copyright (c) OpenMMLab. All rights reserved.
"""Data format:
[
    {
        "conversation": [
            {
                "system": "",
                "input": "xxx",
                "output": "xxx"
            },
            {
                "input": "xxx",
                "output": "xxx"
            }
        ]
    },
    ...
]
Please refer to https://github.com/InternLM/xtuner/blob/main/docs/en/user_guides/dataset_format.md for details.
"""  # noqa: E501
from datasets import load_dataset
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from torch.optim import AdamW
from torch.utils.data import BatchSampler
from transformers import AutoModelForCausalLM, AutoTokenizer

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import template_map_fn_factory
from xtuner.dataset.samplers import InternRepoSampler
from xtuner.engine import (DatasetInfoHook, EvaluateChatHook, ThroughputHook,
                           VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE
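
# A minimal, hypothetical sample in the conversation format documented in the
# module docstring above (illustrative only; this variable is not read by
# anything in this config):
_example_data = [
    {
        'conversation': [
            {'system': '', 'input': 'xxx', 'output': 'xxx'},
            {'input': 'xxx', 'output': 'xxx'},
        ]
    },
]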

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = 'internlm/internlm2-chat-7b'
use_varlen_attn = True

# Data
data_files = ['/path/to/json/file.json']
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 32768
pack_to_max_length = True

# Scheduler & Optimizer
# batch size per device, set to 1 if `use_varlen_attn` = True
# To clarify, enlarging the batch size essentially enlarges the `max_length`.
# For example, doubling the max length is tantamount to doubling the batch size
batch_size = 1
accumulative_counts = 1  # 1bs * 1acc * 64gpu = 64 batchsize
dataloader_num_workers = 4
max_epochs = 1
optim_type = AdamW
lr = 4e-5
betas = (0.9, 0.95)
weight_decay = 0.01
max_norm = 1  # grad clip
warm_up_ratio = 0.025
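
# Illustrative arithmetic, not part of the original config (the 64-GPU figure
# comes from the `64gpu` note above): with `pack_to_max_length = True`, each
# sample is one packed sequence of `max_length` tokens, so one optimizer step
# consumes roughly
#   batch_size * accumulative_counts * num_gpus * max_length
#   = 1 * 1 * 64 * 32768 ≈ 2.1M tokens.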

# Save
save_steps = 500
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = ''
evaluation_inputs = [
    '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True))
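
# Note on how these dicts are consumed (an illustrative sketch of the
# mmengine-style lazy-build convention, not code from this file): `type` holds
# a callable and the remaining keys are its kwargs, so materializing the
# tokenizer amounts to roughly
#   cfg = dict(tokenizer); build = cfg.pop('type'); tok = build(**cfg)
# i.e. AutoTokenizer.from_pretrained(pretrained_model_name_or_path, ...),
# deferred until the runner actually needs the object.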

#######################################################################
#                     PART 3  Dataset & Dataloader                    #
#######################################################################
train_dataset = dict(
    type=process_hf_dataset,
    use_varlen_attn=use_varlen_attn,
    dataset=dict(type=load_dataset, path='json', data_files=data_files),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=None,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=InternRepoSampler, shuffle=True, seed=1024),
    batch_sampler=dict(
        type=BatchSampler, drop_last=True, batch_size=batch_size),
    collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
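
# Illustrative note (not code from this file): the nested `dataset` entry
# above resolves to a plain Hugging Face Datasets call, roughly
#   load_dataset('json', data_files=data_files)
# whose 'train' split holds records in the format documented in the module
# docstring; `process_hf_dataset` then tokenizes them, applies the
# internlm2_chat template and, with `pack_to_max_length=True`, packs samples
# into `max_length`-token sequences.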

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
)

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1 / 40,
        by_epoch=True,
        begin=0,
        end=warm_up_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=lr * 0.15,
        by_epoch=True,
        begin=warm_up_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]
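
# Worked example (illustrative, derived from the values above): with
# warm_up_ratio = 0.025 and max_epochs = 1, LinearLR ramps the learning rate
# from lr / 40 = 1e-6 up to lr = 4e-5 over the first 2.5% of iterations;
# CosineAnnealingLR then decays it to eta_min = 0.15 * lr = 6e-6 by the end
# of training.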

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during training (optional)
custom_hooks = [
    dict(
        type=DatasetInfoHook, tokenizer=tokenizer,
        is_intern_repo_dataset=True),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template),
    dict(type=ThroughputHook)
]

if use_varlen_attn:
    custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every `interval` iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

log_processor = dict(
    by_epoch=False,
    window_size=1,
    mean_pattern=r'.*(loss|time|data_time|grad_norm|tflops).*')
