Skip to content

Commit 04bd932

Browse files
committed
address review comments
1 parent 952932b commit 04bd932

File tree

3 files changed

+45
-17
lines changed

3 files changed

+45
-17
lines changed

src/weathergen/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def evaluate_from_args(argl: list[str]):
5454
evaluate_overwrite,
5555
cli_overwrite,
5656
)
57-
config.set_run_id(cf, args.run_id, args.reuse_run_id)
57+
cf = config.set_run_id(cf, args.run_id, args.reuse_run_id)
5858

5959
cf.run_history += [(args.from_run_id, cf.istep)]
6060

@@ -103,7 +103,7 @@ def train_continue() -> None:
103103
*args.config,
104104
cli_overwrite,
105105
)
106-
config.set_run_id(cf, args.run_id, args.reuse_run_id)
106+
cf = config.set_run_id(cf, args.run_id, args.reuse_run_id)
107107

108108
# track history of run to ensure traceability of results
109109
cf.run_history += [(args.from_run_id, cf.istep)]
@@ -144,7 +144,7 @@ def train_with_args(argl: list[str], stream_dir: str | None):
144144

145145
cli_overwrite = config.from_cli_arglist(args.options)
146146
cf = config.load_config(args.private_config, None, None, *args.config, cli_overwrite)
147-
config.set_run_id(cf, args.run_id, False)
147+
cf = config.set_run_id(cf, args.run_id, False)
148148

149149
if cf.with_flash_attention:
150150
assert cf.with_mixed_precision

src/weathergen/utils/config.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,20 +129,37 @@ def load_config(
129129

130130
def set_run_id(config: Config, run_id: str | None, reuse_run_id: bool):
131131
"""
132-
Determine run_id of current run.
132+
Determine and set run_id of current run.
133+
134+
Determining the run id should follow the following logic:
135+
136+
1. (default case): run train, train_continue or evaluate without any flags => generate a new run_id for this run.
137+
2. (assign run_id): run train, train_continue or evaluate with --run_id <RUNID> flag => assign a run_id manually to this run
138+
3. (reuse run_id -> only for train_continue and evaluate): reuse the run_id from the run specified by --from_run_id <RUNID>. Since the run_id correct run_id is already loaded in the config nothing has to be assigned. This case will happen if --reuse_run_id is specified.
139+
133140
134141
Args:
135142
config: Base configuration loaded from previous run or default.
136143
run_id: Id assigned to this run. If None a new one will be generated.
137144
reuse_run_id: Reuse run_id from base configuration instead.
145+
146+
Returns:
147+
config object with the run_id attribute properly set.
138148
"""
139-
if not reuse_run_id:
149+
config = config.copy()
150+
if reuse_run_id:
151+
assert config.run_id is not None, "run_id loaded from previous run should not be None."
152+
_logger.info(f"reusing run_id from previous run: {config.run_id}")
153+
else:
140154
if run_id is None:
141-
run_id = get_run_id()
142-
143-
config.run_id = run_id
144-
145-
assert config.run_id is not None
155+
# generate new id if run_id is None
156+
config.run_id = run_id or get_run_id()
157+
_logger.info(f"using generated run_id: {config.run_id}")
158+
else:
159+
config.run_id = run_id
160+
_logger.info(f"using assigned run_id: {config.run_id}")
161+
162+
return config
146163

147164

148165
def from_cli_arglist(arg_list: list[str]) -> Config:

tests/test_config.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def overwrite_file(overwrite_config):
145145
@pytest.fixture
146146
def config_fresh(private_config_file):
147147
cf = config.load_config(private_config_file, None, None)
148-
config.set_run_id(cf, TEST_RUN_ID, False)
148+
cf = config.set_run_id(cf, TEST_RUN_ID, False)
149149
cf.data_loader_rng_seed = 42
150150

151151
return cf
@@ -205,14 +205,25 @@ def test_from_cli(options, cf):
205205
parsed_config = config.from_cli_arglist(options)
206206

207207
assert parsed_config == OmegaConf.create(cf)
208-
@pytest.mark.parametrize("run_id,reuse,expected", [(None, False, "generated"), ("new_id", False, "new_id"), (None, True, TEST_RUN_ID), ("new_id", True, TEST_RUN_ID)])
208+
209+
210+
@pytest.mark.parametrize(
211+
"run_id,reuse,expected",
212+
[
213+
(None, False, "generated"),
214+
("new_id", False, "new_id"),
215+
(None, True, TEST_RUN_ID),
216+
("new_id", True, TEST_RUN_ID),
217+
],
218+
)
209219
def test_set_run_id(config_fresh, run_id, reuse, expected, mocker):
210-
patch = mocker.patch("weathergen.utils.config.get_run_id", return_value="generated")
211-
212-
config.set_run_id(config_fresh, run_id, reuse)
213-
220+
mocker.patch("weathergen.utils.config.get_run_id", return_value="generated")
221+
222+
config_fresh = config.set_run_id(config_fresh, run_id, reuse)
223+
214224
assert config_fresh.run_id == expected
215225

226+
216227
def test_print_cf_no_secrets(config_fresh):
217228
output = config._format_cf(config_fresh)
218229

@@ -252,4 +263,4 @@ def test_save(epoch, config_fresh):
252263
config.save(config_fresh, epoch)
253264

254265
cf = config.load_model_config(config_fresh.run_id, epoch, config_fresh.model_path)
255-
assert is_equal(cf, config_fresh)
266+
assert is_equal(cf, config_fresh)

0 commit comments

Comments
 (0)