diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7ab8002ce..84550c5ef 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,7 +28,7 @@ jobs: python -m pip install ".[dev,all_extras,github-actions]" - name: Show dependencies - run: python -m pip list + run: python -m pip list - name: Run example notebooks run: build_tools/run_examples.sh @@ -103,7 +103,8 @@ jobs: - name: Install dependencies shell: bash run: | - pip install ".[dev,all_extras,github-actions]" + python -m pip install --upgrade pip + python -m pip install ".[dev,all_extras,github-actions]" - name: Show dependencies run: python -m pip list diff --git a/README.md b/README.md index b3b162c7c..6fa8bcdb5 100644 --- a/README.md +++ b/README.md @@ -2,13 +2,12 @@ _PyTorch Forecasting_ is a PyTorch-based package for forecasting with state-of-the-art deep learning architectures. It provides a high-level API and uses [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/) to scale training on GPU or CPU, with automatic logging. - -| | **[Documentation](https://pytorch-forecasting.readthedocs.io)** · **[Tutorials](https://pytorch-forecasting.readthedocs.io/en/latest/tutorials.html)** · **[Release Notes](https://pytorch-forecasting.readthedocs.io/en/latest/CHANGELOG.html)** | -|---|---| -| **Open Source** | [![MIT](https://img.shields.io/github/license/sktime/pytorch-forecasting)](https://github.com/sktime/pytorch-forecasting/blob/master/LICENSE) | -| **Community** | [![!discord](https://img.shields.io/static/v1?logo=discord&label=discord&message=chat&color=lightgreen)](https://discord.com/invite/54ACzaFsn7) [![!slack](https://img.shields.io/static/v1?logo=linkedin&label=LinkedIn&message=news&color=lightblue)](https://www.linkedin.com/company/scikit-time/) | -| **CI/CD** | [![github-actions](https://img.shields.io/github/actions/workflow/status/sktime/pytorch-forecasting/pypi_release.yml?logo=github)](https://github.com/sktime/pytorch-forecasting/actions/workflows/pypi_release.yml) [![readthedocs](https://img.shields.io/readthedocs/pytorch-forecasting?logo=readthedocs)](https://pytorch-forecasting.readthedocs.io) [![platform](https://img.shields.io/conda/pn/conda-forge/pytorch-forecasting)](https://github.com/sktime/pytorch-forecasting) [![Code Coverage][coverage-image]][coverage-url] | -| **Code** | [![!pypi](https://img.shields.io/pypi/v/pytorch-forecasting?color=orange)](https://pypi.org/project/pytorch-forecasting/) [![!conda](https://img.shields.io/conda/vn/conda-forge/pytorch-forecasting)](https://anaconda.org/conda-forge/pytorch-forecasting) [![!python-versions](https://img.shields.io/pypi/pyversions/pytorch-forecasting)](https://www.python.org/) [![!black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) | +| | **[Documentation](https://pytorch-forecasting.readthedocs.io)** · **[Tutorials](https://pytorch-forecasting.readthedocs.io/en/latest/tutorials.html)** · **[Release Notes](https://pytorch-forecasting.readthedocs.io/en/latest/CHANGELOG.html)** | +| -------------------- | 
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| **Open Source** | [![MIT](https://img.shields.io/github/license/sktime/pytorch-forecasting)](https://github.com/sktime/pytorch-forecasting/blob/master/LICENSE) | +| **Community** | [![!discord](https://img.shields.io/static/v1?logo=discord&label=discord&message=chat&color=lightgreen)](https://discord.com/invite/54ACzaFsn7) [![!slack](https://img.shields.io/static/v1?logo=linkedin&label=LinkedIn&message=news&color=lightblue)](https://www.linkedin.com/company/scikit-time/) | +| **CI/CD** | [![github-actions](https://img.shields.io/github/actions/workflow/status/sktime/pytorch-forecasting/pypi_release.yml?logo=github)](https://github.com/sktime/pytorch-forecasting/actions/workflows/pypi_release.yml) [![readthedocs](https://img.shields.io/readthedocs/pytorch-forecasting?logo=readthedocs)](https://pytorch-forecasting.readthedocs.io) [![platform](https://img.shields.io/conda/pn/conda-forge/pytorch-forecasting)](https://github.com/sktime/pytorch-forecasting) [![Code Coverage][coverage-image]][coverage-url] | +| **Code** | [![!pypi](https://img.shields.io/pypi/v/pytorch-forecasting?color=orange)](https://pypi.org/project/pytorch-forecasting/) [![!conda](https://img.shields.io/conda/vn/conda-forge/pytorch-forecasting)](https://anaconda.org/conda-forge/pytorch-forecasting) [![!python-versions](https://img.shields.io/pypi/pyversions/pytorch-forecasting)](https://www.python.org/) [![!black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) | [coverage-image]: https://codecov.io/gh/sktime/pytorch-forecasting/branch/main/graph/badge.svg [coverage-url]: https://codecov.io/github/sktime/pytorch-forecasting?branch=main diff --git a/poetry.lock b/poetry.lock index 0b1922170..189a55ebd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -679,7 +679,6 @@ files = [ {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18a64814ae7bce73925131381603fff0116e2df25230dfc80d6d690aa6e20b37"}, {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90c81f22b4f572f8a2110b0b741bb64e5a6427e0a198b2cdc1fbaf85f352a3aa"}, {file = "contourpy-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:53cc3a40635abedbec7f1bde60f8c189c49e84ac180c665f2cd7c162cc454baa"}, - {file = "contourpy-1.1.0-cp310-cp310-win32.whl", hash = "sha256:9b2dd2ca3ac561aceef4c7c13ba654aaa404cf885b187427760d7f7d4c57cff8"}, {file = "contourpy-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:1f795597073b09d631782e7245016a4323cf1cf0b4e06eef7ea6627e06a37ff2"}, {file = "contourpy-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0b7b04ed0961647691cfe5d82115dd072af7ce8846d31a5fac6c142dcce8b882"}, {file = "contourpy-1.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:27bc79200c742f9746d7dd51a734ee326a292d77e7d94c8af6e08d1e6c15d545"}, @@ -688,7 +687,6 @@ files = [ {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:e5cec36c5090e75a9ac9dbd0ff4a8cf7cecd60f1b6dc23a374c7d980a1cd710e"}, {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f0cbd657e9bde94cd0e33aa7df94fb73c1ab7799378d3b3f902eb8eb2e04a3a"}, {file = "contourpy-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:181cbace49874f4358e2929aaf7ba84006acb76694102e88dd15af861996c16e"}, - {file = "contourpy-1.1.0-cp311-cp311-win32.whl", hash = "sha256:edb989d31065b1acef3828a3688f88b2abb799a7db891c9e282df5ec7e46221b"}, {file = "contourpy-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fb3b7d9e6243bfa1efb93ccfe64ec610d85cfe5aec2c25f97fbbd2e58b531256"}, {file = "contourpy-1.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bcb41692aa09aeb19c7c213411854402f29f6613845ad2453d30bf421fe68fed"}, {file = "contourpy-1.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5d123a5bc63cd34c27ff9c7ac1cd978909e9c71da12e05be0231c608048bb2ae"}, @@ -697,7 +695,6 @@ files = [ {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:317267d915490d1e84577924bd61ba71bf8681a30e0d6c545f577363157e5e94"}, {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d551f3a442655f3dcc1285723f9acd646ca5858834efeab4598d706206b09c9f"}, {file = "contourpy-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e7a117ce7df5a938fe035cad481b0189049e8d92433b4b33aa7fc609344aafa1"}, - {file = "contourpy-1.1.0-cp38-cp38-win32.whl", hash = "sha256:108dfb5b3e731046a96c60bdc46a1a0ebee0760418951abecbe0fc07b5b93b27"}, {file = "contourpy-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:d4f26b25b4f86087e7d75e63212756c38546e70f2a92d2be44f80114826e1cd4"}, {file = "contourpy-1.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc00bb4225d57bff7ebb634646c0ee2a1298402ec10a5fe7af79df9a51c1bfd9"}, {file = "contourpy-1.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:189ceb1525eb0655ab8487a9a9c41f42a73ba52d6789754788d1883fb06b2d8a"}, @@ -706,7 +703,6 @@ files = [ {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:143dde50520a9f90e4a2703f367cf8ec96a73042b72e68fcd184e1279962eb6f"}, {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e94bef2580e25b5fdb183bf98a2faa2adc5b638736b2c0a4da98691da641316a"}, {file = "contourpy-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ed614aea8462735e7d70141374bd7650afd1c3f3cb0c2dbbcbe44e14331bf002"}, - {file = "contourpy-1.1.0-cp39-cp39-win32.whl", hash = "sha256:71551f9520f008b2950bef5f16b0e3587506ef4f23c734b71ffb7b89f8721999"}, {file = "contourpy-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:438ba416d02f82b692e371858143970ed2eb6337d9cdbbede0d8ad9f3d7dd17d"}, {file = "contourpy-1.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a698c6a7a432789e587168573a864a7ea374c6be8d4f31f9d87c001d5a843493"}, {file = "contourpy-1.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:397b0ac8a12880412da3551a8cb5a187d3298a72802b45a3bd1805e204ad8439"}, @@ -1273,7 +1269,6 @@ files = [ {file = "greenlet-2.0.2-cp27-cp27m-win32.whl", hash = "sha256:6c3acb79b0bfd4fe733dff8bc62695283b57949ebcca05ae5c129eb606ff2d74"}, {file = "greenlet-2.0.2-cp27-cp27m-win_amd64.whl", hash = "sha256:283737e0da3f08bd637b5ad058507e578dd462db259f7f6e4c5c365ba4ee9343"}, {file = "greenlet-2.0.2-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:d27ec7509b9c18b6d73f2f5ede2622441de812e7b1a80bbd446cb0633bd3d5ae"}, - {file = 
"greenlet-2.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d967650d3f56af314b72df7089d96cda1083a7fc2da05b375d2bc48c82ab3f3c"}, {file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:30bcf80dda7f15ac77ba5af2b961bdd9dbc77fd4ac6105cee85b0d0a5fcf74df"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26fbfce90728d82bc9e6c38ea4d038cba20b7faf8a0ca53a9c07b67318d46088"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9190f09060ea4debddd24665d6804b995a9c122ef5917ab26e1566dcc712ceeb"}, @@ -1282,7 +1277,6 @@ files = [ {file = "greenlet-2.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:76ae285c8104046b3a7f06b42f29c7b73f77683df18c49ab5af7983994c2dd91"}, {file = "greenlet-2.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:2d4686f195e32d36b4d7cf2d166857dbd0ee9f3d20ae349b6bf8afc8485b3645"}, {file = "greenlet-2.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c4302695ad8027363e96311df24ee28978162cdcdd2006476c43970b384a244c"}, - {file = "greenlet-2.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d4606a527e30548153be1a9f155f4e283d109ffba663a15856089fb55f933e47"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c48f54ef8e05f04d6eff74b8233f6063cb1ed960243eacc474ee73a2ea8573ca"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a1846f1b999e78e13837c93c778dcfc3365902cfb8d1bdb7dd73ead37059f0d0"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a06ad5312349fec0ab944664b01d26f8d1f05009566339ac6f63f56589bc1a2"}, @@ -1312,7 +1306,6 @@ files = [ {file = "greenlet-2.0.2-cp37-cp37m-win32.whl", hash = "sha256:3f6ea9bd35eb450837a3d80e77b517ea5bc56b4647f5502cd28de13675ee12f7"}, {file = "greenlet-2.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:7492e2b7bd7c9b9916388d9df23fa49d9b88ac0640db0a5b4ecc2b653bf451e3"}, {file = "greenlet-2.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b864ba53912b6c3ab6bcb2beb19f19edd01a6bfcbdfe1f37ddd1778abfe75a30"}, - {file = "greenlet-2.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1087300cf9700bbf455b1b97e24db18f2f77b55302a68272c56209d5587c12d1"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:ba2956617f1c42598a308a84c6cf021a90ff3862eddafd20c3333d50f0edb45b"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc3a569657468b6f3fb60587e48356fe512c1754ca05a564f11366ac9e306526"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8eab883b3b2a38cc1e050819ef06a7e6344d4a990d24d45bc6f2cf959045a45b"}, @@ -1321,7 +1314,6 @@ files = [ {file = "greenlet-2.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b0ef99cdbe2b682b9ccbb964743a6aca37905fda5e0452e5ee239b1654d37f2a"}, {file = "greenlet-2.0.2-cp38-cp38-win32.whl", hash = "sha256:b80f600eddddce72320dbbc8e3784d16bd3fb7b517e82476d8da921f27d4b249"}, {file = "greenlet-2.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:4d2e11331fc0c02b6e84b0d28ece3a36e0548ee1a1ce9ddde03752d9b79bba40"}, - {file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8512a0c38cfd4e66a858ddd1b17705587900dd760c6003998e9472b77b56d417"}, {file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:88d9ab96491d38a5ab7c56dd7a3cc37d83336ecc564e4e8816dbed12e5aaefc8"}, {file = 
"greenlet-2.0.2-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:561091a7be172ab497a3527602d467e2b3fbe75f9e783d8b8ce403fa414f71a6"}, {file = "greenlet-2.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:971ce5e14dc5e73715755d0ca2975ac88cfdaefcaab078a284fea6cfabf866df"}, @@ -1968,6 +1960,24 @@ cli = ["fire"] docs = ["requests (>=2.0.0)"] typing = ["mypy (>=1.0.0)", "types-setuptools"] +[[package]] +name = "loguru" +version = "0.7.2" +description = "Python logging made (stupidly) simple" +optional = false +python-versions = ">=3.5" +files = [ + {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"}, + {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"}, +] + +[package.dependencies] +colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} +win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} + +[package.extras] +dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"] + [[package]] name = "mako" version = "1.2.4" @@ -2032,16 +2042,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = 
"sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -3726,7 +3726,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -3734,7 +3733,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, @@ -3760,7 +3758,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = 
"PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -3768,7 +3765,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -4993,6 +4989,20 @@ files = [ {file = "widgetsnbextension-4.0.10.tar.gz", hash = "sha256:64196c5ff3b9a9183a8e699a4227fb0b7002f252c814098e66c4d1cd0644688f"}, ] +[[package]] +name = "win32-setctime" +version = "1.1.0" +description = "A small Python utility to set file creation time on Windows" +optional = false +python-versions = ">=3.5" +files = [ + {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, + {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, +] + +[package.extras] +dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] + [[package]] name = "yarl" version = "1.9.2" diff --git a/pyproject.toml b/pyproject.toml index fbc8bb8fa..76332243b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,9 +31,7 @@ name = "pytorch-forecasting" readme = "README.md" # Markdown files are supported version = "1.1.1" # is being replaced automatically -authors = [ - {name = "Jan Beitner"}, -] +authors = [{ name = "Jan Beitner" }] requires-python = ">=3.8,<3.13" classifiers = [ "Intended Audience :: Developers", @@ -89,11 +87,7 @@ all_extras = [ "statsmodels", ] -tuning = [ - "optuna >=3.1.0,<4.0.0", - "optuna-integration", - "statsmodels", -] +tuning = ["optuna >=3.1.0,<4.0.0", "optuna-integration", "statsmodels"] mqf2 = ["cpflows"] @@ -134,15 +128,7 @@ dev = [ ] # docs - dependencies for building the documentation -docs = [ - "sphinx>3.2", - "pydata-sphinx-theme", - "nbsphinx", - "pandoc", - "nbconvert", - "recommonmark", - "docutils", -] +docs = ["sphinx>3.2", "pydata-sphinx-theme", "nbsphinx", "pandoc", "nbconvert", "recommonmark", "docutils"] github-actions = ["pytest-github-actions-annotate-failures"] @@ -151,6 +137,4 @@ exclude = ["build_tools"] [build-system] build-backend = "setuptools.build_meta" -requires = [ - "setuptools>=70.0.0", -] +requires = ["setuptools>=70.0.0"] diff --git a/pytorch_forecasting/__init__.py b/pytorch_forecasting/__init__.py index 47f927614..9a06e028d 100644 --- a/pytorch_forecasting/__init__.py +++ b/pytorch_forecasting/__init__.py @@ -41,6 +41,7 @@ BaseModelWithCovariates, DecoderMLP, DeepAR, + LSTMModel, MultiEmbedding, NBeats, NHiTS, @@ -72,6 
+73,7 @@
     "TemporalFusionTransformer",
     "NBeats",
     "NHiTS",
+    "LSTMModel",
     "Baseline",
     "DeepAR",
     "BaseModel",
diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py
index bc955159f..4490f5a31 100644
--- a/pytorch_forecasting/data/timeseries.py
+++ b/pytorch_forecasting/data/timeseries.py
@@ -475,8 +475,8 @@ def __init__(
         # preprocess data
         data = self._preprocess_data(data)
-        for target in self.target_names:
-            assert target not in self.scalers, "Target normalizer is separate and not in scalers."
+        # for target in self.target_names:
+        #     assert target not in self.scalers, "Target normalizer is separate and not in scalers."
         # create index
         self.index = self._construct_index(data, predict_mode=self.predict_mode)
@@ -1569,9 +1569,9 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
         # switch some variables to nan if encode length is 0
         if encoder_length == 0 and len(self.dropout_categoricals) > 0:
-            data_cat[:, [self.flat_categoricals.index(c) for c in self.dropout_categoricals]] = (
-                0  # zero is encoded nan
-            )
+            fc = self.flat_categoricals
+            dc = self.dropout_categoricals
+            data_cat[:, [fc.index(c) for c in dc]] = 0  # zero is encoded nan
         assert decoder_length > 0, "Decoder length should be greater than 0"
         assert encoder_length >= 0, "Encoder length should be at least 0"
diff --git a/pytorch_forecasting/models/__init__.py b/pytorch_forecasting/models/__init__.py
index ec143f2f8..1ce95f40f 100644
--- a/pytorch_forecasting/models/__init__.py
+++ b/pytorch_forecasting/models/__init__.py
@@ -17,6 +17,8 @@
 from pytorch_forecasting.models.rnn import RecurrentNetwork
 from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer
+from .lstm import LSTMModel
+
 __all__ = [
     "NBeats",
     "NHiTS",
@@ -33,4 +35,5 @@
     "GRU",
     "MultiEmbedding",
     "DecoderMLP",
+    "LSTMModel",
 ]
diff --git a/pytorch_forecasting/models/_base_autoregressive.py b/pytorch_forecasting/models/_base_autoregressive.py
new file mode 100644
index 000000000..86fd58504
--- /dev/null
+++ b/pytorch_forecasting/models/_base_autoregressive.py
@@ -0,0 +1,141 @@
+__all__ = ["AutoRegressiveBaseModel"]
+
+from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
+
+import torch
+from torch import Tensor
+
+from pytorch_forecasting.metrics import DistributionLoss, MultiLoss
+from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel as AutoRegressiveBaseModel_
+from pytorch_forecasting.utils import apply_to_list, to_list
+
+
+class AutoRegressiveBaseModel(AutoRegressiveBaseModel_):  # pylint: disable=abstract-method
+    """Essentially the ``AutoRegressiveBaseModel`` from `pytorch_forecasting`, but fixed for multi-target prediction. Works with `LSTM`."""
+
+    def output_to_prediction(
+        self,
+        normalized_prediction_parameters: torch.Tensor,
+        target_scale: Union[List[torch.Tensor], torch.Tensor],
+        n_samples: int = 1,
+        **kwargs: Any,
+    ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], torch.Tensor]:
+        """
+        Convert network output to rescaled and normalized prediction.
+        Function is typically not called directly but via :py:meth:`~decode_autoregressive`.
+        Args:
+            normalized_prediction_parameters (torch.Tensor): network prediction output
+            target_scale (Union[List[torch.Tensor], torch.Tensor]): target scale to rescale network output
+            n_samples (int, optional): Number of samples to draw independently. Defaults to 1.
+            **kwargs: extra arguments for dictionary passed to :py:meth:`~transform_output` method.
+        Returns:
+            Tuple[Union[List[torch.Tensor], torch.Tensor], torch.Tensor]: tuple of rescaled prediction and
+                normalized prediction (e.g. for input into next auto-regressive step)
+        """
+        B = normalized_prediction_parameters.size(0)
+        D = normalized_prediction_parameters.size(-1)
+        single_prediction = to_list(normalized_prediction_parameters)[0].ndim == 2
+        if single_prediction:  # add time dimension as it is expected
+            normalized_prediction_parameters = apply_to_list(normalized_prediction_parameters, lambda x: x.unsqueeze(1))
+        # transform into real space
+        prediction_parameters = self.transform_output(
+            prediction=normalized_prediction_parameters, target_scale=target_scale, **kwargs
+        )
+
+        # sample value(s) from distribution and select first sample
+        if isinstance(self.loss, DistributionLoss) or (
+            isinstance(self.loss, MultiLoss) and isinstance(self.loss[0], DistributionLoss)
+        ):
+            if n_samples > 1:
+                prediction_parameters = apply_to_list(
+                    prediction_parameters, lambda x: x.reshape(int(x.size(0) / n_samples), n_samples, -1)
+                )
+                prediction = self.loss.sample(prediction_parameters, 1)
+                prediction = apply_to_list(prediction, lambda x: x.reshape(x.size(0) * n_samples, 1, -1))
+            else:
+                prediction = self.loss.sample(normalized_prediction_parameters, 1)
+        else:
+            prediction = prediction_parameters
+        # normalize prediction
+        normalized_prediction = self.output_transformer.transform(prediction, target_scale=target_scale)
+        if isinstance(normalized_prediction, list):
+            input_target = normalized_prediction[-1]  # torch.cat(normalized_prediction, dim=-1)  # dim=-1
+        else:
+            input_target = normalized_prediction  # set next input target to normalized prediction
+        assert input_target.size(0) == B
+        assert input_target.size(-1) == D, f"{input_target.size()} but D={D}"
+        # remove time dimension
+        if single_prediction:
+            prediction = apply_to_list(prediction, lambda x: x.squeeze(1))
+            input_target = input_target.squeeze(1)
+        return prediction, input_target
+
+    def decode_autoregressive(
+        self,
+        decode_one: Callable,
+        first_target: Union[List[torch.Tensor], torch.Tensor],
+        first_hidden_state: Any,
+        target_scale: Union[List[torch.Tensor], torch.Tensor],
+        n_decoder_steps: int,
+        n_samples: int = 1,
+        **kwargs: Any,
+    ) -> Union[List[torch.Tensor], torch.Tensor]:
+        """
+        Make predictions in an auto-regressive manner. Supports only continuous targets.
+        Args:
+            decode_one (Callable):
+                function that takes at least the following arguments:
+                * ``idx`` (int): index of decoding step (from 0 to n_decoder_steps-1)
+                * ``lagged_targets`` (List[torch.Tensor]): list of normalized targets.
+                  List is ``idx + 1`` elements long with the most recent entry at the end, i.e.
+                  ``previous_target = lagged_targets[-1]`` and in general ``lagged_targets[-lag]``.
+                * ``hidden_state`` (Any): Current hidden state required for prediction. Keys are variable
+                  names. Only lags that are greater than ``idx`` are included.
+                * additional arguments are not dynamic but can be passed via the ``**kwargs`` argument,
+                and returns a tuple of the (not rescaled) network prediction output and the hidden state
+                for the next auto-regressive step.
+            first_target (Union[List[torch.Tensor], torch.Tensor]): first target value to use for decoding
+            first_hidden_state (Any): first hidden state used for decoding
+            target_scale (Union[List[torch.Tensor], torch.Tensor]): target scale as in ``x``
+            n_decoder_steps (int): number of decoding/prediction steps
+            n_samples (int): number of independent samples to draw from the distribution -
+                only relevant for multivariate models. Defaults to 1.
+            **kwargs: additional arguments that are passed to the decode_one function.
+        Returns:
+            Union[List[torch.Tensor], torch.Tensor]: re-scaled prediction
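+
+        Example:
+            A minimal sketch of a compatible ``decode_one`` (mirroring
+            :py:meth:`LSTMModel.decode_one` in this PR; ``input_vector``, ``lstm`` and
+            ``output_layer`` are illustrative names, not part of this API):
+
+            >>> def decode_one(idx, lagged_targets, hidden_state):
+            ...     x = input_vector[:, [idx]]  # decoder input for this step, shape (B, 1, D)
+            ...     x[:, 0, :] = lagged_targets[-1]  # feed back most recent normalized prediction
+            ...     lstm_output, hidden_state = lstm(x, hidden_state)
+            ...     return output_layer(lstm_output)[:, 0], hidden_state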
+        """
+        # make predictions which are fed into next step
+        output: List[Union[List[Tensor], Tensor]] = []
+        current_hidden_state = first_hidden_state
+        normalized_output = [first_target]
+        for idx in range(n_decoder_steps):
+            # get lagged targets
+            current_target, current_hidden_state = decode_one(
+                idx, lagged_targets=normalized_output, hidden_state=current_hidden_state, **kwargs
+            )
+            assert isinstance(current_target, Tensor)
+            # get prediction and its normalized version for the next step
+            prediction, current_target = self.output_to_prediction(
+                current_target, target_scale=target_scale, n_samples=n_samples
+            )
+
+            # save normalized output for lagged targets
+            normalized_output.append(current_target)
+            # set output to unnormalized samples, append each target as n_batch_samples x n_random_samples
+            output.append(prediction)
+
+        if isinstance(self.hparams.target, str):
+            # Here, output is List[Tensor]
+            final_output = torch.stack(output, dim=1)  # type: ignore
+            return final_output
+        # For multi-targets: output is List[List[Tensor]]
+        # final_output_multitarget = [
+        #     torch.stack([out[idx] for out in output], dim=1) for idx in range(len(self.target_positions))
+        # ]
+        # here, self.target_positions is always Tensor([0]), so its length is always 1
+        final_output_multitarget = torch.stack([out[0] for out in output], dim=1)
+        if final_output_multitarget.dim() > 3:
+            final_output_multitarget = final_output_multitarget.squeeze(2)
+
+        r = [final_output_multitarget[..., i] for i in range(final_output_multitarget.size(-1))]
+        return r
diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py
index a82196cf4..baa98cadd 100644
--- a/pytorch_forecasting/models/base_model.py
+++ b/pytorch_forecasting/models/base_model.py
@@ -564,7 +564,7 @@ def transform_output(
         prediction: Union[torch.Tensor, List[torch.Tensor]],
         target_scale: Union[torch.Tensor, List[torch.Tensor]],
         loss: Optional[Metric] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
         """
         Extract prediction from network output and rescale it to real space / de-normalize it.
diff --git a/pytorch_forecasting/models/lstm.py b/pytorch_forecasting/models/lstm.py
new file mode 100644
index 000000000..10d4591bf
--- /dev/null
+++ b/pytorch_forecasting/models/lstm.py
@@ -0,0 +1,240 @@
+__all__ = ["LSTMModel"]
+
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+import torch
+from torch import Tensor, nn
+
+from pytorch_forecasting.metrics import MAE, Metric, MultiLoss
+from pytorch_forecasting.models.nn import LSTM
+
+from ._base_autoregressive import AutoRegressiveBaseModel
+
+
+class LSTMModel(AutoRegressiveBaseModel):
+    """Simple LSTM model.
+
+    Args:
+        target (Union[str, Sequence[str]]):
+            Name (or list of names) of target variable(s).
+
+        target_lags (Dict[str, Dict[str, int]]):
+            Lags for the target variable(s); required by autoregressive models,
+            although lags cannot be used without covariates.
+
+        n_layers (int):
+            Number of LSTM layers.
+
+        hidden_size (int):
+            Hidden size for LSTM model.
+
+        dropout (float, optional):
+            Dropout probability (<1). Defaults to 0.1.
+
+        input_size (int, optional):
+            Input size. Defaults to the number of target variables (inferred from `target`).
+
+        loss (Metric):
+            Loss criterion. Can be different for each target in multi-target setting thanks to
+            `MultiLoss`. Defaults to `MAE`.
+
+        **kwargs:
+            See :class:`pytorch_forecasting.models.base_model.AutoRegressiveBaseModel`.
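+
+    Example:
+        A minimal usage sketch; ``dataset`` is an assumed, pre-built
+        :class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`:
+
+        >>> model = LSTMModel.from_dataset(dataset, n_layers=2, hidden_size=10)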
+    """
+
+    def __init__(
+        self,
+        target: Union[str, Sequence[str]],
+        target_lags: Dict[str, Dict[str, int]],  # pylint: disable=unused-argument
+        n_layers: int,
+        hidden_size: int,
+        dropout: float = 0.1,
+        input_size: Optional[int] = None,
+        loss: Optional[Metric] = None,
+        **kwargs: Any,
+    ):
+        """Prefer using the `LSTMModel.from_dataset()` method rather than this constructor.
+
+        Args:
+            target (Union[str, Sequence[str]]):
+                Name (or list of names) of target variable(s).
+            target_lags (Dict[str, Dict[str, int]]):
+                Lags for the target variable(s); required by autoregressive models,
+                although lags cannot be used without covariates.
+
+            n_layers (int):
+                Number of LSTM layers.
+
+            hidden_size (int):
+                Hidden size for LSTM model.
+
+            dropout (float, optional):
+                Dropout probability (<1). Defaults to 0.1.
+
+            input_size (int, optional):
+                Input size. Defaults to the number of target variables (inferred from `target`).
+
+            loss (Metric):
+                Loss criterion. Can be different for each target in multi-target setting thanks to
+                `MultiLoss`. Defaults to `MAE`.
+
+            **kwargs:
+                See :class:`pytorch_forecasting.models.base_model.AutoRegressiveBaseModel`.
+        """
+        n_targets = len(target) if isinstance(target, (list, tuple)) else 1
+        if input_size is None:
+            input_size = n_targets
+        # arguments target and target_lags are required for autoregressive models
+        # even though target_lags cannot be used without covariates
+        # saves arguments in signature to `.hparams` attribute, mandatory call - do not skip this
+        self.save_hyperparameters()
+        # loss
+        if loss is None:
+            loss = MultiLoss([MAE() for _ in range(n_targets)]) if n_targets > 1 else MAE()  # type: ignore
+        # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this
+        super().__init__(loss=loss, **kwargs)  # type: ignore
+        # use version of LSTM that can handle zero-length sequences
+        self.lstm = LSTM(
+            hidden_size=hidden_size,
+            input_size=input_size,
+            num_layers=n_layers,
+            dropout=dropout,
+            batch_first=True,
+        )
+        # output layer
+        self.output_layer = nn.Linear(hidden_size, n_targets)
+        # others
+        self._input_vector: Tensor
+
+    def encode(self, x: Dict[str, torch.Tensor]) -> Tuple[Tensor, Tensor]:
+        """Encode method.
+        Args:
+            x (Dict[str, torch.Tensor]):
+                First item returned by a `DataLoader` obtained from `TimeSeriesDataSet.to_dataloader()`.
+        Returns:
+            Tuple[Tensor, Tensor]:
+                Tuple of hidden and cell state.
+ """ + # we need at least one encoding step as because the target needs to be lagged by one time step + # because we use the custom LSTM, we do not have to require encoder lengths of > 1 + # but can handle lengths of >= 1 + encoder_lengths = x["encoder_lengths"] + assert encoder_lengths.min() >= 1, f"encoder_lengths = {encoder_lengths.min()}" + input_vector = x["encoder_cont"].clone() + # lag target by one + input_vector[..., self.target_positions] = torch.roll( + input_vector[..., self.target_positions], + shifts=1, + dims=1, + ) + input_vector = input_vector[:, 1:] # first time step cannot be used because of lagging + # determine effective encoder_length length + effective_encoder_lengths = x["encoder_lengths"] - 1 + # run through LSTM network + hidden_state: Tuple[Tensor, Tensor] + _, hidden_state = self.lstm( + input_vector, + lengths=effective_encoder_lengths, + enforce_sorted=False, # passing the lengths directly + ) # second ouput is not needed (hidden state) + return hidden_state + + def decode( + self, + x: Dict[str, torch.Tensor], + hidden_state: Tuple[Tensor, Tensor], + ) -> Union[List[Tensor], Tensor]: + """ + Args: + x (Dict[str, torch.Tensor]): + First item returned by a `DataLoader` obtained from `TimeSeriesDataset.to_dataloader()`. + hidden_state (Tuple[Tensor, Tensor]): + Tuple of hidden and cell state. + Returns: + (Union[List[Tensor], Tensor]): + Tensor if one target, list of Tensors if multi-target. + """ + # again lag target by one + input_vector = x["decoder_cont"].clone() # (B,L,D) + B, L, D = input_vector.size() + input_vector[..., self.target_positions] = torch.roll( + input_vector[..., self.target_positions], shifts=1, dims=1 + ) + # but this time fill in missing target from encoder_cont at the first time step instead of throwing it away + last_encoder_target = x["encoder_cont"][ + torch.arange(x["encoder_cont"].size(0), device=x["encoder_cont"].device), + x["encoder_lengths"] - 1, + self.target_positions.unsqueeze(-1), + ].T + input_vector[:, 0, self.target_positions] = last_encoder_target + # Training mode + if self.training: # training mode + lstm_output, _ = self.lstm(input_vector, hidden_state, lengths=x["decoder_lengths"], enforce_sorted=False) + # transform into right shape + out: Tensor = self.output_layer(lstm_output) + if self.n_targets > 1: + out = [out[:, :, i].view(B, L, -1) for i in range(D)] # type: ignore + prediction: List[Tensor] = self.transform_output(out, target_scale=x["target_scale"]) + # predictions are not yet rescaled + return prediction + # Prediction mode + self._input_vector = input_vector + n_decoder_steps = input_vector.size(1) + first_target = input_vector[:, 0, :] # self.target_positions? + first_target = first_target.view(B, 1, D) + target_scale = x["target_scale"] + output: Union[List[Tensor], Tensor] = self.decode_autoregressive( + self.decode_one, # make predictions which are fed into next step + first_target=first_target, + first_hidden_state=hidden_state, + target_scale=target_scale, + n_decoder_steps=n_decoder_steps, + ) + # predictions are already rescaled + return output + + def forward(self, x: Dict[str, Tensor]) -> Dict[str, Union[Tensor, List[Tensor]]]: + """ + Args: + x (Dict[str, torch.Tensor]): Input dict. + + Returns: + Dict[str, torch.Tensor]: Output dict. 
+ """ + hidden_state = self.encode(x) # encode to hidden state + output = self.decode(x, hidden_state) # decode leveraging hidden state + out = self.to_network_output(prediction=output) + return out + + def decode_one( + self, + idx: int, + lagged_targets: List[Tensor], + hidden_state: Tuple[Tensor, Tensor], + ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + """_summary_ + Args: + idx (int): + (???). + lagged_targets (List[Tensor]): + (???). + hidden_state (Tuple[Tensor, Tensor]): + `(h,c)` (hidden state, cell state). + Returns: + Tuple[Tensor, Tuple[Tensor, Tensor]]: + One-step-ahead prediction and tuple of `(h,c)` (hidden state, cell state). + """ + B, _, D = self._input_vector.size() + + # input has shape (B,L,D) + x = self._input_vector[:, [idx]] + # take most recent target (i.e. lag=1) + lag = lagged_targets[-1] + assert lag.size(0) == B + assert lag.size(-1) == D + # make sure it has shape (B,D) + lag = lag.view(B, D) + # overwrite at target positions + x[:, 0, :] = lag + lstm_output, hidden_state = self.lstm(x, hidden_state) + # transform into right shape + prediction: Tensor = self.output_layer(lstm_output)[:, 0] # take first timestep + if self.n_targets > 1: + prediction = prediction.view(B, 1, D) + return prediction, hidden_state diff --git a/pytorch_forecasting/models/tuning.py b/pytorch_forecasting/models/tuning.py new file mode 100644 index 000000000..c896b0857 --- /dev/null +++ b/pytorch_forecasting/models/tuning.py @@ -0,0 +1,312 @@ +""" +Module for hyperparameter optimization. + +Hyperparameters can be efficiently tuned with `optuna `. +""" + +__all__ = ["optimize_hyperparameters"] + +import copy +import logging +import os +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.tuner import Tuner +from lightning.pytorch.tuner.lr_finder import _LRFinder +import numpy as np +import optuna +from optuna import Trial +from optuna.integration import PyTorchLightningPruningCallback +import optuna.logging +import statsmodels.api as sm +from torch import Tensor +from torch.utils.data import DataLoader + +from pytorch_forecasting import BaseModel, TemporalFusionTransformer +from pytorch_forecasting.data import TimeSeriesDataSet + +optuna_logger = logging.getLogger("optuna") + + +NUMBER = Union[float, int] + + +class PyTorchLightningPruningCallbackAdjusted(pl.Callback, PyTorchLightningPruningCallback): # type: ignore + """Need to inherit from callback for this to work.""" + + +def input_params_generator_tft( + trial: Trial, + hidden_size_range: Tuple[int, int] = (16, 265), + hidden_continuous_size_range: Tuple[int, int] = (8, 64), + attention_head_size_range: Tuple[int, int] = (1, 4), + dropout_range: Tuple[float, float] = (0.1, 0.3), +) -> dict: + """Generates parameters for `TemporalFusionTransformer`.""" + hidden_size = trial.suggest_int("hidden_size", *hidden_size_range, log=True) + dropout = trial.suggest_uniform("dropout", *dropout_range) + hidden_continuous_size = trial.suggest_int( + "hidden_continuous_size", + hidden_continuous_size_range[0], + min(hidden_continuous_size_range[1], hidden_size), + log=True, + ) + attention_head_size = trial.suggest_int("attention_head_size", *attention_head_size_range) + params = dict( + hidden_size=hidden_size, + dropout=dropout, + hidden_continuous_size=hidden_continuous_size, + attention_head_size=attention_head_size, + ) + return 
+
+
+def optimize_hyperparameters(
+    train_dataloaders: DataLoader,
+    val_dataloaders: DataLoader,
+    model_path: str = "hpo",
+    monitor: str = "val_loss",
+    direction: str = "minimize",
+    model_class: Type[BaseModel] = TemporalFusionTransformer,
+    max_epochs: int = 20,
+    n_trials: int = 100,
+    timeout: float = 3600 * 8.0,  # 8 hours
+    gradient_clip_val_range: Tuple[float, float] = (0.01, 100.0),
+    input_params: Dict[str, Dict[str, Any]] = None,
+    input_params_generator: Callable = None,
+    generator_params: dict = None,
+    learning_rate_range: Tuple[float, float] = (1e-5, 1.0),
+    use_learning_rate_finder: bool = True,
+    trainer_kwargs: Dict[str, Any] = None,
+    log_dir: str = "lightning_logs",
+    study: optuna.Study = None,
+    verbose: Union[int, bool] = None,
+    pruner: optuna.pruners.BasePruner = optuna.pruners.SuccessiveHalvingPruner(),
+    **kwargs: Any,
+) -> optuna.Study:
+    """
+    Run hyperparameter optimization.
+
+    The learning rate is determined with the PyTorch Lightning learning rate finder.
+
+    Args:
+        train_dataloaders (DataLoader):
+            Dataloader for training model.
+        val_dataloaders (DataLoader):
+            Dataloader for validating model.
+        model_path (str):
+            Folder to which model checkpoints are saved.
+        monitor (str):
+            Metric to optimize. The hyperparameter (HP) tuner trains a model for a certain HP config
+            and reads this metric to score the configuration. By default, lower is better.
+        direction (str):
+            By default, direction is "minimize", meaning that lower values of the specified
+            ``monitor`` are better. You can change this, e.g. to "maximize".
+        model_class (Type[BaseModel], optional):
+            Model class to tune. Defaults to ``TemporalFusionTransformer``.
+        max_epochs (int, optional):
+            Maximum number of epochs to run training. Defaults to 20.
+        n_trials (int, optional):
+            Number of hyperparameter trials to run. Defaults to 100.
+        timeout (float, optional):
+            Time in seconds after which training is stopped regardless of number of epochs or
+            validation metric. Defaults to 3600*8.0.
+        input_params (dict, optional):
+            A dictionary where each ``key`` contains another dictionary with two keys: ``"method"``
+            and ``"ranges"``. Example:
+            >>> {"hidden_size": {
+                    "method": "suggest_int",
+                    "ranges": (16, 265),
+                }}
+            The ``method`` value has to be a method of the ``optuna.Trial`` object.
+            The ``ranges`` value holds the input ranges for the specified method.
+        input_params_generator (Callable, optional):
+            A function with the following signature:
+            `fn(trial: optuna.Trial, **kwargs: Any) -> Dict[str, Any]`,
+            returning the parameter values to set up your model for the current trial/run.
+            Example:
+            >>> def fn(trial, param_ranges = (16, 265)) -> Dict[str, Any]:
+                    param = trial.suggest_int("param", *param_ranges, log=True)
+                    model_params = {"param": param}
+                    return model_params
+            Then, when your model is created (before training it and reporting the metrics for
+            the current combination of hyperparameters), this dictionary is used as follows:
+            >>> model = YourModelClass.from_dataset(
+                    train_dataloaders.dataset,
+                    log_interval=-1,
+                    **model_params,
+                )
+        generator_params (dict, optional):
+            Additional parameters to be passed to the ``input_params_generator`` function,
+            if required.
+        learning_rate_range (Tuple[float, float], optional):
+            Learning rate range. Defaults to (1e-5, 1.0).
+        use_learning_rate_finder (bool):
+            Whether to use the learning rate finder or to optimize the learning rate as part of
+            the hyperparameters. Defaults to True.
+        trainer_kwargs (Dict[str, Any], optional):
+            Additional arguments to the
+            PyTorch Lightning trainer such as ``limit_train_batches``.
+            Defaults to {}.
+        log_dir (str, optional):
+            Folder into which to log results for tensorboard. Defaults to "lightning_logs".
+        study (optuna.Study, optional):
+            Study to resume. Will create new study by default.
+        verbose (Union[int, bool]):
+            Level of verbosity.
+            * None: no change in verbosity level (equivalent to verbose=1).
+            * 0 or False: log only warnings.
+            * 1 or True: log pruning events.
+            * 2: optuna logging level at debug level.
+            Defaults to None.
+        pruner (optuna.pruners.BasePruner, optional):
+            The optuna pruner to use. Defaults to ``optuna.pruners.SuccessiveHalvingPruner()``.
+        **kwargs:
+            Additional arguments for your model's class.
+
+    Returns:
+        optuna.Study: optuna study results
+    """
+    if generator_params is None:
+        generator_params = {}
+
+    assert isinstance(train_dataloaders.dataset, TimeSeriesDataSet) and isinstance(
+        val_dataloaders.dataset, TimeSeriesDataSet
+    ), "Dataloaders must be built from TimeSeriesDataSet."
+
+    logging_level = {
+        None: optuna.logging.get_verbosity(),
+        0: optuna.logging.WARNING,
+        1: optuna.logging.INFO,
+        2: optuna.logging.DEBUG,
+    }
+    optuna_verbose = logging_level[verbose]
+    optuna.logging.set_verbosity(optuna_verbose)
+
+    # need a deepcopy of loss as it will otherwise propagate from one trial to the next
+    loss = kwargs.get("loss", None)
+
+    # create objective function
+    def objective(trial: optuna.Trial) -> float:
+        # Filenames for each trial must be made unique in order to access each checkpoint
+        checkpoint_callback = ModelCheckpoint(
+            dirpath=os.path.join(model_path, f"trial_{trial.number}"),
+            filename="{epoch}",
+            monitor=monitor,
+        )
+        # Create Trainer
+        learning_rate_callback = LearningRateMonitor()
+        gradient_clip_val = trial.suggest_loguniform("gradient_clip_val", *gradient_clip_val_range)
+        default_trainer_kwargs = dict(
+            accelerator="auto",
+            max_epochs=max_epochs,
+            gradient_clip_val=gradient_clip_val,
+            callbacks=[
+                learning_rate_callback,
+                checkpoint_callback,
+                PyTorchLightningPruningCallbackAdjusted(trial, monitor=monitor),
+            ],
+            logger=TensorBoardLogger(log_dir, name="optuna", version=trial.number),
+            enable_progress_bar=optuna_verbose < optuna.logging.INFO,
+            enable_model_summary=[False, True][optuna_verbose < optuna.logging.INFO],
+        )
+        if trainer_kwargs is not None:
+            default_trainer_kwargs.update(trainer_kwargs)
+        trainer = pl.Trainer(**default_trainer_kwargs)  # type: ignore
+        # Create model: set up kwargs
+        if input_params_generator is None:
+            assert (
+                input_params is not None
+            ), "Please provide `input_params` when not providing an `input_params_generator` function."
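+            # Example of how an `input_params` entry is resolved by `extract_params` below
+            # (entry shown is illustrative):
+            #   {"hidden_size": {"method": "suggest_int", "ranges": (4, 16)}}
+            #   -> trial.suggest_int("hidden_size", 4, 16)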
+            params = dict()
+            for key, cfg in input_params.items():
+                fn, low, high, more_kwargs = extract_params(trial, cfg)
+                try:
+                    params[key] = fn(key, low, high, **more_kwargs)
+                except ValueError as ex:
+                    raise ValueError(f"Error while calling {fn} for {key}.") from ex
+        else:
+            params = input_params_generator(trial, **generator_params)
+        kwargs.update(params)
+        kwargs["loss"] = copy.deepcopy(loss)
+        # Create model
+        model = model_class.from_dataset(
+            train_dataloaders.dataset,
+            log_interval=-1,
+            **kwargs,
+        )
+        # find a good learning rate
+        if use_learning_rate_finder:
+            lr_trainer = pl.Trainer(
+                gradient_clip_val=gradient_clip_val,
+                accelerator=trainer_kwargs.get("accelerator", "auto"),
+                logger=False,
+                enable_progress_bar=False,
+                enable_model_summary=False,
+            )
+            tuner = Tuner(lr_trainer)
+            res: Optional[_LRFinder] = tuner.lr_find(
+                model,
+                train_dataloaders=train_dataloaders,
+                val_dataloaders=val_dataloaders,
+                early_stop_threshold=10000,
+                min_lr=learning_rate_range[0],
+                num_training=100,
+                max_lr=learning_rate_range[1],
+            )
+            assert res is not None, "`tuner.lr_find()` returned no results."
+            loss_finite = np.isfinite(res.results["loss"])
+            if loss_finite.sum() > 3:  # at least 3 valid values required for learning rate finder
+                lr_smoothed, loss_smoothed = sm.nonparametric.lowess(
+                    np.asarray(res.results["loss"])[loss_finite],
+                    np.asarray(res.results["lr"])[loss_finite],
+                    frac=1.0 / 10.0,
+                )[min(loss_finite.sum() - 3, 10) : -1].T
+                optimal_idx = np.gradient(loss_smoothed).argmin()
+                optimal_lr = lr_smoothed[optimal_idx]
+            else:
+                optimal_idx = np.asarray(res.results["loss"]).argmin()
+                optimal_lr = res.results["lr"][optimal_idx]
+            optuna_logger.info(f"Using learning rate of {optimal_lr:.3g}")
+            # add learning rate artificially
+            model.hparams.learning_rate = trial.suggest_uniform("learning_rate", optimal_lr, optimal_lr)
+        else:
+            model.hparams.learning_rate = trial.suggest_loguniform("learning_rate", *learning_rate_range)
+        # fit
+        trainer.fit(model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
+        # report result: choose logged metric
+        metrics: dict = trainer.callback_metrics
+        try:
+            metric_value: Tensor = metrics[monitor]
+        except KeyError as ex:
+            raise KeyError(f"Available metrics: {metrics.keys()}") from ex
+        return metric_value.item()
+
+    # setup optuna and run
+    if study is None:
+        study = optuna.create_study(direction=direction, pruner=pruner)
+    study.optimize(objective, n_trials=n_trials, timeout=timeout)
+    return study
+
+
+def extract_params(
+    trial: Trial,
+    cfg: Dict[str, Any],
+) -> Tuple[Callable, NUMBER, NUMBER, Dict[str, Any]]:
+    """Helper to extract the config for one hyperparameter."""
+    more_kwargs: Dict[str, Any] = {}
+    for k, v in cfg.items():
+        if k.lower() in ["method"]:
+            method: str = cfg["method"]
+            assert isinstance(method, str), f"You must provide a {str} as method."
+            fn: Callable = getattr(trial, method)
+        elif k.lower() in ["ranges"]:
+            ranges: Sequence[NUMBER] = cfg["ranges"]
+            assert isinstance(ranges, (list, tuple)), f"You must provide a {list} or {tuple} as ranges."
+            assert len(ranges) == 2, f"Expected exactly 2 values for ranges, got {len(ranges)}."
+            low = ranges[0]
+            high = ranges[1]
+        else:
+            more_kwargs[k.lower()] = v
+    return fn, low, high, more_kwargs
diff --git a/tests/conftest.py b/tests/conftest.py
index 58dd0189e..2587b8610 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,6 +2,7 @@
 import sys
 
 import numpy as np
+import pandas as pd
 import pytest
 
 sys.path.insert(0, os.path.abspath(os.path.join(__file__, "../..")))  # isort:skip
@@ -75,3 +76,29 @@ def test_dataset(test_data):
 def disable_mps(monkeypatch):
     """Disable MPS for all tests"""
     monkeypatch.setattr("torch._C._mps_is_available", lambda: False)
+
+
+@pytest.fixture(scope="session")
+def timeseriesdataset_multitarget() -> TimeSeriesDataSet:
+    """Dummy multi-target `TimeSeriesDataSet`."""
+    multi_target_test_data = pd.DataFrame(
+        dict(
+            target1=np.random.rand(30),
+            target2=np.random.rand(30),
+            group=np.repeat(np.arange(3), 10),
+            time_idx=np.tile(np.arange(10), 3),
+        )
+    )
+    dataset = TimeSeriesDataSet(
+        multi_target_test_data,
+        group_ids=["group"],
+        target=["target1", "target2"],
+        time_idx="time_idx",
+        min_encoder_length=5,
+        max_encoder_length=5,
+        min_prediction_length=1,
+        max_prediction_length=1,
+        time_varying_unknown_reals=["target1", "target2"],
+        target_normalizer="auto",
+    )
+    return dataset
diff --git a/tests/test_models/test_tuning.py b/tests/test_models/test_tuning.py
new file mode 100644
index 000000000..2caab19ca
--- /dev/null
+++ b/tests/test_models/test_tuning.py
@@ -0,0 +1,55 @@
+import os
+import sys
+import typing as ty
+
+import optuna
+import pytest
+
+from pytorch_forecasting import TimeSeriesDataSet
+from pytorch_forecasting.models import LSTMModel
+from pytorch_forecasting.models.tuning import optimize_hyperparameters
+
+
+def test_tuning_lstm(timeseriesdataset_multitarget: TimeSeriesDataSet) -> optuna.Study:
+    """Test that we can tune an `LSTMModel`."""
+    # create dataloaders for model
+    batch_size = 32
+    train_dataloader = timeseriesdataset_multitarget.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
+    val_dataloader = timeseriesdataset_multitarget.to_dataloader(train=False, batch_size=batch_size, num_workers=0)
+    # Create HP to explore
+    hp: ty.Dict[str, ty.Dict[str, ty.Any]] = {
+        "n_layers": {
+            "method": "suggest_int",
+            "ranges": (1, 8),
+        },
+        "hidden_size": {
+            "method": "suggest_int",
+            "ranges": (4, 16),
+        },
+        "dropout": {
+            "method": "suggest_uniform",
+            "ranges": (0.1, 0.3),
+        },
+    }
+    # create study
+    model_class = LSTMModel
+    study = optimize_hyperparameters(
+        train_dataloader,
+        val_dataloader,
+        monitor="val_loss",
+        model_path=os.path.join("pytest_artifacts", f"hpo_{model_class.__name__}"),
+        model_class=model_class,
+        input_params=hp,
+        n_trials=2,
+        max_epochs=10,
+        gradient_clip_val_range=(0.01, 1.0),
+        learning_rate_range=(0.001, 0.1),
+        trainer_kwargs=dict(limit_train_batches=30, accelerator="cpu"),
+        reduce_on_plateau_patience=4,
+        use_learning_rate_finder=True,  # True: use the in-built learning rate finder; False: tune the learning rate with Optuna
+    )
+    return study
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-x", "-s"])
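+
+# Illustrative follow-up (an assumption, not exercised by the test): inspect the study
+# returned above, e.g. interactively:
+#   study.best_trial.params  # best hyperparameter combination found
+#   study.best_value         # corresponding value of the monitored metric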