
Commit 043b483

switch logic
1 parent e1f04db commit 043b483

23 files changed: +141 / -129 lines

lmms_eval/tasks/ok_vqa/_default_template_vqa_yaml

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ metric_list:
     ignore_case: true
     ignore_punctuation: true
   - metric: submission
-    aggregation: !function utils.ok_vqa_aggreate_submissions
+    aggregation: !function utils.ok_vqa_aggregate_submissions
     higher_is_better: true
 process_results: !function utils.ok_vqa_process_results
 model_specific_prompt_kwargs:

lmms_eval/tasks/ok_vqa/utils.py

Lines changed: 7 additions & 3 deletions
@@ -19,7 +19,9 @@ def ok_vqa_doc_to_visual(doc):
 
 def ok_vqa_process_results(doc, result):
     eval_ai_processor = EvalAIAnswerProcessor()
-    assert len(result) == 1, f"The result should be a list of length 1, but got {len(result)}."
+    assert (
+        len(result) == 1
+    ), f"The result should be a list of length 1, but got {len(result)}."
     resAns = eval_ai_processor(result[0])
     accuracy = 0
 

@@ -30,7 +32,9 @@ def ok_vqa_process_results(doc, result):
         doc["answers"][i] = eval_ai_processor(doc["answers"][i])
 
     for i in range(len(doc["answers"])):
-        otherGTAns = [doc["answers"][j] for j in range(len(doc["answers"])) if i != j]
+        otherGTAns = [
+            doc["answers"][j] for j in range(len(doc["answers"])) if i != j
+        ]
         matchingAns = [item for item in otherGTAns if item == resAns]
         acc = min(1, float(len(matchingAns)) / 3)
         gtAcc.append(acc)

@@ -61,7 +65,7 @@ def ok_vqa_doc_to_text(doc, model_specific_prompt_kwargs=None):
     return f"{pre_prompt}{question}{post_prompt}"
 
 
-def ok_vqa_aggreate_submissions(results, args):
+def ok_vqa_aggregate_submissions(results, args):
     now_date_time = datetime.datetime.now().strftime("%Y-%m%d-%H%M-%S")
     file = f"ok_vqa-test-submission-{now_date_time}.json"
     path = generate_submission_file(file, args)
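Note: the reformatted assertion and list comprehension above belong to the standard leave-one-out VQA accuracy. A minimal sketch of that metric, with an illustrative helper name and toy values that are not part of the repo:

# Sketch of the leave-one-out accuracy computed in ok_vqa_process_results
# (function name and example data are illustrative, not from the repo).
def vqa_accuracy(pred, gt_answers):
    per_annotator = []
    for i in range(len(gt_answers)):
        # compare the prediction against every annotator except the i-th
        others = [gt_answers[j] for j in range(len(gt_answers)) if j != i]
        matches = [a for a in others if a == pred]
        # full credit once three of the remaining annotators agree
        per_annotator.append(min(1.0, len(matches) / 3))
    return sum(per_annotator) / len(per_annotator)

# vqa_accuracy("cat", ["cat", "cat", "cat", "dog"]) -> 0.75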

lmms_eval/tasks/textvqa/textvqa_test.yaml

Lines changed: 1 addition & 1 deletion
@@ -2,6 +2,6 @@ task: textvqa_test
 test_split: test
 metric_list:
   - metric: submission
-    aggregation: !function utils.textvqa_aggreate_submissions
+    aggregation: !function utils.textvqa_aggregate_submissions
     higher_is_better: true
 include: _default_template_textvqa_yaml

lmms_eval/tasks/textvqa/textvqa_val.yaml

Lines changed: 1 addition & 1 deletion
@@ -7,6 +7,6 @@ metric_list:
     ignore_case: true
     ignore_punctuation: true
   - metric: submission
-    aggregation: !function utils.textvqa_aggreate_submissions
+    aggregation: !function utils.textvqa_aggregate_submissions
     higher_is_better: true
 include: _default_template_textvqa_yaml

lmms_eval/tasks/textvqa/utils.py

Lines changed: 11 additions & 4 deletions
@@ -19,7 +19,9 @@ def textvqa_doc_to_visual(doc):
 
 def textvqa_process_results(doc, result):
     eval_ai_processor = EvalAIAnswerProcessor()
-    assert len(result) == 1, f"The result should be a list of length 1, but got {len(result)}."
+    assert (
+        len(result) == 1
+    ), f"The result should be a list of length 1, but got {len(result)}."
     resAns = eval_ai_processor(result[0])
     accuracy = 0
 

@@ -30,7 +32,9 @@ def textvqa_process_results(doc, result):
         doc["answers"][i] = eval_ai_processor(doc["answers"][i])
 
     for i in range(len(doc["answers"])):
-        otherGTAns = [doc["answers"][j] for j in range(len(doc["answers"])) if i != j]
+        otherGTAns = [
+            doc["answers"][j] for j in range(len(doc["answers"])) if i != j
+        ]
         matchingAns = [item for item in otherGTAns if item == resAns]
         acc = min(1, float(len(matchingAns)) / 3)
         gtAcc.append(acc)

@@ -54,12 +58,15 @@ def textvqa_doc_to_text(doc, model_specific_prompt_kwargs=None):
         pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
     if "post_prompt" in model_specific_prompt_kwargs:
         post_prompt = model_specific_prompt_kwargs["post_prompt"]
-    if "ocr" in model_specific_prompt_kwargs and model_specific_prompt_kwargs["ocr"]:
+    if (
+        "ocr" in model_specific_prompt_kwargs
+        and model_specific_prompt_kwargs["ocr"]
+    ):
         ocr_ref = f"\nReference OCR token: {', '.join(doc['ocr_tokens'])}"
     return f"{pre_prompt}{doc['question'].capitalize()}{ocr_ref}{post_prompt}"
 
 
-def textvqa_aggreate_submissions(results, args):
+def textvqa_aggregate_submissions(results, args):
     now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
     path = generate_submission_file(f"textvqa_submission_{now_date_time}.json", args)
     with open(path, "w") as f:
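As a quick illustration of the ocr branch that was wrapped above, this is roughly the prompt textvqa_doc_to_text assembles when ocr is enabled. The kwargs and OCR tokens below are made up for the example and are not taken from the repo's configs:

# Illustrative values only.
kwargs = {
    "pre_prompt": "",
    "post_prompt": "\nAnswer the question using a single word or phrase.",
    "ocr": True,
}
doc = {"question": "what is the brand of this camera?", "ocr_tokens": ["canon", "eos"]}

ocr_ref = f"\nReference OCR token: {', '.join(doc['ocr_tokens'])}"
prompt = f"{kwargs['pre_prompt']}{doc['question'].capitalize()}{ocr_ref}{kwargs['post_prompt']}"
# prompt:
# "What is the brand of this camera?
#  Reference OCR token: canon, eos
#  Answer the question using a single word or phrase."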

lmms_eval/tasks/vcr_wiki/utils.py

Lines changed: 64 additions & 73 deletions
@@ -6,6 +6,9 @@
 from spacy.cli import download
 from nltk.util import ngrams
 from functools import partial
+import datetime
+from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
+import json
 
 # Download the English and Chinese models
 download("en_core_web_sm")

@@ -46,7 +49,7 @@ def fast_filter(answer_text):
 
 
 def vcr_doc_to_visual(doc):
-    return [doc["stacked_image"].convert("RGB"), doc["only_it_image"].convert("RGB")]
+    return [doc["stacked_image"].convert("RGB")]
 
 
 def vcr_doc_to_text(doc, model_specific_prompt_kwargs=None):

@@ -80,7 +83,7 @@ def vcr_process_results_single(crossed_text, result, language):
         doc: a instance of the eval dataset
         results: [pred]
     Returns:
-        a dictionary with key: metric name (in this case mme score), value: metric value
+        a dictionary with key: metric name (in this case vcr score), value: metric value
     """
 
     assert language in ["en", "zh"], f"Language {language} is not supported."

@@ -171,29 +174,28 @@ def vcr_en_process_results(doc, results):
         doc: a instance of the eval dataset
         results: [pred]
     Returns:
-        a dictionary with key: metric name (in this case mme score), value: metric value
+        a dictionary with key: metric name (in this case vcr score), value: metric value
     """
-    assert len(results) == 2, f"Expected 2 results, got {len(results)}"
-    output = {}
-    for i in range(len(doc["crossed_text"])):
-        res_stacked_image_results = vcr_process_results_single(
-            doc["crossed_text"][i], results[0], "en"
-        )
-        res_only_image_results = vcr_process_results_single(
-            doc["crossed_text"][i], results[1], "en"
-        )
-        output.update(
-            {
-                f"res_stacked_image__{k}___{i}": v
-                for k, v in res_stacked_image_results.items()
-            }
-        )
-        output.update(
-            {
-                f"res_only_it_image__{k}___{i}": v
-                for k, v in res_only_image_results.items()
-            }
-        )
+    output = {
+        "max_sim_val": [],
+        "precision": [],
+        "recall": [],
+        "f1": [],
+        "jaccard": [],
+        "rouge1": [],
+        "exact_match": [],
+    }
+    crossed_text = doc["crossed_text"]
+    for i in range(len(crossed_text)):
+        tmp = vcr_process_results_single(crossed_text[i], results, "en")
+        for k in output.keys():
+            output[k].append(
+                {
+                    "score": tmp[k],
+                    "max_sim_string": tmp["max_sim_string"],
+                    "caption": doc["caption"],
+                }
+            )
     return output
 
 

@@ -203,62 +205,51 @@ def vcr_zh_process_results(doc, results):
         doc: a instance of the eval dataset
         results: [pred]
     Returns:
-        a dictionary with key: metric name (in this case mme score), value: metric value
+        a dictionary with key: metric name (in this case vcr score), value: metric value
     """
-    assert len(results) == 2, f"Expected 2 results, got {len(results)}"
-    output = {}
-    for i in range(len(doc["crossed_text"])):
-        res_stacked_image_results = vcr_process_results_single(
-            doc["crossed_text"][i], results[0], "zh"
-        )
-        res_only_image_results = vcr_process_results_single(
-            doc["crossed_text"][i], results[1], "zh"
-        )
-        output.update(
-            {
-                f"res_stacked_image__{k}___{i}": v
-                for k, v in res_stacked_image_results.items()
-            }
-        )
-        output.update(
-            {
-                f"res_only_it_image__{k}___{i}": v
-                for k, v in res_only_image_results.items()
-            }
-        )
+    output = {
+        "max_sim_val": [],
+        "precision": [],
+        "recall": [],
+        "f1": [],
+        "jaccard": [],
+        "rouge1": [],
+        "exact_match": [],
+    }
+    crossed_text = doc["crossed_text"]
+    for i in range(len(crossed_text)):
+        tmp = vcr_process_results_single(crossed_text[i], results, "zh")
+        for k in output.keys():
+            output[k].append(
+                {
+                    "score": tmp[k],
+                    "max_sim_string": tmp["max_sim_string"],
+                    "caption": doc["caption"],
+                }
+            )
     return output
 
 
-def vcr_aggregate_results(results):
+def vcr_aggregate_results(results, args):
     """
     Args:
         results: a list of values returned by process_results
     Returns:
         A dictionary of dictionary of float, where the outer dictionary has keys "res_stacked_image" and "res_only_it_image"
     """
-    output = {
-        "res_stacked_image__precision": 0,
-        "res_stacked_image__recall": 0,
-        "res_stacked_image__f1": 0,
-        "res_stacked_image__jaccard": 0,
-        "res_stacked_image__rouge1": 0,
-        "res_stacked_image__exact_match": 0,
-        "res_only_it_image__precision": 0,
-        "res_only_it_image__recall": 0,
-        "res_only_it_image__f1": 0,
-        "res_only_it_image__jaccard": 0,
-        "res_only_it_image__rouge1": 0,
-        "res_only_it_image__exact_match": 0,
-    }
-
-    for output_key in output.keys():
-        count = 0
-        query_domain, query_metric_name = output_key.split("__")
-        for inner_dict in results:
-            for inner_key, inner_value in inner_dict.items():
-                key_domain, key_metric_name, _ = inner_key.split("__")
-                if query_domain == key_domain and query_metric_name == key_metric_name:
-                    output[output_key] += inner_value
-                    count += 1
-        output[output_key] /= count
-    return output
+    scores = 0
+    count = 0
+    output_dict = {}
+    for i in range(len(results)):
+        for blank_id in range(len(results[i])):
+            scores += results[i][blank_id]["score"]
+            count += 1
+        output_dict[str(i)] = results[i]
+
+    now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+    path = generate_submission_file(f"vcr_submission_{now_date_time}.json", args)
+    with open(path, "w") as f:
+        json.dump(output_dict, f)
+    # print(f"Submission file saved to {path}")
+    eval_logger.info(f"Submission file saved to {path}")
+    return scores / count
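A toy walk-through of the switched logic in this file: each vcr_*_process_results call now scores a single (stacked-image) prediction and returns one entry per crossed-out blank, and vcr_aggregate_results averages the "score" fields over all docs and blanks while also dumping the raw entries to a submission JSON. The data below is invented purely to show the shapes:

# Invented example; shapes match the new per-metric process_results output.
results = [
    [  # doc 0: two blanks
        {"score": 1.0, "max_sim_string": "queen", "caption": "..."},
        {"score": 0.5, "max_sim_string": "crowned", "caption": "..."},
    ],
    [  # doc 1: one blank
        {"score": 0.0, "max_sim_string": "river", "caption": "..."},
    ],
]

scores, count = 0.0, 0
for per_doc in results:
    for blank in per_doc:
        scores += blank["score"]
        count += 1
print(scores / count)  # 0.5, the value vcr_aggregate_results would return for this metric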

lmms_eval/tasks/vcr_wiki/vcr_wiki_en_easy_100.yaml

Lines changed: 3 additions & 3 deletions
@@ -1,13 +1,13 @@
 "include": "_default_template_vcr_yaml"
 dataset_path: vcr-org/VCR-wiki-en-easy-test
-task: "vcr_wiki_en_easy"
+task: "vcr_wiki_en_easy_100"
 test_split: train[:100]
 process_results: !function utils.vcr_en_process_results
 metric_list:
-  - metric: vcr_percetion_score
+  - metric: jaccard
     aggregation: !function utils.vcr_en_process_results
     higher_is_better: true
-  - metric: vcr_cognition_score
+  - metric: exact_match
     aggregation: !function utils.vcr_en_process_results
     higher_is_better: true
 model_specific_prompt_kwargs:

lmms_eval/tasks/vcr_wiki/vcr_wiki_en_easy_500.yaml

Lines changed: 3 additions & 3 deletions
@@ -1,13 +1,13 @@
 "include": "_default_template_vcr_yaml"
 dataset_path: vcr-org/VCR-wiki-en-easy-test
-task: "vcr_wiki_en_easy"
+task: "vcr_wiki_en_easy_500"
 test_split: train[:500]
 process_results: !function utils.vcr_en_process_results
 metric_list:
-  - metric: vcr_percetion_score
+  - metric: jaccard
     aggregation: !function utils.vcr_en_process_results
     higher_is_better: true
-  - metric: vcr_cognition_score
+  - metric: exact_match
     aggregation: !function utils.vcr_en_process_results
     higher_is_better: true
 model_specific_prompt_kwargs:

lmms_eval/tasks/vcr_wiki/vcr_wiki_en_easy_5000.yaml

Lines changed: 3 additions & 3 deletions
@@ -1,13 +1,13 @@
 "include": "_default_template_vcr_yaml"
 dataset_path: vcr-org/VCR-wiki-en-easy-test
-task: "vcr_wiki_en_easy"
+task: "vcr_wiki_en_easy_5000"
 test_split: train
 process_results: !function utils.vcr_en_process_results
 metric_list:
-  - metric: vcr_percetion_score
+  - metric: jaccard
     aggregation: !function utils.vcr_en_process_results
     higher_is_better: true
-  - metric: vcr_cognition_score
+  - metric: exact_match
     aggregation: !function utils.vcr_en_process_results
     higher_is_better: true
 model_specific_prompt_kwargs:

lmms_eval/tasks/vcr_wiki/vcr_wiki_en_hard_100.yaml

Lines changed: 3 additions & 3 deletions
@@ -1,13 +1,13 @@
 "include": "_default_template_vcr_yaml"
 dataset_path: vcr-org/VCR-wiki-en-hard-test
-task: "vcr_wiki_en_hard"
+task: "vcr_wiki_en_hard_100"
 test_split: train[:100]
 process_results: !function utils.vcr_en_process_results
 metric_list:
-  - metric: vcr_percetion_score
+  - metric: jaccard
     aggregation: !function utils.vcr_en_process_results
     higher_is_better: true
-  - metric: vcr_cognition_score
+  - metric: exact_match
     aggregation: !function utils.vcr_en_process_results
     higher_is_better: true
 model_specific_prompt_kwargs:

lmms_eval/tasks/vcr_wiki/vcr_wiki_en_hard_500.yaml

Lines changed: 3 additions & 3 deletions
@@ -1,13 +1,13 @@
 "include": "_default_template_vcr_yaml"
 dataset_path: vcr-org/VCR-wiki-en-hard-test
-task: "vcr_wiki_en_hard"
+task: "vcr_wiki_en_hard_500"
 test_split: train[:500]
 process_results: !function utils.vcr_en_process_results
 metric_list:
-  - metric: vcr_percetion_score
+  - metric: jaccard
     aggregation: !function utils.vcr_en_process_results
     higher_is_better: true
-  - metric: vcr_cognition_score
+  - metric: exact_match
     aggregation: !function utils.vcr_en_process_results
     higher_is_better: true
 model_specific_prompt_kwargs:

lmms_eval/tasks/vcr_wiki/vcr_wiki_en_hard_5000.yaml

Lines changed: 3 additions & 3 deletions
@@ -1,13 +1,13 @@
 "include": "_default_template_vcr_yaml"
 dataset_path: vcr-org/VCR-wiki-en-hard-test
-task: "vcr_wiki_en_hard"
+task: "vcr_wiki_en_hard_5000"
 test_split: train
 process_results: !function utils.vcr_en_process_results
 metric_list:
-  - metric: vcr_percetion_score
+  - metric: jaccard
     aggregation: !function utils.vcr_en_process_results
     higher_is_better: true
-  - metric: vcr_cognition_score
+  - metric: exact_match
     aggregation: !function utils.vcr_en_process_results
     higher_is_better: true
 model_specific_prompt_kwargs:

lmms_eval/tasks/vcr_wiki/vcr_wiki_zh_easy_100.yaml

Lines changed: 3 additions & 3 deletions
@@ -1,13 +1,13 @@
 "include": "_default_template_vcr_yaml"
 dataset_path: vcr-org/VCR-wiki-zh-easy-test
-task: "vcr_wiki_zh_easy"
+task: "vcr_wiki_zh_easy_100"
 test_split: train[:100]
 process_results: !function utils.vcr_zh_process_results
 metric_list:
-  - metric: vcr_percetion_score
+  - metric: jaccard
     aggregation: !function utils.vcr_zh_process_results
     higher_is_better: true
-  - metric: vcr_cognition_score
+  - metric: exact_match
     aggregation: !function utils.vcr_zh_process_results
     higher_is_better: true
 model_specific_prompt_kwargs:
