
Commit 45b143b

feat: finetune Qwen and demo

1 parent: dc9208f

14 files changed: +606 -16 lines

.github/workflows/reademe-contributors (-14)

This file was deleted.

.gitignore (+1 -1)

@@ -2,4 +2,4 @@ ESConv.json
 .DS_Store
 __pycache__/
 tmp/
-data/zhipuai/
+zhipuai/

README.md (+1 -1)

@@ -3,6 +3,6 @@
 
 ## 🌟 Contributors
 
-[![EmoLLM contributors](https://contrib.rocks/image?repo=aJupyter/EmoLLM&max=2000)](https://github.com/aJupyter/EmoLLM/graphs/contributors)
+[![EmoLLM contributors](https://contrib.rocks/image?repo=aJupyter/EmoLLM&max=200)](https://github.com/aJupyter/EmoLLM/graphs/contributors)
 
 

demo/cli_qwen.py (new file, +210)

@@ -0,0 +1,210 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""A simple command-line interactive chat demo."""

import argparse
import os
import platform
import shutil
from copy import deepcopy

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers.trainer_utils import set_seed

DEFAULT_CKPT_PATH = './merged'

_WELCOME_MSG = '''\
Welcome to the Emo-Chat model. Type text to start a chat, or :h to show command help.
(欢迎使用 Emo-Chat 模型,输入内容即可进行对话,:h 显示命令帮助。)

Note: This demo is governed by the original license of Qwen.
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, including hate speech, violence, pornography, deception, etc.
(注:本演示受EmoLLM的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)
'''
_HELP_MSG = '''\
Commands:
    :help / :h              Show this help message              显示帮助信息
    :exit / :quit / :q      Exit the demo                       退出Demo
    :clear / :cl            Clear screen                        清屏
    :clear-history / :clh   Clear history                       清除对话历史
    :history / :his         Show history                        显示对话历史
    :seed                   Show current random seed            显示当前随机种子
    :seed <N>               Set random seed to <N>              设置随机种子
    :conf                   Show current generation config      显示生成配置
    :conf <key>=<value>     Change generation config            修改生成配置
    :reset-conf             Reset generation config             重置生成配置
'''


def _load_model_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        trust_remote_code=True,
        resume_download=True,
    ).eval()

    config = GenerationConfig.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    return model, tokenizer, config


def _gc():
    # Free Python-side garbage and cached CUDA memory.
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _clear_screen():
    if platform.system() == "Windows":
        os.system("cls")
    else:
        os.system("clear")


def _print_history(history):
    terminal_width = shutil.get_terminal_size()[0]
    print(f'History ({len(history)})'.center(terminal_width, '='))
    for index, (query, response) in enumerate(history):
        print(f'User[{index}]: {query}')
        print(f'QWen[{index}]: {response}')
    print('=' * terminal_width)


def _get_input() -> str:
    while True:
        try:
            message = input('User> ').strip()
        except UnicodeDecodeError:
            print('[ERROR] Encoding error in input')
            continue
        except KeyboardInterrupt:
            exit(1)
        if message:
            return message
        print('[ERROR] Query is empty')


def main():
    parser = argparse.ArgumentParser(
        description='QWen-Chat command-line interactive chat demo.')
    parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                        help="Checkpoint name or path, default to %(default)r")
    parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    args = parser.parse_args()

    history, response = [], ''

    model, tokenizer, config = _load_model_tokenizer(args)
    # Install the loaded config as the model's live config so that :conf and
    # :reset-conf edits actually affect generation.
    model.generation_config = config
    orig_gen_config = deepcopy(model.generation_config)

    _clear_screen()
    print(_WELCOME_MSG)

    seed = args.seed

    while True:
        query = _get_input()

        # Process commands.
        if query.startswith(':'):
            command_words = query[1:].strip().split()
            if not command_words:
                command = ''
            else:
                command = command_words[0]

            if command in ['exit', 'quit', 'q']:
                break
            elif command in ['clear', 'cl']:
                _clear_screen()
                print(_WELCOME_MSG)
                _gc()
                continue
            elif command in ['clear-history', 'clh']:
                print(f'[INFO] All {len(history)} history cleared')
                history.clear()
                _gc()
                continue
            elif command in ['help', 'h']:
                print(_HELP_MSG)
                continue
            elif command in ['history', 'his']:
                _print_history(history)
                continue
            elif command in ['seed']:
                if len(command_words) == 1:
                    print(f'[INFO] Current random seed: {seed}')
                    continue
                else:
                    new_seed_s = command_words[1]
                    try:
                        new_seed = int(new_seed_s)
                    except ValueError:
                        print(f'[WARNING] Failed to change random seed: {new_seed_s!r} is not a valid number')
                    else:
                        print(f'[INFO] Random seed changed to {new_seed}')
                        seed = new_seed
                    continue
            elif command in ['conf']:
                if len(command_words) == 1:
                    print(model.generation_config)
                else:
                    for key_value_pairs_str in command_words[1:]:
                        eq_idx = key_value_pairs_str.find('=')
                        if eq_idx == -1:
                            print('[WARNING] format: <key>=<value>')
                            continue
                        conf_key, conf_value_str = key_value_pairs_str[:eq_idx], key_value_pairs_str[eq_idx + 1:]
                        try:
                            # eval() turns e.g. "0.8" or "True" into Python values,
                            # but executes arbitrary input; acceptable for a local demo.
                            conf_value = eval(conf_value_str)
                        except Exception as e:
                            print(e)
                            continue
                        else:
                            print(f'[INFO] Change config: model.generation_config.{conf_key} = {conf_value}')
                            setattr(model.generation_config, conf_key, conf_value)
                continue
            elif command in ['reset-conf']:
                print('[INFO] Reset generation config')
                model.generation_config = deepcopy(orig_gen_config)
                print(model.generation_config)
                continue
            else:
                # Not a known command: treat the input as a normal query.
                pass

        # Run chat.
        set_seed(seed)
        try:
            # Pass the live config so on-the-fly :conf changes take effect.
            for response in model.chat_stream(tokenizer, query, history=history,
                                              generation_config=model.generation_config):
                _clear_screen()
                print(f"\nUser: {query}")
                print(f"\nQwen-Chat: {response}")
        except KeyboardInterrupt:
            print('[WARNING] Generation interrupted')
            continue

        history.append((query, response))


if __name__ == "__main__":
    main()
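
For reference, the load-and-chat path above can also be driven non-interactively. The following is a minimal sketch, not part of this commit; it assumes the merged checkpoint in ./merged ships Qwen-style remote code exposing the same chat_stream API the demo calls:

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

ckpt = './merged'  # hypothetical path, mirroring the demo's DEFAULT_CKPT_PATH
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    ckpt, device_map='auto', trust_remote_code=True,
).eval()
config = GenerationConfig.from_pretrained(ckpt, trust_remote_code=True)

history, query, response = [], 'Hello', ''
# chat_stream yields progressively longer partial responses; keep the last one.
for response in model.chat_stream(tokenizer, query, history=history,
                                  generation_config=config):
    pass
print(response)
history.append((query, response))

Inside the interactive demo itself, the same settings can be changed on the fly, e.g. :conf top_p=0.8 max_new_tokens=512, which the eval-based parser assigns attribute by attribute onto model.generation_config.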

demo/requirements_qwen.txt (new file, +2)

@@ -0,0 +1,2 @@
gradio<3.42
mdtex2html
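
The two pins above are web front-end dependencies (Gradio, plus mdtex2html for rendering Markdown/TeX output in a browser). The CLI demo itself needs torch and transformers, which this file does not pin, and Qwen-style checkpoints typically also require transformers_stream_generator for streaming. A hedged way to try the CLI demo with the script's own defaults:

pip install torch transformers transformers_stream_generator
pip install -r demo/requirements_qwen.txt
python demo/cli_qwen.py --checkpoint-path ./merged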
