Commit 4e2be2f

Merge branch 'master' into mypy_utilities_auto-restart

2 parents d35be35 + f1cc6e3

9 files changed: +110 −94 lines changed

CHANGELOG.md

Lines changed: 11 additions & 1 deletion
@@ -42,7 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The `trainer.lightning_module` reference is now properly set at the very beginning of the run ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))


-- Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352)))
+- Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))


 - The `Trainer` functions `reset_{train,val,test,predict}_dataloader`, `reset_train_val_dataloaders`, and `request_dataloader` `model` argument is now optional ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))
@@ -54,6 +54,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Improved string conversion for `ResultCollection` ([#8622](https://github.com/PyTorchLightning/pytorch-lightning/pull/8622))


+- `LightningCLI` changes:
+    * `LightningCLI.init_parser` now returns the parser instance. ([#8721](https://github.com/PyTorchLightning/pytorch-lightning/pull/8721))
+    * `LightningCLI.add_core_arguments_to_parser`, `LightningCLI.parse_arguments` now take a `parser` argument. ([#8721](https://github.com/PyTorchLightning/pytorch-lightning/pull/8721))
+    * `LightningCLI.instantiate_trainer` now takes a config and a list of callbacks. ([#8721](https://github.com/PyTorchLightning/pytorch-lightning/pull/8721))
+    * Split `LightningCLI.add_core_arguments_to_parser` into `LightningCLI.add_default_arguments_to_parser` + `LightningCLI.add_core_arguments_to_parser`. ([#8721](https://github.com/PyTorchLightning/pytorch-lightning/pull/8721))
+
+
 - The accelerator and training type plugin `setup` hooks no longer have a `model` argument ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))

 - Removed restrictions in the trainer that loggers can only log from rank 0. Existing logger behavior has not changed. ([#8608](https://github.com/PyTorchLightning/pytorch-lightning/pull/8608))
@@ -102,6 +109,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed the deprecated `model` argument from `ModelCheckpoint.save_checkpoint` ([#8688](https://github.com/PyTorchLightning/pytorch-lightning/pull/8688))


+- Removed the deprecated `sync_step` argument from `WandbLogger` ([#8763](https://github.com/PyTorchLightning/pytorch-lightning/pull/8763))
+
+
 ### Fixed

 - Fixed `trainer.fit_loop.split_idx` always returning `None` ([#8601](https://github.com/PyTorchLightning/pytorch-lightning/pull/8601))
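The `LightningCLI` entries above are signature changes to overridable hooks. A minimal sketch of what a subclass looks like against the new signatures, assuming this commit's `pytorch_lightning.utilities.cli`; `VerboseCLI` is a hypothetical name used only for illustration:

from typing import Any

from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI


class VerboseCLI(LightningCLI):
    def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
        # `init_parser` now returns the parser instance instead of assigning `self.parser`
        parser = super().init_parser(**kwargs)
        print("parser created")
        return parser

    def parse_arguments(self, parser: LightningArgumentParser) -> None:
        # `parse_arguments` now receives the parser to parse with
        super().parse_arguments(parser)
        print("parsed configuration:", self.config)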

pl_examples/run_examples.sh

Lines changed: 3 additions & 3 deletions
@@ -4,6 +4,6 @@ set -ex
 dir_path=$(dirname "${BASH_SOURCE[0]}")
 args="--trainer.max_epochs=1 --data.batch_size=32 --trainer.limit_train_batches=2 --trainer.limit_val_batches=2"

-python "${dir_path}/basic_examples/simple_image_classifier.py" "$@" ${args}
-python "${dir_path}/basic_examples/backbone_image_classifier.py" "$@" ${args}
-python "${dir_path}/basic_examples/autoencoder.py" "$@" ${args}
+python "${dir_path}/basic_examples/simple_image_classifier.py" ${args} "$@"
+python "${dir_path}/basic_examples/backbone_image_classifier.py" ${args} "$@"
+python "${dir_path}/basic_examples/autoencoder.py" ${args} "$@"

pytorch_lightning/loggers/wandb.py

Lines changed: 0 additions & 7 deletions
@@ -113,7 +113,6 @@ def __init__(
         log_model: Optional[bool] = False,
         experiment=None,
         prefix: Optional[str] = "",
-        sync_step: Optional[bool] = None,
         **kwargs,
     ):
         if wandb is None:
@@ -136,12 +135,6 @@
                 "Hint: Upgrade with `pip install --ugrade wandb`."
             )

-        if sync_step is not None:
-            warning_cache.deprecation(
-                "`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5."
-                " Metrics are now logged separately and automatically synchronized."
-            )
-
         super().__init__()
         self._offline = offline
         self._log_model = log_model
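With the deprecation path removed, `sync_step` is no longer part of the logger's API. A minimal usage sketch after the removal, assuming `wandb` is installed; the project name is hypothetical:

from pytorch_lightning.loggers import WandbLogger

# metrics are logged separately and synchronized automatically,
# so no `sync_step` argument is passed
logger = WandbLogger(project="demo-project", offline=True)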

pytorch_lightning/trainer/trainer.py

Lines changed: 9 additions & 10 deletions
@@ -32,7 +32,7 @@
 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
 from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
 from pytorch_lightning.loops.fit_loop import FitLoop
-from pytorch_lightning.plugins import Plugin
+from pytorch_lightning.plugins import DDPSpawnPlugin, Plugin
 from pytorch_lightning.plugins.environments import ClusterEnvironment
 from pytorch_lightning.profiler import (
     AdvancedProfiler,
@@ -76,7 +76,6 @@
 )
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.distributed import distributed_available
-from pytorch_lightning.utilities.enums import DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -633,8 +632,8 @@ def test(
         test_dataloaders=None,  # TODO: remove with 1.6
     ) -> _EVALUATE_OUTPUT:
         r"""
-        Perform one evaluation epoch over the test set. It's separated from
-        fit to make sure you never run on your test set until you want to.
+        Perform one evaluation epoch over the test set.
+        It's separated from fit to make sure you never run on your test set until you want to.

         Args:
             model: The model to test.
@@ -711,9 +710,9 @@ def predict(
         ckpt_path: Optional[str] = None,
     ) -> Optional[_PREDICT_OUTPUT]:
         r"""
-
-        Separates from fit to make sure you never run on your predictions set until you want to.
-        This will call the model forward function to compute predictions.
+        Run inference on your data.
+        This will call the model forward function to compute predictions. Useful to perform distributed
+        and batched predictions. Logging is disabled in the predict hooks.

         Args:
             model: The model to predict with.
@@ -947,7 +946,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,

         # teardown if necessary (similar calls for spawn plugins are excluded as they have
         # been included at the end of `new_process` functions)
-        if self._distrib_type not in DistributedType.interactive_compatible_types():
+        if not isinstance(self.training_type_plugin, DDPSpawnPlugin):
             self._call_teardown_hook()

         if self.state.status != TrainerStatus.INTERRUPTED:
@@ -1078,7 +1077,7 @@ def _run_train(self) -> None:
                 self.training_type_plugin.reconciliate_processes(traceback.format_exc())
             # give accelerators a chance to finish
             self.accelerator.on_train_end()
-            self._on_expection()
+            self._on_exception()
             # reset bookkeeping
             self.state.stage = None
             raise
@@ -1334,7 +1333,7 @@ def _log_device_info(self) -> None:
                 " `Trainer(ipus=8)` or script `--ipus=8`."
             )

-    def _on_expection(self):
+    def _on_exception(self):
        if not _fault_tolerant_enabled():
            return
        # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure.
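Two of the trainer changes are behavioral rather than cosmetic: the misspelled `_on_expection` hook is renamed to `_on_exception`, and the teardown guard now keys off the plugin type instead of the `DistributedType` enum. A rough sketch of the new guard's logic, assuming the 1.4-era hierarchy where spawn-based plugins (including TPU spawn) subclass `DDPSpawnPlugin` and run teardown inside `new_process`; the helper name is hypothetical:

from pytorch_lightning.plugins import DDPSpawnPlugin


def should_call_teardown(training_type_plugin) -> bool:
    # spawn plugins already tear down at the end of `new_process`, so the
    # main process skips the extra call for them and for any subclass
    return not isinstance(training_type_plugin, DDPSpawnPlugin)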

pytorch_lightning/utilities/cli.py

Lines changed: 71 additions & 56 deletions
@@ -13,22 +13,17 @@
 # limitations under the License.
 import inspect
 import os
-import warnings
 from argparse import Namespace
 from types import MethodType
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

 from torch.optim import Optimizer

-from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.core.datamodule import LightningDataModule
-from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.trainer.trainer import Trainer
+from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
+from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, warnings
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _JSONARGPARSE_AVAILABLE
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.seed import seed_everything
 from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple

 if _JSONARGPARSE_AVAILABLE:
@@ -79,6 +74,9 @@ def add_lightning_class_args(
             lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}.
             nested_key: Name of the nested namespace to store arguments.
             subclass_mode: Whether allow any subclass of the given class.
+
+        Returns:
+            A list with the names of the class arguments added.
         """
         if callable(lightning_class) and not inspect.isclass(lightning_class):
             lightning_class = class_from_function(lightning_class)
@@ -191,7 +189,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st

     def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]:
         # `ArgumentParser` is un-pickleable. Drop it
-        return (self.__class__, (None, self.config, self.config_filename), {})
+        return self.__class__, (None, self.config, self.config_filename), {}


 class LightningCLI:
@@ -205,21 +203,22 @@ def __init__(
         save_config_filename: str = "config.yaml",
         save_config_overwrite: bool = False,
         trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer,
-        trainer_defaults: Dict[str, Any] = None,
-        seed_everything_default: int = None,
+        trainer_defaults: Optional[Dict[str, Any]] = None,
+        seed_everything_default: Optional[int] = None,
         description: str = "pytorch-lightning trainer command line tool",
         env_prefix: str = "PL",
         env_parse: bool = False,
-        parser_kwargs: Dict[str, Any] = None,
+        parser_kwargs: Optional[Dict[str, Any]] = None,
         subclass_mode_model: bool = False,
         subclass_mode_data: bool = False,
     ) -> None:
         """
         Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are
-        called / instantiated using a parsed configuration file and / or command line args and then runs trainer.fit.
-        Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``. A full
-        configuration yaml would be parsed from ``PL_CONFIG`` if set. Individual settings are so parsed from variables
-        named for example ``PL_TRAINER__MAX_EPOCHS``.
+        called / instantiated using a parsed configuration file and / or command line args.
+
+        Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``.
+        A full configuration yaml would be parsed from ``PL_CONFIG`` if set.
+        Individual settings are so parsed from variables named for example ``PL_TRAINER__MAX_EPOCHS``.

         Example, first implement the ``trainer.py`` tool as::
@@ -266,56 +265,73 @@ def __init__(
         self.save_config_filename = save_config_filename
         self.save_config_overwrite = save_config_overwrite
         self.trainer_class = trainer_class
-        self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults
+        self.trainer_defaults = trainer_defaults or {}
         self.seed_everything_default = seed_everything_default
         self.subclass_mode_model = subclass_mode_model
         self.subclass_mode_data = subclass_mode_data
-        self.parser_kwargs = {} if parser_kwargs is None else parser_kwargs
-        self.parser_kwargs.update({"description": description, "env_prefix": env_prefix, "default_env": env_parse})

-        self.init_parser()
-        self.add_core_arguments_to_parser()
-        self.add_arguments_to_parser(self.parser)
+        parser_kwargs = parser_kwargs or {}
+        parser_kwargs.update({"description": description, "env_prefix": env_prefix, "default_env": env_parse})
+        self.setup_parser(**parser_kwargs)
         self.link_optimizers_and_lr_schedulers()
-        self.parse_arguments()
-        if self.config["seed_everything"] is not None:
-            seed_everything(self.config["seed_everything"], workers=True)
+        self.parse_arguments(self.parser)
+
+        seed = self.config.get("seed_everything")
+        if seed is not None:
+            seed_everything(seed, workers=True)
+
         self.before_instantiate_classes()
         self.instantiate_classes()
         self.add_configure_optimizers_method_to_model()
+
         self.prepare_fit_kwargs()
         self.before_fit()
         self.fit()
         self.after_fit()

-    def init_parser(self) -> None:
-        """Method that instantiates the argument parser"""
-        self.parser = LightningArgumentParser(**self.parser_kwargs)
+    def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
+        """Method that instantiates the argument parser."""
+        return LightningArgumentParser(**kwargs)
+
+    def setup_parser(self, **kwargs: Any) -> None:
+        """Initialize and setup the parser, and arguments."""
+        self.parser = self.init_parser(**kwargs)
+        self._add_arguments(self.parser)

-    def add_core_arguments_to_parser(self) -> None:
-        """Adds arguments from the core classes to the parser"""
-        self.parser.add_argument(
+    def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+        """Adds default arguments to the parser."""
+        parser.add_argument(
             "--seed_everything",
             type=Optional[int],
             default=self.seed_everything_default,
             help="Set to an int to run seed_everything with this value before classes instantiation",
         )
-        self.parser.add_lightning_class_args(self.trainer_class, "trainer")
+
+    def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+        """Adds arguments from the core classes to the parser."""
+        parser.add_lightning_class_args(self.trainer_class, "trainer")
         trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
-        self.parser.set_defaults(trainer_defaults)
-        self.parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model)
+        parser.set_defaults(trainer_defaults)
+        parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model)
         if self.datamodule_class is not None:
-            self.parser.add_lightning_class_args(self.datamodule_class, "data", subclass_mode=self.subclass_mode_data)
+            parser.add_lightning_class_args(self.datamodule_class, "data", subclass_mode=self.subclass_mode_data)
+
+    def _add_arguments(self, parser: LightningArgumentParser) -> None:
+        # default + core + custom arguments
+        self.add_default_arguments_to_parser(parser)
+        self.add_core_arguments_to_parser(parser)
+        self.add_arguments_to_parser(parser)

     def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
-        """Implement to add extra arguments to parser or link arguments
+        """
+        Implement to add extra arguments to the parser or link arguments.

         Args:
-            parser: The argument parser object to which arguments can be added
+            parser: The parser object to which arguments can be added
         """

     def link_optimizers_and_lr_schedulers(self) -> None:
-        """Creates argument links for optimizers and lr_schedulers that specified a link_to"""
+        """Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``."""
         for key, (class_type, link_to) in self.parser.optimizers_and_lr_schedulers.items():
             if link_to == "AUTOMATIC":
                 continue
@@ -325,41 +341,40 @@ def link_optimizers_and_lr_schedulers(self) -> None:
             add_class_path = _add_class_path_generator(class_type)
             self.parser.link_arguments(key, link_to, compute_fn=add_class_path)

-    def parse_arguments(self) -> None:
-        """Parses command line arguments and stores it in self.config"""
-        self.config = self.parser.parse_args()
+    def parse_arguments(self, parser: LightningArgumentParser) -> None:
+        """Parses command line arguments and stores it in ``self.config``."""
+        self.config = parser.parse_args()

     def before_instantiate_classes(self) -> None:
-        """Implement to run some code before instantiating the classes"""
+        """Implement to run some code before instantiating the classes."""

     def instantiate_classes(self) -> None:
-        """Instantiates the classes using settings from self.config"""
+        """Instantiates the classes and sets their attributes."""
         self.config_init = self.parser.instantiate_classes(self.config)
         self.datamodule = self.config_init.get("data")
         self.model = self.config_init["model"]
-        self.instantiate_trainer()
-
-    def instantiate_trainer(self) -> None:
-        """Instantiates the trainer using self.config_init['trainer']"""
-        if self.config_init["trainer"].get("callbacks") is None:
-            self.config_init["trainer"]["callbacks"] = []
         callbacks = [self.config_init[c] for c in self.parser.callback_keys]
-        self.config_init["trainer"]["callbacks"].extend(callbacks)
+        self.trainer = self.instantiate_trainer(self.config_init["trainer"], callbacks)
+
+    def instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer:
+        """Instantiates the trainer."""
+        config["callbacks"] = config["callbacks"] or []
+        config["callbacks"].extend(callbacks)
         if "callbacks" in self.trainer_defaults:
             if isinstance(self.trainer_defaults["callbacks"], list):
-                self.config_init["trainer"]["callbacks"].extend(self.trainer_defaults["callbacks"])
+                config["callbacks"].extend(self.trainer_defaults["callbacks"])
             else:
-                self.config_init["trainer"]["callbacks"].append(self.trainer_defaults["callbacks"])
-        if self.save_config_callback and not self.config_init["trainer"]["fast_dev_run"]:
+                config["callbacks"].append(self.trainer_defaults["callbacks"])
+        if self.save_config_callback and not config["fast_dev_run"]:
             config_callback = self.save_config_callback(
                 self.parser, self.config, self.save_config_filename, overwrite=self.save_config_overwrite
             )
-            self.config_init["trainer"]["callbacks"].append(config_callback)
-        self.trainer = self.trainer_class(**self.config_init["trainer"])
+            config["callbacks"].append(config_callback)
+        return self.trainer_class(**config)

     def add_configure_optimizers_method_to_model(self) -> None:
         """
-        Adds to the model an automatically generated configure_optimizers method
+        Adds to the model an automatically generated ``configure_optimizers`` method.

         If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC',
         then a `configure_optimizers` method is automatically implemented in the model class.
@@ -390,7 +405,7 @@ def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]:
            )

        if is_overridden("configure_optimizers", self.model):
-            warnings.warn(
+            warnings._warn(
                f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
                f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`."
            )
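Because `instantiate_trainer` now receives the trainer config and callback list, a subclass can inject callbacks without touching `self.config_init`. A hedged sketch against this commit's API; `TimerCallback` and `MyCLI` are hypothetical names defined only for the example:

from typing import Any, Dict, List

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities.cli import LightningCLI


class TimerCallback(Callback):
    """Hypothetical callback used only for illustration."""

    def on_fit_start(self, trainer, pl_module) -> None:
        print("fit is starting")


class MyCLI(LightningCLI):
    def instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer:
        # append to the passed-in callback list instead of mutating self.config_init["trainer"]
        callbacks.append(TimerCallback())
        return super().instantiate_trainer(config, callbacks)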
