-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Add KFold Loop example #9965
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add KFold Loop example #9965
Changes from all commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
a7687b8
update
tchaton 8fb2261
update
tchaton 9710ae0
update
tchaton 0e108b3
update
tchaton 55456b8
update
tchaton 75a9454
add comments
tchaton 7c52c4e
update
tchaton 5e31c46
update
tchaton 6acc222
update
tchaton 4ee74e6
update on comments
tchaton 206d614
add doc
tchaton 66f81d7
Merge branch 'master' into add_example_kfold_loop
tchaton 665cf68
update
tchaton a6a919e
update
tchaton a71af10
update on comments
tchaton 01def9f
typo
tchaton c525a82
update
tchaton 70f90a5
update
tchaton deb0765
update
tchaton adb4cd1
update
tchaton b56804c
purge notebooks
tchaton 8f84ea0
purged notebooks
tchaton 3f51783
update
tchaton 837fd92
update
tchaton 249ffec
update
tchaton File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import os.path as osp | ||
from abc import ABC, abstractmethod | ||
from copy import deepcopy | ||
from dataclasses import dataclass | ||
from typing import Any, Dict, List, Optional, Type | ||
|
||
import torch | ||
import torchvision.transforms as T | ||
from sklearn.model_selection import KFold | ||
from torch.nn import functional as F | ||
from torch.utils.data import random_split | ||
from torch.utils.data.dataloader import DataLoader | ||
from torch.utils.data.dataset import Dataset, Subset | ||
|
||
from pl_examples import _DATASETS_PATH | ||
from pl_examples.basic_examples.mnist_datamodule import MNIST | ||
from pl_examples.basic_examples.simple_image_classifier import LitClassifier | ||
from pytorch_lightning import LightningDataModule, seed_everything, Trainer | ||
from pytorch_lightning.core.lightning import LightningModule | ||
from pytorch_lightning.loops.base import Loop | ||
from pytorch_lightning.loops.fit_loop import FitLoop | ||
from pytorch_lightning.trainer.states import TrainerFn | ||
|
||
############################################################################################# | ||
# KFold Loop / Cross Validation Example # | ||
# This example demonstrates how to leverage Lightning Loop Customization introduced in v1.5 # | ||
# Learn more about the loop structure from the documentation: # | ||
# https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html # | ||
############################################################################################# | ||
|
||
|
||
seed_everything(42) | ||
|
||
|
||
############################################################################################# | ||
# Step 1 / 5: Define KFold DataModule API # | ||
# Our KFold DataModule requires to implement the `setup_folds` and `setup_fold_index` # | ||
# methods. # | ||
############################################################################################# | ||
|
||
|
||
class BaseKFoldDataModule(LightningDataModule, ABC): | ||
@abstractmethod | ||
def setup_folds(self, num_folds: int) -> None: | ||
pass | ||
|
||
@abstractmethod | ||
def setup_fold_index(self, fold_index: int) -> None: | ||
pass | ||
|
||
|
||
############################################################################################# | ||
# Step 2 / 5: Implement the KFoldDataModule # | ||
# The `KFoldDataModule` will take a train and test dataset. # | ||
# On `setup_folds`, folds will be created depending on the provided argument `num_folds` # | ||
# Our `setup_fold_index`, the provided train dataset will be splitted accordingly to # | ||
# the current fold split. # | ||
############################################################################################# | ||
|
||
|
||
@dataclass | ||
class MNISTKFoldDataModule(BaseKFoldDataModule): | ||
|
||
train_dataset: Optional[Dataset] = None | ||
test_dataset: Optional[Dataset] = None | ||
train_fold: Optional[Dataset] = None | ||
val_fold: Optional[Dataset] = None | ||
|
||
def prepare_data(self) -> None: | ||
# download the data. | ||
MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])) | ||
|
||
def setup(self, stage: Optional[str] = None) -> None: | ||
# load the data | ||
dataset = MNIST(_DATASETS_PATH, transform=T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])) | ||
self.train_dataset, self.test_dataset = random_split(dataset, [50000, 10000]) | ||
|
||
def setup_folds(self, num_folds: int) -> None: | ||
self.num_folds = num_folds | ||
self.splits = [split for split in KFold(num_folds).split(range(len(self.train_dataset)))] | ||
|
||
def setup_fold_index(self, fold_index: int) -> None: | ||
train_indices, val_indices = self.splits[fold_index] | ||
self.train_fold = Subset(self.train_dataset, train_indices) | ||
self.val_fold = Subset(self.train_dataset, val_indices) | ||
|
||
def train_dataloader(self) -> DataLoader: | ||
return DataLoader(self.train_fold) | ||
|
||
def val_dataloader(self) -> DataLoader: | ||
return DataLoader(self.val_fold) | ||
|
||
def test_dataloader(self) -> DataLoader: | ||
return DataLoader(self.test_dataset) | ||
|
||
|
||
############################################################################################# | ||
# Step 3 / 5: Implement the EnsembleVotingModel module # | ||
# The `EnsembleVotingModel` will take our custom LightningModule and # | ||
# several checkpoint_paths. # | ||
# # | ||
############################################################################################# | ||
|
||
|
||
class EnsembleVotingModel(LightningModule): | ||
def __init__(self, model_cls: Type[LightningModule], checkpoint_paths: List[str]): | ||
super().__init__() | ||
# Create `num_folds` models with their associated fold weights | ||
self.models = torch.nn.ModuleList([model_cls.load_from_checkpoint(p) for p in checkpoint_paths]) | ||
|
||
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: | ||
# Compute the averaged predictions over the `num_folds` models. | ||
logits = torch.stack([m(batch[0]) for m in self.models]).mean(0) | ||
loss = F.cross_entropy(logits, batch[1]) | ||
self.log("test_loss", loss) | ||
|
||
|
||
############################################################################################# | ||
# Step 4 / 5: Implement the KFoldLoop # | ||
# From Lightning v1.5, it is possible to implement your own loop. There is several steps # | ||
# to do so which are described in detail within the documentation # | ||
# https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html. # | ||
# Here, we will implement an outer fit_loop. It means we will implement subclass the # | ||
# base Loop and wrap the current trainer `fit_loop`. # | ||
############################################################################################# | ||
|
||
|
||
############################################################################################# | ||
# Here is the `Pseudo Code` for the base Loop. # | ||
# class Loop: # | ||
# # | ||
# def run(self, ...): # | ||
# self.reset(...) # | ||
# self.on_run_start(...) # | ||
# # | ||
# while not self.done: # | ||
# self.on_advance_start(...) # | ||
# self.advance(...) # | ||
# self.on_advance_end(...) # | ||
# # | ||
# return self.on_run_end(...) # | ||
############################################################################################# | ||
|
||
|
||
class KFoldLoop(Loop): | ||
def __init__(self, num_folds: int, fit_loop: FitLoop, export_path: str): | ||
super().__init__() | ||
self.num_folds = num_folds | ||
self.fit_loop = fit_loop | ||
self.current_fold: int = 0 | ||
self.export_path = export_path | ||
|
||
@property | ||
def done(self) -> bool: | ||
return self.current_fold >= self.num_folds | ||
|
||
def reset(self) -> None: | ||
"""Nothing to reset in this loop.""" | ||
|
||
def on_run_start(self, *args: Any, **kwargs: Any) -> None: | ||
"""Used to call `setup_folds` from the `BaseKFoldDataModule` instance and store the original weights of the | ||
model.""" | ||
assert isinstance(self.trainer.datamodule, BaseKFoldDataModule) | ||
self.trainer.datamodule.setup_folds(self.num_folds) | ||
self.lightning_module_state_dict = deepcopy(self.trainer.lightning_module.state_dict()) | ||
|
||
def on_advance_start(self, *args: Any, **kwargs: Any) -> None: | ||
"""Used to call `setup_fold_index` from the `BaseKFoldDataModule` instance.""" | ||
print(f"STARTING FOLD {self.current_fold}") | ||
assert isinstance(self.trainer.datamodule, BaseKFoldDataModule) | ||
self.trainer.datamodule.setup_fold_index(self.current_fold) | ||
|
||
def advance(self, *args: Any, **kwargs: Any) -> None: | ||
"""Used to the run a fitting and testing on the current hold.""" | ||
self._reset_fitting() # requires to reset the tracking stage. | ||
self.fit_loop.run() | ||
|
||
self._reset_testing() # requires to reset the tracking stage. | ||
self.trainer.test_loop.run() | ||
self.current_fold += 1 # increment fold tracking number. | ||
|
||
def on_advance_end(self) -> None: | ||
"""Used to save the weights of the current fold and reset the LightningModule and its optimizers.""" | ||
self.trainer.save_checkpoint(osp.join(self.export_path, f"model.{self.current_fold}.pt")) | ||
# restore the original weights + optimizers and schedulers. | ||
self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict) | ||
self.trainer.accelerator.setup_optimizers(self.trainer) | ||
|
||
def on_run_end(self) -> None: | ||
"""Used to compute the performance of the ensemble model on the test set.""" | ||
checkpoint_paths = [osp.join(self.export_path, f"model.{f_idx + 1}.pt") for f_idx in range(self.num_folds)] | ||
voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths) | ||
voting_model.trainer = self.trainer | ||
# This requires to connect the new model and move it the right device. | ||
self.trainer.accelerator.connect(voting_model) | ||
self.trainer.training_type_plugin.model_to_device() | ||
self.trainer.test_loop.run() | ||
|
||
def on_save_checkpoint(self) -> Dict[str, int]: | ||
return {"current_fold": self.current_fold} | ||
|
||
def on_load_checkpoint(self, state_dict: Dict) -> None: | ||
self.current_fold = state_dict["current_fold"] | ||
|
||
def _reset_fitting(self) -> None: | ||
self.trainer.reset_train_dataloader() | ||
self.trainer.reset_val_dataloader() | ||
self.trainer.state.fn = TrainerFn.FITTING | ||
self.trainer.training = True | ||
|
||
def _reset_testing(self) -> None: | ||
self.trainer.reset_test_dataloader() | ||
self.trainer.state.fn = TrainerFn.TESTING | ||
self.trainer.testing = True | ||
|
||
def __getattr__(self, key) -> Any: | ||
# requires to be overridden as attributes of the wrapped loop are being accessed. | ||
if key not in self.__dict__: | ||
return getattr(self.fit_loop, key) | ||
return self.__dict__[key] | ||
|
||
|
||
############################################################################################# | ||
# Step 5 / 5: Connect the KFoldLoop to the Trainer # | ||
# After creating the `KFoldDataModule` and our model, the `KFoldLoop` is being connected to # | ||
# the Trainer. # | ||
# Finally, use `trainer.fit` to start the cross validation training. # | ||
############################################################################################# | ||
|
||
model = LitClassifier() | ||
datamodule = MNISTKFoldDataModule() | ||
trainer = Trainer( | ||
max_epochs=10, | ||
limit_train_batches=2, | ||
limit_val_batches=2, | ||
limit_test_batches=2, | ||
num_sanity_val_steps=0, | ||
devices=1, | ||
accelerator="auto", | ||
strategy="ddp", | ||
) | ||
trainer.fit_loop = KFoldLoop(5, trainer.fit_loop, export_path="./") | ||
trainer.fit(model, datamodule) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.