|
8 | 8 | import argparse
|
9 | 9 | import requests
|
10 | 10 |
|
| 11 | +import os |
| 12 | +import threading |
| 13 | +from concurrent.futures import ThreadPoolExecutor |
| 14 | +from tqdm import tqdm |
| 15 | +from functools import partial |
| 16 | + |
11 | 17 | API_KEY = ""
|
12 | 18 | random.seed(12345)
|
13 | 19 |
|
@@ -326,49 +332,117 @@ def merge_result(res, curr):
|
326 | 332 | return res
|
327 | 333 |
|
328 | 334 |
|
| 335 | +# def evaluate(subjects): |
| 336 | +# # client = get_client() |
| 337 | +# test_df, dev_df = load_mmlu_pro() |
| 338 | +# if not subjects: |
| 339 | +# subjects = list(test_df.keys()) |
| 340 | +# print("assigned subjects", subjects) |
| 341 | +# for subject in subjects: |
| 342 | +# test_data = test_df[subject] |
| 343 | +# output_res_path = os.path.join(args.output_dir, subject + "_result.json") |
| 344 | +# output_summary_path = os.path.join(args.output_dir, subject + "_summary.json") |
| 345 | +# res, category_record = update_result(output_res_path) |
| 346 | + |
| 347 | +# k = 0 |
| 348 | +# for each in tqdm(test_data): |
| 349 | +# # k += 1 |
| 350 | +# # if k % 10 != 0: |
| 351 | +# # continue |
| 352 | + |
| 353 | +# label = each["answer"] |
| 354 | +# category = subject |
| 355 | +# # import pdb;pdb.set_trace() |
| 356 | +# pred, response, exist = single_request(None, each, dev_df, res) |
| 357 | +# if response is not None: |
| 358 | +# res, category_record = update_result(output_res_path) |
| 359 | +# if category not in category_record: |
| 360 | +# category_record[category] = {"corr": 0.0, "wrong": 0.0} |
| 361 | +# each["pred"] = pred |
| 362 | +# each["model_outputs"] = response |
| 363 | +# merge_result(res, each) |
| 364 | + |
| 365 | +# if pred is not None: |
| 366 | +# if pred == label: |
| 367 | +# category_record[category]["corr"] += 1 |
| 368 | +# else: |
| 369 | +# category_record[category]["wrong"] += 1 |
| 370 | +# else: |
| 371 | +# category_record[category]["wrong"] += 1 |
| 372 | +# # import pdb;pdb.set_trace() |
| 373 | +# save_res(res, output_res_path) |
| 374 | +# save_summary(category_record, output_summary_path) |
| 375 | +# res, category_record = update_result(output_res_path) |
| 376 | +# save_res(res, output_res_path) |
| 377 | +# save_summary(category_record, output_summary_path) |
| 378 | + |
| 379 | + |
def _process_question(each, subject, dev_df, res, output_res_path,
                      output_summary_path, lock):
    """Query the model for one question and persist its outcome.

    Args:
        each: The question record. It is mutated in place: ``"pred"`` and
            ``"model_outputs"`` are set once a response is available.
        subject: Subject name; also used as the category key in the summary.
        dev_df: Dev split forwarded unchanged to ``single_request``.
        res: Result collection loaded when the subject started; forwarded
            unchanged to ``single_request``.
        output_res_path: Path of the per-subject result file.
        output_summary_path: Path of the per-subject summary file.
        lock: Lock shared by all workers of the subject; it serializes the
            reload / merge / save sequence on the two output files.
    """
    label = each["answer"]

    # The model request is the slow part, so it runs outside the lock and
    # overlaps with the requests issued by the other worker threads.
    pred, response, _exist = single_request(None, each, dev_df, res)
    if response is None:
        return

    with lock:
        # Reload what is currently on disk so that this merge builds on the
        # results already saved by other workers.
        # NOTE(review): assumes update_result only reads the file - confirm.
        latest_res, category_record = update_result(output_res_path)

        each["pred"] = pred
        each["model_outputs"] = response
        merge_result(latest_res, each)

        if subject not in category_record:
            category_record[subject] = {"corr": 0, "wrong": 0}
        # A missing prediction (pred is None) never equals the label, so it
        # is counted as wrong.
        outcome = "corr" if pred == label else "wrong"
        category_record[subject][outcome] += 1

        save_res(latest_res, output_res_path)
        save_summary(category_record, output_summary_path)


def evaluate(subjects, max_workers=20):
    """Evaluate the given subjects, issuing model requests from a thread pool.

    For every subject the questions are dispatched to a thread pool. Each
    worker calls ``single_request`` and then, under a per-subject lock,
    merges its answer into ``<subject>_result.json`` and updates
    ``<subject>_summary.json`` inside ``args.output_dir``.

    Args:
        subjects: Subject names to evaluate. When empty or ``None``, every
            subject returned by ``load_mmlu_pro`` is evaluated.
        max_workers: Number of worker threads used per subject.
    """
    test_df, dev_df = load_mmlu_pro()
    if not subjects:
        subjects = list(test_df.keys())
    print("assigned subjects", subjects)

    for subject in subjects:
        test_data = test_df[subject]
        output_res_path = os.path.join(args.output_dir, f"{subject}_result.json")
        output_summary_path = os.path.join(args.output_dir, f"{subject}_summary.json")
        res, _ = update_result(output_res_path)

        # Bind everything that is constant for this subject; the pool then
        # only has to supply the question itself.
        worker = partial(
            _process_question,
            subject=subject,
            dev_df=dev_df,
            res=res,
            output_res_path=output_res_path,
            output_summary_path=output_summary_path,
            lock=threading.Lock(),
        )

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Draining the iterator advances the progress bar and re-raises
            # any exception raised inside a worker.
            for _ in tqdm(executor.map(worker, test_data), total=len(test_data)):
                pass

        # Final reload-and-save pass so both output files exist and agree
        # even when no worker produced a response.
        final_res, final_summary = update_result(output_res_path)
        save_res(final_res, output_res_path)
        save_summary(final_summary, output_summary_path)
372 | 446 |
|
373 | 447 |
|
374 | 448 | def save_res(res, output_res_path):
|
|
0 commit comments