
Inconsistency in loading from checkpoint in LightningCLI #20801


Open
Northo opened this issue May 6, 2025 · 1 comment
Labels
bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.5.x

Comments

@Northo

Northo commented May 6, 2025

Bug description

When a checkpoint is used with LightningCLI, the model is first instantiated from the CLI arguments, and the checkpoint is only loaded afterwards by passing its path to the ckpt_path argument of the Trainer method (e.g. fit or predict).

The problem is that the hyperparameters stored in the checkpoint are not used when instantiating the model, and therefore not when allocating its tensors, so checkpoint loading can fail if the tensor sizes do not match. Furthermore, if the model has complicated instantiation logic, this may lead to other silent bugs or failures.
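
To illustrate the mechanism outside the CLI, here is a minimal plain-PyTorch sketch (an assumed example, relying on DemoModel's default out_dim of 10 from the boring classes):

from lightning.pytorch.demos.boring_classes import DemoModel

trained = DemoModel(out_dim=2)     # the configuration that was trained and checkpointed
state_dict = trained.state_dict()

fresh = DemoModel()                # re-instantiated with the default out_dim=10, as the CLI does
fresh.load_state_dict(state_dict)  # RuntimeError: size mismatch for l1.weight / l1.bias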

This was first raised as a discussion in #20715

What version are you seeing the problem on?

v2.5

How to reproduce the bug

Here is a minimal example using the predict subcommand. We override out_dim during fit so that the last layer has a non-default size, which causes checkpoint loading in predict to fail.

# cli.py
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule

class DemoModelWithHyperparameters(DemoModel):
    def __init__(self, *args, **kwargs):
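        # Record the init args (e.g. out_dim) so they are stored in the checkpoint under "hyper_parameters".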
        self.save_hyperparameters()
        super().__init__(*args, **kwargs)

def cli_main():
    cli = LightningCLI(DemoModelWithHyperparameters, BoringDataModule)

if __name__ == "__main__":
    cli_main()

and then run

$ python src/lightning_cli_load_checkpoint/cli.py fit --trainer.max_epochs 1 --model.out_dim 2
$ python src/lightning_cli_load_checkpoint/cli.py predict --ckpt_path <path_to_checkpoint>

Error messages and logs

Restoring states from the checkpoint path at lightning_logs/version_23/checkpoints/epoch=0-step=64.ckpt
Traceback (most recent call last):
  File ".../lightning_cli_load_checkpoint/src/lightning_cli_load_checkpoint/cli.py", line 20, in <module>
    cli_main()
  File ".../lightning_cli_load_checkpoint/src/lightning_cli_load_checkpoint/cli.py", line 16, in cli_main
    cli = MyLightningCLI(DemoModelWithHyperparameters, datamodule_class=BoringDataModule)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 398, in __init__
    self._run_subcommand(self.subcommand)
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 708, in _run_subcommand
    fn(**fn_kwargs)
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 887, in predict
    return call._call_and_handle_interrupt(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 928, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 409, in _restore_modules_and_callbacks
    self.restore_model()
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 286, in restore_model
    self.trainer.strategy.load_model_state_dict(
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 372, in load_model_state_dict
    self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=strict)
  File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2581, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for DemoModelWithHyperparameters:
	size mismatch for l1.weight: copying a param with shape torch.Size([2, 32]) from checkpoint, the shape in current model is torch.Size([10, 32]).
	size mismatch for l1.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([10]).

Expected behavior

I'd expect loading to respect the hyperparameters stored in the checkpoint. In other words, while the current implementation roughly follows this logic:

model = Model(**cli_args)
Trainer().predict(model, data, ckpt_path=ckpt_path)

I'd expect it to be closer to

model = Model.load_from_checkpoint(ckpt_path, **cli_args)
Trainer().predict(model, data, ckpt_path=ckpt_path)
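
As a possible interim workaround (a rough sketch, not the official behavior), one could subclass LightningCLI and override its before_instantiate_classes hook to seed the parsed model config with the hyperparameters stored in the checkpoint. The sketch assumes run=True (subcommand mode), a fixed model class rather than subclass mode (so the model config holds the init args directly), and that save_hyperparameters() was used; it does not handle argument links:

import torch
from lightning.pytorch.cli import LightningCLI

class CheckpointAwareCLI(LightningCLI):
    def before_instantiate_classes(self) -> None:
        subcommand = self.config.get("subcommand")
        if subcommand is None:
            return
        sub_config = self.config[subcommand]
        ckpt_path = sub_config.get("ckpt_path")
        if not ckpt_path:
            return
        # Read the hyperparameters that save_hyperparameters() stored in the checkpoint.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        hparams = checkpoint.get("hyper_parameters", {})
        # Overwrite the parsed model init args with the checkpointed values.
        # Note: this ignores command-line overrides and does not handle argument links.
        for key, value in hparams.items():
            if key in sub_config.model:
                sub_config.model[key] = value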

Environment

Current environment
  • CUDA:
    • GPU: None
    • available: False
    • version: None
  • Lightning:
    • lightning: 2.5.1
    • lightning-cli-load-checkpoint: 0.1.0
    • lightning-utilities: 0.14.3
    • pytorch-lightning: 2.5.1
    • torch: 2.6.0
    • torchmetrics: 1.7.1
  • Packages:
    • aiohappyeyeballs: 2.6.1
    • aiohttp: 3.11.16
    • aiosignal: 1.3.2
    • antlr4-python3-runtime: 4.9.3
    • attrs: 25.3.0
    • autocommand: 2.2.2
    • backports.tarfile: 1.2.0
    • contourpy: 1.3.1
    • cycler: 0.12.1
    • docstring-parser: 0.16
    • filelock: 3.18.0
    • fonttools: 4.57.0
    • frozenlist: 1.5.0
    • fsspec: 2025.3.2
    • hydra-core: 1.3.2
    • idna: 3.10
    • importlib-metadata: 8.0.0
    • importlib-resources: 6.5.2
    • inflect: 7.3.1
    • jaraco.collections: 5.1.0
    • jaraco.context: 5.3.0
    • jaraco.functools: 4.0.1
    • jaraco.text: 3.12.1
    • jinja2: 3.1.6
    • jsonargparse: 4.38.0
    • kiwisolver: 1.4.8
    • lightning: 2.5.1
    • lightning-cli-load-checkpoint: 0.1.0
    • lightning-utilities: 0.14.3
    • markdown-it-py: 3.0.0
    • markupsafe: 3.0.2
    • matplotlib: 3.10.1
    • mdurl: 0.1.2
    • more-itertools: 10.3.0
    • mpmath: 1.3.0
    • multidict: 6.4.3
    • networkx: 3.4.2
    • numpy: 2.2.4
    • omegaconf: 2.3.0
    • packaging: 24.2
    • pillow: 11.2.1
    • platformdirs: 4.2.2
    • propcache: 0.3.1
    • protobuf: 6.30.2
    • pygments: 2.19.1
    • pyparsing: 3.2.3
    • python-dateutil: 2.9.0.post0
    • pytorch-lightning: 2.5.1
    • pyyaml: 6.0.2
    • rich: 13.9.4
    • setuptools: 78.1.0
    • six: 1.17.0
    • sympy: 1.13.1
    • tensorboardx: 2.6.2.2
    • tomli: 2.0.1
    • torch: 2.6.0
    • torchmetrics: 1.7.1
    • tqdm: 4.67.1
    • typeguard: 4.3.0
    • typeshed-client: 2.7.0
    • typing-extensions: 4.13.2
    • wheel: 0.45.1
    • yarl: 1.19.0
    • zipp: 3.19.2
  • System:
    • OS: Darwin
    • architecture:
      • 64bit
    • processor: arm
    • python: 3.12.7
    • release: 24.4.0
    • version: Darwin Kernel Version 24.4.0: Fri Apr 11 18:33:47 PDT 2025; root:xnu-11417.101.15~117/RELEASE_ARM64_T6000
@Northo added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on May 6, 2025
@mauvilsa
Contributor

mauvilsa commented May 8, 2025

After thinking about it a bit, I believe implementing this is not trivial. The main difficulty comes from argument links: regardless of how the code is changed, it should still work well when argument links are added to the parser. I am not sure I am aware of all the details that need to be considered, and some may only become apparent during implementation. For now I would propose the following:

  1. The behavior only changes when LightningCLI is working in subcommands mode, i.e. run=True.
  2. Before instantiation and running, check whether ckpt_path is passed to the subcommand. Note that in principle ckpt_path could be given as a command-line argument, in a config file, or as an environment variable. So a simple solution could be to parse once, check whether ckpt_path is set, and if it is, do a second parse.
  3. If ckpt_path is set, use torch.load to read the checkpoint and check whether hyperparameters are included (i.e. save_hyperparameters was used).
  4. If hyperparameters are included, remove the keys that correspond to link targets (for links applied both on parse and on instantiate). Unfortunately, there is currently no official way (jsonargparse public API) to know which keys are link targets.
  5. After removing the link targets, parse again, but modify the args so that right after the subcommand (e.g. predict) and before all other arguments there is a new --config option whose value is the modified hyperparameters from the checkpoint.
  6. Continue the normal flow, which instantiates the classes and then runs the trainer method.

I need to figure out what to do about point 4. Most likely a new feature in jsonargparse is needed.
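
As a very rough sketch of how points 2, 3 and 5 could fit together (hedged pseudocode, not a working patch; the helper name inject_checkpoint_config and the argument handling are assumptions), the checkpoint's hyperparameters could be dumped to a temporary config and injected right after the subcommand, so that later command-line arguments still take precedence:

import tempfile
import torch
import yaml

def inject_checkpoint_config(args: list[str], ckpt_path: str) -> list[str]:
    # Point 3: read the checkpoint and check whether hyperparameters are included.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    hparams = checkpoint.get("hyper_parameters")
    if not hparams:
        return args  # save_hyperparameters() was not used; nothing to inject
    # Point 4 (dropping link-target keys) would have to happen here before dumping.
    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
        yaml.safe_dump({"model": dict(hparams)}, f)
    # Point 5: insert a --config right after the subcommand (e.g. "predict") and
    # before all other arguments, so explicit CLI arguments still override it.
    subcommand, rest = args[0], args[1:]
    return [subcommand, "--config", f.name, *rest]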
