diff --git a/.gitignore b/.gitignore index 7b1247433e7b4..997886d648614 100644 --- a/.gitignore +++ b/.gitignore @@ -137,6 +137,9 @@ ENV/ Datasets/ mnist/ legacy/checkpoints/ +*.gz +*ubyte + # pl tests ml-runs/ diff --git a/CHANGELOG.md b/CHANGELOG.md index b9b567bce0a87..209a8a4671028 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -119,6 +119,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Updated precision attributes in `DeepSpeedPlugin` ([#10164](https://github.com/PyTorchLightning/pytorch-lightning/pull/10164)) * Added the ability to return a result from rank 0 in `DDPSpawnPlugin.spawn` ([#10162](https://github.com/PyTorchLightning/pytorch-lightning/pull/10162)) * Added `pytorch_lightning.lite` package ([#10175](https://github.com/PyTorchLightning/pytorch-lightning/pull/10175)) + * Added `LightningLite` documentation ([#10043](https://github.com/PyTorchLightning/pytorch-lightning/pull/10043)) + * Added `LightningLite` examples ([#9987](https://github.com/PyTorchLightning/pytorch-lightning/pull/9987)) * Make the `_LiteDataLoader` an iterator and add supports for custom dataloader ([#10279](https://github.com/PyTorchLightning/pytorch-lightning/pull/10279)) - Added `use_omegaconf` argument to `save_hparams_to_yaml` plugin ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170)) - Added `ckpt_path` argument for `Trainer.fit()` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061)) diff --git a/docs/source/conf.py b/docs/source/conf.py index 16b2ed7509ee3..845b3b946972a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -16,6 +16,7 @@ import os import shutil import sys +import warnings from importlib.util import module_from_spec, spec_from_file_location import pt_lightning_sphinx_theme @@ -26,10 +27,13 @@ sys.path.insert(0, os.path.abspath(PATH_ROOT)) sys.path.append(os.path.join(PATH_RAW_NB, ".actions")) +_SHOULD_COPY_NOTEBOOKS = True + try: from helpers import HelperCLI except Exception: - raise ModuleNotFoundError("To build the code, please run: `git submodule update --init --recursive`") + _SHOULD_COPY_NOTEBOOKS = False + warnings.warn("To build the code, please run: `git submodule update --init --recursive`", stacklevel=2) FOLDER_GENERATED = "generated" SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True)) @@ -41,8 +45,8 @@ spec.loader.exec_module(about) # -- Project documents ------------------------------------------------------- - -HelperCLI.copy_notebooks(PATH_RAW_NB, PATH_HERE, "notebooks") +if _SHOULD_COPY_NOTEBOOKS: + HelperCLI.copy_notebooks(PATH_RAW_NB, PATH_HERE, "notebooks") def _transform_changelog(path_in: str, path_out: str) -> None: diff --git a/pl_examples/README.md b/pl_examples/README.md index ff6ed1bf17ad9..58cc5c64d8c4b 100644 --- a/pl_examples/README.md +++ b/pl_examples/README.md @@ -5,20 +5,31 @@ can be found in our sister library [Lightning Bolts](https://pytorch-lightning.r ______________________________________________________________________ -## Basic examples +## MNIST Examples -In this folder we add several starter examples: +5 MNIST examples showing how to gradually convert from pure PyTorch to PyTorch Lightning. -- [MNIST Classifier](./basic_examples/simple_image_classifier.py): Shows how to define the model inside the `LightningModule`. -- [Image Classifier](./basic_examples/backbone_image_classifier.py): Trains arbitrary datasets with arbitrary backbones. 
-- [Autoencoder](./basic_examples/autoencoder.py): Shows how the `LightningModule` can be used as a system.
-- [Profiler](./basic_examples/profiler_example.py): Shows the basic usage of the PyTorch profilers and how to inspect traces in Google Chrome.
-- [Image Classifier with DALI](./basic_examples/dali_image_classifier.py): Shows how to use [NVIDIA DALI](https://developer.nvidia.com/DALI) with Lightning.
-- [Mnist Datamodule](.basic_examples/mnist_datamodule.py): Shows how to define a simple `LightningDataModule` using the MNIST dataset.
+The transition through [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.html) from pure PyTorch is optional, but it can be helpful to learn about it.
+
+- [MNIST with vanilla PyTorch](./basic_examples/mnist_examples/image_classifier_1_pytorch.py)
+- [MNIST with LightningLite](./basic_examples/mnist_examples/image_classifier_2_lite.py)
+- [MNIST LightningLite to LightningModule](./basic_examples/mnist_examples/image_classifier_3_lite_to_lightning_module.py)
+- [MNIST with LightningModule](./basic_examples/mnist_examples/image_classifier_4_lightning_module.py)
+- [MNIST with LightningModule + LightningDataModule](./basic_examples/mnist_examples/image_classifier_5_lightning_datamodule.py)
+
+______________________________________________________________________
+
+## Basic Examples
+
+In this folder, we have 3 simple examples:
+
+- [Image Classifier](./basic_examples/backbone_image_classifier.py) (trains arbitrary datasets with arbitrary backbones).
+- [Image Classifier + DALI](./integration_examples/dali_image_classifier.py) (shows how to use [NVIDIA DALI](https://developer.nvidia.com/DALI) with Lightning).
+- [Autoencoder](./basic_examples/autoencoder.py) (shows how the `LightningModule` can be used as a system).

______________________________________________________________________

-## Domain examples
+## Domain Examples

This folder contains older examples. You should instead use the examples
in [Lightning Bolts](https://pytorch-lightning.readthedocs.io/en/latest/ecosystem/bolts.html)
diff --git a/pl_examples/basic_examples/README.md b/pl_examples/basic_examples/README.md
index 6e8d69cc7cbfb..b1b02f90ecb24 100644
--- a/pl_examples/basic_examples/README.md
+++ b/pl_examples/basic_examples/README.md
@@ -2,62 +2,70 @@

Use these examples to test how Lightning works.

-#### MNIST
+## MNIST Examples

-Trains MNIST where the model is defined inside the `LightningModule`.
+Here are 5 MNIST examples showing you how to gradually convert from pure PyTorch to PyTorch Lightning.

-```bash
-# cpu
-python simple_image_classifier.py
+The transition through [LightningLite](https://pytorch-lightning.readthedocs.io/en/stable/starter/lightning_lite.html) from pure PyTorch is optional, but it can be helpful to learn about it.
+
+#### 1. Image Classifier with Vanilla PyTorch

-# gpus (any number)
-python simple_image_classifier.py --trainer.gpus 2
+Trains a simple CNN over MNIST using vanilla PyTorch.

-# Distributed Data Parallel
-python simple_image_classifier.py --trainer.gpus 2 --trainer.accelerator ddp
+```bash
+# CPU
+python image_classifier_1_pytorch.py
```

______________________________________________________________________

-#### MNIST with DALI
+#### 2. Image Classifier with LightningLite

-The MNIST example above using [NVIDIA DALI](https://developer.nvidia.com/DALI).
-Requires NVIDIA DALI to be installed based on your CUDA version, see [here](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html).
+This script shows you how to scale the previous script to enable GPU and multi-GPU training using [LightningLite](https://pytorch-lightning.readthedocs.io/en/stable/starter/lightning_lite.html).

```bash
-python dali_image_classifier.py
+# CPU / multiple GPUs if available
+python image_classifier_2_lite.py
```

______________________________________________________________________

-#### Image classifier
+#### 3. Image Classifier - Conversion from Lite to Lightning

-Generic image classifier with an arbitrary backbone (ie: a simple system)
+This script shows you how to prepare your conversion from [LightningLite](https://pytorch-lightning.readthedocs.io/en/stable/starter/lightning_lite.html) to `LightningModule`.

```bash
-# cpu
-python backbone_image_classifier.py
+# CPU / multiple GPUs if available
+python image_classifier_3_lite_to_lightning_module.py
+```

-# gpus (any number)
-python backbone_image_classifier.py --trainer.gpus 2
+______________________________________________________________________
+
+#### 4. Image Classifier with LightningModule
+
+This script shows you the result of the conversion to the `LightningModule` and all the benefits you get from the Lightning ecosystem.
+
+```bash
+# CPU
+python image_classifier_4_lightning_module.py

-# Distributed Data Parallel
-python backbone_image_classifier.py --trainer.gpus 2 --trainer.accelerator ddp
+# GPUs (any number)
+python image_classifier_4_lightning_module.py --trainer.gpus 2
```

______________________________________________________________________

-#### Autoencoder
+#### 5. Image Classifier with LightningModule and LightningDataModule

-Showing the power of a system... arbitrarily complex training loops
+This script shows you how to extract the data-related components into a `LightningDataModule`.

```bash
-# cpu
-python autoencoder.py
+# CPU
+python image_classifier_5_lightning_datamodule.py

-# gpus (any number)
-python autoencoder.py --trainer.gpus 2
+# GPUs (any number)
+python image_classifier_5_lightning_datamodule.py --trainer.gpus 2

-# Distributed Data Parallel
-python autoencoder.py --trainer.gpus 2 --trainer.accelerator ddp
+# Distributed Data Parallel (DDP)
+python image_classifier_5_lightning_datamodule.py --trainer.gpus 2 --trainer.strategy 'ddp'
```
diff --git a/pl_examples/basic_examples/mnist_examples/README.md b/pl_examples/basic_examples/mnist_examples/README.md
new file mode 100644
index 0000000000000..c82960af1ff22
--- /dev/null
+++ b/pl_examples/basic_examples/mnist_examples/README.md
@@ -0,0 +1,67 @@
+## MNIST Examples
+
+Here are 5 MNIST examples showing you how to gradually convert from pure PyTorch to PyTorch Lightning.
+
+The transition through [LightningLite](https://pytorch-lightning.readthedocs.io/en/stable/starter/lightning_lite.html) from pure PyTorch is optional, but it can be helpful to learn about it.
+
+#### 1. Image Classifier with Vanilla PyTorch
+
+Trains a simple CNN over MNIST using vanilla PyTorch.
+
+```bash
+# CPU
+python image_classifier_1_pytorch.py
+```
+
+______________________________________________________________________
+
+#### 2. Image Classifier with LightningLite
+
+This script shows you how to scale the previous script to enable GPU and multi-GPU training using [LightningLite](https://pytorch-lightning.readthedocs.io/en/stable/starter/lightning_lite.html).
+
+```bash
+# CPU / multiple GPUs if available
+python image_classifier_2_lite.py
+```
+
+______________________________________________________________________
+
+#### 3.
Image Classifier - Conversion from Lite to Lightning
+
+This script shows you how to prepare your conversion from [LightningLite](https://pytorch-lightning.readthedocs.io/en/stable/starter/lightning_lite.html) to `LightningModule`.
+
+```bash
+# CPU / multiple GPUs if available
+python image_classifier_3_lite_to_lightning_module.py
+```
+
+______________________________________________________________________
+
+#### 4. Image Classifier with LightningModule
+
+This script shows you the result of the conversion to the `LightningModule` and all the benefits you get from Lightning.
+
+```bash
+# CPU
+python image_classifier_4_lightning_module.py
+
+# GPUs (any number)
+python image_classifier_4_lightning_module.py --trainer.gpus 2
+```
+
+______________________________________________________________________
+
+#### 5. Image Classifier with LightningModule and LightningDataModule
+
+This script shows you how to extract the data-related components into a `LightningDataModule`.
+
+```bash
+# CPU
+python image_classifier_5_lightning_datamodule.py
+
+# GPUs (any number)
+python image_classifier_5_lightning_datamodule.py --trainer.gpus 2
+
+# Distributed Data Parallel (DDP)
+python image_classifier_5_lightning_datamodule.py --trainer.gpus 2 --trainer.strategy 'ddp'
+```
diff --git a/pl_examples/basic_examples/mnist_examples/__init__.py b/pl_examples/basic_examples/mnist_examples/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/pl_examples/basic_examples/mnist_examples/image_classifier_1_pytorch.py b/pl_examples/basic_examples/mnist_examples/image_classifier_1_pytorch.py
new file mode 100644
index 0000000000000..4073c485e6017
--- /dev/null
+++ b/pl_examples/basic_examples/mnist_examples/image_classifier_1_pytorch.py
@@ -0,0 +1,154 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision.transforms as T
+from torch.optim.lr_scheduler import StepLR
+
+from pl_examples.basic_examples.mnist_datamodule import MNIST
+
+# Credit to the PyTorch Team
+# Taken from https://github.com/pytorch/examples/blob/master/mnist/main.py and slightly adapted.
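+# Shape check for the layers below: a 1x28x28 input goes through
+# conv1 (3x3) -> 32x26x26, then conv2 (3x3) -> 64x24x24, then max_pool2d(2) -> 64x12x12,
+# so fc1 sees 64 * 12 * 12 = 9216 flattened features per sample.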
+ + +class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +def run(hparams): + + torch.manual_seed(hparams.seed) + + use_cuda = torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + + transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))]) + train_dataset = MNIST("./data", train=True, download=True, transform=transform) + test_dataset = MNIST("./data", train=False, transform=transform) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=hparams.batch_size, + ) + test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=hparams.batch_size) + + model = Net().to(device) + optimizer = optim.Adadelta(model.parameters(), lr=hparams.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=hparams.gamma) + + # EPOCH LOOP + for epoch in range(1, hparams.epochs + 1): + + # TRAINING LOOP + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if (batch_idx == 0) or ((batch_idx + 1) % hparams.log_interval == 0): + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + if hparams.dry_run: + break + scheduler.step() + + # TESTING LOOP + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + if hparams.dry_run: + break + + test_loss /= len(test_loader.dataset) + + print( + "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) + ) + ) + + if hparams.dry_run: + break + + if hparams.save_model: + torch.save(model.state_dict(), "mnist_cnn.pt") + + +def main(): + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)" + ) + parser.add_argument("--epochs", type=int, default=14, metavar="N", help="number of epochs to train (default: 14)") + parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)") + parser.add_argument("--gamma", type=float, default=0.7, metavar="M", help="Learning rate step gamma (default: 0.7)") + parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + 
"--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model") + hparams = parser.parse_args() + run(hparams) + + +if __name__ == "__main__": + main() diff --git a/pl_examples/basic_examples/mnist_examples/image_classifier_2_lite.py b/pl_examples/basic_examples/mnist_examples/image_classifier_2_lite.py new file mode 100644 index 0000000000000..4240a9b7c4e08 --- /dev/null +++ b/pl_examples/basic_examples/mnist_examples/image_classifier_2_lite.py @@ -0,0 +1,161 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Here are 5 required steps to convert to LightningLite. + +1. Subclass LightningLite and override its ``run`` method. + +2. Move the body of your existing ``run`` function into the ``run`` method. + +3. Remove all ``.to``, ``.cuda`` etc calls since LightningLite will take care of it. + +4. Apply ``setup`` over each model and optimizers pair, ``setup_dataloaders`` on all your dataloaders, +and replace ``loss.backward()`` with ``self.backward(loss)``. + +5. Instantiate your LightningLite and call its ``run`` method. + +Learn more from the documentation: https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.html. +""" + +import argparse + +import torch +import torch.nn.functional as F +import torch.optim as optim +import torchvision.transforms as T +from torch.optim.lr_scheduler import StepLR +from torchmetrics.classification import Accuracy + +from pl_examples.basic_examples.mnist_datamodule import MNIST +from pl_examples.basic_examples.mnist_examples.image_classifier_1_pytorch import Net +from pytorch_lightning import seed_everything +from pytorch_lightning.lite import LightningLite # import LightningLite + + +class Lite(LightningLite): + def run(self, hparams): + self.hparams = hparams + seed_everything(hparams.seed) # instead of torch.manual_seed(...) + + transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))]) + # This is meant to ensure the data are download only by 1 process. + if self.is_global_zero: + MNIST("./data", download=True) + self.barrier() + train_dataset = MNIST("./data", train=True, transform=transform) + test_dataset = MNIST("./data", train=False, transform=transform) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=hparams.batch_size, + ) + test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=hparams.batch_size) + + # don't forget to call `setup_dataloaders` to prepare for dataloaders for distributed training. + train_loader, test_loader = self.setup_dataloaders(train_loader, test_loader) + + model = Net() # remove call to .to(device) + optimizer = optim.Adadelta(model.parameters(), lr=hparams.lr) + + # don't forget to call `setup` to prepare for model / optimizer for distributed training. + # the model is moved automatically to the right device. 
+        model, optimizer = self.setup(model, optimizer)
+
+        scheduler = StepLR(optimizer, step_size=1, gamma=hparams.gamma)
+
+        # use torchmetrics instead of manually computing the accuracy
+        test_acc = Accuracy().to(self.device)
+
+        # EPOCH LOOP
+        for epoch in range(1, hparams.epochs + 1):
+
+            # TRAINING LOOP
+            model.train()
+            for batch_idx, (data, target) in enumerate(train_loader):
+                # NOTE: no need to call `.to(device)` on the data, target
+                optimizer.zero_grad()
+                output = model(data)
+                loss = F.nll_loss(output, target)
+                self.backward(loss)  # instead of loss.backward()
+
+                optimizer.step()
+                if (batch_idx == 0) or ((batch_idx + 1) % hparams.log_interval == 0):
+                    print(
+                        "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+                            epoch,
+                            batch_idx * len(data),
+                            len(train_loader.dataset),
+                            100.0 * batch_idx / len(train_loader),
+                            loss.item(),
+                        )
+                    )
+                    if hparams.dry_run:
+                        break
+
+            scheduler.step()
+
+            # TESTING LOOP
+            model.eval()
+            test_loss = 0
+            with torch.no_grad():
+                for data, target in test_loader:
+                    # NOTE: no need to call `.to(device)` on the data, target
+                    output = model(data)
+                    test_loss += F.nll_loss(output, target, reduction="sum").item()
+
+                    # WITHOUT TorchMetrics
+                    # pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+                    # correct += pred.eq(target.view_as(pred)).sum().item()
+
+                    # WITH TorchMetrics
+                    test_acc(output, target)
+
+                    if hparams.dry_run:
+                        break
+
+            # all_gather is used to aggregate the value across processes
+            test_loss = self.all_gather(test_loss).sum() / len(test_loader.dataset)
+
+            print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({100 * test_acc.compute():.0f}%)\n")
+            test_acc.reset()
+
+            if hparams.dry_run:
+                break
+
+        # When using distributed training, use `self.save`
+        # so that only the process allowed to save a checkpoint does so.
+        if hparams.save_model:
+            self.save(model.state_dict(), "mnist_cnn.pt")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="LightningLite MNIST Example")
+    parser.add_argument(
+        "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
+    )
+    parser.add_argument("--epochs", type=int, default=14, metavar="N", help="number of epochs to train (default: 14)")
+    parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)")
+    parser.add_argument("--gamma", type=float, default=0.7, metavar="M", help="Learning rate step gamma (default: 0.7)")
+    parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass")
+    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status",
+    )
+    parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model")
+    hparams = parser.parse_args()
+
+    Lite(accelerator="auto", devices="auto").run(hparams)
diff --git a/pl_examples/basic_examples/mnist_examples/image_classifier_3_lite_to_lightning_module.py b/pl_examples/basic_examples/mnist_examples/image_classifier_3_lite_to_lightning_module.py
new file mode 100644
index 0000000000000..0d6925fc68c1a
--- /dev/null
+++ b/pl_examples/basic_examples/mnist_examples/image_classifier_3_lite_to_lightning_module.py
@@ -0,0 +1,169 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Here are the steps to convert from LightningLite to a LightningModule.
+
+1. Start implementing the ``training_step``, ``forward``, ``train_dataloader`` and ``configure_optimizers``
+methods on the LightningLite class.
+
+2. Use those methods within the ``run`` method.
+
+3. Finally, switch to LightningModule and validate that your results are still reproducible (next script).
+
+Learn more from the documentation: https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.html.
+"""
+
+import argparse
+
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision.transforms as T
+from torch.optim.lr_scheduler import StepLR
+from torchmetrics import Accuracy
+
+from pl_examples.basic_examples.mnist_datamodule import MNIST
+from pl_examples.basic_examples.mnist_examples.image_classifier_1_pytorch import Net
+from pytorch_lightning import seed_everything
+from pytorch_lightning.lite import LightningLite
+
+
+class Lite(LightningLite):
+    """Lite is starting to look like a LightningModule."""
+
+    def run(self, hparams):
+        self.hparams = hparams
+        seed_everything(hparams.seed)  # instead of torch.manual_seed(...)
+
+        self.model = Net()
+        [optimizer], [scheduler] = self.configure_optimizers()
+        model, optimizer = self.setup(self.model, optimizer)
+
+        if self.is_global_zero:
+            # In multi-device training, this code will only run on the first process / GPU
+            self.prepare_data()
+        self.barrier()  # make the other processes wait until the download has finished
+
+        train_loader, test_loader = self.setup_dataloaders(self.train_dataloader(), self.test_dataloader())
+
+        self.test_acc = Accuracy().to(self.device)
+
+        # EPOCH LOOP
+        for epoch in range(1, hparams.epochs + 1):
+
+            # TRAINING LOOP
+            self.model.train()
+            for batch_idx, batch in enumerate(train_loader):
+                optimizer.zero_grad()
+                loss = self.training_step(batch, batch_idx)
+                self.backward(loss)
+                optimizer.step()
+
+                if (batch_idx == 0) or ((batch_idx + 1) % hparams.log_interval == 0):
+                    print(
+                        "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+                            epoch,
+                            (batch_idx + 1) * self.hparams.batch_size,
+                            len(train_loader.dataset),
+                            100.0 * batch_idx / len(train_loader),
+                            loss.item(),
+                        )
+                    )
+                    if hparams.dry_run:
+                        break
+
+            scheduler.step()
+
+            # TESTING LOOP
+            self.model.eval()
+            test_loss = 0
+            with torch.no_grad():
+                for batch_idx, batch in enumerate(test_loader):
+                    test_loss += self.test_step(batch, batch_idx)
+                    if hparams.dry_run:
+                        break
+
+            test_loss = self.all_gather(test_loss).sum() / len(test_loader.dataset)
+
+            print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({100 * self.test_acc.compute():.0f}%)\n")
+            self.test_acc.reset()
+
+            if hparams.dry_run:
+                break
+
+        if hparams.save_model:
+            self.save(model.state_dict(), "mnist_cnn.pt")
+
+    # Methods for the `LightningModule` conversion
+
+    def forward(self, x):
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        """Here you compute and return the training loss and compute extra training metrics."""
+        x, y = batch
+        logits =
self.forward(x) + loss = F.nll_loss(logits, y.long()) + return loss + + def test_step(self, batch, batch_idx): + """Here you compute and return the testing loss and compute extra testing metrics.""" + x, y = batch + logits = self.forward(x) + loss = F.nll_loss(logits, y.long()) + self.test_acc(logits, y.long()) + return loss + + def configure_optimizers(self): + optimizer = optim.Adadelta(self.model.parameters(), lr=self.hparams.lr) + return [optimizer], [StepLR(optimizer, step_size=1, gamma=self.hparams.gamma)] + + # Methods for the `LightningDataModule` conversion + + @property + def transform(self): + return T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))]) + + def prepare_data(self) -> None: + MNIST("./data", download=True) + + def train_dataloader(self): + train_dataset = MNIST("./data", train=True, download=False, transform=self.transform) + return torch.utils.data.DataLoader(train_dataset, batch_size=self.hparams.batch_size) + + def test_dataloader(self): + test_dataset = MNIST("./data", train=False, download=False, transform=self.transform) + return torch.utils.data.DataLoader(test_dataset, batch_size=self.hparams.batch_size) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="LightningLite to LightningModule MNIST Example") + parser.add_argument( + "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)" + ) + parser.add_argument("--epochs", type=int, default=14, metavar="N", help="number of epochs to train (default: 14)") + parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)") + parser.add_argument("--gamma", type=float, default=0.7, metavar="M", help="Learning rate step gamma (default: 0.7)") + parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model") + hparams = parser.parse_args() + + Lite(accelerator="auto", devices="auto").run(hparams) diff --git a/pl_examples/basic_examples/mnist_examples/image_classifier_4_lightning_module.py b/pl_examples/basic_examples/mnist_examples/image_classifier_4_lightning_module.py new file mode 100644 index 0000000000000..cb67d3446c51a --- /dev/null +++ b/pl_examples/basic_examples/mnist_examples/image_classifier_4_lightning_module.py @@ -0,0 +1,81 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Simple MNIST image classifier example with LightningModule. 
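+
+The model, the optimizer, and the MNIST dataloaders are all defined on a single ``LightningModule``, and ``LightningCLI`` provides the command-line interface.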
+ +To run: python image_classifier_4_lightning_module.py --trainer.max_epochs=50 +""" +import torch +import torchvision.transforms as T +from torch.nn import functional as F + +from pl_examples import cli_lightning_logo +from pl_examples.basic_examples.mnist_datamodule import MNIST +from pl_examples.basic_examples.mnist_examples.image_classifier_1_pytorch import Net +from pytorch_lightning import LightningModule +from pytorch_lightning.utilities.cli import LightningCLI + + +class ImageClassifier(LightningModule): + def __init__(self, model=None, lr=1.0, gamma=0.7, batch_size=32): + super().__init__() + self.save_hyperparameters() + self.model = model or Net() + + def forward(self, x): + return self.model(x) + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + loss = F.nll_loss(logits, y.long()) + return loss + + def test_step(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + loss = F.nll_loss(logits, y.long()) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.Adadelta(self.model.parameters(), lr=self.hparams.lr) + return [optimizer], [torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.hparams.gamma)] + + # Methods for the `LightningDataModule` conversion + + @property + def transform(self): + return T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))]) + + def prepare_data(self) -> None: + MNIST("./data", download=True) + + def train_dataloader(self): + train_dataset = MNIST("./data", train=True, download=False, transform=self.transform) + return torch.utils.data.DataLoader(train_dataset, batch_size=self.hparams.batch_size) + + def test_dataloader(self): + test_dataset = MNIST("./data", train=False, download=False, transform=self.transform) + return torch.utils.data.DataLoader(test_dataset, batch_size=self.hparams.batch_size) + + +def cli_main(): + # The LightningCLI removes all the boilerplate associated with arguments parsing. This is purely optional. + cli = LightningCLI(ImageClassifier, seed_everything_default=42, save_config_overwrite=True, run=False) + cli.trainer.fit(cli.model, datamodule=cli.datamodule) + cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule) + + +if __name__ == "__main__": + cli_lightning_logo() + cli_main() diff --git a/pl_examples/basic_examples/mnist_examples/image_classifier_5_lightning_datamodule.py b/pl_examples/basic_examples/mnist_examples/image_classifier_5_lightning_datamodule.py new file mode 100644 index 0000000000000..4020d101ccab6 --- /dev/null +++ b/pl_examples/basic_examples/mnist_examples/image_classifier_5_lightning_datamodule.py @@ -0,0 +1,87 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Simple MNIST image classifier example with LightningModule and LightningDataModule. 
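+
+Compared to the previous script, the data-related components now live in a separate ``MNISTDataModule``.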
+ +To run: python image_classifier_5_lightning_datamodule.py --trainer.max_epochs=50 +""" +import torch +import torchvision.transforms as T +from torch.nn import functional as F + +from pl_examples import cli_lightning_logo +from pl_examples.basic_examples.mnist_datamodule import MNIST +from pl_examples.basic_examples.mnist_examples.image_classifier_1_pytorch import Net +from pytorch_lightning import LightningDataModule, LightningModule +from pytorch_lightning.utilities.cli import LightningCLI + + +class ImageClassifier(LightningModule): + def __init__(self, model, lr=1.0, gamma=0.7, batch_size=32): + super().__init__() + self.save_hyperparameters() + self.model = model or Net() + + def forward(self, x): + return self.model(x) + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + loss = F.nll_loss(logits, y.long()) + return loss + + def test_step(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + loss = F.nll_loss(logits, y.long()) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.Adadelta(self.model.parameters(), lr=self.hparams.lr) + return [optimizer], [torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.hparams.gamma)] + + +class MNISTDataModule(LightningDataModule): + def __init__(self, batch_size=32): + super().__init__() + self.save_hyperparameters() + + @property + def transform(self): + return T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))]) + + def prepare_data(self) -> None: + MNIST("./data", download=True) + + def train_dataloader(self): + train_dataset = MNIST("./data", train=True, download=False, transform=self.transform) + return torch.utils.data.DataLoader(train_dataset, batch_size=self.hparams.batch_size) + + def test_dataloader(self): + test_dataset = MNIST("./data", train=False, download=False, transform=self.transform) + return torch.utils.data.DataLoader(test_dataset, batch_size=self.hparams.batch_size) + + +def cli_main(): + # The LightningCLI removes all the boilerplate associated with arguments parsing. This is purely optional. + cli = LightningCLI( + ImageClassifier, MNISTDataModule, seed_everything_default=42, save_config_overwrite=True, run=False + ) + cli.trainer.fit(cli.model, datamodule=cli.datamodule) + cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule) + + +if __name__ == "__main__": + cli_lightning_logo() + cli_main() diff --git a/pl_examples/basic_examples/simple_image_classifier.py b/pl_examples/basic_examples/simple_image_classifier.py deleted file mode 100644 index 146f25c27c0d4..0000000000000 --- a/pl_examples/basic_examples/simple_image_classifier.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""MNIST simple image classifier example. 
- -To run: python simple_image_classifier.py --trainer.max_epochs=50 -""" - -import torch -from torch.nn import functional as F - -import pytorch_lightning as pl -from pl_examples import cli_lightning_logo -from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule -from pytorch_lightning.utilities.cli import LightningCLI - - -class LitClassifier(pl.LightningModule): - """ - >>> LitClassifier() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - LitClassifier( - (l1): Linear(...) - (l2): Linear(...) - ) - """ - - def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.0001): - super().__init__() - self.save_hyperparameters() - - self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim) - self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10) - - def forward(self, x): - x = x.view(x.size(0), -1) - x = torch.relu(self.l1(x)) - x = torch.relu(self.l2(x)) - return x - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) - return loss - - def validation_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) - self.log("valid_loss", loss) - - def test_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) - self.log("test_loss", loss) - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - - -def cli_main(): - cli = LightningCLI( - LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False - ) - cli.trainer.fit(cli.model, datamodule=cli.datamodule) - cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule) - - -if __name__ == "__main__": - cli_lightning_logo() - cli_main() diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 48492c8ce7f04..26a6c8aa89f67 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -17,24 +17,21 @@ tensorboard --logdir default """ -import os from argparse import ArgumentParser, Namespace import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import DataLoader from pl_examples import cli_lightning_logo -from pl_examples.basic_examples.mnist_datamodule import MNIST -from pytorch_lightning.core import LightningDataModule, LightningModule +from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule +from pytorch_lightning.core import LightningModule from pytorch_lightning.trainer import Trainer from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: import torchvision - from torchvision import transforms class Generator(nn.Module): @@ -212,35 +209,6 @@ def on_epoch_end(self): self.logger.experiment.add_image("generated_images", grid, self.current_epoch) -class MNISTDataModule(LightningDataModule): - """ - >>> MNISTDataModule() # doctest: +ELLIPSIS - <...generative_adversarial_net.MNISTDataModule object at ...> - """ - - def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4): - super().__init__() - self.batch_size = batch_size - self.data_path = data_path - self.num_workers = num_workers - - self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) - self.dims = (1, 28, 28) - - def prepare_data(self, stage=None): - # Use this method to do things that might write 
to disk or that need to be done only from a single GPU
-        # in distributed settings. Like downloading the dataset for the first time.
-        MNIST(self.data_path, train=True, download=True, transform=transforms.ToTensor())
-
-    def setup(self, stage=None):
-        # There are also data operations you might want to perform on every GPU, such as applying transforms
-        # defined explicitly in your datamodule or assigned in init.
-        self.mnist_train = MNIST(self.data_path, train=True, transform=self.transform)
-
-    def train_dataloader(self):
-        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)
-
-
 def main(args: Namespace) -> None:
     # ------------------------
     # 1 INIT LIGHTNING MODEL
@@ -250,7 +218,7 @@ def main(args: Namespace) -> None:
     # ------------------------
     # 2 INIT TRAINER
     # ------------------------
-    # If use distubuted training PyTorch recommends to use DistributedDataParallel.
+    # If using distributed training, PyTorch recommends DistributedDataParallel.
     # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
     dm = MNISTDataModule.from_argparse_args(args)
     trainer = Trainer.from_argparse_args(args)
diff --git a/pl_examples/integration_examples/__init__.py b/pl_examples/integration_examples/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/integration_examples/dali_image_classifier.py
similarity index 100%
rename from pl_examples/basic_examples/dali_image_classifier.py
rename to pl_examples/integration_examples/dali_image_classifier.py
diff --git a/pl_examples/loop_examples/kfold.py b/pl_examples/loop_examples/kfold.py
index 632734b30137c..bd14d42eb796f 100644
--- a/pl_examples/loop_examples/kfold.py
+++ b/pl_examples/loop_examples/kfold.py
@@ -27,7 +27,7 @@
 from pl_examples import _DATASETS_PATH
 from pl_examples.basic_examples.mnist_datamodule import MNIST
-from pl_examples.basic_examples.simple_image_classifier import LitClassifier
+from pl_examples.basic_examples.mnist_examples.image_classifier_4_lightning_module import ImageClassifier
 from pytorch_lightning import LightningDataModule, seed_everything, Trainer
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loops.base import Loop
@@ -241,7 +241,7 @@ def __getattr__(self, key) -> Any:
 #############################################################################################
 if __name__ == "__main__":
-    model = LitClassifier()
+    model = ImageClassifier()
     datamodule = MNISTKFoldDataModule()
     trainer = Trainer(
         max_epochs=10,
diff --git a/pl_examples/loop_examples/mnist_lite.py b/pl_examples/loop_examples/mnist_lite.py
new file mode 100644
index 0000000000000..4d59ef326f408
--- /dev/null
+++ b/pl_examples/loop_examples/mnist_lite.py
@@ -0,0 +1,188 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
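+# This example re-implements `image_classifier_2_lite.py` on top of the Loop API:
+# `TrainLoop` and `TestLoop` subclass `pytorch_lightning.loops.Loop` and implement
+# `reset`/`advance`, while `MainLoop` drives one training and one testing pass per epoch.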
+import argparse
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision.transforms as T
+from torch.optim.lr_scheduler import StepLR
+from torchmetrics import Accuracy
+
+from pl_examples.basic_examples.mnist_datamodule import MNIST
+from pl_examples.basic_examples.mnist_examples.image_classifier_1_pytorch import Net
+from pytorch_lightning import seed_everything
+from pytorch_lightning.lite import LightningLite
+from pytorch_lightning.loops import Loop
+
+
+class TrainLoop(Loop):
+    def __init__(self, lite, args, model, optimizer, scheduler, dataloader):
+        super().__init__()
+        self.lite = lite
+        self.args = args
+        self.model = model
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.dataloader = dataloader
+        self.dataloader_iter = None
+
+    @property
+    def done(self) -> bool:
+        return False
+
+    def reset(self):
+        self.dataloader_iter = enumerate(self.dataloader)
+
+    def advance(self, epoch) -> None:
+        batch_idx, (data, target) = next(self.dataloader_iter)
+        self.optimizer.zero_grad()
+        output = self.model(data)
+        loss = F.nll_loss(output, target)
+        self.lite.backward(loss)
+        self.optimizer.step()
+
+        if (batch_idx == 0) or ((batch_idx + 1) % self.args.log_interval == 0):
+            print(
+                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+                    epoch,
+                    batch_idx * len(data),
+                    len(self.dataloader.dataset),
+                    100.0 * batch_idx / len(self.dataloader),
+                    loss.item(),
+                )
+            )
+
+        if self.args.dry_run:
+            raise StopIteration
+
+    def on_run_end(self):
+        self.scheduler.step()
+        self.dataloader_iter = None
+
+
+class TestLoop(Loop):
+    def __init__(self, lite, args, model, dataloader):
+        super().__init__()
+        self.lite = lite
+        self.args = args
+        self.model = model
+        self.dataloader = dataloader
+        self.dataloader_iter = None
+        self.accuracy = Accuracy().to(lite.device)
+        self.test_loss = 0
+
+    @property
+    def done(self) -> bool:
+        return False
+
+    def reset(self):
+        self.dataloader_iter = enumerate(self.dataloader)
+        self.test_loss = 0
+        self.accuracy.reset()
+
+    def advance(self) -> None:
+        _, (data, target) = next(self.dataloader_iter)
+        output = self.model(data)
+        self.test_loss += F.nll_loss(output, target, reduction="sum")
+        self.accuracy(output, target)
+
+        if self.args.dry_run:
+            raise StopIteration
+
+    def on_run_end(self):
+        test_loss = self.lite.all_gather(self.test_loss).sum() / len(self.dataloader.dataset)
+
+        if self.lite.is_global_zero:
+            print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({100 * self.accuracy.compute():.0f}%)\n")
+
+
+class MainLoop(Loop):
+    def __init__(self, lite, args, model, optimizer, scheduler, train_loader, test_loader):
+        super().__init__()
+        self.lite = lite
+        self.args = args
+        self.epoch = 0
+        self.train_loop = TrainLoop(self.lite, self.args, model, optimizer, scheduler, train_loader)
+        self.test_loop = TestLoop(self.lite, self.args, model, test_loader)
+
+    @property
+    def done(self) -> bool:
+        return self.epoch >= self.args.epochs
+
+    def reset(self):
+        pass
+
+    def advance(self, *args: Any, **kwargs: Any) -> None:
+        self.train_loop.run(self.epoch)
+        self.test_loop.run()
+
+        if self.args.dry_run:
+            raise StopIteration
+
+        self.epoch += 1
+
+
+class Lite(LightningLite):
+    def run(self, hparams):
+        transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
+        if self.is_global_zero:
+            MNIST("./data", download=True)
+        self.barrier()
+        train_dataset = MNIST("./data", train=True, transform=transform)
+        test_dataset = MNIST("./data", train=False, transform=transform)
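+        # `batch_size` is passed positionally below (DataLoader's second argument).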
+        train_loader = torch.utils.data.DataLoader(train_dataset, hparams.batch_size)
+        test_loader = torch.utils.data.DataLoader(test_dataset, hparams.test_batch_size)
+
+        train_loader, test_loader = self.setup_dataloaders(train_loader, test_loader)
+
+        model = Net()
+        optimizer = optim.Adadelta(model.parameters(), lr=hparams.lr)
+
+        model, optimizer = self.setup(model, optimizer)
+        scheduler = StepLR(optimizer, step_size=1, gamma=hparams.gamma)
+
+        MainLoop(self, hparams, model, optimizer, scheduler, train_loader, test_loader).run()
+
+        if hparams.save_model and self.is_global_zero:
+            self.save(model.state_dict(), "mnist_cnn.pt")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="LightningLite MNIST Example with Lightning Loops.")
+    parser.add_argument(
+        "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
+    )
+    parser.add_argument(
+        "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)"
+    )
+    parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 2)")
+    parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)")
+    parser.add_argument("--gamma", type=float, default=0.7, metavar="M", help="Learning rate step gamma (default: 0.7)")
+    parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass")
+    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status",
+    )
+    parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model")
+    hparams = parser.parse_args()
+
+    seed_everything(hparams.seed)
+
+    Lite(accelerator="cpu", devices=1).run(hparams)
diff --git a/pl_examples/run_examples.sh b/pl_examples/run_examples.sh
index 7555e472d24e2..4a15c3367d35f 100755
--- a/pl_examples/run_examples.sh
+++ b/pl_examples/run_examples.sh
@@ -9,8 +9,28 @@
 args="
   --trainer.limit_val_batches=2
   --trainer.limit_test_batches=2
   --trainer.limit_predict_batches=2
+  --optimizer=Adam
 "
-python "${dir_path}/basic_examples/simple_image_classifier.py" ${args} "$@"
 python "${dir_path}/basic_examples/backbone_image_classifier.py" ${args} "$@"
 python "${dir_path}/basic_examples/autoencoder.py" ${args} "$@"
+
+
+args="--dry-run"
+python "${dir_path}/basic_examples/mnist_examples/image_classifier_1_pytorch.py" ${args}
+python "${dir_path}/basic_examples/mnist_examples/image_classifier_2_lite.py" ${args}
+python "${dir_path}/basic_examples/mnist_examples/image_classifier_3_lite_to_lightning_module.py" ${args}
+python "${dir_path}/loop_examples/mnist_lite.py" ${args}
+
+
+args="
+  --trainer.max_epochs=1
+  --trainer.limit_train_batches=2
+  --trainer.limit_val_batches=2
+  --trainer.limit_test_batches=2
+  --trainer.limit_predict_batches=2
+  --optimizer=Adam
+"
+
+python "${dir_path}/basic_examples/mnist_examples/image_classifier_4_lightning_module.py" ${args} "$@"
+python "${dir_path}/basic_examples/mnist_examples/image_classifier_5_lightning_datamodule.py" ${args} "$@"
diff --git a/pl_examples/test_examples.py b/pl_examples/test_examples.py
index b0b451692e4b6..19d09836ef34c 100644
--- a/pl_examples/test_examples.py
+++ b/pl_examples/test_examples.py
@@ -34,7 +34,7 @@
 @RunIf(min_gpus=1, skip_windows=True)
 @pytest.mark.parametrize("cli_args",
[ARGS_GPU]) def test_examples_mnist_dali(tmpdir, cli_args): - from pl_examples.basic_examples.dali_image_classifier import cli_main + from pl_examples.integration_examples.dali_image_classifier import cli_main # update the temp dir cli_args = cli_args % {"tmpdir": tmpdir} diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index d9acba70bcba1..8b6f072c57adc 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -14,7 +14,7 @@ import functools import inspect from contextlib import contextmanager -from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Type, Union +from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Sized, Type, Union import torch from torch import nn as nn @@ -150,7 +150,7 @@ def _replace_dataloader_init_method() -> Generator: class _LiteDataLoader: - def __init__(self, dataloader: Iterable, device: Optional[torch.device] = None) -> None: + def __init__(self, dataloader: Union[Iterable, DataLoader], device: Optional[torch.device] = None) -> None: """The LiteDataLoader is an extension of an Iterator. It would move the data to the device automatically if the device is specified. @@ -164,6 +164,11 @@ def __init__(self, dataloader: Iterable, device: Optional[torch.device] = None) self._dataloader = dataloader self._device = device + def __len__(self) -> Union[int, float]: + if isinstance(self._dataloader, Sized): + return len(self._dataloader) + return float("inf") + @property def device(self) -> Optional[torch.device]: return self._device diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index d396ec1e60174..aa9e9b401f40a 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -64,7 +64,7 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool: return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset) -def has_len(dataloader: DataLoader) -> bool: +def has_len(dataloader: Union[DataLoader, Iterable]) -> bool: """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or infinite dataloader. 
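For reference, a minimal usage sketch of the new `__len__` on `_LiteDataLoader` (illustrative only; `range` stands in for any `Sized` source):

```python
from pytorch_lightning.lite.wrappers import _LiteDataLoader

sized = _LiteDataLoader(range(4))          # `range` implements `__len__`, so it is `Sized`
assert len(sized) == 4

unsized = _LiteDataLoader(iter(range(4)))  # a bare iterator has no `__len__`
assert unsized.__len__() == float("inf")   # note: the built-in len() itself rejects non-int returns
```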
diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 4dd7b4a890648..8fc4f7e9c6e53 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -93,6 +93,8 @@ def test_lite_dataloader_device_placement(src_device, dest_device): batch0 = next(iterator) assert batch0 == 0 + assert len(lite_dataloader) == 4 + def test_lite_optimizer_wraps(): """Test that the LiteOptimizer fully wraps the optimizer.""" diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 1346cea295d54..f4b760dd75291 100755 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -87,9 +87,9 @@ fi # report+="Ran\ttests/plugins/environments/torch_elastic_deadlock.py\n" # test that a user can manually launch individual processes -args="--trainer.gpus 2 --trainer.accelerator ddp --trainer.max_epochs=1 --trainer.limit_train_batches=1 --trainer.limit_val_batches=1 --trainer.limit_test_batches=1" -MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=1 python pl_examples/basic_examples/simple_image_classifier.py ${args} & -MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=0 python pl_examples/basic_examples/simple_image_classifier.py ${args} +args="--trainer.gpus 2 --trainer.strategy ddp --trainer.max_epochs=1 --trainer.limit_train_batches=1 --trainer.limit_val_batches=1 --trainer.limit_test_batches=1" +MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=1 python pl_examples/basic_examples/mnist_examples/image_classifier_5_lightning_datamodule.py ${args} & +MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=0 python pl_examples/basic_examples/mnist_examples/image_classifier_5_lightning_datamodule.py ${args} report+="Ran\tmanual ddp launch test\n" # echo test report