35
35
unpicklable_function = lambda : None
36
36
37
37
38
- @pytest .fixture (scope = "module" )
39
38
def model_cases ():
40
39
class TestHparamsNamespace :
41
40
learning_rate = 1
@@ -93,9 +92,9 @@ class TestModel7: # test for datamodule w/ hparams w/ attribute (should use dat
93
92
return model1 , model2 , model3 , model4 , model5 , model6 , model7
94
93
95
94
96
- def test_lightning_hasattr (tmpdir , model_cases ):
95
+ def test_lightning_hasattr (tmpdir ):
97
96
"""Test that the lightning_hasattr works in all cases."""
98
- model1 , model2 , model3 , model4 , model5 , model6 , model7 = models = model_cases
97
+ model1 , model2 , model3 , model4 , model5 , model6 , model7 = models = model_cases ()
99
98
assert lightning_hasattr (model1 , "learning_rate" ), "lightning_hasattr failed to find namespace variable"
100
99
assert lightning_hasattr (model2 , "learning_rate" ), "lightning_hasattr failed to find hparams namespace variable"
101
100
assert lightning_hasattr (model3 , "learning_rate" ), "lightning_hasattr failed to find hparams dict variable"
@@ -112,9 +111,9 @@ def test_lightning_hasattr(tmpdir, model_cases):
112
111
assert not lightning_hasattr (m , "this_attr_not_exist" )
113
112
114
113
115
- def test_lightning_getattr (tmpdir , model_cases ):
114
+ def test_lightning_getattr (tmpdir ):
116
115
"""Test that the lightning_getattr works in all cases."""
117
- models = model_cases
116
+ models = model_cases ()
118
117
for i , m in enumerate (models [:3 ]):
119
118
value = lightning_getattr (m , "learning_rate" )
120
119
assert value == i , "attribute not correctly extracted"
@@ -132,9 +131,9 @@ def test_lightning_getattr(tmpdir, model_cases):
132
131
lightning_getattr (m , "this_attr_not_exist" )
133
132
134
133
135
- def test_lightning_setattr (tmpdir , model_cases ):
134
+ def test_lightning_setattr (tmpdir ):
136
135
"""Test that the lightning_setattr works in all cases."""
137
- models = model_cases
136
+ models = model_cases ()
138
137
for m in models [:3 ]:
139
138
lightning_setattr (m , "learning_rate" , 10 )
140
139
assert lightning_getattr (m , "learning_rate" ) == 10 , "attribute not correctly set"
0 commit comments