
Commit 0124001

other 2 configs for base model
1 parent df81a99 commit 0124001

2 files changed: +408 -0 lines changed
@@ -0,0 +1,205 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import template_map_fn_factory
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
# pretrained_model_name_or_path = '/root/share/model_repos/internlm2-chat-7b'
pretrained_model_name_or_path = '/root/share/model_repos/internlm2-base-7b'

# Data
# data_path = 'merge.json'
data_path = '/root/StableCascade/emollm2/EmoLLM/datasets/processed/combined_data.json'
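# NOTE (assumption): since `dataset_map_fn=None` below, combined_data.json is expected to
# already be in xtuner's standard conversation format, each record looking roughly like
#   {"conversation": [{"system": "...", "input": "...", "output": "..."}]}
# This is a sketch of the assumed schema, not something enforced by this config.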
# https://github.com/InternLM/xtuner/blob/main/xtuner/utils/templates.py#L24C25-L24C25
prompt_template = PROMPT_TEMPLATE.internlm2_chat  # there is no internlm2_base template, so the chat template is reused

max_length = 2048
pack_to_max_length = True

# Scheduler & Optimizer

# batch_size = 8  # per_device
# accumulative_counts = 2
batch_size = 8  # per_device
accumulative_counts = 1
dataloader_num_workers = 0
max_epochs = 10
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Evaluate the generation performance during the training
evaluation_freq = 500
# SYSTEM = "现在你是一个心理专家,我有一些心理问题,请你用专业的知识帮我解决。"
# (English gloss of the commented prompt: "You are now a psychology expert; I have some
#  psychological problems, please use your professional knowledge to help me solve them.")
SYSTEM = "你是心理健康助手EmoLLM,由EmoLLM团队打造。你旨在通过专业心理咨询,协助来访者完成心理诊断。请充分利用专业心理学知识与咨询技术,一步步帮助来访者解决心理问题。"
# (English gloss of SYSTEM: "You are the mental health assistant EmoLLM, built by the EmoLLM
#  team. Through professional psychological counseling you assist visitors in completing a
#  psychological assessment. Make full use of professional psychological knowledge and
#  counseling techniques to help visitors resolve their problems step by step.")
evaluation_inputs = [
    '我最近总是感到很焦虑,尤其是在学业上。我有个特别崇拜的同学,他好像在各方面都比我优秀,我总觉得自己怎么努力也追不上他,这让我压力特别大。',
    '我知道应该理性看待,但就是忍不住会去比较。我甚至晚上会因为这个睡不着觉,总想着怎样才能像他那样出色。'
]
# (English gloss of the evaluation inputs: two anxious-student messages about constantly
#  comparing oneself with an admired classmate and losing sleep over it.)
#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        # r=64,
        # lora_alpha=16,
        r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))
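# NOTE: the settings above amount to QLoRA fine-tuning: the internlm2-base-7b weights are
# loaded frozen in 4-bit NF4 (double quantization, fp16 compute) and only the rank-16 LoRA
# adapters (alpha=32, dropout=0.1) are trained. Which modules receive adapters is left to the
# xtuner/peft defaults here, since `target_modules` is not set (an assumption, not something
# pinned down in this config).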
#######################################################################
#                     PART 3  Dataset & Dataloader                    #
#######################################################################
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path='json', data_files=dict(train=data_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=None,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=alpaca_en,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))
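# NOTE: with pack_to_max_length=True, tokenized samples are shuffled and then concatenated
# into fixed 2048-token blocks, so each dataloader "sample" is one packed block and
# batch_size=8 means 8 packed blocks per device. The `alpaca_en` name is presumably inherited
# from the xtuner template config this file was adapted from; it still points at the EmoLLM
# JSON defined above.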
#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        T_max=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)
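# NOTE (worked numbers): warmup lasts warmup_ratio * max_epochs = 0.03 * 10 = 0.3 epochs
# (turned into iterations by convert_to_iter_based), after which the learning rate follows a
# cosine decay from lr=2e-4 towards eta_min=0.0 for the rest of training. The effective
# optimisation batch per step is batch_size * accumulative_counts = 8 * 1 packed sequences of
# up to 2048 tokens per GPU (times the number of GPUs when training is distributed).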
#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per epoch.
    checkpoint=dict(type=CheckpointHook, interval=1),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)
# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)
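# Usage sketch (an assumption based on xtuner's documented CLI, not part of this commit):
#   xtuner train <this_config>.py --deepspeed deepspeed_zero2    # launch QLoRA fine-tuning
#   xtuner convert pth_to_hf <this_config>.py <checkpoint>.pth <adapter_save_dir>
# The --deepspeed flag is optional and the paths above are placeholders.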
