
link_arguments does not work in lightning 2.3 #20147


Open
peacekurella opened this issue Aug 1, 2024 · 8 comments · May be fixed by #20777
Labels
bug (Something isn't working), lightningcli (pl.cli.LightningCLI), ver: 2.2.x

Comments

@peacekurella

peacekurella commented Aug 1, 2024

Bug description

When parser.link_arguments is used to link field a to field b with apply_on="instantiate", field b is not populated when it is accessed later. This was not a problem in lightning 2.2.5, which we currently use. After upgrading to 2.3.x, field b is no longer populated.
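
To make the expected behaviour concrete, here is a minimal plain-Python sketch of what an instantiate-time link is meant to achieve: the source object is built first, and a value read from it is then passed to the target's constructor. The names ToyData, ToyModel and instantiate_with_link are hypothetical, and this is not how jsonargparse or LightningCLI implements link_arguments; it only illustrates the semantics described above.

```python
class ToyData:
    """Stand-in for the data module (the source of the link)."""

    def __init__(self, vocab_size: int = 10):
        self.vocab_size = vocab_size


class ToyModel:
    """Stand-in for the model (the target of the link)."""

    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size


def instantiate_with_link(data_kwargs: dict) -> tuple:
    """Build the source first, then feed its attribute to the target."""
    data = ToyData(**data_kwargs)
    # The "link": the target argument is taken from the instantiated source.
    model = ToyModel(vocab_size=data.vocab_size)
    return data, model


data, model = instantiate_with_link({"vocab_size": 15})
print(data.vocab_size, model.vocab_size)
```

With a working link, the target always sees whatever value the source ended up with, whether it came from the command line or from the source's default.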

What version are you seeing the problem on?

2.3.3

How to reproduce the bug

#20147 (comment)

Error messages and logs

Environment

Current environment
  • CUDA:
    - GPU: None
    - available: False
    - version: 12.1
  • Lightning:
    - lightning: 2.2.5
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.3.0
    - torch: 2.3.1
    - torchmetrics: 1.4.0.post0
  • Packages:
    - aiobotocore: 2.7.0
    - aiohttp: 3.9.5
    - aioitertools: 0.7.1
    - aiosignal: 1.2.0
    - alabaster: 0.7.16
    - altair: 5.0.1
    - anyio: 4.2.0
    - appdirs: 1.4.4
    - argon2-cffi: 21.3.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.2.3
    - astroid: 2.14.2
    - astropy: 6.1.0
    - astropy-iers-data: 0.2024.6.3.0.31.14
    - asttokens: 2.0.5
    - async-lru: 2.0.4
    - async-timeout: 4.0.3
    - atomicwrites: 1.4.0
    - attrs: 23.1.0
    - automat: 20.2.0
    - autopep8: 2.0.4
    - babel: 2.11.0
    - bcrypt: 3.2.0
    - beautifulsoup4: 4.12.3
    - binaryornot: 0.4.4
    - black: 24.4.2
    - bleach: 4.1.0
    - blinker: 1.6.2
    - bokeh: 3.4.1
    - boto3: 1.34.131
    - botocore: 1.34.131
    - bottleneck: 1.3.7
    - brotli: 1.0.9
    - cachetools: 5.3.3
    - cattrs: 23.2.3
    - certifi: 2024.6.2
    - cffi: 1.16.0
    - chardet: 4.0.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - cloudpickle: 2.2.1
    - colorama: 0.4.6
    - colorcet: 3.1.0
    - comm: 0.2.1
    - constantly: 23.10.4
    - contourpy: 1.2.0
    - cookiecutter: 2.6.0
    - cryptography: 42.0.5
    - cssselect: 1.2.0
    - cycler: 0.11.0
    - cytoolz: 0.12.2
    - dask: 2024.5.0
    - dask-expr: 1.1.0
    - datasets: 2.14.6
    - datashader: 0.16.2
    - debugpy: 1.6.7
    - decorator: 5.1.1
    - defusedxml: 0.7.1
    - diff-match-patch: 20200713
    - dill: 0.3.7
    - distributed: 2024.5.0
    - docker: 7.1.0
    - docstring-parser: 0.16
    - docstring-to-markdown: 0.11
    - docutils: 0.18.1
    - entrypoints: 0.4
    - et-xmlfile: 1.1.0
    - exceptiongroup: 1.2.0
    - executing: 0.8.3
    - fastjsonschema: 2.16.2
    - filelock: 3.13.1
    - flake8: 7.0.0
    - flask: 3.0.3
    - fonttools: 4.51.0
    - frozenlist: 1.4.0
    - fsspec: 2023.10.0
    - gensim: 4.3.2
    - gitdb: 4.0.7
    - gitpython: 3.1.37
    - gmpy2: 2.1.2
    - google-pasta: 0.2.0
    - greenlet: 3.0.1
    - h5py: 3.11.0
    - heapdict: 1.0.1
    - holoviews: 1.19.0
    - huggingface-hub: 0.23.4
    - hvplot: 0.10.0
    - hyperlink: 21.0.0
    - idna: 3.7
    - imagecodecs: 2023.1.23
    - imageio: 2.33.1
    - imagesize: 1.4.1
    - imbalanced-learn: 0.12.3
    - importlib-metadata: 6.11.0
    - importlib-resources: 6.4.0
    - incremental: 22.10.0
    - inflection: 0.5.1
    - iniconfig: 1.1.1
    - intake: 0.7.0
    - intervaltree: 3.1.0
    - ipykernel: 6.28.0
    - ipython: 8.25.0
    - ipython-genutils: 0.2.0
    - ipywidgets: 7.6.5
    - isort: 5.13.2
    - itemadapter: 0.3.0
    - itemloaders: 1.1.0
    - itsdangerous: 2.2.0
    - jaraco.classes: 3.2.1
    - jedi: 0.18.1
    - jeepney: 0.7.1
    - jellyfish: 1.0.1
    - jinja2: 3.1.4
    - jmespath: 1.0.1
    - joblib: 1.4.2
    - json5: 0.9.6
    - jsonargparse: 4.30.0
    - jsonschema: 4.19.2
    - jsonschema-specifications: 2023.7.1
    - jupyter: 1.0.0
    - jupyter-client: 8.6.0
    - jupyter-console: 6.6.3
    - jupyter-core: 5.5.0
    - jupyter-events: 0.10.0
    - jupyter-lsp: 2.2.0
    - jupyter-server: 2.10.0
    - jupyter-server-terminals: 0.4.4
    - jupyterlab: 4.0.11
    - jupyterlab-pygments: 0.1.2
    - jupyterlab-server: 2.25.1
    - jupyterlab-widgets: 3.0.10
    - keyring: 24.3.1
    - kiwisolver: 1.4.4
    - klon: 2.3.0
    - lazy-loader: 0.4
    - lazy-object-proxy: 1.10.0
    - lckr-jupyterlab-variableinspector: 3.1.0
    - lightning: 2.2.5
    - lightning-utilities: 0.11.2
    - linkify-it-py: 2.0.0
    - llvmlite: 0.42.0
    - lmdb: 1.4.1
    - locket: 1.0.0
    - lsprotocol: 2023.0.1
    - lxml: 4.9.4
    - lxml-stubs: 0.1.1
    - lz4: 4.3.2
    - markdown: 3.4.1
    - markdown-it-py: 2.2.0
    - markupsafe: 2.1.3
    - matplotlib: 3.8.4
    - matplotlib-inline: 0.1.6
    - mccabe: 0.7.0
    - mdit-py-plugins: 0.3.0
    - mdurl: 0.1.0
    - mistune: 2.0.4
    - mkl-fft: 1.3.8
    - mkl-random: 1.2.4
    - mkl-service: 2.4.0
    - more-itertools: 10.1.0
    - mpmath: 1.3.0
    - msgpack: 1.0.3
    - multidict: 6.0.4
    - multipledispatch: 0.6.0
    - multiprocess: 0.70.15
    - mypy: 1.10.0
    - mypy-extensions: 1.0.0
    - nbclient: 0.8.0
    - nbconvert: 7.10.0
    - nbformat: 5.9.2
    - nest-asyncio: 1.6.0
    - networkx: 3.2.1
    - nltk: 3.8.1
    - notebook: 7.0.8
    - notebook-shim: 0.2.3
    - numba: 0.59.1
    - numexpr: 2.8.7
    - numpy: 1.26.4
    - numpydoc: 1.7.0
    - nvidia-cublas-cu12: 12.1.3.1
    - nvidia-cuda-cupti-cu12: 12.1.105
    - nvidia-cuda-nvrtc-cu12: 12.1.105
    - nvidia-cuda-runtime-cu12: 12.1.105
    - nvidia-cudnn-cu12: 8.9.2.26
    - nvidia-cufft-cu12: 11.0.2.54
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-nccl-cu12: 2.20.5
    - nvidia-nvjitlink-cu12: 12.5.40
    - nvidia-nvtx-cu12: 12.1.105
    - openpyxl: 3.1.2
    - overrides: 7.4.0
    - packaging: 23.2
    - pandas: 2.2.2
    - pandocfilters: 1.5.0
    - panel: 1.4.4
    - param: 2.1.0
    - parsel: 1.8.1
    - parso: 0.8.3
    - partd: 1.4.1
    - pathos: 0.3.1
    - pathspec: 0.10.3
    - patsy: 0.5.6
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 10.3.0
    - pip: 24.0
    - platformdirs: 3.10.0
    - plotly: 5.22.0
    - pluggy: 1.5.0
    - ply: 3.11
    - pox: 0.3.4
    - ppft: 1.7.6.8
    - prometheus-client: 0.14.1
    - prompt-toolkit: 3.0.43
    - protego: 0.1.16
    - protobuf: 3.20.3
    - psutil: 5.9.0
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py-cpuinfo: 9.0.0
    - pyarrow: 14.0.2
    - pyasn1: 0.4.8
    - pyasn1-modules: 0.2.8
    - pycodestyle: 2.11.1
    - pycparser: 2.21
    - pyct: 0.5.0
    - pycurl: 7.45.2
    - pydeck: 0.8.0
    - pydispatcher: 2.0.5
    - pydocstyle: 6.3.0
    - pyerfa: 2.0.1.4
    - pyflakes: 3.2.0
    - pygls: 1.3.1
    - pygments: 2.15.1
    - pylint: 2.16.2
    - pylint-venv: 3.0.3
    - pyls-spyder: 0.4.0
    - pyodbc: 5.0.1
    - pyopenssl: 24.0.0
    - pyparsing: 3.0.9
    - pyproj: 3.6.1
    - pyqt5: 5.15.10
    - pyqt5-sip: 12.13.0
    - pyqtwebengine: 5.15.6
    - pysocks: 1.7.1
    - pytest: 8.2.2
    - python-dateutil: 2.9.0.post0
    - python-json-logger: 2.0.7
    - python-lsp-black: 2.0.0
    - python-lsp-jsonrpc: 1.1.2
    - python-lsp-server: 1.10.0
    - python-slugify: 5.0.2
    - python-snappy: 0.6.1
    - pytoolconfig: 1.2.6
    - pytorch-lightning: 2.3.0
    - pytz: 2024.1
    - pyviz-comms: 3.0.2
    - pywavelets: 1.5.0
    - pyxdg: 0.27
    - pyyaml: 6.0.1
    - pyzmq: 25.1.2
    - qdarkstyle: 3.2.3
    - qstylizer: 0.2.2
    - qtawesome: 1.2.2
    - qtconsole: 5.5.1
    - qtpy: 2.4.1
    - queuelib: 1.6.2
    - referencing: 0.30.2
    - regex: 2023.10.3
    - requests: 2.32.3
    - requests-file: 1.5.1
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rich: 13.3.5
    - rope: 1.12.0
    - rpds-py: 0.10.6
    - rtree: 1.0.1
    - ruff: 0.4.9
    - ruff-lsp: 0.0.53
    - s3fs: 2023.10.0
    - s3transfer: 0.10.1
    - sagemaker: 2.224.0
    - schema: 0.7.7
    - scikit-image: 0.23.2
    - scikit-learn: 1.4.2
    - scipy: 1.11.4
    - scrapy: 2.11.1
    - seaborn: 0.13.2
    - secretstorage: 3.3.1
    - send2trash: 1.8.2
    - service-identity: 18.1.0
    - setuptools: 69.5.1
    - sip: 6.7.12
    - six: 1.16.0
    - smart-open: 5.2.1
    - smdebug-rulesconfig: 1.0.1
    - smmap: 4.0.0
    - sniffio: 1.3.0
    - snowballstemmer: 2.2.0
    - sortedcontainers: 2.4.0
    - soupsieve: 2.5
    - sphinx: 7.3.7
    - sphinxcontrib-applehelp: 1.0.2
    - sphinxcontrib-devhelp: 1.0.2
    - sphinxcontrib-htmlhelp: 2.0.0
    - sphinxcontrib-jsmath: 1.0.1
    - sphinxcontrib-qthelp: 1.0.3
    - sphinxcontrib-serializinghtml: 1.1.10
    - spyder: 5.5.1
    - spyder-kernels: 2.5.0
    - sqlalchemy: 2.0.30
    - stack-data: 0.2.0
    - statsmodels: 0.14.2
    - streamlit: 1.32.0
    - sympy: 1.12
    - tables: 3.9.2
    - tabulate: 0.9.0
    - tblib: 1.7.0
    - tenacity: 8.2.2
    - tensorboardx: 2.6.2.2
    - terminado: 0.17.1
    - text-unidecode: 1.3
    - textdistance: 4.2.1
    - threadpoolctl: 2.2.0
    - three-merge: 0.1.1
    - tifffile: 2023.4.12
    - tinycss2: 1.2.1
    - tldextract: 3.2.0
    - toml: 0.10.2
    - tomli: 2.0.1
    - tomlkit: 0.11.1
    - toolz: 0.12.0
    - torch: 2.3.1
    - torchmetrics: 1.4.0.post0
    - tornado: 6.4.1
    - tqdm: 4.66.4
    - traitlets: 5.14.3
    - triton: 2.3.1
    - twisted: 23.10.0
    - typeshed-client: 2.5.1
    - typing-extensions: 4.11.0
    - tzdata: 2023.3
    - uc-micro-py: 1.0.1
    - ujson: 5.10.0
    - unicodedata2: 15.1.0
    - unidecode: 1.2.0
    - urllib3: 2.0.7
    - w3lib: 2.1.2
    - watchdog: 4.0.1
    - wcwidth: 0.2.5
    - webencodings: 0.5.1
    - websocket-client: 1.8.0
    - werkzeug: 3.0.3
    - whatthepatch: 1.0.2
    - wheel: 0.43.0
    - widgetsnbextension: 3.5.2
    - wrapt: 1.14.1
    - wurlitzer: 3.0.2
    - xarray: 2023.6.0
    - xxhash: 3.4.1
    - xyzservices: 2022.9.0
    - yapf: 0.40.2
    - yarl: 1.9.3
    - zict: 3.0.0
    - zipp: 3.17.0
    - zope.interface: 5.4.0
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.10.13
    - release: 5.10.220-188.869.amzn2int.x86_64
    - version: #1 SMP Wed Jul 17 14:39:49 UTC 2024

More info

No response

cc @carmocca @mauvilsa

peacekurella added the bug and needs triage labels on Aug 1, 2024
@peacekurella
Author

I noticed that the version drop-down menu does not include 2.3.x as an option.

@awaelchli
Contributor

awaelchli commented Aug 1, 2024

Hey @peacekurella can you please provide a code example based on https://github.com/Lightning-AI/pytorch-lightning/blob/master/examples/pytorch/bug_report/bug_report_model.py so we can verify it's not working?

awaelchli added the repro needed label and removed the needs triage label on Aug 1, 2024
@peacekurella
Author

I can do that.

@peacekurella
Author

import torch
from typing import Type, TypeVar
from lightning.pytorch import LightningModule
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch.cli import LightningCLI
from lightning import LightningDataModule
from lightning.pytorch.callbacks import ModelCheckpoint

class MyLightningCLI(LightningCLI):
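    def add_arguments_to_parser(self, parser):
        # Declare the two fields that get linked; the data-side one defaults to 10.
        parser.add_argument("data.destinationaddressid_vocab_size", default=10)
        parser.add_argument("model.destinationaddressid_vocab_size")
        # Optional checkpoint to restore from; None means start from scratch.
        parser.add_argument("--ckpt_path_ex", type=str, default=None)

        # Link data -> model; with apply_on="instantiate" the value is taken
        # from the instantiated data module rather than from the parsed config.
        parser.link_arguments(
            "data.destinationaddressid_vocab_size",
            "model.destinationaddressid_vocab_size",
            apply_on="instantiate",
        )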
    
    def before_instantiate_classes(self) -> None:
        if self.config.ckpt_path_ex:
            print("restoring from checkpoint")
            # we are restoring from a checkpoint
            CheckpointModuleInstantiatiorCLI.before_instantiate_classes(self)

class MyDataModule(LightningDataModule):
    def __init__(self, destinationaddressid_vocab_size: int = None):
        super().__init__()
        self.destinationaddressid_vocab_size = destinationaddressid_vocab_size
        print(f"The value of destinationaddressid_vocab_size in data module is {destinationaddressid_vocab_size}")
    
    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)
    
    def predict_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

ModuleType = TypeVar("ModuleType")

class CheckpointModuleInstantiatiorCLI:
    def __init__(self, cli: LightningCLI):
        self.cli = cli

    def class_instantiator(self, class_type: Type[ModuleType], *args, **kwargs) -> ModuleType:
        if args:
            raise ValueError("Unexpected args")

        map_location = None if torch.cuda.is_available() else "cpu"
        defaults = self.cli.parser.get_defaults()
        if class_type == BoringModel:
            non_default_kwargs = {k: v for k, v in kwargs.items() if defaults.model.get(k) != v}
            return BoringModel.load_from_checkpoint(
                self.cli.config.ckpt_path_ex,
                map_location=map_location,
                **non_default_kwargs,
            )
        elif class_type == MyDataModule:
            non_default_kwargs = {k: v for k, v in kwargs.items() if defaults.data.get(k) != v}
            return MyDataModule.load_from_checkpoint(
                self.cli.config.ckpt_path_ex,
                map_location=map_location,
                **non_default_kwargs,
            )
        else:
            raise ValueError("Unexpected class")

    @staticmethod
    def before_instantiate_classes(cli: LightningCLI) -> None:
        instantiator = CheckpointModuleInstantiatiorCLI(cli)
        cli.parser.add_instantiator(instantiator.class_instantiator, BoringModel)
        cli.parser.add_instantiator(instantiator.class_instantiator, MyDataModule)



class BoringModel(LightningModule):
    def __init__(self, destinationaddressid_vocab_size):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.destinationaddressid_vocab_size = destinationaddressid_vocab_size
        self.save_hyperparameters()
        print(f"The value of destinationaddressid_vocab_size in model module is {self.destinationaddressid_vocab_size}")
        

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run(args):

    cli = MyLightningCLI(
        BoringModel,
        MyDataModule,
        args=args,
        trainer_defaults={"callbacks": [ModelCheckpoint(dirpath="ckpts")]},
        run=False,
    )

    cli.trainer.fit(
        model=cli.model,
        datamodule=cli.datamodule,
        ckpt_path=cli.config.ckpt_path_ex if cli.config.ckpt_path_ex else None,
    )

if __name__ == "__main__":
    run(args=None)
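
The non_default_kwargs filtering in class_instantiator above decides which constructor arguments are forwarded to load_from_checkpoint. A minimal standalone sketch of that filter follows. It uses plain dictionaries in place of the parser's defaults namespace, and the helper name filter_non_default and the sample values are illustrative assumptions, not part of the repro.

```python
def filter_non_default(kwargs: dict, defaults: dict) -> dict:
    """Keep only the entries whose value differs from the parser default."""
    return {k: v for k, v in kwargs.items() if defaults.get(k) != v}


# Illustrative defaults: a plain dict standing in for parser.get_defaults().
defaults = {"destinationaddressid_vocab_size": 10}

# A value that differs from the default is kept and would be forwarded.
explicit = filter_non_default({"destinationaddressid_vocab_size": 15}, defaults)
print(explicit)

# A value equal to the default is dropped, so nothing would be forwarded.
same_as_default = filter_non_default({"destinationaddressid_vocab_size": 10}, defaults)
print(same_as_default)
```

An argument whose value equals the parser default is filtered out, which presumably leaves load_from_checkpoint to fall back on whatever the checkpoint stored for it. Any value that differs from the default is forwarded and would take precedence over the stored one.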

Running with lightning 2.2.5

  1. Generate checkpoints with python bug_report.py --data.destinationaddressid_vocab_size 15 --trainer.max_epoch=1, which prints:
The value of destinationaddressid_vocab_size in data module is 15
The value of destinationaddressid_vocab_size in model module is 15
  2. Load the model from the checkpoint with python bug_report.py --trainer.max_epoch=2 --ckpt_path_ex ckpts/epoch=0-step=32.ckpt, which prints:
The value of destinationaddressid_vocab_size in data module is None
The value of destinationaddressid_vocab_size in model module is 15

Running with lightning 2.3.3

  1. Generate checkpoints with python bug_report.py --data.destinationaddressid_vocab_size 15 --trainer.max_epoch=1, which prints:
The value of destinationaddressid_vocab_size in data module is 15
The value of destinationaddressid_vocab_size in model module is 15
  2. Load the model from the checkpoint with python bug_report.py --trainer.max_epoch=2 --ckpt_path_ex ckpts/epoch=0-step=32.ckpt, which prints:
The value of destinationaddressid_vocab_size in data module is 10
The value of destinationaddressid_vocab_size in model module is 10

@peacekurella
Author

@awaelchli I added the repro code and the scenarios with their outputs.

@awaelchli
Contributor

OK, thanks @peacekurella. But the default value is 10, and the second command does not pass --data.destinationaddressid_vocab_size 15. When you resume training, you need to pass the same configuration. We can't expect the output to be 15 in the second example if data.destinationaddressid_vocab_size is not passed.

awaelchli added the lightningcli label and removed the repro needed label on Aug 3, 2024
@peacekurella
Author

peacekurella commented Aug 5, 2024

@awaelchli As I understand it, save_hyperparameters() no longer stores the values of parameters that were populated through a link. This was not the case in lightning 2.2.5. It is a problem when restoring from ckpt files for inference, because we typically try to get all the hyperparameters required for inference from the ckpt file itself.
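
The expectation described here can be illustrated with a toy stand-in. The sketch assumes a checkpoint is a dictionary with a hyper_parameters entry, which is the key Lightning checkpoints use when save_hyperparameters() has been called. ToyModel, save_toy_checkpoint and load_toy_checkpoint are hypothetical helpers, not Lightning APIs, and the precedence rule shown (explicit overrides win over stored values) is an assumption made for the illustration.

```python
import pickle
import tempfile
from pathlib import Path


class ToyModel:
    """Stand-in for a model that records its init arguments."""

    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size
        self.hparams = {"vocab_size": vocab_size}


def save_toy_checkpoint(model: ToyModel, path: Path) -> None:
    """Write the recorded hyperparameters into a checkpoint-like dict."""
    with open(path, "wb") as fh:
        pickle.dump({"hyper_parameters": dict(model.hparams)}, fh)


def load_toy_checkpoint(path: Path, **overrides) -> ToyModel:
    """Rebuild the model from stored hyperparameters; overrides take precedence."""
    with open(path, "rb") as fh:
        checkpoint = pickle.load(fh)
    init_args = {**checkpoint["hyper_parameters"], **overrides}
    return ToyModel(**init_args)


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "toy.ckpt"
    save_toy_checkpoint(ToyModel(vocab_size=15), ckpt)
    restored = load_toy_checkpoint(ckpt)                   # nothing passed at load time
    overridden = load_toy_checkpoint(ckpt, vocab_size=10)  # a value passed explicitly

print(restored.vocab_size)
print(overridden.vocab_size)
```

For inference, the first case is the one being relied on: the stored value is enough to rebuild the model without repeating the training configuration.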

@mauvilsa
Contributor

I noticed that this is a duplicate of #20311. Even though this issue is older, there is a temporary workaround in #20311 (comment).

In addition to the workaround, I have just created pull request #20777 with a potential fix. It would be helpful if those of you affected could review it and test it out.

3 participants