
Callbacks are not saved to the config file #7540

Closed
tshu-w opened this issue May 14, 2021 · 13 comments
Labels
argparse (removed): Related to argument parsing (argparse, Hydra, ...)
help wanted: Open to be worked on
question: Further information is requested

Comments

@tshu-w
Contributor

tshu-w commented May 14, 2021

🐛 Bug

LightningCLI cannot save callbacks to the config file.

Please reproduce using the BoringModel

python train.py --print_config
# or
python train.py
# or
python train.py --config config.yaml # with config like this https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_cli.html#trainer-callbacks-and-arguments-with-class-type

train.py:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities.cli import LightningCLI

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)
    def __getitem__(self, index):
        return self.data[index]
    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
    def forward(self, x):
        return self.layer(x)
    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}
    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)
    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)
    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)
    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)
    def test_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

def run():
    early_stopping = EarlyStopping(monitor="valid_loss")
    checkpoint_callback = ModelCheckpoint(dirpath="logs", monitor="valid_loss")

    cli = LightningCLI(
        BoringModel,
        seed_everything_default=123,
        trainer_defaults={
            "max_epochs": 2,
            "callbacks": [
                checkpoint_callback,
                early_stopping,
            ],
            "logger": {
                "class_path": "pytorch_lightning.loggers.TestTubeLogger",
                "init_args": {
                    "save_dir": "logs",
                    "create_git_tag": True,
                },
            },
        },
    )
    cli.trainer.test(cli.model)

if __name__ == '__main__':
    run()

Expected behavior

Callbacks should be saved to the config file. Also, should the logger and callbacks provided to trainer_defaults be of the same form?

Environment

* CUDA:
	- GPU:
		- TITAN RTX
		- TITAN RTX
		- TITAN RTX
		- TITAN RTX
		- TITAN RTX
		- TITAN RTX
		- TITAN RTX
		- TITAN RTX
		- TITAN RTX
		- TITAN RTX
	- available:         True
	- version:           10.2
* Packages:
	- numpy:             1.19.2
	- pyTorch_debug:     False
	- pyTorch_version:   1.8.1+cu102
	- pytorch-lightning: 1.3.1
	- tqdm:              4.50.2
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- ELF
	- processor:         x86_64
	- python:            3.8.5
	- version:           #146-Ubuntu SMP Tue Apr 13 01:11:19 UTC 2021
@tshu-w added the bug and help wanted labels May 14, 2021
@awaelchli
Contributor

It cannot save the callbacks to the config file because these are objects. You need to provide the class path, I believe. It is explained here:
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_cli.html#trainer-callbacks-and-arguments-with-class-type

Basically, since you instantiate them yourself, LightningCLI is unaware of them and doesn't know how to instantiate them. This does not just apply to callbacks.

@awaelchli added the argparse (removed) label May 14, 2021
@tshu-w
Contributor Author

tshu-w commented May 14, 2021

It cannot save the callbacks to the config file, because these are objects. You need to provide class path I believe. It is explained here:
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_cli.html#trainer-callbacks-and-arguments-with-class-type

As I mentioned, I also tried python train.py --config config.yaml --print_config with config.yaml correctly set as described in the above link. The callback settings are still not printed.

Basically, since you instantiate them yourself LightningCLI is unaware of them and doesn't know how to instantiate them. This does not just apply to callbacks.

I have to instantiate them because I cannot pass them the way the logger is passed, as in the following code. Should we handle callbacks the same way we handle loggers?

cli = LightningCLI(
    BoringModel,
    seed_everything_default=123,
    trainer_defaults={
        "max_epochs": 2,
        "callbacks": [
            {
                "class_path": "pytorch_lightning.callbacks.ModelCheckpoint"
                "init_args": {
                    ...
                }
            },
        ],
        "logger": {
            "class_path": "pytorch_lightning.loggers.TestTubeLogger",
            "init_args": {
                "save_dir": "logs",
                "create_git_tag": True,
            },
        },
    },
)

@awaelchli
Contributor

awaelchli commented May 14, 2021

Try this:

seed_everything: 123
trainer:
  callbacks:
    - class_path: pytorch_lightning.callbacks.EarlyStopping
      init_args:
        patience: 5
        monitor: valid_loss
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        monitor: valid_loss
        dirpath: here

And then when we print the config it outputs everything:

seed_everything: 123
trainer:
  logger: true
  checkpoint_callback: true
  callbacks:
  - class_path: pytorch_lightning.callbacks.EarlyStopping
    init_args:
      monitor: valid_loss
      min_delta: 0.0
      patience: 5
      verbose: false
      mode: min
      strict: true
      check_finite: true
      check_on_train_epoch_end: false
  - class_path: pytorch_lightning.callbacks.ModelCheckpoint
    init_args:
      dirpath: here
      monitor: valid_loss
      verbose: false
      save_weights_only: false
      mode: min
      auto_insert_metric_name: true
  default_root_dir: null
....

I have to instantiate them because I cannot pass them like logger like the following code. Should we handle callbacks the same way we handle loggers?

Yes, I think the idea of jsonargparse is that we can instantiate any object like this.
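
For anyone curious, the underlying mechanism looks roughly like the minimal jsonargparse sketch below (illustrative only and closely mirroring the jsonargparse docs; the Calendar classes come from the Python standard library, not Lightning):

from calendar import Calendar

from jsonargparse import ArgumentParser

parser = ArgumentParser()
# Accept any subclass of Calendar under the "calendar" key,
# selectable via class_path/init_args.
parser.add_subclass_arguments(Calendar, "calendar")

cfg = parser.parse_args(
    ["--calendar", '{"class_path": "calendar.TextCalendar", "init_args": {"firstweekday": 1}}']
)
init = parser.instantiate_classes(cfg)
print(type(init.calendar).__name__)  # prints: TextCalendar

LightningCLI relies on this same class_path/init_args mechanism for callbacks, loggers, models, and datamodules.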

cc @carmocca @mauvilsa

@tshu-w
Contributor Author

tshu-w commented May 14, 2021

Try this:

Thank you for your detailed reply; sorry, I may have made a mistake somewhere before.

@mauvilsa
Contributor

Some additional clarification. The description of trainer_defaults is "Set to override Trainer defaults or add persistent callbacks". The callbacks added via trainer_defaults are not configurable; they will always be present for this particular CLI. Since they are not configurable, it is intentional that they are not included in the config file. If they were in the config file, these callbacks would be added twice, because callbacks is a list and specifying callbacks in the config appends to the existing list.

In summary: if you want persistent callbacks, add them in trainer_defaults and they shouldn't be in the config file. If you don't want persistent callbacks, specify them in the config file and not in trainer_defaults.
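
To make the two options concrete, here is a minimal sketch (a hedged example reusing the BoringModel and the PL 1.3 import paths from the report above; the monitor name is just a placeholder):

# Option 1: persistent callback. Instantiated in code and passed via
# trainer_defaults, so it is always added and never appears in the config file.
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.cli import LightningCLI

cli = LightningCLI(
    BoringModel,
    trainer_defaults={"callbacks": [EarlyStopping(monitor="valid_loss")]},
)

# Option 2: configurable callbacks. Listed in the config file in
# class_path/init_args form (see the YAML earlier in this thread); whatever is
# listed there gets appended to the callbacks coming from trainer_defaults.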

@awaelchli added the question label and removed the bug label May 14, 2021
@awaelchli
Contributor

@tshu-w does it answer your questions or are there still any open problems here?

@tshu-w
Contributor Author

tshu-w commented May 15, 2021

@mauvilsa Thank you for your detailed replies! I understand that the current setting is to only provide persistent callbacks in the code. As a newbie, I'm trying to understand the rationale behind treating callbacks specially, and why callback configuration can't be provided to trainer_defaults the way the logger is.

@mauvilsa
Contributor

@tshu-w It is great that you are challenging the concept. This is precisely why LightningCLI is in beta, so that important cases are identified, discussed and improved for the stable version.

Regarding your comment: you are comparing logger and callbacks, but they are conceptually different cases. The callbacks are a list, whereas there is a single logger. With a list, the behavior of a default is not necessarily obvious. Say, for example, there are three default callbacks and the callbacks parameter is overridden with a config file that includes a single callback. There are several possibilities. One is that all default callbacks are discarded and only the given callback is used. Another is that config callbacks are appended to the default list, so in the example four callbacks would be used (this is what is currently implemented). Yet another is that someone wants to change some settings of a default callback; in that case, how do you specify that you want to modify, say, the second callback without being required to include in your config the settings of the first and third callbacks, which shouldn't change?

Note that, in general, what LightningCLI does is look at the type hints of a parameter and make it configurable. This is why it is able to work with user-defined modules and datamodules. The behavior for lists is that the whole list is replaced when overriding. Since callbacks are important in Lightning, persistent callbacks seemed like a useful feature. Maybe we should make the persistent callbacks an independent init parameter of LightningCLI.
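
As an illustration of the type-hint-driven configuration described above, here is a hedged sketch (the LitClassifier module and its arguments are made up for the example; imports follow the PL 1.3 paths used elsewhere in this thread):

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.cli import LightningCLI

class LitClassifier(LightningModule):
    # Because these __init__ arguments are type-hinted, LightningCLI exposes
    # them automatically as --model.hidden_dim / --model.lr options and under
    # the model: section of the generated config file.
    def __init__(self, hidden_dim: int = 64, lr: float = 0.1):
        super().__init__()
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, hidden_dim)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.hparams.lr)

if __name__ == "__main__":
    cli = LightningCLI(LitClassifier)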

@tshu-w
Contributor Author

tshu-w commented May 17, 2021

@mauvilsa Thanks for your replies! This makes sense.

Another possibility could be that someone wants to change some settings of a default callback. In this case how do you specify that you want to modify say the second callback and not be required to include in your config of the settings of the first and third callback which shouldn't change?

Last question: since it is now appended to the end of the list, can it override a previous callback of the same type?

@tshu-w closed this as completed May 17, 2021
@tshu-w reopened this May 17, 2021
@mauvilsa
Contributor

Last question: since it is now appended to the end of the list, can it override a previous callback of the same type?

In the current implementation, no. It just appends. But it could be a good idea to change it so that callbacks of the same type are overridden instead of appended.

@carmocca
Contributor

But it could be a good idea to change it so that callbacks of the same type are overridden instead of appended.

We have plans to add better support for multiple callbacks of the same type, so overriding would be problematic, as users may want to use multiple callbacks of the same type.

#6467

@tshu-w
Contributor Author

tshu-w commented May 18, 2021

Thank you for your efforts, and it's great to see PyTorch Lightning keep getting better!

@apple2373

apple2373 commented Nov 24, 2023

I came to this issue when I wanted to have a default ModelCheckpoint callback but sometimes change the monitor to other metrics via command-line args. I think the following is the supported way to achieve what I want to do.

from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.cli import LightningCLI
# from pytorch_lightning.callbacks import ModelCheckpoint  # for some reason, this won't work....

class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.add_lightning_class_args(ModelCheckpoint, "checkpoint")
        parser.set_defaults({"checkpoint.monitor": "val_acc1", "checkpoint.mode": "max"})

cf. https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_expert.html#configure-forced-callbacks
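
With that subclass in place, the forced checkpoint callback's defaults can then be overridden per run from the command line, for example (a hedged example; the metric name and the fit subcommand below are assumptions, not something taken from this thread):

python train.py fit --checkpoint.monitor=val_loss --checkpoint.mode=min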
