@@ -75,10 +75,12 @@ def _reference_rouge_score(
     aggregator_avg = BootstrapAggregator()
 
     if accumulate == "best":
-        key_curr = next(iter(list_results[0].keys()))
-        all_fmeasure = torch.tensor([v[key_curr].fmeasure for v in list_results])
-        highest_idx = torch.argmax(all_fmeasure).item()
-        aggregator.add_scores(list_results[highest_idx])
+        scores = {}
+        for rouge_key in list_results[0]:
+            all_fmeasure = torch.tensor([v[rouge_key].fmeasure for v in list_results])
+            highest_idx = torch.argmax(all_fmeasure).item()
+            scores[rouge_key] = list_results[highest_idx][rouge_key]
+        aggregator.add_scores(scores)
     elif accumulate == "avg":
         for _score in list_results:
             aggregator_avg.add_scores(_score)
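Why the per-key loop matters: the removed code picked a single winning reference using only the first rouge key's f-measure and then reported that reference's scores for every key. The added loop instead lets each rouge key take the scores of whichever reference matches it best. A minimal sketch of the difference, using a hypothetical `Score` namedtuple in place of the `rouge-score` package's score objects:

```python
# Minimal sketch of per-key "best" selection. Assumes each entry of
# list_results maps a rouge key to a score tuple with an .fmeasure field,
# mirroring the shape the reference implementation above works with.
from collections import namedtuple

Score = namedtuple("Score", ["precision", "recall", "fmeasure"])

# Hypothetical scores for one prediction against two references.
list_results = [
    {"rouge1": Score(1.0, 1.0, 1.0), "rouge2": Score(0.0, 0.0, 0.0)},  # reference A
    {"rouge1": Score(0.5, 0.5, 0.5), "rouge2": Score(1.0, 1.0, 1.0)},  # reference B
]

# Old behaviour: reference A wins via rouge1 and supplies every score,
# so rouge2 is reported as 0.0 even though reference B matches it perfectly.
# New behaviour: each rouge key independently picks its best reference.
best = {
    key: max((res[key] for res in list_results), key=lambda s: s.fmeasure)
    for key in list_results[0]
}
assert best["rouge1"].fmeasure == 1.0  # reference A wins for rouge1
assert best["rouge2"].fmeasure == 1.0  # reference B wins for rouge2
```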
@@ -270,3 +272,56 @@ def test_rouge_lsum_score(pl_rouge_metric_key, use_stemmer):
         use_stemmer=use_stemmer,
     )
     assert torch.isclose(metrics_score[rouge_level + "_" + metric], original_score)
+
+
+@pytest.mark.parametrize(
+    ("preds", "references", "expected_scores"),
+    [
+        (
+            "a b c",
+            ["a b c", "c b a"],
+            {
+                "rouge1_fmeasure": 1.0,
+                "rouge1_precision": 1.0,
+                "rouge1_recall": 1.0,
+                "rouge2_fmeasure": 1.0,
+                "rouge2_precision": 1.0,
+                "rouge2_recall": 1.0,
+                "rougeL_fmeasure": 1.0,
+                "rougeL_precision": 1.0,
+                "rougeL_recall": 1.0,
+                "rougeLsum_fmeasure": 1.0,
+                "rougeLsum_precision": 1.0,
+                "rougeLsum_recall": 1.0,
+            },
+        ),
+        (
+            "a b c",
+            ["c b a", "a b c"],
+            {
+                "rouge1_fmeasure": 1.0,
+                "rouge1_precision": 1.0,
+                "rouge1_recall": 1.0,
+                "rouge2_fmeasure": 1.0,
+                "rouge2_precision": 1.0,
+                "rouge2_recall": 1.0,
+                "rougeL_fmeasure": 1.0,
+                "rougeL_precision": 1.0,
+                "rougeL_recall": 1.0,
+                "rougeLsum_fmeasure": 1.0,
+                "rougeLsum_precision": 1.0,
+                "rougeLsum_recall": 1.0,
+            },
+        ),
+    ],
+)
+def test_rouge_score_accumulate_best(preds, references, expected_scores):
+    """Issue: https://github.com/Lightning-AI/torchmetrics/issues/2148."""
+    # Calculate ROUGE scores
+    result = rouge_score(preds, references, accumulate="best")
+
+    # Assert each expected score
+    for key in expected_scores:
+        assert torch.isclose(
+            result[key], torch.tensor(expected_scores[key])
+        ), f"Expected {expected_scores[key]} for {key}, but got {result[key]}"