@@ -15,27 +15,35 @@ def test_raises_error_on_wrong_input():
15
15
16
16
def test_output_no_labels ():
17
17
"""Test that wrapper works with no label input."""
18
+ base = Accuracy (num_classes = 3 , average = None )
18
19
metric = ClasswiseWrapper (Accuracy (num_classes = 3 , average = None ))
19
- preds = torch .randn (10 , 3 ).softmax (dim = - 1 )
20
- target = torch .randint (3 , (10 ,))
21
- val = metric (preds , target )
22
- assert isinstance (val , dict )
23
- assert len (val ) == 3
24
- for i in range (3 ):
25
- assert f"accuracy_{ i } " in val
20
+ for _ in range (2 ):
21
+ preds = torch .randn (10 , 3 ).softmax (dim = - 1 )
22
+ target = torch .randint (3 , (10 ,))
23
+ val = metric (preds , target )
24
+ val_base = base (preds , target )
25
+ assert isinstance (val , dict )
26
+ assert len (val ) == 3
27
+ for i in range (3 ):
28
+ assert f"accuracy_{ i } " in val
29
+ assert val [f"accuracy_{ i } " ] == val_base [i ]
26
30
27
31
28
32
def test_output_with_labels ():
29
33
"""Test that wrapper works with label input."""
30
34
labels = ["horse" , "fish" , "cat" ]
35
+ base = Accuracy (num_classes = 3 , average = None )
31
36
metric = ClasswiseWrapper (Accuracy (num_classes = 3 , average = None ), labels = labels )
32
- preds = torch .randn (10 , 3 ).softmax (dim = - 1 )
33
- target = torch .randint (3 , (10 ,))
34
- val = metric (preds , target )
35
- assert isinstance (val , dict )
36
- assert len (val ) == 3
37
- for lab in labels :
38
- assert f"accuracy_{ lab } " in val
37
+ for _ in range (2 ):
38
+ preds = torch .randn (10 , 3 ).softmax (dim = - 1 )
39
+ target = torch .randint (3 , (10 ,))
40
+ val = metric (preds , target )
41
+ val_base = base (preds , target )
42
+ assert isinstance (val , dict )
43
+ assert len (val ) == 3
44
+ for i , lab in enumerate (labels ):
45
+ assert f"accuracy_{ lab } " in val
46
+ assert val [f"accuracy_{ lab } " ] == val_base [i ]
39
47
40
48
41
49
@pytest .mark .parametrize ("prefix" , [None , "pre_" ])
0 commit comments