@@ -659,6 +659,37 @@ def test_corner_case():
659
659
assert res == 1.0
660
660
661
661
662
+ def test_multiclass_recall_ignore_index ():
663
+ """Issue: https://github.com/Lightning-AI/torchmetrics/issues/2441."""
664
+ y_true = torch .tensor ([0 , 0 , 1 , 1 ])
665
+ y_pred = torch .tensor ([
666
+ [0.9 , 0.1 ],
667
+ [0.9 , 0.1 ],
668
+ [0.9 , 0.1 ],
669
+ [0.1 , 0.9 ],
670
+ ])
671
+
672
+ # Test with ignore_index=0 and average="macro"
673
+ metric_ignore_0 = MulticlassRecall (num_classes = 2 , ignore_index = 0 , average = "macro" )
674
+ res_ignore_0 = metric_ignore_0 (y_pred , y_true )
675
+ assert res_ignore_0 == 0.5 , f"Expected 0.5, but got { res_ignore_0 } "
676
+
677
+ # Test with ignore_index=1 and average="macro"
678
+ metric_ignore_1 = MulticlassRecall (num_classes = 2 , ignore_index = 1 , average = "macro" )
679
+ res_ignore_1 = metric_ignore_1 (y_pred , y_true )
680
+ assert res_ignore_1 == 1.0 , f"Expected 1.0, but got { res_ignore_1 } "
681
+
682
+ # Test with no ignore_index and average="macro"
683
+ metric_no_ignore = MulticlassRecall (num_classes = 2 , average = "macro" )
684
+ res_no_ignore = metric_no_ignore (y_pred , y_true )
685
+ assert res_no_ignore == 0.75 , f"Expected 0.75, but got { res_no_ignore } "
686
+
687
+ # Test with ignore_index=0 and average="none"
688
+ metric_none = MulticlassRecall (num_classes = 2 , ignore_index = 0 , average = "none" )
689
+ res_none = metric_none (y_pred , y_true )
690
+ assert torch .allclose (res_none , torch .tensor ([0.0 , 0.5 ])), f"Expected [0.0, 0.5], but got { res_none } "
691
+
692
+
662
693
@pytest .mark .parametrize (
663
694
("metric" , "kwargs" , "base_metric" ),
664
695
[
0 commit comments