Skip to content

Commit 0773bab

Browse files
rittik9Borda
authored andcommitted
test: Add test for MulticlassRecall with ignore_index+macro (fixes Lightning-AI#2441)
1 parent df36d0f commit 0773bab

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

tests/unittests/classification/test_precision_recall.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,37 @@ def test_corner_case():
659659
assert res == 1.0
660660

661661

662+
def test_multiclass_recall_ignore_index():
663+
"""Issue: https://github.com/Lightning-AI/torchmetrics/issues/2441."""
664+
y_true = torch.tensor([0, 0, 1, 1])
665+
y_pred = torch.tensor([
666+
[0.9, 0.1],
667+
[0.9, 0.1],
668+
[0.9, 0.1],
669+
[0.1, 0.9],
670+
])
671+
672+
# Test with ignore_index=0 and average="macro"
673+
metric_ignore_0 = MulticlassRecall(num_classes=2, ignore_index=0, average="macro")
674+
res_ignore_0 = metric_ignore_0(y_pred, y_true)
675+
assert res_ignore_0 == 0.5, f"Expected 0.5, but got {res_ignore_0}"
676+
677+
# Test with ignore_index=1 and average="macro"
678+
metric_ignore_1 = MulticlassRecall(num_classes=2, ignore_index=1, average="macro")
679+
res_ignore_1 = metric_ignore_1(y_pred, y_true)
680+
assert res_ignore_1 == 1.0, f"Expected 1.0, but got {res_ignore_1}"
681+
682+
# Test with no ignore_index and average="macro"
683+
metric_no_ignore = MulticlassRecall(num_classes=2, average="macro")
684+
res_no_ignore = metric_no_ignore(y_pred, y_true)
685+
assert res_no_ignore == 0.75, f"Expected 0.75, but got {res_no_ignore}"
686+
687+
# Test with ignore_index=0 and average="none"
688+
metric_none = MulticlassRecall(num_classes=2, ignore_index=0, average="none")
689+
res_none = metric_none(y_pred, y_true)
690+
assert torch.allclose(res_none, torch.tensor([0.0, 0.5])), f"Expected [0.0, 0.5], but got {res_none}"
691+
692+
662693
@pytest.mark.parametrize(
663694
("metric", "kwargs", "base_metric"),
664695
[

0 commit comments

Comments
 (0)