Skip to content

Commit fe6bfc5

Browse files
committed
多线程 mmlu 脚本
1 parent c5a4c21 commit fe6bfc5

File tree

2 files changed

+98
-22
lines changed

2 files changed

+98
-22
lines changed

llm/benchmark/mmlu_pro/evaluate_from_api.py

+95-21
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
import argparse
99
import requests
1010

11+
import os
12+
import threading
13+
from concurrent.futures import ThreadPoolExecutor
14+
from tqdm import tqdm
15+
from functools import partial
16+
1117
API_KEY = ""
1218
random.seed(12345)
1319

@@ -326,49 +332,117 @@ def merge_result(res, curr):
326332
return res
327333

328334

335+
# def evaluate(subjects):
336+
# # client = get_client()
337+
# test_df, dev_df = load_mmlu_pro()
338+
# if not subjects:
339+
# subjects = list(test_df.keys())
340+
# print("assigned subjects", subjects)
341+
# for subject in subjects:
342+
# test_data = test_df[subject]
343+
# output_res_path = os.path.join(args.output_dir, subject + "_result.json")
344+
# output_summary_path = os.path.join(args.output_dir, subject + "_summary.json")
345+
# res, category_record = update_result(output_res_path)
346+
347+
# k = 0
348+
# for each in tqdm(test_data):
349+
# # k += 1
350+
# # if k % 10 != 0:
351+
# # continue
352+
353+
# label = each["answer"]
354+
# category = subject
355+
# # import pdb;pdb.set_trace()
356+
# pred, response, exist = single_request(None, each, dev_df, res)
357+
# if response is not None:
358+
# res, category_record = update_result(output_res_path)
359+
# if category not in category_record:
360+
# category_record[category] = {"corr": 0.0, "wrong": 0.0}
361+
# each["pred"] = pred
362+
# each["model_outputs"] = response
363+
# merge_result(res, each)
364+
365+
# if pred is not None:
366+
# if pred == label:
367+
# category_record[category]["corr"] += 1
368+
# else:
369+
# category_record[category]["wrong"] += 1
370+
# else:
371+
# category_record[category]["wrong"] += 1
372+
# # import pdb;pdb.set_trace()
373+
# save_res(res, output_res_path)
374+
# save_summary(category_record, output_summary_path)
375+
# res, category_record = update_result(output_res_path)
376+
# save_res(res, output_res_path)
377+
# save_summary(category_record, output_summary_path)
378+
379+
329380
def evaluate(subjects):
330-
# client = get_client()
331381
test_df, dev_df = load_mmlu_pro()
332382
if not subjects:
333383
subjects = list(test_df.keys())
334384
print("assigned subjects", subjects)
385+
335386
for subject in subjects:
336387
test_data = test_df[subject]
337-
output_res_path = os.path.join(args.output_dir, subject + "_result.json")
338-
output_summary_path = os.path.join(args.output_dir, subject + "_summary.json")
388+
output_res_path = os.path.join(args.output_dir, f"{subject}_result.json")
389+
output_summary_path = os.path.join(args.output_dir, f"{subject}_summary.json")
339390
res, category_record = update_result(output_res_path)
391+
392+
lock = threading.Lock()
340393

341-
k = 0
342-
for each in tqdm(test_data):
343-
# k += 1
344-
# if k % 10 != 0:
345-
# continue
346-
394+
def process_each(each, subject, dev_df, output_res_path, output_summary_path, res):
347395
label = each["answer"]
348396
category = subject
349-
# import pdb;pdb.set_trace()
397+
398+
# 多线程执行single_request
350399
pred, response, exist = single_request(None, each, dev_df, res)
351-
if response is not None:
400+
401+
if response is None:
402+
return
403+
404+
with lock: # 保证以下操作单线程访问
405+
# 读取最新结果
352406
res, category_record = update_result(output_res_path)
353-
if category not in category_record:
354-
category_record[category] = {"corr": 0.0, "wrong": 0.0}
407+
408+
# # 检查是否已处理(假设each有唯一标识)
409+
# if any(e['id'] == each['id'] for e in res.values()):
410+
# return
411+
412+
# 更新结果数据
355413
each["pred"] = pred
356414
each["model_outputs"] = response
357415
merge_result(res, each)
358416

359-
if pred is not None:
360-
if pred == label:
361-
category_record[category]["corr"] += 1
362-
else:
363-
category_record[category]["wrong"] += 1
417+
# 更新统计信息
418+
if category not in category_record:
419+
category_record[category] = {"corr": 0, "wrong": 0}
420+
if pred == label:
421+
category_record[category]["corr"] += 1
364422
else:
365423
category_record[category]["wrong"] += 1
366-
# import pdb;pdb.set_trace()
424+
425+
# 保存更新
367426
save_res(res, output_res_path)
368427
save_summary(category_record, output_summary_path)
369428
res, category_record = update_result(output_res_path)
370-
save_res(res, output_res_path)
371-
save_summary(category_record, output_summary_path)
429+
430+
# 绑定固定参数
431+
process_func = partial(process_each,
432+
subject=subject,
433+
dev_df=dev_df,
434+
output_res_path=output_res_path,
435+
output_summary_path=output_summary_path,
436+
res=res)
437+
438+
# 使用线程池并发处理
439+
with ThreadPoolExecutor(max_workers=20) as executor:
440+
tasks = list(tqdm(executor.map(process_func, test_data), total=len(test_data)))
441+
442+
# 最终保存确保完整性
443+
final_res, final_summary = update_result(output_res_path)
444+
save_res(final_res, output_res_path)
445+
save_summary(final_summary, output_summary_path)
372446

373447

374448
def save_res(res, output_res_path):
+3-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
python3 evaluate_from_api.py --backend trtllm --ip $IP --port $PORT --output_dir ./eval_trtlllm
1+
export PPNLP_HOME="/opt/output/ppnlp_home"
2+
3+
python3 evaluate_from_api.py --backend paddle --ip 127.0.0.1 --port 9965 --output_dir ./eval_paddle

0 commit comments

Comments
 (0)