Commit 86df7dc

Add KFold Loop example (#9965)
1 parent a99b744 commit 86df7dc

6 files changed: +289 -14 lines

CHANGELOG.md
Lines changed: 5 additions & 0 deletions

@@ -186,14 +186,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))
 
+
 - Added a `len` method to `LightningDataModule` ([#9895](https://github.com/PyTorchLightning/pytorch-lightning/pull/9895))
 
+
 - Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))
 
 
 - Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))
 
 
+- Added `kfold` example for loop customization ([#9965](https://github.com/PyTorchLightning/pytorch-lightning/pull/9965))
+
+
 - LightningLite:
     * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
 

docs/source/extensions/loops.rst
Lines changed: 11 additions & 1 deletion

@@ -395,7 +395,17 @@ To run the following demo, install Flash and `BaaL <https://github.com/ElementAI
     # 5. Save the model!
     trainer.save_checkpoint("image_classification_model.pt")
 
-Here is the `runnable example <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash_examples/integrations/baal/image_classification_active_learning.py>`_ and the `code for the active learning loop <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/classification/integrations/baal/loop.py#L31>`_.
+Here is the `Active Learning Loop example <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash_examples/integrations/baal/image_classification_active_learning.py>`_ and the `code for the active learning loop <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/classification/integrations/baal/loop.py#L31>`_.
+
+
+`KFold / Cross Validation <https://en.wikipedia.org/wiki/Cross-validation_(statistics)>`__ is a machine learning practice in which the training dataset is partitioned into `num_folds` complementary subsets.
+Each cross-validation round performs a fit in which one fold is held out for validation and the remaining folds are used for training.
+To reduce variability, once all rounds have been performed with the different folds, the trained models are ensembled and their predictions are
+averaged when estimating the model's predictive performance on the test dataset.
+KFold can be implemented elegantly with `Lightning Loop Customization`, as shown below.
+
+Here is the `KFold Loop example <https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/loops/kfold.py>`_.
+
 
 Advanced Topics and Examples
 ----------------------------
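
For readers new to the splitting step described above, here is a minimal, self-contained sketch (illustrative only, not part of this commit) of the `(train_indices, val_indices)` pairs that `sklearn.model_selection.KFold` produces; these are the same splits the new datamodule in this commit stores for its folds:

    from sklearn.model_selection import KFold

    # Partition 10 sample indices into 5 complementary folds.
    for fold_index, (train_indices, val_indices) in enumerate(KFold(n_splits=5).split(range(10))):
        print(fold_index, train_indices, val_indices)
    # Fold 0 holds out indices [0 1] for validation and trains on [2..9], and so on:
    # every sample is used for validation exactly once across the 5 rounds.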

pl_examples/basic_examples/mnist_datamodule.py
Lines changed: 15 additions & 12 deletions

@@ -26,18 +26,21 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
 
-_TORCHVISION_MNIST_AVAILABLE = not bool(os.getenv("PL_USE_MOCKED_MNIST", False))
-if _TORCHVISION_MNIST_AVAILABLE:
-    try:
-        from torchvision.datasets import MNIST
-
-        MNIST(_DATASETS_PATH, download=True)
-    except HTTPError as e:
-        print(f"Error {e} downloading `torchvision.datasets.MNIST`")
-        _TORCHVISION_MNIST_AVAILABLE = False
-if not _TORCHVISION_MNIST_AVAILABLE:
-    print("`torchvision.datasets.MNIST` not available. Using our hosted version")
-    from tests.helpers.datasets import MNIST
+
+def MNIST(*args, **kwargs):
+    torchvision_mnist_available = not bool(os.getenv("PL_USE_MOCKED_MNIST", False))
+    if torchvision_mnist_available:
+        try:
+            from torchvision.datasets import MNIST
+
+            MNIST(_DATASETS_PATH, download=True)
+        except HTTPError as e:
+            print(f"Error {e} downloading `torchvision.datasets.MNIST`")
+            torchvision_mnist_available = False
+    if not torchvision_mnist_available:
+        print("`torchvision.datasets.MNIST` not available. Using our hosted version")
+        from tests.helpers.datasets import MNIST
+    return MNIST(*args, **kwargs)
 
 
 class MNISTDataModule(LightningDataModule):
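
This change turns `MNIST` into a lazy factory: the torchvision import, the download attempt, and the fallback to the hosted test dataset now happen at call time rather than at module import time. A minimal usage sketch (illustrative only; it assumes the torchvision-compatible keyword arguments used elsewhere in this commit):

    from pl_examples.basic_examples.mnist_datamodule import MNIST

    # Nothing is imported or downloaded until this call; the factory picks
    # `torchvision.datasets.MNIST` when available and the hosted version otherwise.
    dataset = MNIST("./datasets", download=True)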

pl_examples/loop_examples/__init__.py

Whitespace-only changes.

pl_examples/loop_examples/kfold.py

Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type

import torch
import torchvision.transforms as T
from sklearn.model_selection import KFold
from torch.nn import functional as F
from torch.utils.data import random_split
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset, Subset

from pl_examples import _DATASETS_PATH
from pl_examples.basic_examples.mnist_datamodule import MNIST
from pl_examples.basic_examples.simple_image_classifier import LitClassifier
from pytorch_lightning import LightningDataModule, seed_everything, Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.trainer.states import TrainerFn

#############################################################################################
# KFold Loop / Cross Validation Example                                                     #
# This example demonstrates how to leverage Lightning Loop Customization introduced in v1.5 #
# Learn more about the loop structure from the documentation:                               #
# https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html                  #
#############################################################################################


seed_everything(42)


#############################################################################################
# Step 1 / 5: Define the KFold DataModule API                                               #
# Our KFold DataModule requires implementing the `setup_folds` and `setup_fold_index`       #
# methods.                                                                                  #
#############################################################################################


class BaseKFoldDataModule(LightningDataModule, ABC):
    @abstractmethod
    def setup_folds(self, num_folds: int) -> None:
        pass

    @abstractmethod
    def setup_fold_index(self, fold_index: int) -> None:
        pass


#############################################################################################
# Step 2 / 5: Implement the KFoldDataModule                                                 #
# The `KFoldDataModule` will take a train and test dataset.                                 #
# In `setup_folds`, folds are created based on the provided `num_folds` argument.           #
# In `setup_fold_index`, the provided train dataset is split according to the               #
# current fold indices.                                                                     #
#############################################################################################


@dataclass
class MNISTKFoldDataModule(BaseKFoldDataModule):

    train_dataset: Optional[Dataset] = None
    test_dataset: Optional[Dataset] = None
    train_fold: Optional[Dataset] = None
    val_fold: Optional[Dataset] = None

    def prepare_data(self) -> None:
        # download the data.
        MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))]))

    def setup(self, stage: Optional[str] = None) -> None:
        # load the data and reserve 10000 samples as the held-out test set.
        dataset = MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))]))
        self.train_dataset, self.test_dataset = random_split(dataset, [50000, 10000])

    def setup_folds(self, num_folds: int) -> None:
        self.num_folds = num_folds
        self.splits = list(KFold(num_folds).split(range(len(self.train_dataset))))

    def setup_fold_index(self, fold_index: int) -> None:
        train_indices, val_indices = self.splits[fold_index]
        self.train_fold = Subset(self.train_dataset, train_indices)
        self.val_fold = Subset(self.train_dataset, val_indices)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_fold)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_fold)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_dataset)
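
# Usage sketch (for illustration; not part of the committed file): the `KFoldLoop`
# below drives this two-stage API roughly as
#   datamodule.setup_folds(5)       # compute the (train, val) index splits once
#   datamodule.setup_fold_index(0)  # materialize fold 0 as `Subset` views
# With the 50000-sample training split and 5 folds, each `val_fold` holds 10000 samples.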


#############################################################################################
# Step 3 / 5: Implement the EnsembleVotingModel module                                      #
# The `EnsembleVotingModel` will take our custom LightningModule and                        #
# several checkpoint_paths.                                                                 #
#############################################################################################


class EnsembleVotingModel(LightningModule):
    def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str]):
        super().__init__()
        # Create `num_folds` models with their associated fold weights
        self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths])

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        # Compute the averaged predictions over the `num_folds` models:
        # stacking yields (num_folds, batch, num_classes) logits; `.mean(0)` averages over models.
        logits = torch.stack([m(batch[0]) for m in self.models]).mean(0)
        loss = F.cross_entropy(logits, batch[1])
        self.log("test_loss", loss)


#############################################################################################
# Step 4 / 5: Implement the KFoldLoop                                                       #
# Since Lightning v1.5, it is possible to implement your own loop. The steps to do so       #
# are described in detail in the documentation:                                             #
# https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html                  #
# Here, we implement an outer fit loop: we subclass the base Loop and wrap the              #
# Trainer's current `fit_loop`.                                                             #
#############################################################################################


#############################################################################################
# Here is the pseudo-code for the base Loop:                                                #
# class Loop:                                                                               #
#                                                                                           #
#     def run(self, ...):                                                                   #
#         self.reset(...)                                                                   #
#         self.on_run_start(...)                                                            #
#                                                                                           #
#         while not self.done:                                                              #
#             self.on_advance_start(...)                                                    #
#             self.advance(...)                                                             #
#             self.on_advance_end(...)                                                      #
#                                                                                           #
#         return self.on_run_end(...)                                                       #
#############################################################################################


class KFoldLoop(Loop):
    def __init__(self, num_folds: int, fit_loop: FitLoop, export_path: str):
        super().__init__()
        self.num_folds = num_folds
        self.fit_loop = fit_loop
        self.current_fold: int = 0
        self.export_path = export_path

    @property
    def done(self) -> bool:
        return self.current_fold >= self.num_folds

    def reset(self) -> None:
        """Nothing to reset in this loop."""

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        """Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the
        model."""
        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
        self.trainer.datamodule.setup_folds(self.num_folds)
        self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict())

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        """Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance."""
        print(f"STARTING FOLD {self.current_fold}")
        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
        self.trainer.datamodule.setup_fold_index(self.current_fold)

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Used to run fitting and testing on the current fold."""
        self._reset_fitting()  # reset the tracking stage to fitting.
        self.fit_loop.run()

        self._reset_testing()  # reset the tracking stage to testing.
        self.trainer.test_loop.run()
        self.current_fold += 1  # increment the fold tracking number.

    def on_advance_end(self) -> None:
        """Used to save the weights of the current fold and reset the LightningModule and its optimizers."""
        # `current_fold` was already incremented in `advance`, so fold 0 is saved as `model.1.pt`.
        self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt"))
        # restore the original weights + optimizers and schedulers.
        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
        self.trainer.accelerator.setup_optimizers(self.trainer)

    def on_run_end(self) -> None:
        """Used to compute the performance of the ensemble model on the test set."""
        checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)]
        voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
        voting_model.trainer = self.trainer
        # This requires connecting the new model and moving it to the right device.
        self.trainer.accelerator.connect(voting_model)
        self.trainer.training_type_plugin.model_to_device()
        self.trainer.test_loop.run()

    def on_save_checkpoint(self) -> Dict[str, int]:
        return {"current_fold": self.current_fold}

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        self.current_fold = state_dict["current_fold"]

    def _reset_fitting(self) -> None:
        self.trainer.reset_train_dataloader()
        self.trainer.reset_val_dataloader()
        self.trainer.state.fn = TrainerFn.FITTING
        self.trainer.training = True

    def _reset_testing(self) -> None:
        self.trainer.reset_test_dataloader()
        self.trainer.state.fn = TrainerFn.TESTING
        self.trainer.testing = True

    def __getattr__(self, key) -> Any:
        # required because attribute lookups on this loop fall through to the wrapped fit loop.
        if key not in self.__dict__:
            return getattr(self.fit_loop, key)
        return self.__dict__[key]


#############################################################################################
# Step 5 / 5: Connect the KFoldLoop to the Trainer                                          #
# After creating the `KFoldDataModule` and our model, the `KFoldLoop` is connected to       #
# the Trainer.                                                                              #
# Finally, use `trainer.fit` to start the cross validation training.                        #
#############################################################################################

model = LitClassifier()
datamodule = MNISTKFoldDataModule()
trainer = Trainer(
    max_epochs=10,
    limit_train_batches=2,
    limit_val_batches=2,
    limit_test_batches=2,
    num_sanity_val_steps=0,
    devices=1,
    accelerator="auto",
    strategy="ddp",
)
trainer.fit_loop = KFoldLoop(5, trainer.fit_loop, export_path="./")
trainer.fit(model, datamodule)
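
The pseudo-code banner in Step 4 is the key to reading `KFoldLoop`. Here is a toy, runnable sketch (illustrative only, not part of this commit) of the same `run` / `done` / `advance` control flow, with a print statement standing in for the fit and test runs:

    class ToyLoop:
        """Minimal stand-in for the `run` protocol of `pytorch_lightning.loops.base.Loop`."""

        def __init__(self, num_folds: int) -> None:
            self.num_folds = num_folds
            self.current_fold = 0

        @property
        def done(self) -> bool:
            return self.current_fold >= self.num_folds

        def advance(self) -> None:
            print(f"fit + test on fold {self.current_fold}")
            self.current_fold += 1

        def run(self) -> None:
            while not self.done:
                self.advance()

    ToyLoop(num_folds=5).run()  # prints one line per fold, 0 through 4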

pytorch_lightning/trainer/trainer.py
Lines changed: 2 additions & 1 deletion

@@ -1225,7 +1225,8 @@ def _run_train(self) -> None:
         # reload data when needed
         model = self.lightning_module
 
-        self.reset_train_val_dataloaders(model)
+        if isinstance(self.fit_loop, FitLoop):
+            self.reset_train_val_dataloaders(model)
 
         self.fit_loop.trainer = self
         with torch.autograd.set_detect_anomaly(self._detect_anomaly):
