Skip to content

Commit c27b39a

Browse files
committed
add/update tests
1 parent 368bbdd commit c27b39a

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

tests/test_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
DATE_FORMATS = ["2022-12-01T00:00:00", "20221201", "2022-12-01", "12.01.2022"]
88
EXPECTED_DATE_STR = "202212010000"
9-
MODEL_LOADING_ARGS = ["from_run_id", "epoch"]
9+
MODEL_LOADING_ARGS = ["from_run_id", "epoch", "reuse_run_id"]
1010
GENERAL_ARGS = ["config", "private_config", "options", "run_id"]
1111
MODEL_LOADING_PARSERS = [cli.get_continue_parser(), cli.get_evaluate_parser()]
1212
BASIC_ARGLIST = ["--from_run_id", "test123"]

tests/test_config.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import weathergen.utils.config as config
88

9+
TEST_RUN_ID = "test123"
910
SECRET_COMPONENT = "53CR3T"
1011
DUMMY_PRIVATE_CONF = {
1112
"data_path_anemoi": "/path/to/anmoi/data",
@@ -144,6 +145,7 @@ def overwrite_file(overwrite_config):
144145
@pytest.fixture
145146
def config_fresh(private_config_file):
146147
cf = config.load_config(private_config_file, None, None)
148+
config.set_run_id(cf, TEST_RUN_ID, False)
147149
cf.data_loader_rng_seed = 42
148150

149151
return cf
@@ -188,14 +190,12 @@ def test_load_multiple_overwrites(private_config_file):
188190

189191
@pytest.mark.parametrize("epoch", [None, 0, 1, 2, -1])
190192
def test_load_existing_config(epoch, private_config_file, config_fresh):
191-
test_run_id = "test123"
192193
test_num_epochs = 3000
193194

194-
config_fresh.run_id = test_run_id # done in trainer
195195
config_fresh.num_epochs = test_num_epochs # some specific change
196196
config.save(config_fresh, epoch)
197197

198-
cf = config.load_config(private_config_file, test_run_id, epoch)
198+
cf = config.load_config(private_config_file, config_fresh.run_id, epoch)
199199

200200
assert cf.num_epochs == test_num_epochs
201201

@@ -205,7 +205,13 @@ def test_from_cli(options, cf):
205205
parsed_config = config.from_cli_arglist(options)
206206

207207
assert parsed_config == OmegaConf.create(cf)
208-
208+
@pytest.mark.parametrize("run_id,reuse,expected", [(None, False, "generated"), ("new_id", False, "new_id"), (None, True, TEST_RUN_ID), ("new_id", True, TEST_RUN_ID)])
209+
def test_set_run_id(config_fresh, run_id, reuse, expected, mocker):
210+
patch = mocker.patch("weathergen.utils.config.get_run_id", return_value="generated")
211+
212+
config.set_run_id(config_fresh, run_id, reuse)
213+
214+
assert config_fresh.run_id == expected
209215

210216
def test_print_cf_no_secrets(config_fresh):
211217
output = config._format_cf(config_fresh)
@@ -243,9 +249,7 @@ def test_load_malformed_stream(streams_dir):
243249

244250
@pytest.mark.parametrize("epoch", [None, 0, 1, 2, -1]) # maybe add -5 as test case
245251
def test_save(epoch, config_fresh):
246-
test_run_id = "test123"
247-
config_fresh.run_id = test_run_id
248252
config.save(config_fresh, epoch)
249253

250-
cf = config.load_model_config(test_run_id, epoch, config_fresh.model_path)
251-
assert is_equal(cf, config_fresh)
254+
cf = config.load_model_config(config_fresh.run_id, epoch, config_fresh.model_path)
255+
assert is_equal(cf, config_fresh)

0 commit comments

Comments
 (0)