Skip to content

Commit d93ae95

Browse files
author
Your Name
committed
重构代码
1 parent 472926d commit d93ae95

File tree

18 files changed

+165
-213
lines changed

18 files changed

+165
-213
lines changed

examples/train/rft/rft.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int):
2222
for device in range(device_count):
2323
sample_cmd = (f'{conda_prefix} CUDA_VISIBLE_DEVICES={device} swift sample '
2424
f'--model {model} --model_type {model_type} '
25-
f'--dataset {" ".join(dataset)} '
25+
f'--dataset {' '.join(dataset)} '
2626
f'--data_range {device} {device_count} '
2727
f'--max_length 2048 '
2828
f'--system "You are a math model, you should **think step by step** carefully, '
@@ -61,7 +61,7 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int):
6161
sample_cmd = (
6262
f'{conda_prefix} CUDA_VISIBLE_DEVICES={device} swift sample '
6363
f'--model {model} --model_type {model_type} ' # change to --resume_from_checkpoint to use the latest optimizer state # noqa
64-
f'--dataset {" ".join(dataset)} '
64+
f'--dataset {' '.join(dataset)} '
6565
f'--data_range {device} {device_count} '
6666
f'--max_length 2048 '
6767
f'--system "You are a math model, you should **think step by step** carefully, '
@@ -108,7 +108,7 @@ def do_train(model: str, model_type: str, datasets: List[str], iter, cmd='sft'):
108108
ga = 128 // get_device_count() // 2
109109
train_cmd = (f'{conda_prefix} {gpu_prefix} swift {cmd} '
110110
f'--model {model} --model_type {model_type} '
111-
f'--dataset {" ".join(datasets)} '
111+
f'--dataset {' '.join(datasets)} '
112112
f'--max_length 2048 '
113113
f'--num_train_epochs 1 '
114114
f'--load_args false '

gen_data.py

+22-13
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,36 @@
99
from typing import Dict, List
1010
from math_tool import parse_expression, SYMBOL_TO_OPERATION, OPERATION_DEFINITIONS
1111

12+
1213
def generate_safe_expression():
1314
"""Generate an expression that won't cause overflow errors"""
1415
# Start with a moderate number
1516
expression = str(random.randint(1, 20))
16-
17+
1718
# Add 3-5 operations with safe numbers
1819
num_ops = random.randint(1, 6)
19-
20+
2021
for i in range(num_ops):
2122
# Choose operation
2223
op = random.choice(['@', '&', '$', '^'])
23-
24+
2425
# For @ operation, use smaller numbers to avoid overflow
2526
if op == '@':
2627
# For exponentiation, keep the exponent small
2728
num = random.randint(1, 3)
2829
else:
2930
num = random.randint(1, 10)
30-
31+
3132
expression += op + str(num)
32-
33+
3334
return expression
3435

35-
def generate_dataset(num_samples: int = 1000, output_file: str = "math_operations_dataset.jsonl") -> None:
36+
37+
def generate_dataset(num_samples: int = 1000, output_file: str = 'math_operations_dataset.jsonl') -> None:
3638
"""
3739
Generates a dataset of custom mathematical expressions and their results.
3840
Saves the dataset as a JSONL file.
39-
41+
4042
Args:
4143
num_samples: Number of samples to generate
4244
output_file: Path to save the JSONL file
@@ -45,24 +47,31 @@ def generate_dataset(num_samples: int = 1000, output_file: str = "math_operation
4547
for _ in range(num_samples):
4648
# Generate a safe expression
4749
expression = generate_safe_expression()
48-
50+
4951
# Calculate the result
5052
try:
5153
result = parse_expression(expression)
52-
54+
5355
# Create the data entry
5456
# data_entry = {
5557
# "query": f"Calculate the result of the expression: {expression}",
5658
# "answer": result
5759
# }
58-
data_entry = {"messages": [{"role": "user", "content": f"Calculate the result of the expression: {expression}"}],"response":result}
59-
60+
data_entry = {
61+
'messages': [{
62+
'role': 'user',
63+
'content': f"Calculate the result of the expression: {expression}"
64+
}],
65+
'response': result
66+
}
67+
6068
# Write to JSONL file
6169
f.write(json.dumps(data_entry) + '\n')
6270
except Exception as e:
6371
print(f"Skipping problematic expression {expression}: {e}")
6472
continue
65-
73+
6674
print(f"Generated dataset with {num_samples} samples and saved to {output_file}")
6775

68-
generate_dataset()
76+
77+
generate_dataset()

scripts/benchmark/exp_utils.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def run(self, exp: Experiment):
122122
exp.runtime = runtime
123123
envs = deepcopy(runtime.get('env', {}))
124124
envs.update(os.environ)
125-
logger.info(f'Running cmd: {runtime["running_cmd"]}, env: {runtime.get("env", {})}')
125+
logger.info(f'Running cmd: {runtime['running_cmd']}, env: {runtime.get('env', {})}')
126126
os.makedirs('exp', exist_ok=True)
127127
log_file = os.path.join('exp', f'{exp.name}.eval.log')
128128
exp.handler = subprocess.Popen(runtime['running_cmd'] + f' > {log_file} 2>&1', env=envs, shell=True)
@@ -140,7 +140,7 @@ def run(self, exp: Experiment):
140140
exp.runtime = runtime
141141
envs = deepcopy(runtime.get('env', {}))
142142
envs.update(os.environ)
143-
logger.info(f'Running cmd: {runtime["running_cmd"]}, env: {runtime.get("env", {})}')
143+
logger.info(f'Running cmd: {runtime['running_cmd']}, env: {runtime.get('env', {})}')
144144
os.makedirs('exp', exist_ok=True)
145145
log_file = os.path.join('exp', f'{exp.name}.{exp.cmd}.log')
146146
exp.handler = subprocess.Popen(runtime['running_cmd'] + f' > {log_file} 2>&1', env=envs, shell=True)
@@ -162,10 +162,10 @@ def _build_eval_cmd(self, exp: Experiment):
162162
if best_model_checkpoint is not None:
163163
if not os.path.exists(os.path.join(best_model_checkpoint, 'args.json')):
164164
cmd = f'swift eval --ckpt_dir {best_model_checkpoint} ' \
165-
+ f'--infer_backend pt --train_type full --eval_dataset {" ".join(eval_dataset)}'
165+
+ f'--infer_backend pt --train_type full --eval_dataset {' '.join(eval_dataset)}'
166166
else:
167-
cmd = f'swift eval --model {exp.args.get("model")} --infer_backend pt ' \
168-
f'--eval_dataset {" ".join(eval_dataset)}'
167+
cmd = f'swift eval --model {exp.args.get('model')} --infer_backend pt ' \
168+
f'--eval_dataset {' '.join(eval_dataset)}'
169169

170170
return {
171171
'running_cmd': cmd,

scripts/benchmark/generate_report.py

+22-22
Original file line numberDiff line numberDiff line change
@@ -69,23 +69,23 @@ def tuner_hyper_params(self):
6969
return ''
7070
if args['sft_type'] in ('lora', 'adalora', 'longlora'):
7171
if 'lora_rank' in args:
72-
hyper_params += f'rank={args["lora_rank"]}/' \
73-
f'target={args["lora_target_modules"]}/' \
74-
f'alpha={args["lora_alpha"]}/' \
75-
f'lr_ratio={args.get("lora_lr_ratio", None)}/' \
76-
f'use_rslora={args.get("use_rslora", False)}/' \
77-
f'use_dora={args.get("use_dora", False)}'
72+
hyper_params += f'rank={args['lora_rank']}/' \
73+
f'target={args['lora_target_modules']}/' \
74+
f'alpha={args['lora_alpha']}/' \
75+
f'lr_ratio={args.get('lora_lr_ratio', None)}/' \
76+
f'use_rslora={args.get('use_rslora', False)}/' \
77+
f'use_dora={args.get('use_dora', False)}'
7878
else:
7979
hyper_params = ''
8080
if args['sft_type'] == 'full':
8181
if 'use_galore' in args and args['use_galore'] == 'true':
82-
hyper_params += f'galore_rank={args["galore_rank"]}/' \
83-
f'galore_per_parameter={args["galore_optim_per_parameter"]}/' \
84-
f'galore_with_embedding={args["galore_with_embedding"]}/'
82+
hyper_params += f'galore_rank={args['galore_rank']}/' \
83+
f'galore_per_parameter={args['galore_optim_per_parameter']}/' \
84+
f'galore_with_embedding={args['galore_with_embedding']}/'
8585
if args['sft_type'] == 'llamapro':
86-
hyper_params += f'num_blocks={args["llamapro_num_new_blocks"]}/'
86+
hyper_params += f'num_blocks={args['llamapro_num_new_blocks']}/'
8787
if 'neftune_noise_alpha' in args and args['neftune_noise_alpha']:
88-
hyper_params += f'neftune_noise_alpha={args["neftune_noise_alpha"]}/'
88+
hyper_params += f'neftune_noise_alpha={args['neftune_noise_alpha']}/'
8989

9090
if hyper_params.endswith('/'):
9191
hyper_params = hyper_params[:-1]
@@ -95,8 +95,8 @@ def tuner_hyper_params(self):
9595
def hyper_parameters(self):
9696
if 'learning_rate' not in self.args:
9797
return ''
98-
return f'lr={self.args["learning_rate"]}/' \
99-
f'epoch={self.args["num_train_epochs"]}'
98+
return f'lr={self.args['learning_rate']}/' \
99+
f'epoch={self.args['num_train_epochs']}'
100100

101101
@property
102102
def train_speed(self):
@@ -190,10 +190,10 @@ def generate_sft_report(outputs: List[ModelOutput]):
190190
ceval_acc = '' if not ceval_acc else f'**{ceval_acc:.3f}**'
191191

192192
line = f'|{output.name}|' \
193-
f'{output.args["model_type"]}|' \
194-
f'{output.args.get("dataset")}|' \
195-
f'{output.args.get("train_dataset_mix_ratio", 0.)}|' \
196-
f'{output.args.get("sft_type")}|' \
193+
f'{output.args['model_type']}|' \
194+
f'{output.args.get('dataset')}|' \
195+
f'{output.args.get('train_dataset_mix_ratio', 0.)}|' \
196+
f'{output.args.get('sft_type')}|' \
197197
f'{output.tuner_hyper_params}|' \
198198
f'{output.num_trainable_parameters}({output.trainable_parameters_percentage})|' \
199199
f'{use_flash_attn}|' \
@@ -267,14 +267,14 @@ def generate_export_report(outputs: List[ModelOutput]):
267267
ceval_acc = '' if not ceval_acc else f'**{ceval_acc:.3f}**'
268268

269269
if output.train_dataset_info:
270-
dataset_info = f'{output.args["dataset"]}/{output.train_dataset_info}'
270+
dataset_info = f'{output.args['dataset']}/{output.train_dataset_info}'
271271
else:
272-
dataset_info = f'{output.args["dataset"]}'
272+
dataset_info = f'{output.args['dataset']}'
273273
line = f'|{output.name}|' \
274-
f'{output.args["model_type"]}|' \
274+
f'{output.args['model_type']}|' \
275275
f'{dataset_info}|' \
276-
f'{output.args["quant_method"]}|' \
277-
f'{output.args["quant_bits"]}|' \
276+
f'{output.args['quant_method']}|' \
277+
f'{output.args['quant_bits']}|' \
278278
f'{infer_speed}|' \
279279
f'{gsm8k_acc}|' \
280280
f'{arc_acc}|' \

swift/llm/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def _compat_dsw_gradio(args) -> None:
4343
os.environ['GRADIO_ROOT_PATH'] = f"/{os.environ['JUPYTER_NAME']}/proxy/{args.server_port}"
4444

4545
def main(self):
46-
logger.info(f'Start time of running main: {dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")}')
46+
logger.info(f'Start time of running main: {dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}')
4747
result = self.run()
48-
logger.info(f'End time of running main: {dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")}')
48+
logger.info(f'End time of running main: {dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}')
4949
return result
5050

5151
@abstractmethod

swift/llm/dataset/dataset/llm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
285285
chapter = row[f'chapter{i}']
286286
if chapter is not None:
287287
cur_prompt += f'{chapter}'
288-
cur_prompt += f'{row["response"]}'
288+
cur_prompt += f'{row['response']}'
289289
return super().preprocess({'response': cur_prompt})
290290

291291

swift/llm/dataset/dataset/mllm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ def preprocess_row(row: Dict[str, Any]) -> Dict[str, Any]:
566566
what = ''
567567
if ':' in action:
568568
action, what = action[:action.find(':')], action[action.find(':') + 1:]
569-
row['response'] = f'Action: {action.strip()}\nAction Input: {where.strip()}{"," + what.strip()}'
569+
row['response'] = f'Action: {action.strip()}\nAction Input: {where.strip()}{',' + what.strip()}'
570570
return row
571571

572572
conversations = []

swift/llm/export/ollama.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@ def export_to_ollama(args: ExportArguments):
3838
with open(os.path.join(args.output_dir, 'Modelfile'), 'w', encoding='utf-8') as f:
3939
f.write(f'FROM {pt_engine.model_dir}\n')
4040
f.write(f'TEMPLATE """{{{{ if .System }}}}'
41-
f'{replace_and_concat(template, template_meta.system_prefix, "{{SYSTEM}}", "{{ .System }}")}'
42-
f'{{{{ else }}}}{replace_and_concat(template, template_meta.prefix, "", "")}'
41+
f'{replace_and_concat(template, template_meta.system_prefix, '{{SYSTEM}}', '{{ .System }}')}'
42+
f'{{{{ else }}}}{replace_and_concat(template, template_meta.prefix, '', '')}'
4343
f'{{{{ end }}}}')
4444
f.write(f'{{{{ if .Prompt }}}}'
45-
f'{replace_and_concat(template, template_meta.prompt, "{{QUERY}}", "{{ .Prompt }}")}'
45+
f'{replace_and_concat(template, template_meta.prompt, '{{QUERY}}', '{{ .Prompt }}')}'
4646
f'{{{{ end }}}}')
4747
f.write('{{ .Response }}')
4848
f.write(replace_and_concat(template, template_meta.suffix, '', '') + '"""\n')
49-
f.write(f'PARAMETER stop "{replace_and_concat(template, template_meta.suffix, "", "")}"\n')
49+
f.write(f'PARAMETER stop "{replace_and_concat(template, template_meta.suffix, '', '')}"\n')
5050

5151
request_config = RequestConfig(
5252
temperature=args.temperature,
@@ -65,5 +65,5 @@ def export_to_ollama(args: ExportArguments):
6565
logger.info('Save Modelfile done, you can start ollama by:')
6666
logger.info('> ollama serve')
6767
logger.info('In another terminal:')
68-
logger.info('> ollama create my-custom-model ' f'-f {os.path.join(args.output_dir, "Modelfile")}')
68+
logger.info('> ollama create my-custom-model ' f'-f {os.path.join(args.output_dir, 'Modelfile')}')
6969
logger.info('> ollama run my-custom-model')

swift/llm/infer/protocol.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ class ChatCompletionResponse:
261261
def to_cmpl_response(self) -> 'CompletionResponse':
262262
self = deepcopy(self)
263263
choices = [choice.to_cmpl_choice() for choice in self.choices]
264-
id_ = f'cmpl{self.id[len("chatcmpl"):]}'
264+
id_ = f'cmpl{self.id[len('chatcmpl'):]}'
265265
return CompletionResponse(self.model, choices, self.usage, id_, created=self.created)
266266

267267

@@ -315,7 +315,7 @@ class ChatCompletionStreamResponse:
315315
def to_cmpl_response(self) -> 'CompletionStreamResponse':
316316
self = deepcopy(self)
317317
choices = [choice.to_cmpl_choice() for choice in self.choices]
318-
id_ = f'cmpl{self.id[len("chatcmpl"):]}'
318+
id_ = f'cmpl{self.id[len('chatcmpl'):]}'
319319
return CompletionStreamResponse(self.model, choices, self.usage, id_, created=self.created)
320320

321321

swift/plugin/tool_call.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from typing import Tuple,Any, Optional
1+
from typing import Tuple, Any, Optional
2+
3+
24
class TOOL_CALL:
5+
36
def __call__(self, completion: str) -> Tuple[Any, bool, Optional[float]]:
47
raise NotImplementedError
58

69

7-
tools = {
8-
9-
}
10+
tools = {}

swift/trainers/arguments.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ class GRPOArgumentsMixin:
104104
offload_optimizer: bool = False
105105
offload_model: bool = False
106106
gc_collect_after_offload: bool = False
107-
is_reward_tool_call:bool = True #是否额外单独计算每个tool call的format得分
108-
tool_call_weight:float = 1.0
109-
tool_call:str = None
107+
is_reward_tool_call: bool = True #是否额外单独计算每个tool call的format得分
108+
tool_call_weight: float = 1.0
109+
tool_call: str = None
110110

111111

112112
@dataclass

swift/trainers/rlhf_arguments.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import List, Optional,Callable
2+
from typing import List, Optional, Callable
33

44
from trl import CPOConfig as HfCPOConfig
55
from trl import DPOConfig as HfDPOConfig

0 commit comments

Comments
 (0)