|
| 1 | +# Copyright The PyTorch Lightning team. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +import os.path as osp |
| 15 | +from abc import ABC, abstractmethod |
| 16 | +from copy import deepcopy |
| 17 | +from dataclasses import dataclass |
| 18 | +from typing import Any, Dict, List, Optional, Type |
| 19 | + |
| 20 | +import torch |
| 21 | +import torchvision.transforms as T |
| 22 | +from sklearn.model_selection import KFold |
| 23 | +from torch.nn import functional as F |
| 24 | +from torch.utils.data import random_split |
| 25 | +from torch.utils.data.dataloader import DataLoader |
| 26 | +from torch.utils.data.dataset import Dataset, Subset |
| 27 | + |
| 28 | +from pl_examples import _DATASETS_PATH |
| 29 | +from pl_examples.basic_examples.mnist_datamodule import MNIST |
| 30 | +from pl_examples.basic_examples.simple_image_classifier import LitClassifier |
| 31 | +from pytorch_lightning import LightningDataModule, seed_everything, Trainer |
| 32 | +from pytorch_lightning.core.lightning import LightningModule |
| 33 | +from pytorch_lightning.loops.base import Loop |
| 34 | +from pytorch_lightning.loops.fit_loop import FitLoop |
| 35 | +from pytorch_lightning.trainer.states import TrainerFn |
| 36 | + |
| 37 | +############################################################################################# |
| 38 | +# KFold Loop / Cross Validation Example # |
| 39 | +# This example demonstrates how to leverage Lightning Loop Customization introduced in v1.5 # |
| 40 | +# Learn more about the loop structure from the documentation: # |
| 41 | +# https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html # |
| 42 | +############################################################################################# |
| 43 | + |
| 44 | + |
# Seed the random number generators up front so the example is reproducible across runs.
seed_everything(seed=42)
| 46 | + |
| 47 | + |
| 48 | +############################################################################################# |
| 49 | +# Step 1 / 5: Define KFold DataModule API # |
| 50 | +# Our KFold DataModule must implement the `setup_folds` and `setup_fold_index` # |
| 51 | +# methods. # |
| 52 | +############################################################################################# |
| 53 | + |
| 54 | + |
class BaseKFoldDataModule(LightningDataModule, ABC):
    """Abstract datamodule interface used by the K-fold loop in this file.

    Concrete subclasses must implement both hooks below; the loop calls
    ``setup_folds`` once and ``setup_fold_index`` before every fold.
    """

    @abstractmethod
    def setup_folds(self, num_folds: int) -> None:
        """Prepare ``num_folds`` folds of the training data."""

    @abstractmethod
    def setup_fold_index(self, fold_index: int) -> None:
        """Make the fold identified by ``fold_index`` the active one."""
| 63 | + |
| 64 | + |
| 65 | +############################################################################################# |
| 66 | +# Step 2 / 5: Implement the KFoldDataModule # |
| 67 | +# The `KFoldDataModule` will take a train and test dataset. # |
| 68 | +# On `setup_folds`, folds will be created depending on the provided argument `num_folds` # |
| 69 | +# On `setup_fold_index`, the provided train dataset will be split according to # |
| 70 | +# the current fold split. # |
| 71 | +############################################################################################# |
| 72 | + |
| 73 | + |
@dataclass
class MNISTKFoldDataModule(BaseKFoldDataModule):
    """MNIST datamodule whose training subset can be partitioned into K folds.

    ``setup`` splits the loaded dataset into a 50000-sample training subset and a
    10000-sample test subset. ``setup_folds`` computes the K-fold index splits of the
    training subset, and ``setup_fold_index`` selects which split backs the train and
    validation dataloaders. The test dataloader always serves the held-out test subset.
    """

    train_dataset: Optional[Dataset] = None
    test_dataset: Optional[Dataset] = None
    train_fold: Optional[Dataset] = None
    val_fold: Optional[Dataset] = None

    def __post_init__(self) -> None:
        # The dataclass-generated ``__init__`` only assigns the fields above and never
        # calls the parent constructor, so run it here to let the base datamodule
        # initialise its own state.
        super().__init__()

    @staticmethod
    def _load_mnist() -> Dataset:
        """Build the MNIST dataset with the tensor conversion + normalisation transform."""
        transform = T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])
        return MNIST(_DATASETS_PATH, transform=transform)

    def prepare_data(self) -> None:
        """Instantiate the dataset once so the data gets downloaded."""
        # The instance is discarded on purpose; only the construction side effect matters.
        self._load_mnist()

    def setup(self, stage: Optional[str] = None) -> None:
        """Load the dataset and split it into the train and test subsets.

        Args:
            stage: Unused; accepted for compatibility with the datamodule hook signature.
        """
        dataset = self._load_mnist()
        self.train_dataset, self.test_dataset = random_split(dataset, [50000, 10000])

    def setup_folds(self, num_folds: int) -> None:
        """Compute the ``num_folds`` (train_indices, val_indices) splits of the train subset.

        Must be called after ``setup``, since it needs the length of ``train_dataset``.
        """
        self.num_folds = num_folds
        self.splits = list(KFold(num_folds).split(range(len(self.train_dataset))))

    def setup_fold_index(self, fold_index: int) -> None:
        """Select the split at ``fold_index`` as the current train / validation folds."""
        train_indices, val_indices = self.splits[fold_index]
        self.train_fold = Subset(self.train_dataset, train_indices)
        self.val_fold = Subset(self.train_dataset, val_indices)

    def train_dataloader(self) -> DataLoader:
        """Return a dataloader over the current training fold."""
        return DataLoader(self.train_fold)

    def val_dataloader(self) -> DataLoader:
        """Return a dataloader over the current validation fold."""
        return DataLoader(self.val_fold)

    def test_dataloader(self) -> DataLoader:
        """Return a dataloader over the held-out test subset."""
        return DataLoader(self.test_dataset)
| 108 | + |
| 109 | + |
| 110 | +############################################################################################# |
| 111 | +# Step 3 / 5: Implement the EnsembleVotingModel module # |
| 112 | +# The `EnsembleVotingModel` will take our custom LightningModule and # |
| 113 | +# several checkpoint_paths. # |
| 114 | +# # |
| 115 | +############################################################################################# |
| 116 | + |
| 117 | + |
class EnsembleVotingModel(LightningModule):
    """Ensemble that averages the outputs of one restored model per checkpoint."""

    def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str]):
        super().__init__()
        # Restore one `model_cls` instance per checkpoint path (one per fold).
        restored_models = [model_cls.load_from_checkpoint(path) for path in checkpoint_paths]
        self.models = torch.nn.ModuleList(restored_models)

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Log the cross-entropy of the ensemble's averaged outputs as ``test_loss``."""
        inputs = batch[0]
        targets = batch[1]
        # Stack the per-model outputs along a new leading axis and average over it.
        per_model_outputs = [member(inputs) for member in self.models]
        averaged = torch.mean(torch.stack(per_model_outputs), dim=0)
        self.log("test_loss", F.cross_entropy(averaged, targets))
| 129 | + |
| 130 | + |
| 131 | +############################################################################################# |
| 132 | +# Step 4 / 5: Implement the KFoldLoop # |
| 133 | +# From Lightning v1.5, it is possible to implement your own loop. There are several steps # |
| 134 | +# to do so which are described in detail within the documentation # |
| 135 | +# https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html. # |
| 136 | +# Here, we will implement an outer fit_loop. It means we will subclass the # |
| 137 | +# base Loop and wrap the current trainer `fit_loop`. # |
| 138 | +############################################################################################# |
| 139 | + |
| 140 | + |
| 141 | +############################################################################################# |
| 142 | +# Here is the `Pseudo Code` for the base Loop. # |
| 143 | +# class Loop: # |
| 144 | +# # |
| 145 | +# def run(self, ...): # |
| 146 | +# self.reset(...) # |
| 147 | +# self.on_run_start(...) # |
| 148 | +# # |
| 149 | +# while not self.done: # |
| 150 | +# self.on_advance_start(...) # |
| 151 | +# self.advance(...) # |
| 152 | +# self.on_advance_end(...) # |
| 153 | +# # |
| 154 | +# return self.on_run_end(...) # |
| 155 | +############################################################################################# |
| 156 | + |
| 157 | + |
class KFoldLoop(Loop):
    """Outer loop that runs the wrapped fit loop once per fold, then tests an ensemble.

    For every fold the loop fits and tests the model, saves a checkpoint, and restores
    the initial weights before the next fold. Once all folds are done, the per-fold
    checkpoints are loaded into an ``EnsembleVotingModel`` which is run through the
    test loop.

    Args:
        num_folds: Number of folds to iterate over.
        fit_loop: The fit loop to wrap; it is run once per fold.
        export_path: Directory in which the per-fold checkpoints are written.
    """

    def __init__(self, num_folds: int, fit_loop: FitLoop, export_path: str):
        super().__init__()
        self.num_folds = num_folds
        self.fit_loop = fit_loop
        self.current_fold: int = 0
        self.export_path = export_path

    @property
    def done(self) -> bool:
        """Whether all ``num_folds`` folds have been processed."""
        return self.current_fold >= self.num_folds

    def reset(self) -> None:
        """Nothing to reset in this loop."""

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
        model."""
        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
        self.trainer.datamodule.setup_folds(self.num_folds)
        # Deep copy so that training a fold cannot mutate the saved initial weights.
        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
        print(f"STARTING FOLD {self.current_fold}")
        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
        self.trainer.datamodule.setup_fold_index(self.current_fold)

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Used to run fitting and testing on the current fold."""
        self._reset_fitting()  # requires to reset the tracking stage.
        self.fit_loop.run()

        self._reset_testing()  # requires to reset the tracking stage.
        self.trainer.test_loop.run()
        self.current_fold += 1  # increment fold tracking number.

    def on_advance_end(self) -> None:
        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
        # `advance` has already incremented `current_fold`, so the files are numbered 1..num_folds.
        self.trainer.save_checkpoint(self._checkpoint_path(self.current_fold))
        # restore the original weights + optimizers and schedulers.
        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
        self.trainer.accelerator.setup_optimizers(self.trainer)

    def on_run_end(self) -> None:
        """Used to compute the performance of the ensemble model on the test set."""
        checkpoint_paths = [self._checkpoint_path(fold_number) for fold_number in range(1, self.num_folds + 1)]
        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
        voting_model.trainer = self.trainer
        # This requires to connect the new model and move it the right device.
        self.trainer.accelerator.connect(voting_model)
        self.trainer.training_type_plugin.model_to_device()
        self.trainer.test_loop.run()

    def on_save_checkpoint(self) -> Dict[str, int]:
        """Return the loop state to persist: the current fold counter."""
        return {"current_fold": self.current_fold}

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        """Restore the current fold counter from ``state_dict``."""
        self.current_fold = state_dict["current_fold"]

    def _checkpoint_path(self, fold_number: int) -> str:
        """Return the checkpoint path for ``fold_number`` (1-based) inside ``export_path``."""
        return osp.join(self.export_path, f"model.{fold_number}.pt")

    def _reset_fitting(self) -> None:
        """Reload the train / val dataloaders and switch the trainer state to fitting."""
        self.trainer.reset_train_dataloader()
        self.trainer.reset_val_dataloader()
        self.trainer.state.fn = TrainerFn.FITTING
        self.trainer.training = True

    def _reset_testing(self) -> None:
        """Reload the test dataloader and switch the trainer state to testing."""
        self.trainer.reset_test_dataloader()
        self.trainer.state.fn = TrainerFn.TESTING
        self.trainer.testing = True

    def __getattr__(self, key) -> Any:
        """Forward attributes that this loop does not define to the wrapped ``fit_loop``.

        Raises:
            AttributeError: If ``fit_loop`` has not been set on this instance yet, or if
                the wrapped loop does not provide ``key`` either.
        """
        # requires to be overridden as attributes of the wrapped loop are being accessed.
        own_attributes = self.__dict__
        if key in own_attributes:
            return own_attributes[key]
        # Read `fit_loop` from `__dict__` directly: going through `self.fit_loop` would
        # re-enter `__getattr__` and recurse without bound whenever it is not set yet
        # (e.g. while the instance is being copied or unpickled, before `__init__` ran).
        if "fit_loop" not in own_attributes:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {key!r}")
        return getattr(own_attributes["fit_loop"], key)
| 234 | + |
| 235 | + |
| 236 | +############################################################################################# |
| 237 | +# Step 5 / 5: Connect the KFoldLoop to the Trainer # |
| 238 | +# After creating the `KFoldDataModule` and our model, the `KFoldLoop` is being connected to # |
| 239 | +# the Trainer. # |
| 240 | +# Finally, use `trainer.fit` to start the cross validation training. # |
| 241 | +############################################################################################# |
| 242 | + |
# Build the model, the K-fold datamodule and a Trainer limited to two batches per stage.
model = LitClassifier()
datamodule = MNISTKFoldDataModule()
trainer = Trainer(
    # hardware / strategy
    accelerator="auto",
    devices=1,
    strategy="ddp",
    # schedule
    max_epochs=10,
    num_sanity_val_steps=0,
    # keep the example fast
    limit_train_batches=2,
    limit_val_batches=2,
    limit_test_batches=2,
)
# Replace the trainer's fit loop with a 5-fold loop wrapping the original one, then start training.
trainer.fit_loop = KFoldLoop(num_folds=5, fit_loop=trainer.fit_loop, export_path="./")
trainer.fit(model, datamodule)
0 commit comments