Skip to content

Commit 412002c

Browse files
authored
Fix missing reset in classwise wrapper (#1129)
* missing reset * changelog
1 parent 883254e commit 412002c

File tree

3 files changed

+28
-15
lines changed

3 files changed

+28
-15
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4040

4141
### Fixed
4242

43-
-
43+
- Fixed missing reset in `ClasswiseWrapper` ([#1129](https://github.com/Lightning-AI/metrics/pull/1129))
4444

4545

4646
- Fixed JaccardIndex multi-label compute ([#1125](https://github.com/Lightning-AI/metrics/pull/1125))

src/torchmetrics/wrappers/classwise.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ class ClasswiseWrapper(Metric):
5151
'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)}
5252
"""
5353

54+
full_state_update: Optional[bool] = True
55+
5456
def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None:
5557
super().__init__()
5658
if not isinstance(metric, Metric):
@@ -71,3 +73,6 @@ def update(self, *args: Any, **kwargs: Any) -> None:
7173

7274
def compute(self) -> Dict[str, Tensor]:
7375
return self._convert(self.metric.compute())
76+
77+
def reset(self) -> None:
78+
self.metric.reset()

tests/unittests/wrappers/test_classwise.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,35 @@ def test_raises_error_on_wrong_input():
1515

1616
def test_output_no_labels():
1717
"""Test that wrapper works with no label input."""
18+
base = Accuracy(num_classes=3, average=None)
1819
metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None))
19-
preds = torch.randn(10, 3).softmax(dim=-1)
20-
target = torch.randint(3, (10,))
21-
val = metric(preds, target)
22-
assert isinstance(val, dict)
23-
assert len(val) == 3
24-
for i in range(3):
25-
assert f"accuracy_{i}" in val
20+
for _ in range(2):
21+
preds = torch.randn(10, 3).softmax(dim=-1)
22+
target = torch.randint(3, (10,))
23+
val = metric(preds, target)
24+
val_base = base(preds, target)
25+
assert isinstance(val, dict)
26+
assert len(val) == 3
27+
for i in range(3):
28+
assert f"accuracy_{i}" in val
29+
assert val[f"accuracy_{i}"] == val_base[i]
2630

2731

2832
def test_output_with_labels():
2933
"""Test that wrapper works with label input."""
3034
labels = ["horse", "fish", "cat"]
35+
base = Accuracy(num_classes=3, average=None)
3136
metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels)
32-
preds = torch.randn(10, 3).softmax(dim=-1)
33-
target = torch.randint(3, (10,))
34-
val = metric(preds, target)
35-
assert isinstance(val, dict)
36-
assert len(val) == 3
37-
for lab in labels:
38-
assert f"accuracy_{lab}" in val
37+
for _ in range(2):
38+
preds = torch.randn(10, 3).softmax(dim=-1)
39+
target = torch.randint(3, (10,))
40+
val = metric(preds, target)
41+
val_base = base(preds, target)
42+
assert isinstance(val, dict)
43+
assert len(val) == 3
44+
for i, lab in enumerate(labels):
45+
assert f"accuracy_{lab}" in val
46+
assert val[f"accuracy_{lab}"] == val_base[i]
3947

4048

4149
@pytest.mark.parametrize("prefix", [None, "pre_"])

0 commit comments

Comments
 (0)