# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from functools import partial
from typing import Generator

import torch

from pl_examples.domain_templates.generative_adversarial_net import GAN as GANTemplate
from pl_examples.domain_templates.generative_adversarial_net import MNISTDataModule
from pytorch_lightning import Trainer
from pytorch_lightning.loops import OptimizerLoop
from pytorch_lightning.loops.optimization.optimizer_loop import ClosureResult
from pytorch_lightning.loops.utilities import _build_training_step_kwargs
from pytorch_lightning.utilities.exceptions import MisconfigurationException

#############################################################################################
#                                        Yield Loop                                         #
#                                                                                           #
# This example shows an implementation of a custom loop that changes how the               #
# `LightningModule.training_step` behaves. In particular, this custom "Yield" loop will    #
# enable the `training_step` to yield like a Python generator, retaining the values        #
# of local variables for subsequent calls. This can result in much cleaner and more        #
# elegant code when dealing with multiple optimizers (automatic optimization).             #
#                                                                                           #
# Learn more about the loop structure from the documentation:                              #
# https://pytorch-lightning.readthedocs.io/en/latest/extensions/loops.html                 #
#############################################################################################

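
# Before diving in, a minimal sketch (plain Python, independent of Lightning) of the
# generator mechanism this example builds on: calling a generator function returns a
# generator object, and each next() call resumes execution right after the previous
# `yield`, with all local variables intact. The function below is illustrative only
# and is not part of the Lightning API.


def _generator_mechanism_demo():
    expensive = "computed once"  # local state survives across next() calls
    yield f"first step uses {expensive}"
    yield f"second step reuses {expensive}"


# `inspect.isgeneratorfunction` (used by the loop below to validate the
# `training_step`) returns True for such functions:
assert inspect.isgeneratorfunction(_generator_mechanism_demo)
demo = _generator_mechanism_demo()
assert next(demo) == "first step uses computed once"
assert next(demo) == "second step reuses computed once"
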

#############################################################################################
#                       Step 1 / 3: Implement a custom OptimizerLoop                        #
#                                                                                           #
# The `training_step` gets called in the                                                   #
# `pytorch_lightning.loops.optimization.OptimizerLoop`. To make it into a Python generator, #
# we need to override the place where it gets called.                                      #
#############################################################################################


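# A quick reading guide for the three overrides below (commentary, not additional API):
# - `connect` is disabled because this loop manages no child loops that could be swapped.
# - `on_run_start` validates the `training_step` and creates the generator once per batch.
# - `_make_step_fn` returns a step function that advances that generator instead of
#   invoking `training_step` from scratch for every optimizer.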
class YieldLoop(OptimizerLoop):
    def __init__(self):
        super().__init__()
        self._generator = None

    def connect(self, **kwargs):
        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

    def on_run_start(self, batch, optimizers, batch_idx):
        super().on_run_start(batch, optimizers, batch_idx)
        if not inspect.isgeneratorfunction(self.trainer.lightning_module.training_step):
            raise MisconfigurationException("The `LightningModule` does not yield anything in the `training_step`.")
        # This example only supports automatic optimization.
        assert self.trainer.lightning_module.automatic_optimization

        # We request the generator once and save it for later
        # so we can call next() on it.
        self._generator = self._get_generator(batch, batch_idx, opt_idx=0)

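    # The parent loop wraps the function returned here into the closure it builds for each
    # optimizer. Binding the saved generator via `partial` means every call advances the
    # same generator rather than starting a fresh `training_step`.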
    def _make_step_fn(self, split_batch, batch_idx, opt_idx):
        return partial(self._training_step, self._generator)

    def _get_generator(self, split_batch, batch_idx, opt_idx):
        step_kwargs = _build_training_step_kwargs(
            self.trainer.lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, hiddens=None
        )

        # Here we are basically calling `lightning_module.training_step()`
        # and this returns a generator! The `training_step` is handled by the
        # accelerator to enable distributed training.
        return self.trainer.accelerator.training_step(step_kwargs)

    def _training_step(self, generator):
        # required for logging
        self.trainer.lightning_module._current_fx_name = "training_step"

        # Here, instead of calling `lightning_module.training_step()`
        # we call next() on the generator!
        training_step_output = next(generator)
        self.trainer.accelerator.post_training_step()

        training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

        # The closure result takes care of properly detaching the loss for logging and performs
        # some additional checks that the output format is correct.
        result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)
        return result


#############################################################################################
#                Step 2 / 3: Implement a model using the new yield mechanism                #
#                                                                                           #
# We can now implement a model that defines the `training_step` using "yield" statements.  #
# We choose a generative adversarial network (GAN) because it alternates between two       #
# optimizers updating the model parameters. In the first step we compute the loss of the   #
# first network (coincidentally also named "generator") and yield the loss. In the second  #
# step we compute the loss of the second network (the "discriminator") and yield again.    #
# The nice property of this yield approach is that we can reuse variables that we computed #
# earlier. If this were a regular Lightning `training_step`, we would have to recompute the #
# output of the first network.                                                             #
#############################################################################################
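
# For contrast, a sketch of the same logic as a regular automatic-optimization
# `training_step`, which Lightning calls once per optimizer. The helper names in this
# sketch are hypothetical placeholders, not part of the example below:
#
#     def training_step(self, batch, batch_idx, optimizer_idx):
#         if optimizer_idx == 0:
#             return self._generator_loss(batch)      # forward pass happens here ...
#         if optimizer_idx == 1:
#             return self._discriminator_loss(batch)  # ... and must be recomputed here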


class GAN(GANTemplate):

    # This training_step method is now a Python generator
    def training_step(self, batch, batch_idx, optimizer_idx=0) -> Generator:
        imgs, _ = batch
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # Here, we compute the generator output once and reuse it later.
        # It gets saved when we yield from the training_step.
        # The output then gets reused in the discriminator update below.
        generator_output = self(z)

        # train generator
        real_labels = torch.ones(imgs.size(0), 1)
        real_labels = real_labels.type_as(imgs)
        g_loss = self.adversarial_loss(self.discriminator(generator_output), real_labels)
        self.log("g_loss", g_loss)

        # Yield instead of return: This makes the training_step a Python generator.
        # Once we call it again, it will continue the execution with the block below
        yield g_loss

        # train discriminator
        real_labels = torch.ones(imgs.size(0), 1)
        real_labels = real_labels.type_as(imgs)
        real_loss = self.adversarial_loss(self.discriminator(imgs), real_labels)
        fake_labels = torch.zeros(imgs.size(0), 1)
        fake_labels = fake_labels.type_as(imgs)

        # We make use of generator_output again. Calling `.detach()` stops gradients from
        # flowing back into the generator during the discriminator update.
        fake_loss = self.adversarial_loss(self.discriminator(generator_output.detach()), fake_labels)
        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss)

        yield d_loss


#############################################################################################
#                        Step 3 / 3: Connect the loop to the Trainer                        #
#                                                                                           #
# Finally, attach the loop to the `Trainer`. Here, we replace the `OptimizerLoop`, which   #
# is a subloop of the `TrainingBatchLoop`. We use `.connect()` to attach it.               #
#############################################################################################

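# For orientation, the loop hierarchy in the Lightning version this example targets is
# roughly (class names are our annotation, inferred from the imports above):
#
#     Trainer.fit_loop (FitLoop)
#       └─ epoch_loop (TrainingEpochLoop)
#            └─ batch_loop (TrainingBatchLoop)
#                 └─ optimizer_loop (OptimizerLoop)  <- replaced below by YieldLoop
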
if __name__ == "__main__":
    model = GAN()
    dm = MNISTDataModule()
    trainer = Trainer()

    # Connect the new loop: YieldLoop replaces the default optimizer loop.
    trainer.fit_loop.epoch_loop.batch_loop.connect(optimizer_loop=YieldLoop())

    # fit() will now use the new loop!
    trainer.fit(model, dm)