From a7687b872c6d64110052711a42e065aa7d7ea990 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 16 Oct 2021 18:48:41 +0100 Subject: [PATCH 01/22] update --- .../basic_examples/mnist_datamodule.py | 27 ++-- pl_examples/loops/__init__.py | 0 pl_examples/loops/cross_validation.py | 140 ++++++++++++++++++ pytorch_lightning/trainer/trainer.py | 3 +- 4 files changed, 157 insertions(+), 13 deletions(-) create mode 100644 pl_examples/loops/__init__.py create mode 100644 pl_examples/loops/cross_validation.py diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index 68823eeac7bba..5a025be77f934 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -26,18 +26,21 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib -_TORCHVISION_MNIST_AVAILABLE = not bool(os.getenv("PL_USE_MOCKED_MNIST", False)) -if _TORCHVISION_MNIST_AVAILABLE: - try: - from torchvision.datasets import MNIST - - MNIST(_DATASETS_PATH, download=True) - except HTTPError as e: - print(f"Error {e} downloading `torchvision.datasets.MNIST`") - _TORCHVISION_MNIST_AVAILABLE = False -if not _TORCHVISION_MNIST_AVAILABLE: - print("`torchvision.datasets.MNIST` not available. Using our hosted version") - from tests.helpers.datasets import MNIST + +def MNIST(*args, **kwargs): + _TORCHVISION_MNIST_AVAILABLE = not bool(os.getenv("PL_USE_MOCKED_MNIST", False)) + if _TORCHVISION_MNIST_AVAILABLE: + try: + from torchvision.datasets import MNIST + + MNIST(_DATASETS_PATH, download=True) + except HTTPError as e: + print(f"Error {e} downloading `torchvision.datasets.MNIST`") + _TORCHVISION_MNIST_AVAILABLE = False + if not _TORCHVISION_MNIST_AVAILABLE: + print("`torchvision.datasets.MNIST` not available. Using our hosted version") + from tests.helpers.datasets import MNIST + return MNIST(*args, **kwargs) class MNISTDataModule(LightningDataModule): diff --git a/pl_examples/loops/__init__.py b/pl_examples/loops/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pl_examples/loops/cross_validation.py b/pl_examples/loops/cross_validation.py new file mode 100644 index 0000000000000..7a9edf95c2e5b --- /dev/null +++ b/pl_examples/loops/cross_validation.py @@ -0,0 +1,140 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from abc import ABC, abstractmethod +from copy import deepcopy +from typing import Any, Dict, Optional + +import torchvision.transforms as T +from sklearn.model_selection import KFold +from torch.utils.data import random_split +from torch.utils.data.dataloader import DataLoader +from torch.utils.data.dataset import Dataset, Subset + +from pl_examples import _DATASETS_PATH +from pl_examples.basic_examples.mnist_datamodule import MNIST +from pl_examples.basic_examples.simple_image_classifier import LitClassifier +from pytorch_lightning import LightningDataModule, seed_everything, Trainer +from pytorch_lightning.loops.base import Loop +from pytorch_lightning.loops.fit_loop import FitLoop +from pytorch_lightning.trainer.states import TrainerFn + +seed_everything(42) + + +class BaseKFoldDataModule(LightningDataModule, ABC): + @abstractmethod + def setup_folds(self, num_folds: int): + pass + + @abstractmethod + def setup_fold_index(self, fold_index: int) -> LightningDataModule: + pass + + +class KFoldDataModule(BaseKFoldDataModule): + def __init__(self, train_dataset: Dataset, test_dataset: Dataset): + super().__init__() + self._train_dataset = train_dataset + self._test_dataset = test_dataset + self._train_fold: Optional[Dataset] = None + self._val_fold: Optional[Dataset] = None + + def setup_folds(self, num_folds: int) -> None: + self.num_folds = num_folds + self.splits = [split for split in KFold(num_folds).split(range(len(self._train_dataset)))] + + def setup_fold_index(self, fold_index: int) -> None: + train_indices, val_indices = self.splits[fold_index] + self._train_fold = Subset(self._train_dataset, train_indices) + self._val_fold = Subset(self._train_dataset, val_indices) + + def train_dataloader(self): + return DataLoader(self._train_fold) + + def val_dataloader(self): + return DataLoader(self._val_fold) + + def test_dataloader(self): + return DataLoader(self._test_dataset) + + +class KFoldLoop(Loop): + def __init__(self, num_folds: int, fit_loop: FitLoop, export_path: str): + super().__init__() + self.num_folds = num_folds + self.fit_loop = fit_loop + self.current_fold = 0 + self.export_path = export_path + + def __getattr__(self, key): + if key not in self.__dict__: + return getattr(self.fit_loop, key) + return self.__dict__[key] + + @property + def done(self) -> bool: + return self.current_fold >= self.num_folds + + def on_run_start(self, *args: Any, **kwargs: Any) -> None: + assert isinstance(self.trainer.datamodule, BaseKFoldDataModule) + self.trainer.datamodule.setup_folds(self.num_folds) + self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict()) + + def on_advance_start(self, *args: Any, **kwargs: Any) -> None: + print(f"STARTING FOLD {self.current_fold}") + assert isinstance(self.trainer.datamodule, BaseKFoldDataModule) + self.trainer.datamodule.setup_fold_index(self.current_fold) + + def advance(self, *args: Any, **kwargs: Any) -> None: + self._reset_fitting() + self.fit_loop.run() + + self._reset_testing() + self.trainer.test_loop.run() + self.current_fold += 1 + + def on_advance_end(self) -> None: + self.trainer.save_checkpoint(os.path.join(self.export_path, f"model.{self.current_fold}.pt")) + # restore the original weights + optimizers and schedulers. + self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict) + self.trainer.accelerator.setup_optimizers(self.trainer) + print() + + def reset(self) -> None: + pass + + def on_save_checkpoint(self): + return {"current_fold": self.current_fold} + + def on_load_checkpoint(self, state_dict: Dict) -> None: + self.current_fold = state_dict["current_fold"] + + def _reset_fitting(self): + self.trainer.reset_train_val_dataloaders() + self.trainer.state.fn = TrainerFn.FITTING + self.trainer.training = True + + def _reset_testing(self): + self.trainer.reset_test_dataloader() + self.trainer.state.fn = TrainerFn.TESTING + self.trainer.testing = True + + +dataset = MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])) +dm = KFoldDataModule(*random_split(dataset, [50000, 10000])) +model = LitClassifier() +trainer = Trainer(max_epochs=10, limit_train_batches=2, limit_val_batches=2, num_sanity_val_steps=0) +trainer.fit_loop = KFoldLoop(5, trainer.fit_loop, export_path=".") +trainer.fit(model, dm) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e6d8ccde91d71..3aff8a7af85bf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1225,7 +1225,8 @@ def _run_train(self) -> None: # reload data when needed model = self.lightning_module - self.reset_train_val_dataloaders(model) + if isinstance(self.fit_loop, FitLoop): + self.reset_train_val_dataloaders(model) self.fit_loop.trainer = self with torch.autograd.set_detect_anomaly(self._detect_anomaly): From 8fb22615fe866bb88f1b84da0c9cd572e2d656b2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 16 Oct 2021 18:49:12 +0100 Subject: [PATCH 02/22] update --- pl_examples/loops/{cross_validation.py => k_fold_loop.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename pl_examples/loops/{cross_validation.py => k_fold_loop.py} (100%) diff --git a/pl_examples/loops/cross_validation.py b/pl_examples/loops/k_fold_loop.py similarity index 100% rename from pl_examples/loops/cross_validation.py rename to pl_examples/loops/k_fold_loop.py From 9710ae086dc7f7291187857f2b5c92dc5397dd3c Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 16 Oct 2021 18:49:35 +0100 Subject: [PATCH 03/22] update --- pl_examples/{loops => loop_customization}/__init__.py | 0 pl_examples/{loops => loop_customization}/k_fold_loop.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename pl_examples/{loops => loop_customization}/__init__.py (100%) rename pl_examples/{loops => loop_customization}/k_fold_loop.py (100%) diff --git a/pl_examples/loops/__init__.py b/pl_examples/loop_customization/__init__.py similarity index 100% rename from pl_examples/loops/__init__.py rename to pl_examples/loop_customization/__init__.py diff --git a/pl_examples/loops/k_fold_loop.py b/pl_examples/loop_customization/k_fold_loop.py similarity index 100% rename from pl_examples/loops/k_fold_loop.py rename to pl_examples/loop_customization/k_fold_loop.py From 0e108b3f9854a963d4a0abd281e309410768d2ef Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 16 Oct 2021 18:53:46 +0100 Subject: [PATCH 04/22] update --- pl_examples/loop_customization/k_fold_loop.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pl_examples/loop_customization/k_fold_loop.py b/pl_examples/loop_customization/k_fold_loop.py index 7a9edf95c2e5b..d6c79884954a8 100644 --- a/pl_examples/loop_customization/k_fold_loop.py +++ b/pl_examples/loop_customization/k_fold_loop.py @@ -87,6 +87,9 @@ def __getattr__(self, key): def done(self) -> bool: return self.current_fold >= self.num_folds + def reset(self) -> None: + """Nothing to reset in this loop.""" + def on_run_start(self, *args: Any, **kwargs: Any) -> None: assert isinstance(self.trainer.datamodule, BaseKFoldDataModule) self.trainer.datamodule.setup_folds(self.num_folds) @@ -98,10 +101,10 @@ def on_advance_start(self, *args: Any, **kwargs: Any) -> None: self.trainer.datamodule.setup_fold_index(self.current_fold) def advance(self, *args: Any, **kwargs: Any) -> None: - self._reset_fitting() + self._reset_fitting() # requires to reset the tracking stage self.fit_loop.run() - self._reset_testing() + self._reset_testing() # requires to reset the tracking stage self.trainer.test_loop.run() self.current_fold += 1 @@ -112,9 +115,6 @@ def on_advance_end(self) -> None: self.trainer.accelerator.setup_optimizers(self.trainer) print() - def reset(self) -> None: - pass - def on_save_checkpoint(self): return {"current_fold": self.current_fold} @@ -135,6 +135,8 @@ def _reset_testing(self): dataset = MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])) dm = KFoldDataModule(*random_split(dataset, [50000, 10000])) model = LitClassifier() -trainer = Trainer(max_epochs=10, limit_train_batches=2, limit_val_batches=2, num_sanity_val_steps=0) +trainer = Trainer( + max_epochs=10, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, num_sanity_val_steps=0 +) trainer.fit_loop = KFoldLoop(5, trainer.fit_loop, export_path=".") trainer.fit(model, dm) From 55456b80c612680ff48efa0bdd8543fa46ea8b87 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 16 Oct 2021 18:56:03 +0100 Subject: [PATCH 05/22] update --- pl_examples/loop_customization/k_fold_loop.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pl_examples/loop_customization/k_fold_loop.py b/pl_examples/loop_customization/k_fold_loop.py index d6c79884954a8..b69f773c32acc 100644 --- a/pl_examples/loop_customization/k_fold_loop.py +++ b/pl_examples/loop_customization/k_fold_loop.py @@ -75,14 +75,9 @@ def __init__(self, num_folds: int, fit_loop: FitLoop, export_path: str): super().__init__() self.num_folds = num_folds self.fit_loop = fit_loop - self.current_fold = 0 + self.current_fold: int = 0 self.export_path = export_path - def __getattr__(self, key): - if key not in self.__dict__: - return getattr(self.fit_loop, key) - return self.__dict__[key] - @property def done(self) -> bool: return self.current_fold >= self.num_folds @@ -131,6 +126,12 @@ def _reset_testing(self): self.trainer.state.fn = TrainerFn.TESTING self.trainer.testing = True + def __getattr__(self, key): + # requires to be overridden as attributes of the wrapped loop as being accessed. + if key not in self.__dict__: + return getattr(self.fit_loop, key) + return self.__dict__[key] + dataset = MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])) dm = KFoldDataModule(*random_split(dataset, [50000, 10000])) From 75a94540fd29a42e08a52f5cac67cf0ed1ae465b Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 16 Oct 2021 19:13:28 +0100 Subject: [PATCH 06/22] add comments --- pl_examples/loop_customization/k_fold_loop.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/pl_examples/loop_customization/k_fold_loop.py b/pl_examples/loop_customization/k_fold_loop.py index b69f773c32acc..3f278e996fbdb 100644 --- a/pl_examples/loop_customization/k_fold_loop.py +++ b/pl_examples/loop_customization/k_fold_loop.py @@ -30,9 +30,21 @@ from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.trainer.states import TrainerFn +############################################################################################# +# KFold Loop / Cross Validation Example # +# This example demonstrates how to leverage Lightning Loop Customization introduced in v1.5 # +############################################################################################# + + seed_everything(42) +############################################################################################# +# Step 1 / 4: Define your DataModule API # +# Our KFold DataModule should implement a `setup_folds` and `setup_fold_index` function # +############################################################################################# + + class BaseKFoldDataModule(LightningDataModule, ABC): @abstractmethod def setup_folds(self, num_folds: int): @@ -43,6 +55,15 @@ def setup_fold_index(self, fold_index: int) -> LightningDataModule: pass +############################################################################################# +# Step 2 / 4: Implement the KFoldDataModule # +# The `KFoldDataModule` will take a train and test dataset. # +# On `setup_folds`, folds will be created depending on the provided argument num_folds # +# Our `setup_fold_index`, the provided train dataset will be splitted accordingly to # +# the current fold split. # +############################################################################################# + + class KFoldDataModule(BaseKFoldDataModule): def __init__(self, train_dataset: Dataset, test_dataset: Dataset): super().__init__() @@ -70,6 +91,34 @@ def test_dataloader(self): return DataLoader(self._test_dataset) +############################################################################################# +# Step 3 / 4: Implement the KFoldLoop # +# From Lightning v1.5, it is possible to implement your own loop. There is several to do # +# so and refer to the documentation to learn more. # +# Here, we will implement an outter fit_loop. It means we will implement subclass the # +# base Loop and wrap the current trainer `fit_loop`. # +# Here is the base Loop structure. # +# # +# reset() # +# on_run_start() # +# # +# while not done: # +# on_advance_start() # +# advance() # +# on_advance_end() # +# # +# on_run_end() # +# # +# On `on_run_start`, the `KFoldLoop` will call the `KFoldDataModule` `setup_folds` function # +# and store the original weights of the model. # +# On `on_advance_start`, the `KFoldLoop` will call the `KFoldDataModule` `setup_fold_index` # +# function. # +# On `advance`, the `KFoldLoop` will run the original trainer `fit_loop` and # +# the trainer `test_loop`. # +# On `advance_end`, the `KFoldLoop` will reset the model weight and optimizers / schedulers # +############################################################################################# + + class KFoldLoop(Loop): def __init__(self, num_folds: int, fit_loop: FitLoop, export_path: str): super().__init__() @@ -133,6 +182,13 @@ def __getattr__(self, key): return self.__dict__[key] +############################################################################################# +# Step 4 / 4: Connect the KFoldLoop to the Trainer # +# After creating the `KFoldDataModule` and our model, the `KFoldLoop` is being connected to # +# the Trainer. # +# Finally, use `trainer.fit` to start the cross validation training. # +############################################################################################# + dataset = MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])) dm = KFoldDataModule(*random_split(dataset, [50000, 10000])) model = LitClassifier() From 7c52c4e29ee0971befa4f37cb6dcef2b8555a9f8 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 18 Oct 2021 10:56:19 +0100 Subject: [PATCH 07/22] update --- .../basic_examples/mnist_datamodule.py | 8 +- .../{loop_customization => loops}/__init__.py | 0 .../k_fold_loop.py => loops/kfold.py} | 95 ++++++++++++++----- 3 files changed, 73 insertions(+), 30 deletions(-) rename pl_examples/{loop_customization => loops}/__init__.py (100%) rename pl_examples/{loop_customization/k_fold_loop.py => loops/kfold.py} (70%) diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index 5a025be77f934..1d2371c702ce0 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -28,16 +28,16 @@ def MNIST(*args, **kwargs): - _TORCHVISION_MNIST_AVAILABLE = not bool(os.getenv("PL_USE_MOCKED_MNIST", False)) - if _TORCHVISION_MNIST_AVAILABLE: + torchvision_mnist_available = not bool(os.getenv("PL_USE_MOCKED_MNIST", False)) + if torchvision_mnist_available: try: from torchvision.datasets import MNIST MNIST(_DATASETS_PATH, download=True) except HTTPError as e: print(f"Error {e} downloading `torchvision.datasets.MNIST`") - _TORCHVISION_MNIST_AVAILABLE = False - if not _TORCHVISION_MNIST_AVAILABLE: + torchvision_mnist_available = False + if not torchvision_mnist_available: print("`torchvision.datasets.MNIST` not available. Using our hosted version") from tests.helpers.datasets import MNIST return MNIST(*args, **kwargs) diff --git a/pl_examples/loop_customization/__init__.py b/pl_examples/loops/__init__.py similarity index 100% rename from pl_examples/loop_customization/__init__.py rename to pl_examples/loops/__init__.py diff --git a/pl_examples/loop_customization/k_fold_loop.py b/pl_examples/loops/kfold.py similarity index 70% rename from pl_examples/loop_customization/k_fold_loop.py rename to pl_examples/loops/kfold.py index 3f278e996fbdb..078204bc3b00f 100644 --- a/pl_examples/loop_customization/k_fold_loop.py +++ b/pl_examples/loops/kfold.py @@ -14,10 +14,13 @@ import os from abc import ABC, abstractmethod from copy import deepcopy -from typing import Any, Dict, Optional +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type +import torch import torchvision.transforms as T from sklearn.model_selection import KFold +from torch.nn import functional as F from torch.utils.data import random_split from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset, Subset @@ -26,6 +29,7 @@ from pl_examples.basic_examples.mnist_datamodule import MNIST from pl_examples.basic_examples.simple_image_classifier import LitClassifier from pytorch_lightning import LightningDataModule, seed_everything, Trainer +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.trainer.states import TrainerFn @@ -40,8 +44,9 @@ ############################################################################################# -# Step 1 / 4: Define your DataModule API # -# Our KFold DataModule should implement a `setup_folds` and `setup_fold_index` function # +# Step 1 / 5: Define KFold DataModule API # +# Our KFold DataModule should requires to implement `setup_folds` and `setup_fold_index` # +# function. # ############################################################################################# @@ -56,7 +61,7 @@ def setup_fold_index(self, fold_index: int) -> LightningDataModule: ############################################################################################# -# Step 2 / 4: Implement the KFoldDataModule # +# Step 2 / 5: Implement the KFoldDataModule # # The `KFoldDataModule` will take a train and test dataset. # # On `setup_folds`, folds will be created depending on the provided argument num_folds # # Our `setup_fold_index`, the provided train dataset will be splitted accordingly to # @@ -64,35 +69,64 @@ def setup_fold_index(self, fold_index: int) -> LightningDataModule: ############################################################################################# -class KFoldDataModule(BaseKFoldDataModule): - def __init__(self, train_dataset: Dataset, test_dataset: Dataset): - super().__init__() - self._train_dataset = train_dataset - self._test_dataset = test_dataset - self._train_fold: Optional[Dataset] = None - self._val_fold: Optional[Dataset] = None +@dataclass +class MnistKFoldDataModule(BaseKFoldDataModule): + + train_dataset: Optional[Dataset] = None + test_dataset: Optional[Dataset] = None + train_fold: Optional[Dataset] = None + val_fold: Optional[Dataset] = None + + def prepare_data(self) -> None: + MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])) + + def setup(self, stage: Optional[str] = None) -> None: + dataset = MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])) + self.train_dataset, self.test_dataset = random_split(dataset, [50000, 10000]) def setup_folds(self, num_folds: int) -> None: self.num_folds = num_folds - self.splits = [split for split in KFold(num_folds).split(range(len(self._train_dataset)))] + self.splits = [split for split in KFold(num_folds).split(range(len(self.train_dataset)))] def setup_fold_index(self, fold_index: int) -> None: train_indices, val_indices = self.splits[fold_index] - self._train_fold = Subset(self._train_dataset, train_indices) - self._val_fold = Subset(self._train_dataset, val_indices) + self.train_fold = Subset(self.train_dataset, train_indices) + self.val_fold = Subset(self.train_dataset, val_indices) def train_dataloader(self): - return DataLoader(self._train_fold) + return DataLoader(self.train_fold) def val_dataloader(self): - return DataLoader(self._val_fold) + return DataLoader(self.val_fold) def test_dataloader(self): - return DataLoader(self._test_dataset) + return DataLoader(self.test_dataset) + + +############################################################################################# +# Step 3 / 5: Implement the EnsembleVotingModel module # +# The `EnsembleVotingModel` will take our custom LightningModule and # +# several checkpoint_paths. # +# On `__init__`, it would create multiple models by reloading the fold weights # +# On `test_step`, the model will perform a forward through all the models and take # +# the average logits produced by the `num_folds` models, and loss the enssembling loss # +# # +############################################################################################# + + +class EnsembleVotingModel(LightningModule): + def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str]): + super().__init__() + self.models = [model_cls.load_from_checkpoint(p) for p in checkpoint_paths] + + def test_step(self, batch, batch_idx, dataloader_idx: int = 0): + logits = torch.stack([m(batch[0]) for m in self.models]).mean(0) + loss = F.cross_entropy(logits, batch[1]) + self.log("test_loss", loss) ############################################################################################# -# Step 3 / 4: Implement the KFoldLoop # +# Step 4 / 5: Implement the KFoldLoop # # From Lightning v1.5, it is possible to implement your own loop. There is several to do # # so and refer to the documentation to learn more. # # Here, we will implement an outter fit_loop. It means we will implement subclass the # @@ -157,7 +191,15 @@ def on_advance_end(self) -> None: # restore the original weights + optimizers and schedulers. self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict) self.trainer.accelerator.setup_optimizers(self.trainer) - print() + + def on_run_end(self) -> None: + checkpoint_paths = [ + os.path.join(self.export_path, f"model.{fold_index}.pt") for fold_index in range(1, self.num_folds) + ] + voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths) + voting_model.trainer = self.trainer + self.trainer.accelerator.connect(voting_model) + self.trainer.test_loop.run() def on_save_checkpoint(self): return {"current_fold": self.current_fold} @@ -166,7 +208,8 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: self.current_fold = state_dict["current_fold"] def _reset_fitting(self): - self.trainer.reset_train_val_dataloaders() + self.trainer.reset_train_dataloader() + self.trainer.reset_val_dataloader() self.trainer.state.fn = TrainerFn.FITTING self.trainer.training = True @@ -183,17 +226,17 @@ def __getattr__(self, key): ############################################################################################# -# Step 4 / 4: Connect the KFoldLoop to the Trainer # +# Step 5 / : Connect the KFoldLoop to the Trainer # # After creating the `KFoldDataModule` and our model, the `KFoldLoop` is being connected to # # the Trainer. # # Finally, use `trainer.fit` to start the cross validation training. # ############################################################################################# -dataset = MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])) -dm = KFoldDataModule(*random_split(dataset, [50000, 10000])) model = LitClassifier() -trainer = Trainer( +trainer_kwargs = dict( max_epochs=10, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, num_sanity_val_steps=0 ) -trainer.fit_loop = KFoldLoop(5, trainer.fit_loop, export_path=".") -trainer.fit(model, dm) +trainer = Trainer(**trainer_kwargs) +# replace the current trainer `fit_loop` +trainer.fit_loop = KFoldLoop(5, trainer.fit_loop, export_path="./") +trainer.fit(model, MnistKFoldDataModule()) From 5e31c460d4713067f5704b551f9d0dd480ba3792 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 18 Oct 2021 11:01:10 +0100 Subject: [PATCH 08/22] update --- pl_examples/loops/kfold.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/pl_examples/loops/kfold.py b/pl_examples/loops/kfold.py index 078204bc3b00f..9dcaab1c8197a 100644 --- a/pl_examples/loops/kfold.py +++ b/pl_examples/loops/kfold.py @@ -52,11 +52,11 @@ class BaseKFoldDataModule(LightningDataModule, ABC): @abstractmethod - def setup_folds(self, num_folds: int): + def setup_folds(self, num_folds: int) -> None: pass @abstractmethod - def setup_fold_index(self, fold_index: int) -> LightningDataModule: + def setup_fold_index(self, fold_index: int) -> None: pass @@ -70,7 +70,7 @@ def setup_fold_index(self, fold_index: int) -> LightningDataModule: @dataclass -class MnistKFoldDataModule(BaseKFoldDataModule): +class MNISTKFoldDataModule(BaseKFoldDataModule): train_dataset: Optional[Dataset] = None test_dataset: Optional[Dataset] = None @@ -78,9 +78,11 @@ class MnistKFoldDataModule(BaseKFoldDataModule): val_fold: Optional[Dataset] = None def prepare_data(self) -> None: + # download the data. MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])) def setup(self, stage: Optional[str] = None) -> None: + # load the data dataset = MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])) self.train_dataset, self.test_dataset = random_split(dataset, [50000, 10000]) @@ -93,13 +95,13 @@ def setup_fold_index(self, fold_index: int) -> None: self.train_fold = Subset(self.train_dataset, train_indices) self.val_fold = Subset(self.train_dataset, val_indices) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: return DataLoader(self.train_fold) - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: return DataLoader(self.val_fold) - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: return DataLoader(self.test_dataset) @@ -119,7 +121,7 @@ def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str] super().__init__() self.models = [model_cls.load_from_checkpoint(p) for p in checkpoint_paths] - def test_step(self, batch, batch_idx, dataloader_idx: int = 0): + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: logits = torch.stack([m(batch[0]) for m in self.models]).mean(0) loss = F.cross_entropy(logits, batch[1]) self.log("test_loss", loss) @@ -201,24 +203,24 @@ def on_run_end(self) -> None: self.trainer.accelerator.connect(voting_model) self.trainer.test_loop.run() - def on_save_checkpoint(self): + def on_save_checkpoint(self) -> Dict[str, int]: return {"current_fold": self.current_fold} def on_load_checkpoint(self, state_dict: Dict) -> None: self.current_fold = state_dict["current_fold"] - def _reset_fitting(self): + def _reset_fitting(self) -> None: self.trainer.reset_train_dataloader() self.trainer.reset_val_dataloader() self.trainer.state.fn = TrainerFn.FITTING self.trainer.training = True - def _reset_testing(self): + def _reset_testing(self) -> None: self.trainer.reset_test_dataloader() self.trainer.state.fn = TrainerFn.TESTING self.trainer.testing = True - def __getattr__(self, key): + def __getattr__(self, key) -> Any: # requires to be overridden as attributes of the wrapped loop as being accessed. if key not in self.__dict__: return getattr(self.fit_loop, key) @@ -233,10 +235,9 @@ def __getattr__(self, key): ############################################################################################# model = LitClassifier() -trainer_kwargs = dict( +datamodule = MNISTKFoldDataModule() +trainer = Trainer( max_epochs=10, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, num_sanity_val_steps=0 ) -trainer = Trainer(**trainer_kwargs) -# replace the current trainer `fit_loop` trainer.fit_loop = KFoldLoop(5, trainer.fit_loop, export_path="./") -trainer.fit(model, MnistKFoldDataModule()) +trainer.fit(model, datamodule) From 6acc22283c0ff691785626f0574d0876715e3fd6 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 18 Oct 2021 06:33:23 -0400 Subject: [PATCH 09/22] update --- pl_examples/loops/kfold.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/pl_examples/loops/kfold.py b/pl_examples/loops/kfold.py index 9dcaab1c8197a..8cf1f3cc5ab04 100644 --- a/pl_examples/loops/kfold.py +++ b/pl_examples/loops/kfold.py @@ -37,6 +37,8 @@ ############################################################################################# # KFold Loop / Cross Validation Example # # This example demonstrates how to leverage Lightning Loop Customization introduced in v1.5 # +# Learn more about the loop structure from the documentation # +# https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html # ############################################################################################# @@ -119,7 +121,7 @@ def test_dataloader(self) -> DataLoader: class EnsembleVotingModel(LightningModule): def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str]): super().__init__() - self.models = [model_cls.load_from_checkpoint(p) for p in checkpoint_paths] + self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths]) def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: logits = torch.stack([m(batch[0]) for m in self.models]).mean(0) @@ -133,18 +135,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None # so and refer to the documentation to learn more. # # Here, we will implement an outter fit_loop. It means we will implement subclass the # # base Loop and wrap the current trainer `fit_loop`. # -# Here is the base Loop structure. # -# # -# reset() # -# on_run_start() # -# # -# while not done: # -# on_advance_start() # -# advance() # -# on_advance_end() # -# # -# on_run_end() # -# # # On `on_run_start`, the `KFoldLoop` will call the `KFoldDataModule` `setup_folds` function # # and store the original weights of the model. # # On `on_advance_start`, the `KFoldLoop` will call the `KFoldDataModule` `setup_fold_index` # @@ -201,6 +191,7 @@ def on_run_end(self) -> None: voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths) voting_model.trainer = self.trainer self.trainer.accelerator.connect(voting_model) + self.trainer.training_type_plugin.model_to_device() self.trainer.test_loop.run() def on_save_checkpoint(self) -> Dict[str, int]: @@ -228,7 +219,7 @@ def __getattr__(self, key) -> Any: ############################################################################################# -# Step 5 / : Connect the KFoldLoop to the Trainer # +# Step 5 / 5: Connect the KFoldLoop to the Trainer # # After creating the `KFoldDataModule` and our model, the `KFoldLoop` is being connected to # # the Trainer. # # Finally, use `trainer.fit` to start the cross validation training. # @@ -237,7 +228,14 @@ def __getattr__(self, key) -> Any: model = LitClassifier() datamodule = MNISTKFoldDataModule() trainer = Trainer( - max_epochs=10, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, num_sanity_val_steps=0 + max_epochs=10, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + num_sanity_val_steps=0, + devices=1, + accelerator="auto", + strategy="ddp", ) trainer.fit_loop = KFoldLoop(5, trainer.fit_loop, export_path="./") trainer.fit(model, datamodule) From 4ee74e6e9b8dd12ac1a4601ad0b9d1ea70914d85 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 18 Oct 2021 06:40:04 -0400 Subject: [PATCH 10/22] update on comments --- pl_examples/loops/kfold.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pl_examples/loops/kfold.py b/pl_examples/loops/kfold.py index 8cf1f3cc5ab04..f386fc6d6049d 100644 --- a/pl_examples/loops/kfold.py +++ b/pl_examples/loops/kfold.py @@ -47,7 +47,7 @@ ############################################################################################# # Step 1 / 5: Define KFold DataModule API # -# Our KFold DataModule should requires to implement `setup_folds` and `setup_fold_index` # +# Our KFold DataModule should require to implement `setup_folds` and `setup_fold_index` # # function. # ############################################################################################# @@ -65,7 +65,7 @@ def setup_fold_index(self, fold_index: int) -> None: ############################################################################################# # Step 2 / 5: Implement the KFoldDataModule # # The `KFoldDataModule` will take a train and test dataset. # -# On `setup_folds`, folds will be created depending on the provided argument num_folds # +# On `setup_folds`, folds will be created depending on the provided argument `num_folds` # # Our `setup_fold_index`, the provided train dataset will be splitted accordingly to # # the current fold split. # ############################################################################################# @@ -186,7 +186,7 @@ def on_advance_end(self) -> None: def on_run_end(self) -> None: checkpoint_paths = [ - os.path.join(self.export_path, f"model.{fold_index}.pt") for fold_index in range(1, self.num_folds) + os.path.join(self.export_path, f"model.{fold_index + 1}.pt") for fold_index in range(self.num_folds) ] voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths) voting_model.trainer = self.trainer From 206d614bea11e84b389154b3d0683ca011ef6c75 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 18 Oct 2021 06:54:56 -0400 Subject: [PATCH 11/22] add doc --- docs/source/extensions/loops.rst | 413 +++++++++++++++++++++++++++++++ pl_examples/loops/kfold.py | 2 +- 2 files changed, 414 insertions(+), 1 deletion(-) create mode 100644 docs/source/extensions/loops.rst diff --git a/docs/source/extensions/loops.rst b/docs/source/extensions/loops.rst new file mode 100644 index 0000000000000..6fc642d2fe867 --- /dev/null +++ b/docs/source/extensions/loops.rst @@ -0,0 +1,413 @@ +.. _loop_customization: + +Loops +===== + +Loops let advanced users swap out the default gradient descent optimization loop at the core of Lightning with a different optimization paradigm. + +The Lightning Trainer is built on top of the standard gradient descent optimization loop which works for 90%+ of machine learning use cases: + +.. code-block:: python + + for i, batch in enumerate(dataloader): + x, y = batch + y_hat = model(x) + loss = loss_function(y_hat, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + +However, some new research use cases such as meta-learning, active learning, recommendation systems, etc., require a different loop structure. +For example here is a simple loop that guides the weight updates with a loss from a special validation split: + +.. code-block:: python + + for i, batch in enumerate(train_dataloader): + x, y = batch + y_hat = model(x) + loss = loss_function(y_hat, y) + optimizer.zero_grad() + loss.backward() + + val_loss = 0 + for i, val_batch in enumerate(val_dataloader): + x, y = val_batch + y_hat = model(x) + val_loss += loss_function(y_hat, y) + + scale_gradients(model, 1 / val_loss) + optimizer.step() + + +With Lightning Loops, you can customize to non-standard gradient descent optimizations to get the same loop above: + +.. code-block:: python + + trainer = Trainer() + trainer.fit_loop.epoch_loop = MyGradientDescentLoop() + +Think of this as swapping out the engine in a car! + +Understanding the default Trainer loop +-------------------------------------- + +The Lightning :class:`~pytorch_lightning.trainer.trainer.Trainer` automates the standard optimization loop which every PyTorch user is familiar with: + +.. code-block:: python + + for i, batch in enumerate(dataloader): + x, y = batch + y_hat = model(x) + loss = loss_function(y_hat, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + +The core research logic is simply shifted to the :class:`~pytorch_lightning.core.lightning.LightningModule`: + +.. code-block:: python + + for i, batch in enumerate(dataloader): + # x, y = batch moved to training_step + # y_hat = model(x) moved to training_step + # loss = loss_function(y_hat, y) moved to training_step + loss = lightning_module.training_step(batch, i) + + # Lighting handles automatically: + optimizer.zero_grad() + loss.backward() + optimizer.step() + +Under the hood, the above loop is implemented using the :class:`~pytorch_lightning.loops.base.Loop` API like so: + +.. code-block:: python + + class DefaultLoop(Loop): + def advance(self, batch, i): + loss = lightning_module.training_step(batch, i) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + def run(self, dataloader): + for i, batch in enumerate(dataloader): + self.advance(batch, i) + +Defining a loop within a class interface instead of hard-coding a raw Python for/while loop has several benefits: + +1. You can have full control over the data flow through loops. +2. You can add new loops and nest as many of them as you want. +3. If needed, the state of a loop can be :ref:`saved and resumed `. +4. New hooks can be injected at any point. + +.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/epoch-loop-steps.gif + :alt: Animation showing how to convert a standard training loop to a Lightning loop + + +.. _override default loops: + +Overriding the default loops +---------------------------- + +The fastest way to get started with loops, is to override functionality of an existing loop. +Lightning has 4 main loops it uses: :class:`~pytorch_lightning.loops.fit_loop.FitLoop` for training and validating, +:class:`~pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop` for testing, +:class:`~pytorch_lightning.loops.dataloader.prediction_loop.PredictionLoop` for predicting. + +For simple changes that don't require a custom loop, you can modify each of these loops. + +Each loop has a series of methods that can be modified. +For example with the :class:`~pytorch_lightning.loops.fit_loop.FitLoop`: + +.. code-block:: + + from pytorch_lightning.loops import FitLoop + + class MyLoop(FitLoop): + + def advance(): + ... + + def on_advance_end(self) + ... + + def on_run_end(self): + ... + +A full list with all built-in loops and subloops can be found :ref:`here `. + +To add your own modifications to a loop, simply subclass an existing loop class and override what you need. +Here is a simple example how to add a new hook: + +.. code-block:: python + + from pytorch_lightning.loops import FitLoop + + + class CustomFitLoop(FitLoop): + def advance(self): + # ... whatever code before + + # pass anything you want to the hook + self.trainer.call_hook("my_new_hook", *args, **kwargs) + + # ... whatever code after + +Now simply attach the correct loop in the trainer directly: + +.. code-block:: python + + trainer = Trainer(...) + trainer.fit_loop = CustomFitLoop() + + # fit() now uses the new FitLoop! + trainer.fit(...) + + # the equivalent for validate(), test(), predict() + val_loop = CustomValLoop() + trainer = Trainer() + trainer.validate_loop = val_loop + trainer.validate(model) + +Now your code is FULLY flexible and you can still leverage ALL the best parts of Lightning! + +.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/replace-fit-loop.gif + :alt: Animation showing how to replace a loop on the Trainer + +Creating a new loop from scratch +-------------------------------- + +You can also go wild and implement a full loop from scratch by sub-classing the :class:`~pytorch_lightning.loops.base.Loop` base class. +You will need to override a minimum of two things: + +.. code-block:: + + from pytorch_lightning.loop import Loop + + class MyFancyLoop(Loop): + + @property + def done(self): + # provide condition to stop the loop + + def advance(self): + # access your dataloader/s in whatever way you want + # do your fancy optimization things + # call the lightning module methods at your leisure + +Finally, attach it into the :class:`~pytorch_lightning.trainer.trainer.Trainer`: + +.. code-block:: python + + trainer = Trainer(...) + trainer.fit_loop = MyFancyLoop() + + # fit() now uses your fancy loop! + trainer.fit(...) + +Now you have full control over the Trainer. +But beware: The power of loop customization comes with great responsibility. +We recommend that you familiarize yourself with :ref:`overriding the default loops ` first before you start building a new loop from the ground up. + +Loop API +-------- +Here is the full API of methods available in the Loop base class. + +The :class:`~pytorch_lightning.loops.base.Loop` class is the base for all loops in Lighting just like the :class:`~pytorch_lightning.core.lightning.LightningModule` is the base for all models. +It defines a public interface that each loop implementation must follow, the key ones are: + +Properties +^^^^^^^^^^ + +done +~~~~ + +.. autoattribute:: pytorch_lightning.loops.base.Loop.done + :noindex: + +skip (optional) +~~~~~~~~~~~~~~~ + +.. autoattribute:: pytorch_lightning.loops.base.Loop.skip + :noindex: + +Methods +^^^^^^^ + +reset (optional) +~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.loops.base.Loop.reset + :noindex: + +advance +~~~~~~~ + +.. automethod:: pytorch_lightning.loops.base.Loop.advance + :noindex: + +run (optional) +~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.loops.base.Loop.run + :noindex: + + +Subloops +-------- + +When you want to customize nested loops within loops, use the :meth:`~pytorch_lightning.loops.base.Loop.connect` method: + +.. code-block:: python + + # Step 1: create your loop + my_epoch_loop = MyEpochLoop() + + # Step 2: use connect() + trainer.fit_loop.connect(epoch_loop=my_epoch_loop) + + # Trainer runs the fit loop with your new epoch loop! + trainer.fit(model) + +More about the built-in loops and how they are composed is explained in the next section. + +.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/connect-epoch-loop.gif + :alt: Animation showing how to connect a custom subloop + +.. _loop structure: + +Built-in Loops +-------------- + +The training loop in Lightning is called *fit loop* and is actually a combination of several loops. +Here is what the structure would look like in plain Python: + +.. code-block:: python + + # FitLoop + for epoch in range(max_epochs): + + # TrainingEpochLoop + for batch_idx, batch in enumerate(train_dataloader): + + # TrainingBatchLoop + for split_batch in tbptt_split(batch): + + # OptimizerLoop + for optimizer_idx, opt in enumerate(optimizers): + + loss = lightning_module.training_step(batch, batch_idx, optimizer_idx) + ... + + # ValidationEpochLoop + for batch_idx, batch in enumerate(val_dataloader): + lightning_module.validation_step(batch, batch_idx, optimizer_idx) + ... + + +Each of these :code:`for`-loops represents a class implementing the :class:`~pytorch_lightning.loops.base.Loop` interface. + + +.. list-table:: Trainer entry points and associated loops + :widths: 25 75 + :header-rows: 1 + + * - Built-in loop + - Description + * - :class:`~pytorch_lightning.loops.fit_loop.FitLoop` + - The :class:`~pytorch_lightning.loops.fit_loop.FitLoop` is the top-level loop where training starts. + It simply counts the epochs and iterates from one to the next by calling :code:`TrainingEpochLoop.run()` in its :code:`advance()` method. + * - :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop` + - The :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop` is the one that iterates over the dataloader that the user returns in their :meth:`~pytorch_lightning.core.lightning.LightningModule.train_dataloader` method. + Its main responsibilities are calling the :code:`*_epoch_start` and :code:`*_epoch_end` hooks, accumulating outputs if the user request them in one of these hooks, and running validation at the requested interval. + The validation is carried out by yet another loop, :class:`~pytorch_lightning.loops.epoch.validation_epoch_loop.ValidationEpochLoop`. + + In the :code:`run()` method, the training epoch loop could in theory simply call the :code:`LightningModule.training_step` already and perform the optimization. + However, Lightning has built-in support for automatic optimization with multiple optimizers and on top of that also supports :doc:`truncated back-propagation through time <../advanced/sequences>`. + For this reason there are actually two more loops nested under :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop`. + * - :class:`~pytorch_lightning.loops.batch.training_batch_loop.TrainingBatchLoop` + - The responsibility of the :class:`~pytorch_lightning.loops.batch.training_batch_loop.TrainingBatchLoop` is to split a batch given by the :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop` along the time-dimension and iterate over the list of splits. + It also keeps track of the hidden state *hiddens* returned by the training step. + By default, when truncated back-propagation through time (TBPTT) is turned off, this loop does not do anything except redirect the call to the :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop`. + Read more about :doc:`TBPTT <../advanced/sequences>`. + * - :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` + - The :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` iterates over one or multiple optimizers and for each one it calls the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` method with the batch, the current batch index and the optimizer index if multiple optimizers are requested. + It is the leaf node in the tree of loops and performs the actual optimization (forward, zero grad, backward, optimizer step). + * - :class:`~pytorch_lightning.loops.optimization.manual_loop.ManualOptimization` + - Substitutes the :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` in case of :ref:`manual_optimization` and implements the manual optimization step. + + +Available Loops in Lightning Flash +---------------------------------- + +`Active Learning `__ is a machine learning practice in which the user interacts with the learner in order to provide new labels when required. + +You can find a real use case in `Lightning Flash `_. + +Flash implements the :code:`ActiveLearningLoop` that you can use together with the :code:`ActiveLearningDataModule` to label new data on the fly. +To run the following demo, install Flash and `BaaL `__ first: + +.. code-block:: bash + + pip install lightning-flash baal + +.. code-block:: python + + import torch + + import flash + from flash.core.classification import Probabilities + from flash.core.data.utils import download_data + from flash.image import ImageClassificationData, ImageClassifier + from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop + + # 1. Create the DataModule + download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data") + + # Implement the research use-case where we mask labels from labelled dataset. + datamodule = ActiveLearningDataModule( + ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2), + val_split=0.1, + ) + + # 2. Build the task + head = torch.nn.Sequential( + torch.nn.Dropout(p=0.1), + torch.nn.Linear(512, datamodule.num_classes), + ) + model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, serializer=Probabilities()) + + # 3.1 Create the trainer + trainer = flash.Trainer(max_epochs=3) + + # 3.2 Create the active learning loop and connect it to the trainer + active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1) + active_learning_loop.connect(trainer.fit_loop) + trainer.fit_loop = active_learning_loop + + # 3.3 Finetune + trainer.finetune(model, datamodule=datamodule, strategy="freeze") + + # 4. Predict what's on a few images! ants or bees? + predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") + print(predictions) + + # 5. Save the model! + trainer.save_checkpoint("image_classification_model.pt") + +Here is the `runnable example `_ and the `code for the active learning loop `_. + + +`KFold / Cross Validation `__ is a machine learning practice in which the training dataset is being partitioned into `num_folds` complementary subsets. +One cross validation round will perform fitting where one fold is left out for validation and the other folds are used for training. +To reduce variability, once all rounds are performed using the different folds, the trained models are essembled and their predictions are +averaged when estimating the model's predictive performance on the test dataset. +KFold can elegantly be implemented with `Lightning Loop Customization` as follows: + +Here is the `runnable example `_. + + +Advanced Topics and Examples +---------------------------- + +Next: :doc:`Advanced loop features and examples <../extensions/loops_advanced>` diff --git a/pl_examples/loops/kfold.py b/pl_examples/loops/kfold.py index f386fc6d6049d..edc803ab9e8cf 100644 --- a/pl_examples/loops/kfold.py +++ b/pl_examples/loops/kfold.py @@ -37,7 +37,7 @@ ############################################################################################# # KFold Loop / Cross Validation Example # # This example demonstrates how to leverage Lightning Loop Customization introduced in v1.5 # -# Learn more about the loop structure from the documentation # +# Learn more about the loop structure from the documentation: # # https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html # ############################################################################################# From 665cf684b416e0a7c90c6c2a9f1f168858716ebb Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 18 Oct 2021 11:56:44 +0100 Subject: [PATCH 12/22] update --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f809e66c6b7ad..6f1c6666a15a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -186,14 +186,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848)) + - Added a `len` method to `LightningDataModule` ([#9895](https://github.com/PyTorchLightning/pytorch-lightning/pull/9895)) + - Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699)) - Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597)) +- Added `kfold` example for loop customization ([#9965](https://github.com/PyTorchLightning/pytorch-lightning/pull/9965)) + ### Changed - Setting `Trainer(accelerator="ddp_cpu")` now does not spawn a subprocess if `num_processes` is kept `1` along with `num_nodes > 1` ([#9603](https://github.com/PyTorchLightning/pytorch-lightning/pull/9603)). From a71af10f4b6da29129b0a229146d94894ec3acea Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 18 Oct 2021 11:58:54 +0100 Subject: [PATCH 13/22] update on comments --- pl_examples/loops/kfold.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pl_examples/loops/kfold.py b/pl_examples/loops/kfold.py index edc803ab9e8cf..30205ac713148 100644 --- a/pl_examples/loops/kfold.py +++ b/pl_examples/loops/kfold.py @@ -131,9 +131,10 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None ############################################################################################# # Step 4 / 5: Implement the KFoldLoop # -# From Lightning v1.5, it is possible to implement your own loop. There is several to do # -# so and refer to the documentation to learn more. # -# Here, we will implement an outter fit_loop. It means we will implement subclass the # +# From Lightning v1.5, it is possible to implement your own loop. There is several steps # +# to do so which are described in detail within the documentation # +# https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html. # +# Here, we will implement an outer fit_loop. It means we will implement subclass the # # base Loop and wrap the current trainer `fit_loop`. # # On `on_run_start`, the `KFoldLoop` will call the `KFoldDataModule` `setup_folds` function # # and store the original weights of the model. # From 01def9f63b55faf10a31eed6d3ef56d2d9076bf3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 18 Oct 2021 12:01:07 +0100 Subject: [PATCH 14/22] typo --- docs/source/extensions/loops.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/extensions/loops.rst b/docs/source/extensions/loops.rst index 6fc642d2fe867..2c2b17c32c3b5 100644 --- a/docs/source/extensions/loops.rst +++ b/docs/source/extensions/loops.rst @@ -400,7 +400,7 @@ Here is the `runnable example `__ is a machine learning practice in which the training dataset is being partitioned into `num_folds` complementary subsets. One cross validation round will perform fitting where one fold is left out for validation and the other folds are used for training. -To reduce variability, once all rounds are performed using the different folds, the trained models are essembled and their predictions are +To reduce variability, once all rounds are performed using the different folds, the trained models are ensembled and their predictions are averaged when estimating the model's predictive performance on the test dataset. KFold can elegantly be implemented with `Lightning Loop Customization` as follows: From c525a827eb14a2ff95a54af06865650746040e1d Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 18 Oct 2021 12:06:19 +0100 Subject: [PATCH 15/22] update --- pl_examples/loops/kfold.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_examples/loops/kfold.py b/pl_examples/loops/kfold.py index 30205ac713148..2b2a8d55ba9fd 100644 --- a/pl_examples/loops/kfold.py +++ b/pl_examples/loops/kfold.py @@ -47,8 +47,8 @@ ############################################################################################# # Step 1 / 5: Define KFold DataModule API # -# Our KFold DataModule should require to implement `setup_folds` and `setup_fold_index` # -# function. # +# Our KFold DataModule requires to implement the `setup_folds` and `setup_fold_index` # +# methods. # ############################################################################################# From 70f90a5d5a551997ec6bdab962d62e8704085e58 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 18 Oct 2021 12:08:16 +0100 Subject: [PATCH 16/22] update --- pl_examples/loops/kfold.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pl_examples/loops/kfold.py b/pl_examples/loops/kfold.py index 2b2a8d55ba9fd..857986b63d322 100644 --- a/pl_examples/loops/kfold.py +++ b/pl_examples/loops/kfold.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os +import os.path as osp from abc import ABC, abstractmethod from copy import deepcopy from dataclasses import dataclass @@ -180,15 +180,13 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.current_fold += 1 def on_advance_end(self) -> None: - self.trainer.save_checkpoint(os.path.join(self.export_path, f"model.{self.current_fold}.pt")) + self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt")) # restore the original weights + optimizers and schedulers. self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict) self.trainer.accelerator.setup_optimizers(self.trainer) def on_run_end(self) -> None: - checkpoint_paths = [ - os.path.join(self.export_path, f"model.{fold_index + 1}.pt") for fold_index in range(self.num_folds) - ] + checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)] voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths) voting_model.trainer = self.trainer self.trainer.accelerator.connect(voting_model) From deb076500e46d3ab53aa3bd726b9366c3c5b9064 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 18 Oct 2021 12:08:35 +0100 Subject: [PATCH 17/22] update --- pl_examples/loops/kfold.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_examples/loops/kfold.py b/pl_examples/loops/kfold.py index 857986b63d322..fb53fe3d4e27a 100644 --- a/pl_examples/loops/kfold.py +++ b/pl_examples/loops/kfold.py @@ -211,7 +211,7 @@ def _reset_testing(self) -> None: self.trainer.testing = True def __getattr__(self, key) -> Any: - # requires to be overridden as attributes of the wrapped loop as being accessed. + # requires to be overridden as attributes of the wrapped loop are being accessed. if key not in self.__dict__: return getattr(self.fit_loop, key) return self.__dict__[key] From adb4cd1ca574a766a08092052e1eeb4cca2dae2d Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 18 Oct 2021 12:46:39 +0100 Subject: [PATCH 18/22] update --- _notebooks | 2 +- docs/source/extensions/loops.rst | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/_notebooks b/_notebooks index a2fb6468112b7..32a3ce70fc49c 160000 --- a/_notebooks +++ b/_notebooks @@ -1 +1 @@ -Subproject commit a2fb6468112b7e1dad501c3b6a17533a4adfeabc +Subproject commit 32a3ce70fc49c8ac3dc024736199758edcbf156d diff --git a/docs/source/extensions/loops.rst b/docs/source/extensions/loops.rst index 2c2b17c32c3b5..d779f16c4505b 100644 --- a/docs/source/extensions/loops.rst +++ b/docs/source/extensions/loops.rst @@ -395,7 +395,7 @@ To run the following demo, install Flash and `BaaL `_ and the `code for the active learning loop `_. +Here is the `Active Learning Loop example `_ and the `code for the active learning loop `_. `KFold / Cross Validation `__ is a machine learning practice in which the training dataset is being partitioned into `num_folds` complementary subsets. @@ -404,7 +404,7 @@ To reduce variability, once all rounds are performed using the different folds, averaged when estimating the model's predictive performance on the test dataset. KFold can elegantly be implemented with `Lightning Loop Customization` as follows: -Here is the `runnable example `_. +Here is the `KFold Loop example `_. Advanced Topics and Examples From b56804c9e1bba37fa5f87a39e234e82907755399 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 18 Oct 2021 12:56:51 +0100 Subject: [PATCH 19/22] purge notebooks --- _notebooks | 1 - 1 file changed, 1 deletion(-) delete mode 160000 _notebooks diff --git a/_notebooks b/_notebooks deleted file mode 160000 index 32a3ce70fc49c..0000000000000 --- a/_notebooks +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 32a3ce70fc49c8ac3dc024736199758edcbf156d From 8f84ea0896a1459110fb33f7319d575453fc7261 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 18 Oct 2021 12:57:06 +0100 Subject: [PATCH 20/22] purged notebooks --- _notebooks | 1 + 1 file changed, 1 insertion(+) create mode 160000 _notebooks diff --git a/_notebooks b/_notebooks new file mode 160000 index 0000000000000..a2fb6468112b7 --- /dev/null +++ b/_notebooks @@ -0,0 +1 @@ +Subproject commit a2fb6468112b7e1dad501c3b6a17533a4adfeabc From 837fd92420ae220014183c6e72866fb7a293366f Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 18 Oct 2021 15:35:42 +0100 Subject: [PATCH 21/22] update --- .../{loops => loop_examples}/__init__.py | 0 pl_examples/{loops => loop_examples}/kfold.py | 25 +++++++++---------- 2 files changed, 12 insertions(+), 13 deletions(-) rename pl_examples/{loops => loop_examples}/__init__.py (100%) rename pl_examples/{loops => loop_examples}/kfold.py (91%) diff --git a/pl_examples/loops/__init__.py b/pl_examples/loop_examples/__init__.py similarity index 100% rename from pl_examples/loops/__init__.py rename to pl_examples/loop_examples/__init__.py diff --git a/pl_examples/loops/kfold.py b/pl_examples/loop_examples/kfold.py similarity index 91% rename from pl_examples/loops/kfold.py rename to pl_examples/loop_examples/kfold.py index fb53fe3d4e27a..546a1edfd9969 100644 --- a/pl_examples/loops/kfold.py +++ b/pl_examples/loop_examples/kfold.py @@ -111,9 +111,6 @@ def test_dataloader(self) -> DataLoader: # Step 3 / 5: Implement the EnsembleVotingModel module # # The `EnsembleVotingModel` will take our custom LightningModule and # # several checkpoint_paths. # -# On `__init__`, it would create multiple models by reloading the fold weights # -# On `test_step`, the model will perform a forward through all the models and take # -# the average logits produced by the `num_folds` models, and loss the enssembling loss # # # ############################################################################################# @@ -121,9 +118,11 @@ def test_dataloader(self) -> DataLoader: class EnsembleVotingModel(LightningModule): def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str]): super().__init__() + # Create `num_folds` models with their associated fold weights self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths]) def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + # Compute the averaged predictions over the `num_folds` models. logits = torch.stack([m(batch[0]) for m in self.models]).mean(0) loss = F.cross_entropy(logits, batch[1]) self.log("test_loss", loss) @@ -136,13 +135,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None # https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html. # # Here, we will implement an outer fit_loop. It means we will implement subclass the # # base Loop and wrap the current trainer `fit_loop`. # -# On `on_run_start`, the `KFoldLoop` will call the `KFoldDataModule` `setup_folds` function # -# and store the original weights of the model. # -# On `on_advance_start`, the `KFoldLoop` will call the `KFoldDataModule` `setup_fold_index` # -# function. # -# On `advance`, the `KFoldLoop` will run the original trainer `fit_loop` and # -# the trainer `test_loop`. # -# On `advance_end`, the `KFoldLoop` will reset the model weight and optimizers / schedulers # ############################################################################################# @@ -162,33 +154,40 @@ def reset(self) -> None: """Nothing to reset in this loop.""" def on_run_start(self, *args: Any, **kwargs: Any) -> None: + """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the + model.""" assert isinstance(self.trainer.datamodule, BaseKFoldDataModule) self.trainer.datamodule.setup_folds(self.num_folds) self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict()) def on_advance_start(self, *args: Any, **kwargs: Any) -> None: + """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance.""" print(f"STARTING FOLD {self.current_fold}") assert isinstance(self.trainer.datamodule, BaseKFoldDataModule) self.trainer.datamodule.setup_fold_index(self.current_fold) def advance(self, *args: Any, **kwargs: Any) -> None: - self._reset_fitting() # requires to reset the tracking stage + """Used to the run a fitting and testing on the current hold.""" + self._reset_fitting() # requires to reset the tracking stage. self.fit_loop.run() - self._reset_testing() # requires to reset the tracking stage + self._reset_testing() # requires to reset the tracking stage. self.trainer.test_loop.run() - self.current_fold += 1 + self.current_fold += 1 # increment fold tracking number. def on_advance_end(self) -> None: + """Used to save the weights of the current fold and reset the LightningModule and its optimizers.""" self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt")) # restore the original weights + optimizers and schedulers. self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict) self.trainer.accelerator.setup_optimizers(self.trainer) def on_run_end(self) -> None: + """Used to compute the performance of the ensemble model on the test set.""" checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)] voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths) voting_model.trainer = self.trainer + # This requires to connect the new model and move it the right device. self.trainer.accelerator.connect(voting_model) self.trainer.training_type_plugin.model_to_device() self.trainer.test_loop.run() From 249ffec7c527f130abf748bf397672c500d8b55c Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 18 Oct 2021 15:39:53 +0100 Subject: [PATCH 22/22] update --- pl_examples/loop_examples/kfold.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/pl_examples/loop_examples/kfold.py b/pl_examples/loop_examples/kfold.py index 546a1edfd9969..630b1f26f3b4a 100644 --- a/pl_examples/loop_examples/kfold.py +++ b/pl_examples/loop_examples/kfold.py @@ -138,6 +138,23 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None ############################################################################################# +############################################################################################# +# Here is the `Pseudo Code` for the base Loop. # +# class Loop: # +# # +# def run(self, ...): # +# self.reset(...) # +# self.on_run_start(...) # +# # +# while not self.done: # +# self.on_advance_start(...) # +# self.advance(...) # +# self.on_advance_end(...) # +# # +# return self.on_run_end(...) # +############################################################################################# + + class KFoldLoop(Loop): def __init__(self, num_folds: int, fit_loop: FitLoop, export_path: str): super().__init__()