Skip to content

Commit e015a16

Browse files
mauvilsaBorda
authored andcommitted
Allow a custom parser class when using LightningCLI (#20596)
* Allow a custom parser class when using LightningCLI * Update changelog (cherry picked from commit 5073ac1)
1 parent 843c647 commit e015a16

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

src/lightning/pytorch/CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
88

99
### Added
1010

11-
-
11+
- Allow LightningCLI to use a customized argument parser class ([#20596](https://github.com/Lightning-AI/pytorch-lightning/pull/20596))
1212

1313

1414
### Changed
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1919
- Added a new `checkpoint_path_prefix` parameter to the MLflow logger which can control the path to where the MLflow artifacts for the model checkpoints are stored ([#20538](https://github.com/Lightning-AI/pytorch-lightning/pull/20538))
2020

2121

22+
2223
### Removed
2324

2425
-

src/lightning/pytorch/cli.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ def __init__(
314314
trainer_defaults: Optional[dict[str, Any]] = None,
315315
seed_everything_default: Union[bool, int] = True,
316316
parser_kwargs: Optional[Union[dict[str, Any], dict[str, dict[str, Any]]]] = None,
317+
parser_class: type[LightningArgumentParser] = LightningArgumentParser,
317318
subclass_mode_model: bool = False,
318319
subclass_mode_data: bool = False,
319320
args: ArgsType = None,
@@ -367,6 +368,7 @@ def __init__(
367368
self.trainer_defaults = trainer_defaults or {}
368369
self.seed_everything_default = seed_everything_default
369370
self.parser_kwargs = parser_kwargs or {}
371+
self.parser_class = parser_class
370372
self.auto_configure_optimizers = auto_configure_optimizers
371373

372374
self.model_class = model_class
@@ -404,7 +406,7 @@ def _setup_parser_kwargs(self, parser_kwargs: dict[str, Any]) -> tuple[dict[str,
404406
def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
405407
"""Method that instantiates the argument parser."""
406408
kwargs.setdefault("dump_header", [f"lightning.pytorch=={pl.__version__}"])
407-
parser = LightningArgumentParser(**kwargs)
409+
parser = self.parser_class(**kwargs)
408410
parser.add_argument(
409411
"-c", "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
410412
)

0 commit comments

Comments
 (0)