Skip to content

Commit 7147275

Browse files
authored
Fix mixed results of rouge_score with accumulate='best' (#2830)
1 parent ea29c89 commit 7147275

File tree

3 files changed

+63
-14
lines changed

3 files changed

+63
-14
lines changed

CHANGELOG.md

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
5050

5151
- Removed `num_outputs` in `R2Score` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))
5252

53-
5453
### Fixed
5554

56-
-
57-
55+
- Fixed mixed results of `rouge_score` with `accumulate='best'` ([#2830](https://github.com/Lightning-AI/torchmetrics/pull/2830))
5856

59-
---
6057

6158
## [1.5.2] - 2024-11-07
6259

src/torchmetrics/functional/text/rouge.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -362,12 +362,9 @@ def _rouge_score_update(
362362
list_results.append(result_inner.copy())
363363

364364
if accumulate == "best":
365-
key_curr = rouge_keys_values[0]
366-
all_fmeasure = torch.tensor([v[key_curr]["fmeasure"] for v in list_results])
367-
highest_idx = int(torch.argmax(all_fmeasure).item())
368-
369-
for rouge_key in rouge_keys_values:
370-
results[rouge_key].append(list_results[highest_idx][rouge_key]) # todo
365+
for k in rouge_keys_values:
366+
index = torch.argmax(torch.tensor([s[k]["fmeasure"] for s in list_results]))
367+
results[k].append(list_results[index][k])
371368

372369
elif accumulate == "avg":
373370
new_result_avg: dict[Union[int, str], dict[str, Tensor]] = {

tests/unittests/text/test_rouge.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,12 @@ def _reference_rouge_score(
7575
aggregator_avg = BootstrapAggregator()
7676

7777
if accumulate == "best":
78-
key_curr = next(iter(list_results[0].keys()))
79-
all_fmeasure = torch.tensor([v[key_curr].fmeasure for v in list_results])
80-
highest_idx = torch.argmax(all_fmeasure).item()
81-
aggregator.add_scores(list_results[highest_idx])
78+
scores = {}
79+
for rouge_key in list_results[0]:
80+
all_fmeasure = torch.tensor([v[rouge_key].fmeasure for v in list_results])
81+
highest_idx = torch.argmax(all_fmeasure).item()
82+
scores[rouge_key] = list_results[highest_idx][rouge_key]
83+
aggregator.add_scores(scores)
8284
elif accumulate == "avg":
8385
for _score in list_results:
8486
aggregator_avg.add_scores(_score)
@@ -270,3 +272,56 @@ def test_rouge_lsum_score(pl_rouge_metric_key, use_stemmer):
270272
use_stemmer=use_stemmer,
271273
)
272274
assert torch.isclose(metrics_score[rouge_level + "_" + metric], original_score)
275+
276+
277+
@pytest.mark.parametrize(
278+
("preds", "references", "expected_scores"),
279+
[
280+
(
281+
"a b c",
282+
["a b c", "c b a"],
283+
{
284+
"rouge1_fmeasure": 1.0,
285+
"rouge1_precision": 1.0,
286+
"rouge1_recall": 1.0,
287+
"rouge2_fmeasure": 1.0,
288+
"rouge2_precision": 1.0,
289+
"rouge2_recall": 1.0,
290+
"rougeL_fmeasure": 1.0,
291+
"rougeL_precision": 1.0,
292+
"rougeL_recall": 1.0,
293+
"rougeLsum_fmeasure": 1.0,
294+
"rougeLsum_precision": 1.0,
295+
"rougeLsum_recall": 1.0,
296+
},
297+
),
298+
(
299+
"a b c",
300+
["c b a", "a b c"],
301+
{
302+
"rouge1_fmeasure": 1.0,
303+
"rouge1_precision": 1.0,
304+
"rouge1_recall": 1.0,
305+
"rouge2_fmeasure": 1.0,
306+
"rouge2_precision": 1.0,
307+
"rouge2_recall": 1.0,
308+
"rougeL_fmeasure": 1.0,
309+
"rougeL_precision": 1.0,
310+
"rougeL_recall": 1.0,
311+
"rougeLsum_fmeasure": 1.0,
312+
"rougeLsum_precision": 1.0,
313+
"rougeLsum_recall": 1.0,
314+
},
315+
),
316+
],
317+
)
318+
def test_rouge_score_accumulate_best(preds, references, expected_scores):
319+
"""Issue: https://github.com/Lightning-AI/torchmetrics/issues/2148."""
320+
# Calculate ROUGE scores
321+
result = rouge_score(preds, references, accumulate="best")
322+
323+
# Assert each expected score
324+
for key in expected_scores:
325+
assert torch.isclose(
326+
result[key], torch.tensor(expected_scores[key])
327+
), f"Expected {expected_scores[key]} for {key}, but got {result[key]}"

0 commit comments

Comments (0)