Skip to content

Commit b940e08

Browse files
committed
ruffed
1 parent 30dc776 commit b940e08

File tree

3 files changed

+11
-9
lines changed

3 files changed

+11
-9
lines changed

src/weathergen/train/trainer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from weathergen.model.model import Model, ModelParams
2727
from weathergen.train.lr_scheduler import LearningRateScheduler
2828
from weathergen.train.trainer_base import Trainer_Base
29-
from weathergen.train.utils import get_run_id
3029
from weathergen.utils.config import Config
3130
from weathergen.utils.distributed import is_root
3231
from weathergen.utils.train_logger import TrainLogger
@@ -44,9 +43,12 @@ def __init__(self, checkpoint_freq=250, print_freq=10):
4443
self.print_freq = print_freq
4544

4645
###########################################
47-
def init(self, cf: Config,):
46+
def init(
47+
self,
48+
cf: Config,
49+
):
4850
self.cf = cf
49-
51+
5052
assert cf.samples_per_epoch % cf.batch_size == 0
5153
assert cf.samples_per_validation % cf.batch_size_validation == 0
5254

src/weathergen/utils/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,5 +116,5 @@ def _add_model_loading_params(parser: argparse.ArgumentParser):
116116
parser.add_argument(
117117
"--reuse_run_id",
118118
action="store_true",
119-
help="Use the id given via --from_run_id also for the current run."
119+
help="Use the id given via --from_run_id also for the current run.",
120120
)

src/weathergen/utils/config.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,20 +118,20 @@ def load_config(
118118
"""
119119
private_config = _load_private_conf(private_home)
120120
overwrite_configs = [_load_overwrite_conf(overwrite) for overwrite in overwrites]
121-
121+
122122
if from_run_id is None:
123123
base_config = _load_default_conf()
124124
else:
125125
base_config = load_model_config(from_run_id, epoch, private_config["model_path"])
126-
127126

128127
# use OmegaConf.unsafe_merge if too slow
129128
return OmegaConf.merge(base_config, private_config, *overwrite_configs)
130129

130+
131131
def set_run_id(config: Config, run_id: str | None, reuse_run_id: bool):
132132
"""
133133
Determine run_id of current run.
134-
134+
135135
Args:
136136
config: Base configuration loaded from previous run or default.
137137
run_id: Id assigned to this run. If None a new one will be generated.
@@ -142,9 +142,9 @@ def set_run_id(config: Config, run_id: str | None, reuse_run_id: bool):
142142
run_id = get_run_id()
143143

144144
config.run_id = run_id
145-
145+
146146
assert config.run_id is not None
147-
147+
148148

149149
def from_cli_arglist(arg_list: list[str]) -> Config:
150150
"""

0 commit comments

Comments
 (0)