Bug description

When using a checkpoint in LightningCLI, the model is first instantiated and then the checkpoint is loaded by passing it to the Trainer method's ckpt_path argument. The problem is that the hyperparameters stored in the checkpoint are not used when instantiating the model, and thus not when allocating tensors, which can cause checkpoint loading to fail if tensor sizes do not match. Furthermore, if the model has complicated instantiation logic, this can lead to other silent bugs or failures.

This was first raised as a discussion in #20715.

What version are you seeing the problem on?

v2.5

How to reproduce the bug

Here is a minimal example using the predict subcommand. We change out_dim in fit so that the last layer has a different size, causing loading in predict to fail.
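The model and CLI below are a minimal sketch consistent with the names in the traceback further down (DemoModelWithHyperparameters, BoringDataModule, MyLightningCLI, out_dim, and the 32 -> out_dim linear layer); the bodies of training_step, predict_step and configure_optimizers are assumptions, not the reporter's exact code. Saved as cli.py:

```python
# cli.py - minimal sketch; class and argument names follow the report,
# the remaining details are assumptions.
import torch
from torch import nn
from lightning.pytorch import LightningModule
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import BoringDataModule


class DemoModelWithHyperparameters(LightningModule):
    def __init__(self, out_dim: int = 10):
        super().__init__()
        # out_dim is stored in checkpoint["hyper_parameters"] via save_hyperparameters
        self.save_hyperparameters()
        self.l1 = nn.Linear(32, out_dim)

    def forward(self, x):
        return torch.relu(self.l1(x))

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def predict_step(self, batch, batch_idx):
        return self(batch)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


class MyLightningCLI(LightningCLI):
    # the report uses a LightningCLI subclass; its customizations are not shown
    # in the traceback, so none are added here
    pass


def cli_main():
    cli = MyLightningCLI(DemoModelWithHyperparameters, datamodule_class=BoringDataModule)


if __name__ == "__main__":
    cli_main()
```

and then run, for example (the version directory and checkpoint name will differ per run):

python cli.py fit --model.out_dim 2 --trainer.max_epochs 1
python cli.py predict --ckpt_path lightning_logs/version_X/checkpoints/<checkpoint>.ckpt

The first command saves weights of shape [2, 32]; the second re-instantiates the model with the default out_dim=10 before loading, which fails as shown below.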
Error messages and logs

Restoring states from the checkpoint path at lightning_logs/version_23/checkpoints/epoch=0-step=64.ckpt
Traceback (most recent call last):
File ".../lightning_cli_load_checkpoint/src/lightning_cli_load_checkpoint/cli.py", line 20, in <module>
cli_main()
File ".../lightning_cli_load_checkpoint/src/lightning_cli_load_checkpoint/cli.py", line 16, in cli_main
cli = MyLightningCLI(DemoModelWithHyperparameters, datamodule_class=BoringDataModule)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 398, in __init__
self._run_subcommand(self.subcommand)
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 708, in _run_subcommand
fn(**fn_kwargs)
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 887, in predict
return call._call_and_handle_interrupt(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 928, in _predict_impl
results = self._run(model, ckpt_path=ckpt_path)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 409, in _restore_modules_and_callbacks
self.restore_model()
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 286, in restore_model
self.trainer.strategy.load_model_state_dict(
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 372, in load_model_state_dict
self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=strict)
File ".../lightning_cli_load_checkpoint/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2581, in load_state_dict
raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for DemoModelWithHyperparameters:
size mismatch for l1.weight: copying a param with shape torch.Size([2, 32]) from checkpoint, the shape in current model is torch.Size([10, 32]).
size mismatch for l1.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([10]).
Expected behavior
I'd expect the loading to respect the checkpoint's arguments. In other words, while the current implementation roughly follows this logic:
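The two snippets being contrasted are reconstructed here as rough sketches; Model, config.model and the exact merge order are placeholders, not actual LightningCLI internals.

```python
# current logic (sketch): the model is instantiated from the CLI/config values
# alone, and the checkpoint is only consulted when loading the state_dict
model = Model(**config.model)                 # checkpoint hparams ignored here
trainer.predict(model, ckpt_path=ckpt_path)   # fails if tensor shapes differ
```

I'd expect it to be closer to

```python
# expected logic (sketch): hyperparameters stored in the checkpoint take part
# in instantiating the model before the state_dict is loaded
hparams = torch.load(ckpt_path, map_location="cpu")["hyper_parameters"]
# here the checkpoint values win; the exact precedence is a design question
model = Model(**{**dict(config.model), **hparams})
trainer.predict(model, ckpt_path=ckpt_path)
```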
After thinking about it a bit, I think implementing this is not trivial. The main difficulty comes from argument links: regardless of how the code is changed, it should still work correctly when argument links are added to the parser. I am not sure I am aware of all the details that need to be considered, and some may only become apparent while implementing. For now I would propose the following:
1. Behavior only changes when LightningCLI is working in subcommands mode, i.e. run=True.
2. Before instantiating and running, check whether ckpt_path is passed to the subcommand. Note that in principle ckpt_path could be given as a command-line argument, in a config file, or as an environment variable, so a simple solution could be to parse once, check whether ckpt_path is set, and if it is, do a second parse.
3. If ckpt_path is set, use torch.load to read the checkpoint and check whether hyperparameters are included (i.e. save_hyperparameters was used).
4. If hyperparameters are included, remove the keys that correspond to link targets (for links applied both on parse and on instantiate). Unfortunately, right now there is no official way (jsonargparse public API) to know which keys are link targets.
5. After removing the link targets, parse again, modifying the args so that right after the subcommand (e.g. predict) and before all other arguments there is a new --config option whose value is the modified hyperparameters from the checkpoint (a rough sketch of these steps follows below).
6. Continue the normal flow, which instantiates the classes and then runs the trainer method.
I need to figure out what to do about point 4. Most likely a new feature in jsonargparse is needed.
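As a loose illustration of steps 2, 3 and 5 (step 4 is left out precisely because of the missing jsonargparse API), something along these lines could run before constructing the CLI. This is only a sketch under those assumptions; inject_checkpoint_hparams is a made-up helper, not existing Lightning or jsonargparse API.

```python
# Sketch only: rewrites sys.argv-style args so the checkpoint's saved
# hyperparameters act as a base --config that explicit arguments still override.
import torch
import yaml

SUBCOMMANDS = {"fit", "validate", "test", "predict"}


def inject_checkpoint_hparams(args: list[str]) -> list[str]:
    # step 1: only act in subcommand mode (run=True)
    if not args or args[0] not in SUBCOMMANDS:
        return args

    # step 2 (simplified): look for ckpt_path on the command line only; a real
    # implementation would also need to handle config files and env variables
    ckpt_path = None
    for i, arg in enumerate(args):
        if arg == "--ckpt_path" and i + 1 < len(args):
            ckpt_path = args[i + 1]
        elif arg.startswith("--ckpt_path="):
            ckpt_path = arg.split("=", 1)[1]
    if ckpt_path is None:
        return args

    # step 3: read the checkpoint and check that save_hyperparameters was used
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    hparams = checkpoint.get("hyper_parameters")
    if not hparams:
        return args

    # step 4 would go here: drop keys that are argument-link targets (no public
    # jsonargparse API for this yet, so it is omitted from the sketch)

    # step 5: insert the hparams right after the subcommand as a --config value
    # (jsonargparse's --config also accepts the configuration as a string)
    config_str = yaml.safe_dump({"model": dict(hparams)})
    return [args[0], "--config", config_str, *args[1:]]


# usage sketch:
#   cli = MyLightningCLI(Model, args=inject_checkpoint_hparams(sys.argv[1:]))
```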