6
6
7
7
import weathergen .utils .config as config
8
8
9
+ TEST_RUN_ID = "test123"
9
10
SECRET_COMPONENT = "53CR3T"
10
11
DUMMY_PRIVATE_CONF = {
11
12
"data_path_anemoi" : "/path/to/anmoi/data" ,
@@ -144,6 +145,7 @@ def overwrite_file(overwrite_config):
144
145
@pytest .fixture
145
146
def config_fresh (private_config_file ):
146
147
cf = config .load_config (private_config_file , None , None )
148
+ config .set_run_id (cf , TEST_RUN_ID , False )
147
149
cf .data_loader_rng_seed = 42
148
150
149
151
return cf
@@ -188,14 +190,12 @@ def test_load_multiple_overwrites(private_config_file):
188
190
189
191
@pytest .mark .parametrize ("epoch" , [None , 0 , 1 , 2 , - 1 ])
190
192
def test_load_existing_config (epoch , private_config_file , config_fresh ):
191
- test_run_id = "test123"
192
193
test_num_epochs = 3000
193
194
194
- config_fresh .run_id = test_run_id # done in trainer
195
195
config_fresh .num_epochs = test_num_epochs # some specific change
196
196
config .save (config_fresh , epoch )
197
197
198
- cf = config .load_config (private_config_file , test_run_id , epoch )
198
+ cf = config .load_config (private_config_file , config_fresh . run_id , epoch )
199
199
200
200
assert cf .num_epochs == test_num_epochs
201
201
@@ -205,7 +205,13 @@ def test_from_cli(options, cf):
205
205
parsed_config = config .from_cli_arglist (options )
206
206
207
207
assert parsed_config == OmegaConf .create (cf )
208
-
208
+ @pytest .mark .parametrize ("run_id,reuse,expected" , [(None , False , "generated" ), ("new_id" , False , "new_id" ), (None , True , TEST_RUN_ID ), ("new_id" , True , TEST_RUN_ID )])
209
+ def test_set_run_id (config_fresh , run_id , reuse , expected , mocker ):
210
+ patch = mocker .patch ("weathergen.utils.config.get_run_id" , return_value = "generated" )
211
+
212
+ config .set_run_id (config_fresh , run_id , reuse )
213
+
214
+ assert config_fresh .run_id == expected
209
215
210
216
def test_print_cf_no_secrets (config_fresh ):
211
217
output = config ._format_cf (config_fresh )
@@ -243,9 +249,7 @@ def test_load_malformed_stream(streams_dir):
243
249
244
250
@pytest .mark .parametrize ("epoch" , [None , 0 , 1 , 2 , - 1 ]) # maybe add -5 as test case
245
251
def test_save (epoch , config_fresh ):
246
- test_run_id = "test123"
247
- config_fresh .run_id = test_run_id
248
252
config .save (config_fresh , epoch )
249
253
250
- cf = config .load_model_config (test_run_id , epoch , config_fresh .model_path )
251
- assert is_equal (cf , config_fresh )
254
+ cf = config .load_model_config (config_fresh . run_id , epoch , config_fresh .model_path )
255
+ assert is_equal (cf , config_fresh )
0 commit comments