Skip to content

Commit ce0a977

Browse files
kaushikb11carmocca
andauthored
Moved env_vars_connector._defaults_from_env_vars to utilities.argsparse._defaults_from_env_vars (#10501)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 8ea39d2 commit ce0a977

File tree

4 files changed

+24
-41
lines changed

4 files changed

+24
-41
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3737
- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))
3838

3939

40+
- Moved `trainer.connectors.env_vars_connector._defaults_from_env_vars` to `utilities.argsparse._defaults_from_env_vars` ([#10501](https://github.com/PyTorchLightning/pytorch-lightning/pull/10501))
41+
42+
4043
- Changes in `LightningCLI` required for the new major release of jsonargparse v4.0.0 ([#10426](https://github.com/PyTorchLightning/pytorch-lightning/pull/10426))
4144

4245

pytorch_lightning/trainer/connectors/env_vars_connector.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
5454
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
5555
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
56-
from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
5756
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
5857
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
5958
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
@@ -75,6 +74,7 @@
7574
rank_zero_warn,
7675
)
7776
from pytorch_lightning.utilities.argparse import (
77+
_defaults_from_env_vars,
7878
add_argparse_args,
7979
from_argparse_args,
8080
parse_argparser,

pytorch_lightning/utilities/argparse.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from abc import ABC
1717
from argparse import _ArgumentGroup, ArgumentParser, Namespace
1818
from contextlib import suppress
19+
from functools import wraps
1920
from typing import Any, Callable, Dict, List, Tuple, Type, Union
2021

2122
import pytorch_lightning as pl
@@ -312,3 +313,22 @@ def _precision_allowed_type(x: Union[int, str]) -> Union[int, str]:
312313
return int(x)
313314
except ValueError:
314315
return x
316+
317+
318+
def _defaults_from_env_vars(fn: Callable) -> Callable:
319+
@wraps(fn)
320+
def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any:
321+
cls = self.__class__ # get the class
322+
if args: # in case any args passed move them to kwargs
323+
# parse only the argument names
324+
cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
325+
# convert args to kwargs
326+
kwargs.update(dict(zip(cls_arg_names, args)))
327+
env_variables = vars(parse_env_variables(cls))
328+
# update the kwargs by env variables
329+
kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
330+
331+
# all args were already moved to kwargs
332+
return fn(self, **kwargs)
333+
334+
return insert_env_defaults

0 commit comments

Comments
 (0)