Skip to content

RuntimeError when running basic GAN model (from tutorial at lightning.ai) with DDP #20328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
pranavrao-qure opened this issue Oct 9, 2024 · 4 comments
Labels
bug Something isn't working needs triage Waiting to be triaged by maintainers ver: 2.4.x

Comments

@pranavrao-qure
Copy link

Bug description

I am trying to train a GAN model on multiple GPUs using DDP. I followed the tutorial at https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html, changing the arguments to Trainer to

trainer = pl.Trainer(
        accelerator="auto",
        devices=[0, 1, 2, 3],
        strategy="ddp",
        max_epochs=5,
)

Running the script raise Runtime error as follows:

[rank0]: RuntimeError: It looks like your LightningModule has parameters that were not used in producing the loss returned by training_step. If this is intentional, you must enable the detection of unused parameters in DDP, either by setting the string value `strategy='ddp_find_unused_parameters_true'` or by setting the flag in the strategy with `strategy=DDPStrategy(find_unused_parameters=True)`.

What version are you seeing the problem on?

v2.4

How to reproduce the bug

import os

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
NUM_WORKERS = int(os.cpu_count() / 2)


class MNISTDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir: str = PATH_DATASETS,
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img

class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity
    
class GAN(pl.LightningModule):
    def __init__(
        self,
        channels,
        width,
        height,
        latent_dim: int = 100,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = BATCH_SIZE,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.automatic_optimization = False

        # networks
        data_shape = (channels, width, height)
        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)

        self.validation_z = torch.randn(8, self.hparams.latent_dim)

        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

    def forward(self, z):
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch):
        imgs, _ = batch

        optimizer_g, optimizer_d = self.optimizers()

        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # train generator
        # generate images
        self.toggle_optimizer(optimizer_g)
        self.generated_imgs = self(z)

        # log sampled images
        sample_imgs = self.generated_imgs[:6]
        grid = torchvision.utils.make_grid(sample_imgs)
        # self.logger.experiment.add_image("train/generated_images", grid, self.current_epoch)

        # ground truth result (ie: all fake)
        # put on GPU because we created this tensor inside training_loop
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        # adversarial loss is binary cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # train discriminator
        # Measure discriminator's ability to classify real from generated samples
        self.toggle_optimizer(optimizer_d)

        # how well can it label as real?
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

        # how well can it label as fake?
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)

        fake_loss = self.adversarial_loss(self.discriminator(self.generated_imgs.detach()), fake)

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def validation_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

    def on_validation_epoch_end(self):
        z = self.validation_z.type_as(self.generator.model[0].weight)

        # log sampled images
        sample_imgs = self(z)
        grid = torchvision.utils.make_grid(sample_imgs)
        # self.logger.experiment.add_image("validation/generated_images", grid, self.current_epoch)

if __name__ == '__main__':
    dm = MNISTDataModule()
    model = GAN(*dm.dims)
    trainer = pl.Trainer(
        accelerator="auto",
        devices=[0, 1, 2, 3],
        strategy="ddp",
        max_epochs=5,
    )
    trainer.fit(model, dm)

Error messages and logs

[rank0]:   File "/home/ubuntu/pranav.rao/qxr_training/test_gan_lightning_ddp.py", line 196, in training_step
[rank0]:     self.manual_backward(d_loss)
[rank0]:   File "/home/ubuntu/miniconda3/envs/lightningTest/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1082, in manual_backward
[rank0]:     self.trainer.strategy.backward(loss, None, *args, **kwargs)
[rank0]:   File "/home/ubuntu/miniconda3/envs/lightningTest/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 208, in backward
[rank0]:     self.pre_backward(closure_loss)
[rank0]:   File "/home/ubuntu/miniconda3/envs/lightningTest/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 317, in pre_backward
[rank0]:     prepare_for_backward(self.model, closure_loss)
[rank0]:   File "/home/ubuntu/miniconda3/envs/lightningTest/lib/python3.10/site-packages/pytorch_lightning/overrides/distributed.py", line 55, in prepare_for_backward
[rank0]:     reducer._rebuild_buckets()  # avoids "INTERNAL ASSERT FAILED" with `find_unused_parameters=False`
[rank0]: RuntimeError: It looks like your LightningModule has parameters that were not used in producing the loss returned by training_step. If this is intentional, you must enable the detection of unused parameters in DDP, either by setting the string value `strategy='ddp_find_unused_parameters_true'` or by setting the flag in the strategy with `strategy=DDPStrategy(find_unused_parameters=True)`.

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA L40S
    - NVIDIA L40S
    - NVIDIA L40S
    - NVIDIA L40S
    - available: True
    - version: 12.1
  • Lightning:
    - lightning: 2.4.0
    - lightning-utilities: 0.11.7
    - pytorch-lightning: 2.4.0
    - torch: 2.4.1
    - torchmetrics: 1.4.2
    - torchvision: 0.19.1
  • Packages:
    - aiohappyeyeballs: 2.4.3
    - aiohttp: 3.10.9
    - aiosignal: 1.3.1
    - async-timeout: 4.0.3
    - attrs: 24.2.0
    - autocommand: 2.2.2
    - backports.tarfile: 1.2.0
    - cxr-training: 0.1.0
    - filelock: 3.16.1
    - frozenlist: 1.4.1
    - fsspec: 2024.9.0
    - idna: 3.10
    - importlib-metadata: 8.0.0
    - importlib-resources: 6.4.0
    - inflect: 7.3.1
    - jaraco.collections: 5.1.0
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jaraco.text: 3.12.1
    - jinja2: 3.1.4
    - lightning: 2.4.0
    - lightning-utilities: 0.11.7
    - markupsafe: 3.0.1
    - more-itertools: 10.3.0
    - mpmath: 1.3.0
    - multidict: 6.1.0
    - networkx: 3.3
    - numpy: 2.1.2
    - nvidia-cublas-cu12: 12.1.3.1
    - nvidia-cuda-cupti-cu12: 12.1.105
    - nvidia-cuda-nvrtc-cu12: 12.1.105
    - nvidia-cuda-runtime-cu12: 12.1.105
    - nvidia-cudnn-cu12: 9.1.0.70
    - nvidia-cufft-cu12: 11.0.2.54
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-nccl-cu12: 2.20.5
    - nvidia-nvjitlink-cu12: 12.6.77
    - nvidia-nvtx-cu12: 12.1.105
    - packaging: 24.1
    - pillow: 10.4.0
    - pip: 24.2
    - platformdirs: 4.2.2
    - propcache: 0.2.0
    - pytorch-lightning: 2.4.0
    - pyyaml: 6.0.2
    - setuptools: 75.1.0
    - sympy: 1.13.3
    - tomli: 2.0.1
    - torch: 2.4.1
    - torchmetrics: 1.4.2
    - torchvision: 0.19.1
    - tqdm: 4.66.5
    - triton: 3.0.0
    - typeguard: 4.3.0
    - typing-extensions: 4.12.2
    - wheel: 0.44.0
    - yarl: 1.14.0
    - zipp: 3.19.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.10.0
    - release: 5.15.0-1063-nvidia
    - version: added test model to do also #64-Ubuntu SMP Fri Aug 9 17:13:45 UTC 2024

More info

No response

@pranavrao-qure pranavrao-qure added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Oct 9, 2024
@dadwadw233
Copy link
Contributor

set strategy == "ddp_find_unused_parameters_true" as error log said

@pranavrao-qure
Copy link
Author

Doesn't setting strategy == "ddp_find_unused_parameters_true" make an extra forward pass, using more computation and time? As far I understant the tutorial, there doesn't seem to be any parameters with requires_grad=True during the computation of d_loss which should have grad=None, as the function call self.toggle_optimizer(optimizer_d) will set the value of the requires_grad to False for parameters other than ones being optimised by optimizer_d

@dadwadw233
Copy link
Contributor

dadwadw233 commented Oct 10, 2024

Doesn't setting strategy == "ddp_find_unused_parameters_true" make an extra forward pass, using more computation and time? As far I understant the tutorial, there doesn't seem to be any parameters with requires_grad=True during the computation of d_loss which should have grad=None, as the function call self.toggle_optimizer(optimizer_d) will set the value of the requires_grad to False for parameters other than ones being optimised by optimizer_d

I check the code you provided and find out the "unused params" follow the https://discuss.pytorch.org/t/how-to-find-the-unused-parameters-in-network/63948/5, it looks like the main reason is discriminator and generator calculate loss separately but lightning module make them as single model, follow the debug method i mentioned above:

# adversarial loss is binary cross-entropy
g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
self.log("g_loss", g_loss, prog_bar=True)
self.manual_backward(g_loss)
for name, param in self.named_parameters():
    if param.grad is None:
        print(name)
optimizer_g.step()
optimizer_g.zero_grad()
self.untoggle_optimizer(optimizer_g)

# train discriminator
# Measure discriminator's ability to classify real from generated samples
self.toggle_optimizer(optimizer_d)

# how well can it label as real?
valid = torch.ones(imgs.size(0), 1)
valid = valid.type_as(imgs)

real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

# how well can it label as fake?
fake = torch.zeros(imgs.size(0), 1)
fake = fake.type_as(imgs)

fake_loss = self.adversarial_loss(self.discriminator(self.generated_imgs.detach()), fake)

# discriminator loss is the average of these
d_loss = (real_loss + fake_loss) / 2
self.log("d_loss", d_loss, prog_bar=True)
self.manual_backward(d_loss)
for name, param in self.named_parameters():
    if param.grad is None:
        print(name)
optimizer_d.step()
optimizer_d.zero_grad()
self.untoggle_optimizer(optimizer_d)

i got the output (by setting "ddp_find_unused_parameters_true"):
image

if you call backward by :

self.manual_backward(d_loss + g_loss)

self.toggle_optimizer(optimizer_d)
optimizer_d.step()
optimizer_d.zero_grad()
self.untoggle_optimizer(optimizer_d)
self.toggle_optimizer(optimizer_g)
optimizer_g.step()
optimizer_g.zero_grad()
self.untoggle_optimizer(optimizer_g)

"ddp" setting will work correctly

@dadwadw233
Copy link
Contributor

dadwadw233 commented Apr 12, 2025

As discussed further in this issue, I need to correct my suggestion ⤴ that adding up the generator and discriminator losses and calling backpropagation only once is completely wrong, even if they do “solve” the problem of needing to set “ddp_find_unused_parameters_true”

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working needs triage Waiting to be triaged by maintainers ver: 2.4.x
Projects
None yet
Development

No branches or pull requests

2 participants