|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +import inspect |
15 | 16 | import json
|
16 | 17 | import os
|
17 | 18 | import pickle
|
18 | 19 | import sys
|
19 | 20 | from argparse import Namespace
|
| 21 | +from contextlib import redirect_stdout |
| 22 | +from io import StringIO |
20 | 23 | from unittest import mock
|
21 | 24 |
|
22 | 25 | import pytest
|
23 | 26 | import yaml
|
24 | 27 |
|
25 |
| -from pytorch_lightning import LightningModule, Trainer |
| 28 | +from pytorch_lightning import LightningDataModule, LightningModule, Trainer |
26 | 29 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
27 | 30 | from pytorch_lightning.utilities import _TPU_AVAILABLE
|
28 | 31 | from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback
|
@@ -329,3 +332,98 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
|
329 | 332 | assert config['model'] == cli.config['model']
|
330 | 333 | assert config['data'] == cli.config['data']
|
331 | 334 | assert config['trainer'] == cli.config['trainer']
|
| 335 | + |
| 336 | + |
| 337 | +def any_model_any_data_cli(): |
| 338 | + LightningCLI( |
| 339 | + LightningModule, |
| 340 | + LightningDataModule, |
| 341 | + subclass_mode_model=True, |
| 342 | + subclass_mode_data=True, |
| 343 | + ) |
| 344 | + |
| 345 | + |
| 346 | +def test_lightning_cli_help(): |
| 347 | + |
| 348 | + cli_args = ['any.py', '--help'] |
| 349 | + out = StringIO() |
| 350 | + with mock.patch('sys.argv', cli_args), redirect_stdout(out), pytest.raises(SystemExit): |
| 351 | + any_model_any_data_cli() |
| 352 | + |
| 353 | + assert '--print_config' in out.getvalue() |
| 354 | + assert '--config' in out.getvalue() |
| 355 | + assert '--seed_everything' in out.getvalue() |
| 356 | + assert '--model.help' in out.getvalue() |
| 357 | + assert '--data.help' in out.getvalue() |
| 358 | + |
| 359 | + skip_params = {'self'} |
| 360 | + for param in inspect.signature(Trainer.__init__).parameters.keys(): |
| 361 | + if param not in skip_params: |
| 362 | + assert f'--trainer.{param}' in out.getvalue() |
| 363 | + |
| 364 | + cli_args = ['any.py', '--data.help=tests.helpers.BoringDataModule'] |
| 365 | + out = StringIO() |
| 366 | + with mock.patch('sys.argv', cli_args), redirect_stdout(out), pytest.raises(SystemExit): |
| 367 | + any_model_any_data_cli() |
| 368 | + |
| 369 | + assert '--data.init_args.data_dir' in out.getvalue() |
| 370 | + |
| 371 | + |
| 372 | +def test_lightning_cli_print_config(): |
| 373 | + |
| 374 | + cli_args = [ |
| 375 | + 'any.py', |
| 376 | + '--seed_everything=1234', |
| 377 | + '--model=tests.helpers.BoringModel', |
| 378 | + '--data=tests.helpers.BoringDataModule', |
| 379 | + '--print_config', |
| 380 | + ] |
| 381 | + |
| 382 | + out = StringIO() |
| 383 | + with mock.patch('sys.argv', cli_args), redirect_stdout(out), pytest.raises(SystemExit): |
| 384 | + any_model_any_data_cli() |
| 385 | + |
| 386 | + outval = yaml.safe_load(out.getvalue()) |
| 387 | + assert outval['seed_everything'] == 1234 |
| 388 | + assert outval['model']['class_path'] == 'tests.helpers.BoringModel' |
| 389 | + assert outval['data']['class_path'] == 'tests.helpers.BoringDataModule' |
| 390 | + |
| 391 | + |
| 392 | +def test_lightning_cli_submodules(tmpdir): |
| 393 | + |
| 394 | + class MainModule(BoringModel): |
| 395 | + def __init__( |
| 396 | + self, |
| 397 | + submodule1: LightningModule, |
| 398 | + submodule2: LightningModule, |
| 399 | + main_param: int = 1, |
| 400 | + ): |
| 401 | + super().__init__() |
| 402 | + self.submodule1 = submodule1 |
| 403 | + self.submodule2 = submodule2 |
| 404 | + |
| 405 | + config = """model: |
| 406 | + main_param: 2 |
| 407 | + submodule1: |
| 408 | + class_path: tests.helpers.BoringModel |
| 409 | + submodule2: |
| 410 | + class_path: tests.helpers.BoringModel |
| 411 | + """ |
| 412 | + config_path = tmpdir / 'config.yaml' |
| 413 | + with open(config_path, 'w') as f: |
| 414 | + f.write(config) |
| 415 | + |
| 416 | + cli_args = [ |
| 417 | + f'--trainer.default_root_dir={tmpdir}', |
| 418 | + '--trainer.max_epochs=1', |
| 419 | + f'--config={str(config_path)}', |
| 420 | + ] |
| 421 | + |
| 422 | + with mock.patch('sys.argv', ['any.py'] + cli_args): |
| 423 | + cli = LightningCLI(MainModule) |
| 424 | + |
| 425 | + assert cli.config_init['model']['main_param'] == 2 |
| 426 | + assert cli.model.submodule1 == cli.config_init['model']['submodule1'] |
| 427 | + assert cli.model.submodule2 == cli.config_init['model']['submodule2'] |
| 428 | + assert isinstance(cli.config_init['model']['submodule1'], BoringModel) |
| 429 | + assert isinstance(cli.config_init['model']['submodule2'], BoringModel) |
0 commit comments