diff --git a/.gitignore b/.gitignore index 7b1247433e7b4..4229c050e9b7f 100644 --- a/.gitignore +++ b/.gitignore @@ -137,6 +137,8 @@ ENV/ Datasets/ mnist/ legacy/checkpoints/ +*.gz +*ubyte # pl tests ml-runs/ diff --git a/grid_generated_0.png b/grid_generated_0.png new file mode 100644 index 0000000000000..77820f68637fd Binary files /dev/null and b/grid_generated_0.png differ diff --git a/grid_ori_0.png b/grid_ori_0.png new file mode 100644 index 0000000000000..497e4973b884c Binary files /dev/null and b/grid_ori_0.png differ diff --git a/pl_examples/README.md b/pl_examples/README.md index 08070015357b0..e5c82c9bdcc83 100644 --- a/pl_examples/README.md +++ b/pl_examples/README.md @@ -5,17 +5,31 @@ can be found in our sister library [lightning-bolts](https://pytorch-lightning.r ______________________________________________________________________ -## Basic examples +## MNIST Examples -In this folder we add 3 simple examples: +5 MNIST examples showing how to gradually convert from pure PyTorch to PyTorch Lightning. + +The transition from pure PyTorch through [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst) is optional, but it might be helpful to learn about it. + +- [MNIST with vanilla PyTorch](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/mnist_examples/image_classifier_1_pytorch.py) +- [MNIST with LightningLite](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/mnist_examples/image_classifier_2_lite.py) +- [MNIST LightningLite to LightningModule](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/mnist_examples/image_classifier_3_lite_to_lightning.py) +- [MNIST with LightningModule](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/mnist_examples/image_classifier_4_lightning.py) +- [MNIST with LightningModule + LightningDataModule](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/mnist_examples/image_classifier_5_lightning_datamodule.py) + +______________________________________________________________________ + +## Basic Examples + +In this folder, we add 3 simple examples: -- [MNIST Classifier](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/simple_image_classifier.py) (defines the model inside the `LightningModule`). - [Image Classifier](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/backbone_image_classifier.py) (trains arbitrary datasets with arbitrary backbones). +- [Image Classifier + DALI](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/integration_examples/dali_image_classifier.py) (defines the model inside the `LightningModule`). - [Autoencoder](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/autoencoder.py) (shows how the `LightningModule` can be used as a system) ______________________________________________________________________ -## Domain examples +## Domain Examples This folder contains older examples.
You should instead use the examples in [lightning-bolts](https://pytorch-lightning.readthedocs.io/en/latest/ecosystem/bolts.html) diff --git a/pl_examples/basic_examples/README.md b/pl_examples/basic_examples/README.md index fd7824140d470..05440cbad6689 100644 --- a/pl_examples/basic_examples/README.md +++ b/pl_examples/basic_examples/README.md @@ -2,47 +2,72 @@ Use these examples to test how lightning works. -#### MNIST +## MNIST Examples -Trains MNIST where the model is defined inside the `LightningModule`. +5 MNIST examples showing how to gradually convert from pure PyTorch to PyTorch Lightning. + +The transition from pure PyTorch through [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst) is optional, but it might be helpful to learn about it. + +#### 1. Image Classifier with Vanilla PyTorch + +Trains a simple CNN over MNIST using vanilla PyTorch. ```bash # cpu -python simple_image_classifier.py +python mnist_examples/image_classifier_1_pytorch.py +``` -# gpus (any number) -python simple_image_classifier.py --trainer.gpus 2 +______________________________________________________________________ -# dataparallel -python simple_image_classifier.py --trainer.gpus 2 --trainer.accelerator 'dp' +#### 2. Image Classifier with LightningLite + +Trains a simple CNN over MNIST using [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst). + +```bash +# cpu / multiple gpus if available +python mnist_examples/image_classifier_2_lite.py ``` ______________________________________________________________________ -#### MNIST with DALI +#### 3. Image Classifier - Conversion Lite to Lightning + +Trains a simple CNN over MNIST where `LightningLite` is almost a `LightningModule`. -The MNIST example above using [NVIDIA DALI](https://developer.nvidia.com/DALI). -Requires NVIDIA DALI to be installed based on your CUDA version, see [here](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html). +```bash +# cpu / multiple gpus if available +python mnist_examples/image_classifier_3_lite_to_lightning.py +``` + +______________________________________________________________________ + +#### 4. Image Classifier with LightningModule + +Trains a simple CNN over MNIST with the `Lightning Trainer` and the converted `LightningModule`. ```bash -python dali_image_classifier.py +# cpu +python mnist_examples/image_classifier_4_lightning.py + +# gpus (any number) +python mnist_examples/image_classifier_4_lightning.py --trainer.gpus 2 ``` ______________________________________________________________________ -#### Image classifier +#### 5.
Image Classifier with LightningModule + LightningDataModule -Generic image classifier with an arbitrary backbone (ie: a simple system) +Trains a simple CNN over MNIST with the `Lightning Trainer`, using the converted `LightningModule` and `LightningDataModule`. ```bash # cpu -python backbone_image_classifier.py +python mnist_examples/image_classifier_5_lightning_datamodule.py # gpus (any number) -python backbone_image_classifier.py --trainer.gpus 2 +python mnist_examples/image_classifier_5_lightning_datamodule.py --trainer.gpus 2 -# dataparallel -python backbone_image_classifier.py --trainer.gpus 2 --trainer.accelerator 'dp' +# data parallel +python mnist_examples/image_classifier_5_lightning_datamodule.py --trainer.gpus 2 --trainer.accelerator 'dp' ``` ______________________________________________________________________ diff --git a/pl_examples/basic_examples/mnist_examples/README.md b/pl_examples/basic_examples/mnist_examples/README.md new file mode 100644 index 0000000000000..323273d9ff718 --- /dev/null +++ b/pl_examples/basic_examples/mnist_examples/README.md @@ -0,0 +1,67 @@ +## MNIST Examples + +5 MNIST examples showing how to gradually convert from pure PyTorch to PyTorch Lightning. + +The transition from pure PyTorch through [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst) is optional, but it might be helpful to learn about it. + +#### 1. Image Classifier with Vanilla PyTorch + +Trains a simple CNN over MNIST using vanilla PyTorch. + +```bash +# cpu +python image_classifier_1_pytorch.py +``` + +______________________________________________________________________ + +#### 2. Image Classifier with LightningLite + +Trains a simple CNN over MNIST using [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst). + +```bash +# cpu / multiple gpus if available +python image_classifier_2_lite.py +``` + +______________________________________________________________________ + +#### 3. Image Classifier - Conversion Lite to Lightning + +Trains a simple CNN over MNIST where `LightningLite` is almost a `LightningModule`. + +```bash +# cpu / multiple gpus if available +python image_classifier_3_lite_to_lightning.py +``` + +______________________________________________________________________ + +#### 4. Image Classifier with LightningModule + +Trains a simple CNN over MNIST with the `Lightning Trainer` and the converted `LightningModule`. + +```bash +# cpu +python image_classifier_4_lightning.py + +# gpus (any number) +python image_classifier_4_lightning.py --trainer.gpus 2 +``` + +______________________________________________________________________ + +#### 5.
Image Classifier with LightningModule + LightningDataModule + +Trains a simple CNN over MNIST with the `Lightning Trainer`, using the converted `LightningModule` and `LightningDataModule`. + +```bash +# cpu +python image_classifier_5_lightning_datamodule.py + +# gpus (any number) +python image_classifier_5_lightning_datamodule.py --trainer.gpus 2 + +# data parallel +python image_classifier_5_lightning_datamodule.py --trainer.gpus 2 --trainer.accelerator 'dp' +``` diff --git a/pl_examples/lite_examples/__init__.py b/pl_examples/basic_examples/mnist_examples/__init__.py similarity index 100% rename from pl_examples/lite_examples/__init__.py rename to pl_examples/basic_examples/mnist_examples/__init__.py diff --git a/pl_examples/basic_examples/mnist_examples/image_classifier_1_pytorch.py b/pl_examples/basic_examples/mnist_examples/image_classifier_1_pytorch.py new file mode 100644 index 0000000000000..e7449473194ed --- /dev/null +++ b/pl_examples/basic_examples/mnist_examples/image_classifier_1_pytorch.py @@ -0,0 +1,160 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torchvision.transforms as T +from torch.optim.lr_scheduler import StepLR + +from pl_examples.basic_examples.mnist_datamodule import MNIST + +# Credit to the PyTorch Team +# Taken from https://github.com/pytorch/examples/blob/master/mnist/main.py and slightly adapted.
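+# The `Net` below is the CNN from the linked PyTorch example: two convolutional +# layers, two dropout layers, and two fully connected layers, with log-softmax +# outputs trained against a negative log-likelihood loss.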
+ + +class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +def train(args, model, device, train_loader, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if (batch_idx == 0) or ((batch_idx + 1) % args.log_interval == 0): + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + if args.dry_run: + break + + +def test(args, model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + if args.dry_run: + break + + test_loss /= len(test_loader.dataset) + + print( + "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) + ) + ) + + +def main(): + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)" + ) + parser.add_argument( + "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" + ) + parser.add_argument("--epochs", type=int, default=14, metavar="N", help="number of epochs to train (default: 14)") + parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)") + parser.add_argument("--gamma", type=float, default=0.7, metavar="M", help="Learning rate step gamma (default: 0.7)") + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model") + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + + torch.manual_seed(args.seed) + + device = torch.device("cuda" if use_cuda else "cpu") + + train_kwargs = {"batch_size": args.batch_size} + test_kwargs = {"batch_size": args.test_batch_size} + if use_cuda: + cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True} + 
train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + + transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))]) + train_dataset = MNIST("./data", train=True, download=True, transform=transform) + test_dataset = MNIST("./data", train=False, transform=transform) + train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs) + test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs) + + model = Net().to(device) + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + for epoch in range(1, args.epochs + 1): + train(args, model, device, train_loader, optimizer, epoch) + test(args, model, device, test_loader) + scheduler.step() + + if args.dry_run: + break + + if args.save_model: + torch.save(model.state_dict(), "mnist_cnn.pt") + + +if __name__ == "__main__": + main() diff --git a/pl_examples/basic_examples/mnist_examples/image_classifier_2_lite.py b/pl_examples/basic_examples/mnist_examples/image_classifier_2_lite.py new file mode 100644 index 0000000000000..78677cdf33bc4 --- /dev/null +++ b/pl_examples/basic_examples/mnist_examples/image_classifier_2_lite.py @@ -0,0 +1,131 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
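+# Compared with `image_classifier_1_pytorch.py`, the changes below are mechanical: the +# model and optimizer are wrapped with `self.setup(...)`, the dataloaders with +# `self.setup_dataloaders(...)`, and `loss.backward()` becomes `self.backward(loss)`, +# which lets `LightningLite` handle device placement, precision, and distribution.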
+import argparse + +import torch +import torch.nn.functional as F +import torch.optim as optim +import torchvision.transforms as T +from torch.optim.lr_scheduler import StepLR +from torchmetrics.classification import Accuracy + +from pl_examples.basic_examples.mnist_datamodule import MNIST +from pl_examples.basic_examples.mnist_examples.image_classifier_1_pytorch import Net +from pytorch_lightning import seed_everything +from pytorch_lightning.lite import LightningLite + + +def train(lite, args, model, train_loader, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + lite.backward(loss) + optimizer.step() + if (batch_idx == 0) or ((batch_idx + 1) % args.log_interval == 0): + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + if args.dry_run: + break + + +def test(lite, args, model, test_loader): + model.eval() + test_loss = 0 + acc = Accuracy().to(lite.device) + with torch.no_grad(): + for data, target in test_loader: + output = model(data) + test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss + acc.update(output, target) + if args.dry_run: + break + + test_loss = lite.all_gather(test_loss).sum() / len(test_loader.dataset) + + if lite.is_global_zero: + print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({100 * acc.compute():.0f}%)\n") + + +class Lite(LightningLite): + def run(self, args): + train_kwargs = {"batch_size": args.batch_size} + test_kwargs = {"batch_size": args.test_batch_size} + transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))]) + train_dataset = MNIST("./data", train=True, download=True, transform=transform) + test_dataset = MNIST("./data", train=False, transform=transform) + train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs) + test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs) + + train_loader, test_loader = self.setup_dataloaders(train_loader, test_loader) + + model = Net() + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + + model, optimizer = self.setup(model, optimizer) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + for epoch in range(1, args.epochs + 1): + train(self, args, model, train_loader, optimizer, epoch) + test(self, args, model, test_loader) + scheduler.step() + + if args.dry_run: + break + + if args.save_model and self.is_global_zero: + torch.save(model.state_dict(), "mnist_cnn.pt") + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="LightningLite MNIST Example") + parser.add_argument( + "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)" + ) + parser.add_argument( + "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" + ) + parser.add_argument("--epochs", type=int, default=14, metavar="N", help="number of epochs to train (default: 14)") + parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)") + parser.add_argument("--gamma", type=float, default=0.7, metavar="M", help="Learning rate step gamma (default: 0.7)") + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--dry-run", action="store_true", default=False,
help="quickly check a single pass") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model") + args = parser.parse_args() + + seed_everything(args.seed) + + if torch.cuda.is_available(): + lite_kwargs = {"accelerator": "gpu", "devices": torch.cuda.device_count()} + else: + lite_kwargs = {"accelerator": "cpu"} + + Lite(**lite_kwargs).run(args) diff --git a/pl_examples/basic_examples/mnist_examples/image_classifier_3_lite_to_lightning.py b/pl_examples/basic_examples/mnist_examples/image_classifier_3_lite_to_lightning.py new file mode 100644 index 0000000000000..223f23312586e --- /dev/null +++ b/pl_examples/basic_examples/mnist_examples/image_classifier_3_lite_to_lightning.py @@ -0,0 +1,166 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +import torch +import torch.nn.functional as F +import torch.optim as optim +import torchvision.transforms as T +from torch.optim.lr_scheduler import StepLR +from torchmetrics import Accuracy + +from pl_examples.basic_examples.mnist_datamodule import MNIST +from pl_examples.basic_examples.mnist_examples.image_classifier_1_pytorch import Net +from pytorch_lightning import seed_everything +from pytorch_lightning.lite import LightningLite + + +def train(lite, args, model, train_loader, optimizer, epoch): + model.train() + for batch_idx, batch in enumerate(train_loader): + optimizer.zero_grad() + loss = lite.training_step(batch, batch_idx) + lite.backward(loss) + optimizer.step() + if (batch_idx == 0) or ((batch_idx + 1) % args.log_interval == 0): + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(batch[0]), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + if args.dry_run: + break + + +def test(lite, args, model, test_loader): + model.eval() + test_loss = 0 + with torch.no_grad(): + for batch_idx, batch in enumerate(test_loader): + test_loss += lite.test_step(batch, batch_idx) + if args.dry_run: + break + + test_loss = lite.all_gather(test_loss).sum() / len(test_loader.dataset) + + if lite.is_global_zero: + print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({lite.test_acc.compute():.0f}%)\n") + + +class Lite(LightningLite): + + """`Lite` is starting to look like a `LightningModule`.""" + + def run(self, hparams): + self.hparams = hparams + + self.model = Net() + [optimizer], [scheduler] = self.configure_optimizers() + model, optimizer = self.setup(self.model, optimizer) + + if self.is_global_zero: + self.prepare_data() + + train_loader, test_loader = self.setup_dataloaders(self.train_dataloader(), self.train_dataloader()) + + self.test_acc = Accuracy() + + for epoch in range(1, hparams.epochs + 1): + 
train(self, hparams, model, train_loader, optimizer, epoch) + test(self, hparams, model, test_loader) + scheduler.step() + + if hparams.dry_run: + break + + if hparams.save_model and self.is_global_zero: + torch.save(model.state_dict(), "mnist_cnn.pt") + + # Functions for the `LightningModule` conversion + + def forward(self, x): + return self.model(x) + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + loss = F.nll_loss(logits, y.long()) + return loss + + def test_step(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + loss = F.nll_loss(logits, y.long()) + self.test_acc(logits, y.long()) + return loss + + def configure_optimizers(self): + optimizer = optim.Adadelta(self.model.parameters(), lr=self.hparams.lr) + return [optimizer], [StepLR(optimizer, step_size=1, gamma=self.hparams.gamma)] + + # Functions for the `LightningDataModule` conversion + + @property + def transform(self): + return T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))]) + + def prepare_data(self) -> None: + MNIST("./data", download=True) + + def train_dataloader(self): + train_dataset = MNIST("./data", train=True, download=False, transform=self.transform) + return torch.utils.data.DataLoader(train_dataset, batch_size=self.hparams.batch_size) + + def test_dataloader(self): + test_dataset = MNIST("./data", train=False, download=False, transform=self.transform) + return torch.utils.data.DataLoader(test_dataset, batch_size=self.hparams.batch_size) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="LightningLite to LightningModule MNIST Example") + parser.add_argument( + "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)" + ) + parser.add_argument( + "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" + ) + parser.add_argument("--epochs", type=int, default=14, metavar="N", help="number of epochs to train (default: 14)") + parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)") + parser.add_argument("--gamma", type=float, default=0.7, metavar="M", help="Learning rate step gamma (default: 0.7)") + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model") + args = parser.parse_args() + + seed_everything(args.seed) + + if torch.cuda.is_available(): + lite_kwargs = {"accelerator": "gpu", "devices": torch.cuda.device_count()} + else: + lite_kwargs = {"accelerator": "cpu"} + + Lite(**lite_kwargs).run(args) diff --git a/pl_examples/basic_examples/mnist_examples/image_classifier_4_lightning.py b/pl_examples/basic_examples/mnist_examples/image_classifier_4_lightning.py new file mode 100644 index 0000000000000..a414d96281b01 --- /dev/null +++ b/pl_examples/basic_examples/mnist_examples/image_classifier_4_lightning.py @@ -0,0 +1,86 @@ +# Copyright The PyTorch Lightning team.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MNIST simple image classifier example with LightningModule. + +To run: python image_classifier_4_lightning.py --trainer.max_epochs=50 +""" +import torch +import torchvision.transforms as T +from torch.nn import functional as F +from torchmetrics.classification import Accuracy + +from pl_examples import cli_lightning_logo +from pl_examples.basic_examples.mnist_datamodule import MNIST +from pl_examples.basic_examples.mnist_examples.image_classifier_1_pytorch import Net +from pytorch_lightning import LightningModule +from pytorch_lightning.utilities.cli import LightningCLI + + +class ImageClassifier(LightningModule): + def __init__(self, model=None, lr=1.0, gamma=0.7, batch_size=32): + super().__init__() + self.save_hyperparameters() + self.model = model or Net() + self.test_acc = Accuracy() + + def forward(self, x): + return self.model(x) + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + loss = F.nll_loss(logits, y.long()) + return loss + + def test_step(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + loss = F.nll_loss(logits, y.long()) + self.test_acc(logits, y.long()) + return loss + + def test_epoch_end(self, *_) -> None: + self.log("test_acc", self.test_acc.compute()) + + def configure_optimizers(self): + optimizer = torch.optim.Adadelta(self.model.parameters(), lr=self.hparams.lr) + return [optimizer], [torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.hparams.gamma)] + + # Functions for the `LightningDataModule` conversion + + @property + def transform(self): + return T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))]) + + def prepare_data(self) -> None: + MNIST("./data", download=True) + + def train_dataloader(self): + train_dataset = MNIST("./data", train=True, download=False, transform=self.transform) + return torch.utils.data.DataLoader(train_dataset, batch_size=self.hparams.batch_size) + + def test_dataloader(self): + test_dataset = MNIST("./data", train=False, download=False, transform=self.transform) + return torch.utils.data.DataLoader(test_dataset, batch_size=self.hparams.batch_size) + + +def cli_main(): + cli = LightningCLI(ImageClassifier, seed_everything_default=1234, save_config_overwrite=True, run=False) + cli.trainer.fit(cli.model, datamodule=cli.datamodule) + cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule) + + +if __name__ == "__main__": + cli_lightning_logo() + cli_main() diff --git a/pl_examples/basic_examples/mnist_examples/image_classifier_5_lightning_datamodule.py b/pl_examples/basic_examples/mnist_examples/image_classifier_5_lightning_datamodule.py new file mode 100644 index 0000000000000..fc30836b6c37b --- /dev/null +++ b/pl_examples/basic_examples/mnist_examples/image_classifier_5_lightning_datamodule.py @@ -0,0 +1,92 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MNIST simple image classifier example with LightningModule and DataModule. + +To run: python image_classifier_5_lightning_datamodule.py --trainer.max_epochs=50 +""" +import torch +import torchvision.transforms as T +from torch.nn import functional as F +from torchmetrics.classification import Accuracy + +from pl_examples import cli_lightning_logo +from pl_examples.basic_examples.mnist_datamodule import MNIST +from pl_examples.basic_examples.mnist_examples.image_classifier_1_pytorch import Net +from pytorch_lightning import LightningDataModule, LightningModule +from pytorch_lightning.utilities.cli import LightningCLI + + +class ImageClassifier(LightningModule): + def __init__(self, model=None, lr=1.0, gamma=0.7, batch_size=32): + super().__init__() + self.save_hyperparameters() + self.model = model or Net() + self.test_acc = Accuracy() + + def forward(self, x): + return self.model(x) + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + loss = F.nll_loss(logits, y.long()) + return loss + + def test_step(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + loss = F.nll_loss(logits, y.long()) + self.test_acc(logits, y.long()) + return loss + + def test_epoch_end(self, *_) -> None: + self.log("test_acc", self.test_acc.compute()) + + def configure_optimizers(self): + optimizer = torch.optim.Adadelta(self.model.parameters(), lr=self.hparams.lr) + return [optimizer], [torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.hparams.gamma)] + + +class MNISTDataModule(LightningDataModule): + def __init__(self, batch_size=32): + super().__init__() + self.save_hyperparameters() + + @property + def transform(self): + return T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))]) + + def prepare_data(self) -> None: + MNIST("./data", download=True) + + def train_dataloader(self): + train_dataset = MNIST("./data", train=True, download=False, transform=self.transform) + return torch.utils.data.DataLoader(train_dataset, batch_size=self.hparams.batch_size) + + def test_dataloader(self): + test_dataset = MNIST("./data", train=False, download=False, transform=self.transform) + return torch.utils.data.DataLoader(test_dataset, batch_size=self.hparams.batch_size) + + +def cli_main(): + cli = LightningCLI( + ImageClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False + ) + cli.trainer.fit(cli.model, datamodule=cli.datamodule) + cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule) + + +if __name__ == "__main__": + cli_lightning_logo() + cli_main() diff --git a/pl_examples/basic_examples/simple_image_classifier.py b/pl_examples/basic_examples/simple_image_classifier.py deleted file mode 100644 index 146f25c27c0d4..0000000000000 --- a/pl_examples/basic_examples/simple_image_classifier.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""MNIST simple image classifier example. - -To run: python simple_image_classifier.py --trainer.max_epochs=50 -""" - -import torch -from torch.nn import functional as F - -import pytorch_lightning as pl -from pl_examples import cli_lightning_logo -from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule -from pytorch_lightning.utilities.cli import LightningCLI - - -class LitClassifier(pl.LightningModule): - """ - >>> LitClassifier() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - LitClassifier( - (l1): Linear(...) - (l2): Linear(...) - ) - """ - - def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.0001): - super().__init__() - self.save_hyperparameters() - - self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim) - self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10) - - def forward(self, x): - x = x.view(x.size(0), -1) - x = torch.relu(self.l1(x)) - x = torch.relu(self.l2(x)) - return x - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) - return loss - - def validation_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) - self.log("valid_loss", loss) - - def test_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) - self.log("test_loss", loss) - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - - -def cli_main(): - cli = LightningCLI( - LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False - ) - cli.trainer.fit(cli.model, datamodule=cli.datamodule) - cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule) - - -if __name__ == "__main__": - cli_lightning_logo() - cli_main() diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 48492c8ce7f04..26a6c8aa89f67 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -17,24 +17,21 @@ tensorboard --logdir default """ -import os from argparse import ArgumentParser, Namespace import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import DataLoader from pl_examples import cli_lightning_logo -from pl_examples.basic_examples.mnist_datamodule import MNIST -from pytorch_lightning.core import LightningDataModule, LightningModule +from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule +from pytorch_lightning.core import LightningModule from pytorch_lightning.trainer import Trainer from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: import torchvision - from torchvision import transforms class Generator(nn.Module): @@ -212,35 +209,6 @@ def on_epoch_end(self): self.logger.experiment.add_image("generated_images", grid, self.current_epoch) -class MNISTDataModule(LightningDataModule): - """ - >>> MNISTDataModule() # doctest: +ELLIPSIS - 
<...generative_adversarial_net.MNISTDataModule object at ...> - """ - - def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4): - super().__init__() - self.batch_size = batch_size - self.data_path = data_path - self.num_workers = num_workers - - self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) - self.dims = (1, 28, 28) - - def prepare_data(self, stage=None): - # Use this method to do things that might write to disk or that need to be done only from a single GPU - # in distributed settings. Like downloading the dataset for the first time. - MNIST(self.data_path, train=True, download=True, transform=transforms.ToTensor()) - - def setup(self, stage=None): - # There are also data operations you might want to perform on every GPU, such as applying transforms - # defined explicitly in your datamodule or assigned in init. - self.mnist_train = MNIST(self.data_path, train=True, transform=self.transform) - - def train_dataloader(self): - return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers) - - def main(args: Namespace) -> None: # ------------------------ # 1 INIT LIGHTNING MODEL @@ -250,7 +218,7 @@ def main(args: Namespace) -> None: # ------------------------ # 2 INIT TRAINER # ------------------------ - # If use distubuted training PyTorch recommends to use DistributedDataParallel. + # If use distributed training PyTorch recommends to use DistributedDataParallel. # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel dm = MNISTDataModule.from_argparse_args(args) trainer = Trainer.from_argparse_args(args) diff --git a/pl_examples/integration_examples/__init__.py b/pl_examples/integration_examples/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/integration_examples/dali_image_classifier.py similarity index 100% rename from pl_examples/basic_examples/dali_image_classifier.py rename to pl_examples/integration_examples/dali_image_classifier.py diff --git a/pl_examples/lite_examples/pytorch_2_lite_2_lightning.py b/pl_examples/lite_examples/pytorch_2_lite_2_lightning.py deleted file mode 100644 index 592d5a7ab951b..0000000000000 --- a/pl_examples/lite_examples/pytorch_2_lite_2_lightning.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch -from torch import nn -from torch.utils.data import DataLoader, Dataset - -from pytorch_lightning import seed_everything -from pytorch_lightning.lite import LightningLite - -############################################################################################# -# Section 1: PyTorch to Lightning Lite # -# # -# What is LightningLite ? 
# -# # -# `LightningLite` is a python class you can override to get access to Lightning # -# accelerators and scale your training, but furthermore, it is intended to be the safest # -# route to fully transition to Lightning. # -# # -# Does LightningLite requires code changes ? # -# # -# `LightningLite` code changes are minimal and this tutorial will show you how easy it # -# is to convert to `lite` using a `BoringModel`. # -# # -############################################################################################# - -############################################################################################# -# Pure PyTorch Section # -############################################################################################# - - -# 1 / 6: Implement a `BoringModel` with only one layer. -class BoringModel(nn.Module): - def __init__(self): - super().__init__() - self.layer = torch.nn.Linear(32, 2) - - def forward(self, x): - x = self.layer(x) - return torch.nn.functional.mse_loss(x, torch.ones_like(x)) - - -# 2 / 6: Implement a `configure_optimizers` taking a module and returning an optimizer. -def configure_optimizers(module: nn.Module): - return torch.optim.SGD(module.parameters(), lr=0.001) - - -# 3 / 6: Implement a simple dataset returning random data with the specified shape. -class RandomDataset(Dataset): - def __init__(self, length: int, size: int): - self.len = length - self.data = torch.randn(length, size) - - def __getitem__(self, index): - return self.data[index] - - def __len__(self): - return self.len - - -# 4 / 6: Implement the functions to create the dataloaders. -def train_dataloader(): - return DataLoader(RandomDataset(64, 32)) - - -def val_dataloader(): - return DataLoader(RandomDataset(64, 32)) - - -# 5 / 6: Our main PyTorch Loop to train our `BoringModel` on our random data. -def main(model: nn.Module, train_dataloader: DataLoader, val_dataloader: DataLoader, num_epochs: int = 10): - optimizer = configure_optimizers(model) - - for epoch in range(num_epochs): - train_losses = [] - val_losses = [] - - model.train() - for batch in train_dataloader: - optimizer.zero_grad() - loss = model(batch) - train_losses.append(loss) - loss.backward() - optimizer.step() - - model.eval() - with torch.no_grad(): - for batch in val_dataloader: - val_losses.append(model(batch)) - - train_epoch_loss = torch.stack(train_losses).mean() - val_epoch_loss = torch.stack(val_losses).mean() - - print(f"{epoch}/{num_epochs}| Train Epoch Loss: {torch.mean(train_epoch_loss)}") - print(f"{epoch}/{num_epochs}| Valid Epoch Loss: {torch.mean(val_epoch_loss)}") - - return model.state_dict() - - -# 6 / 6: Run the pure PyTorch Loop and train / validate the model. -if __name__ == "__main__": - seed_everything(42) - model = BoringModel() - pure_model_weights = main(model, train_dataloader(), val_dataloader()) - - -############################################################################################# -# Convert to LightningLite # -# # -# By converting to `LightningLite`, you get the full power of Lightning accelerators # -# while conversing your original code ! # -# To get started, you would need to `from pytorch_lightning.lite import LightningLite` # -# and override its `run` method. 
# -############################################################################################# - - -class LiteTrainer(LightningLite): - def run(self, model: nn.Module, train_dataloader: DataLoader, val_dataloader: DataLoader, num_epochs: int = 10): - optimizer = configure_optimizers(model) - - ################################################################################### - # You would need to call `self.setup` to wrap `model` and `optimizer`. If you # - # have multiple models (c.f GAN), call `setup` for each one of them and their # - # associated optimizers. # - model, optimizer = self.setup(model, optimizer) - ################################################################################### - - ################################################################################### - # You would need to call `self.setup_dataloaders` to prepare the dataloaders # - # in case you are running in a distributed setting. # - train_dataloader = self.setup_dataloaders(train_dataloader) - val_dataloader = self.setup_dataloaders(val_dataloader) - ################################################################################### - - for epoch in range(num_epochs): - train_losses = [] - val_losses = [] - - model.train() - for batch in train_dataloader: - optimizer.zero_grad() - loss = model(batch) - train_losses.append(loss) - ########################################################################### - # By calling `self.backward` directly, `LightningLite` will automate # - # precision and distributions. # - self.backward(loss) - ########################################################################### - optimizer.step() - - model.eval() - with torch.no_grad(): - for batch in val_dataloader: - val_losses.append(model(batch)) - - train_epoch_loss = torch.stack(train_losses).mean() - val_epoch_loss = torch.stack(val_losses).mean() - - ################################################################################ - # Optional: Utility to print only on rank 0 (when using distributed setting) # - self.print(f"{epoch}/{num_epochs}| Train Epoch Loss: {train_epoch_loss}") - self.print(f"{epoch}/{num_epochs}| Valid Epoch Loss: {val_epoch_loss}") - ################################################################################ - - -if __name__ == "__main__": - seed_everything(42) - lite_model = BoringModel() - lite = LiteTrainer() - lite.run(lite_model, train_dataloader(), val_dataloader()) - - ############################################################################################# - # Assert the weights are the same # - ############################################################################################# - - for pure_w, lite_w in zip(pure_model_weights.values(), lite_model.state_dict().values()): - torch.equal(pure_w, lite_w) - - -############################################################################################# -# Convert to Lightning # -# # -# By converting to Lightning, not-only your research code becomes inter-operable # -# (can easily be shared), but you get access to hundreds of extra features to make your # -# research faster. 
# -# Check `Facebook` blogpost on how `Lightning` enabled their research to scale at scale # -# On https://ai.facebook.com/blog # -# /reengineering-facebook-ais-deep-learning-platforms-for-interoperability/ # -############################################################################################# - -from pytorch_lightning import LightningDataModule, LightningModule, Trainer # noqa E402 - - -class LightningBoringModel(LightningModule): - def __init__(self): - super().__init__() - self.layer = torch.nn.Linear(32, 2) - - def forward(self, x): - x = self.layer(x) - return torch.nn.functional.mse_loss(x, torch.ones_like(x)) - - # LightningModule hooks - def training_step(self, batch, batch_idx): - x = self.forward(batch) - self.log("train_loss", x) - return x - - def validation_step(self, batch, batch_idx): - x = self.forward(batch) - self.log("val_loss", x) - return x - - def configure_optimizers(self): - return configure_optimizers(self) - - -class BoringDataModule(LightningDataModule): - def train_dataloader(self): - return train_dataloader() - - def val_dataloader(self): - return val_dataloader() - - -if __name__ == "__main__": - seed_everything(42) - lightning_module = LightningBoringModel() - datamodule = BoringDataModule() - trainer = Trainer(max_epochs=10) - trainer.fit(lightning_module, datamodule) - - ############################################################################################# - # Assert the weights are the same # - ############################################################################################# - - for pure_w, lite_w in zip(pure_model_weights.values(), lightning_module.state_dict().values()): - torch.equal(pure_w, lite_w) diff --git a/pl_examples/loop_examples/mnist_lite.py b/pl_examples/loop_examples/mnist_lite.py new file mode 100644 index 0000000000000..738964a56f6dc --- /dev/null +++ b/pl_examples/loop_examples/mnist_lite.py @@ -0,0 +1,189 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
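+# This example rebuilds the LightningLite MNIST script on top of Lightning's `Loop` +# interface: `MainLoop.advance` runs a `TrainLoop` and a `TestLoop` once per epoch, +# and each loop implements the `done` / `reset` / `advance` hooks.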
+import argparse +from typing import Any + +import torch +import torch.nn.functional as F +import torch.optim as optim +import torchvision.transforms as T +from torch.optim.lr_scheduler import StepLR +from torchmetrics import Accuracy + +from pl_examples.basic_examples.mnist_datamodule import MNIST +from pl_examples.basic_examples.mnist_examples.image_classifier_1_pytorch import Net +from pytorch_lightning import seed_everything +from pytorch_lightning.lite import LightningLite +from pytorch_lightning.loops import Loop + + +class TrainLoop(Loop): + def __init__(self, lite, args, model, optimizer, scheduler, dataloader): + super().__init__() + self.lite = lite + self.args = args + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.dataloader = dataloader + + @property + def done(self) -> bool: + return False + + def reset(self): + self.dataloader_iter = enumerate(self.dataloader) + + def advance(self, epoch) -> None: + batch_idx, (data, target) = next(self.dataloader_iter) + self.optimizer.zero_grad() + output = self.model(data) + loss = F.nll_loss(output, target) + self.lite.backward(loss) + self.optimizer.step() + + if (batch_idx == 0) or ((batch_idx + 1) % self.args.log_interval == 0): + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(self.dataloader.dataset), + 100.0 * batch_idx / len(self.dataloader), + loss.item(), + ) + ) + + if self.args.dry_run: + raise StopIteration + + def on_run_end(self): + self.scheduler.step() + self.dataloader_iter = None + + +class TestLoop(Loop): + def __init__(self, lite, args, model, dataloader): + super().__init__() + self.lite = lite + self.args = args + self.model = model + self.dataloader = dataloader + self.accuracy = Accuracy() + + @property + def done(self) -> bool: + return False + + def reset(self): + self.dataloader_iter = enumerate(self.dataloader) + self.test_loss = 0 + self.accuracy.reset() + + def advance(self) -> None: + _, (data, target) = next(self.dataloader_iter) + output = self.model(data) + self.test_loss += F.nll_loss(output, target, reduction="sum") + self.accuracy(output, target) + + if self.args.dry_run: + raise StopIteration + + def on_run_end(self): + test_loss = self.lite.all_gather(self.test_loss).sum() / len(self.dataloader.dataset) + + if self.lite.is_global_zero: + print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({100 * self.accuracy.compute():.0f}%)\n") + + +class MainLoop(Loop): + def __init__(self, lite, args, model, optimizer, scheduler, train_loader, test_loader): + super().__init__() + self.lite = lite + self.args = args + self.epoch = 0 + self.train_loop = TrainLoop(self.lite, self.args, model, optimizer, scheduler, train_loader) + self.test_loop = TestLoop(self.lite, self.args, model, test_loader) + + @property + def done(self) -> bool: + return self.epoch >= self.args.epochs + + def reset(self): + pass + + def advance(self, *args: Any, **kwargs: Any) -> None: + self.train_loop.run(self.epoch) + self.test_loop.run() + + if self.args.dry_run: + raise StopIteration + + self.epoch += 1 + + +class Lite(LightningLite): + def run(self, hparams): + transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))]) + train_dataset = MNIST("./data", train=True, download=True, transform=transform) + test_dataset = MNIST("./data", train=False, transform=transform) + train_loader = torch.utils.data.DataLoader(train_dataset, hparams.batch_size) + test_loader =
torch.utils.data.DataLoader(test_dataset, hparams.test_batch_size) + + train_loader, test_loader = self.setup_dataloaders(train_loader, test_loader) + + model = Net() + optimizer = optim.Adadelta(model.parameters(), lr=hparams.lr) + + model, optimizer = self.setup(model, optimizer) + scheduler = StepLR(optimizer, step_size=1, gamma=hparams.gamma) + + MainLoop(self, hparams, model, optimizer, scheduler, train_loader, test_loader).run() + + if hparams.save_model and self.is_global_zero: + torch.save(model.state_dict(), "mnist_cnn.pt") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="LightningLite MNIST Example with Lightning Loops.") + parser.add_argument( + "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)" + ) + parser.add_argument( + "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" + ) + parser.add_argument("--epochs", type=int, default=14, metavar="N", help="number of epochs to train (default: 14)") + parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)") + parser.add_argument("--gamma", type=float, default=0.7, metavar="M", help="Learning rate step gamma (default: 0.7)") + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model") + args = parser.parse_args() + + seed_everything(args.seed) + + if torch.cuda.is_available(): + lite_kwargs = {"accelerator": "gpu", "devices": torch.cuda.device_count()} + else: + lite_kwargs = {"accelerator": "cpu"} + + Lite(**lite_kwargs).run(args) diff --git a/pl_examples/run_examples.sh b/pl_examples/run_examples.sh index 7555e472d24e2..e6b0c6bef1170 100755 --- a/pl_examples/run_examples.sh +++ b/pl_examples/run_examples.sh @@ -11,6 +11,23 @@ args=" --trainer.limit_predict_batches=2 " -python "${dir_path}/basic_examples/simple_image_classifier.py" ${args} "$@" python "${dir_path}/basic_examples/backbone_image_classifier.py" ${args} "$@" python "${dir_path}/basic_examples/autoencoder.py" ${args} "$@" + + +args=" + --trainer.max_epochs=1 + --trainer.limit_train_batches=2 + --trainer.limit_val_batches=2 + --trainer.limit_test_batches=2 + --trainer.limit_predict_batches=2 +" + +python "${dir_path}/basic_examples/mnist_examples/image_classifier_4_lightning.py" ${args} "$@" +python "${dir_path}/basic_examples/mnist_examples/image_classifier_5_lightning_datamodule.py" ${args} "$@" + +args="--dry-run" +python "${dir_path}/basic_examples/mnist_examples/image_classifier_1_pytorch.py" ${args} "$@" +python "${dir_path}/basic_examples/mnist_examples/image_classifier_2_lite.py" ${args} "$@" +python "${dir_path}/basic_examples/mnist_examples/image_classifier_3_lite_to_lightning.py" ${args} "$@" +python "${dir_path}/loop_examples/mnist_lite.py" ${args} "$@" diff --git a/pl_examples/test_examples.py b/pl_examples/test_examples.py index b0b451692e4b6..19d09836ef34c 100644 --- a/pl_examples/test_examples.py +++ b/pl_examples/test_examples.py @@ -34,7 +34,7 @@ @RunIf(min_gpus=1,
skip_windows=True) @pytest.mark.parametrize("cli_args", [ARGS_GPU]) def test_examples_mnist_dali(tmpdir, cli_args): - from pl_examples.basic_examples.dali_image_classifier import cli_main + from pl_examples.integration_examples.dali_image_classifier import cli_main # update the temp dir cli_args = cli_args % {"tmpdir": tmpdir}
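Taken together, the scripts above follow a single conversion recipe. The sketch below distills it, using only `LightningLite` calls that appear in this diff (`setup`, `setup_dataloaders`, `backward`, `is_global_zero`); the linear model and random tensors are illustrative stand-ins for the MNIST `Net` and dataloaders, not code from the PR.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import seed_everything
from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self, epochs: int = 2):
        # Stand-in model and data; the real examples use `Net` and MNIST.
        model = torch.nn.Linear(32, 10)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 10, (64,)))
        train_loader = DataLoader(dataset, batch_size=8)

        # Wrap the model/optimizer and dataloader so Lite moves them to the
        # right device and adds a distributed sampler when needed.
        model, optimizer = self.setup(model, optimizer)
        train_loader = self.setup_dataloaders(train_loader)

        model.train()
        for epoch in range(epochs):
            for data, target in train_loader:
                optimizer.zero_grad()
                loss = F.cross_entropy(model(data), target)
                # Replaces `loss.backward()`; Lite automates precision and distribution.
                self.backward(loss)
                optimizer.step()
            # Log (or save checkpoints) only on rank 0, as the examples do.
            if self.is_global_zero:
                print(f"epoch {epoch}: loss {loss.item():.4f}")


if __name__ == "__main__":
    seed_everything(42)
    Lite(accelerator="cpu").run()
```

From there, the remaining steps shown in `image_classifier_3_lite_to_lightning.py` through `image_classifier_5_lightning_datamodule.py` are mechanical: the loop body moves into `training_step`, `configure_optimizers` returns the optimizer and scheduler, and the dataloader hooks migrate into a `LightningDataModule`.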