Deepspeed + Auto Select GPUs = CUDA Out of Memory Error #6857

Closed
fishbotics opened this issue Apr 6, 2021 · 4 comments
Labels
3rd party · bug · help wanted · priority: 1 · won't fix

Comments

@fishbotics

🐛 Bug

Please reproduce using the BoringModel

https://colab.research.google.com/drive/17Bt2m570f4o16iwbEV1fpUhgO04cuCqg?usp=sharing

To Reproduce

You can see the code in the BoringModel above, but I don't think it'll run on Colab because this is a multi-GPU issue.

Basically, when I have a large-ish model (~2M parameters), I find that DeepSpeed is incompatible with auto_select_gpus.

So,

trainer = pl.Trainer(
    gpus=8,
    accelerator='ddp',
    plugins='deepspeed',
    precision=16,
)

seems to work

But,

trainer = pl.Trainer(
    gpus=-1,
    auto_select_gpus=True,
    accelerator='ddp',
    plugins='deepspeed',
    precision=16,
)

causes a CUDA out of memory error.

Expected behavior

I'd expect it to select the GPUs and run it.

Environment

  • CUDA:
    - GPU:
    - Tesla V100-SXM2-16GB-N
    - Tesla V100-SXM2-16GB-N
    - Tesla V100-SXM2-16GB-N
    - Tesla V100-SXM2-16GB-N
    - Tesla V100-SXM2-16GB-N
    - Tesla V100-SXM2-16GB-N
    - Tesla V100-SXM2-16GB-N
    - Tesla V100-SXM2-16GB-N
    - available: True
    - version: 11.1
  • Packages:
    - numpy: 1.20.2
    - pyTorch_debug: False
    - pyTorch_version: 1.8.1+cu111
    - pytorch-lightning: 1.2.6
    - tqdm: 4.59.0
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    -
    - processor: x86_64
    - python: 3.9.2
    - version: #16~16.04.1-Ubuntu SMP Thu Apr 5 12:19:23 UTC 2018

Additional context

Semi-related: am I supposed to specify an accelerator when using DeepSpeed? In the docs, none is specified, but when I run without an accelerator, it complains and says I should be setting one.

@fishbotics fishbotics added the bug and help wanted labels Apr 6, 2021
@SeanNaren SeanNaren added the 3rd party label Apr 7, 2021
@SeanNaren SeanNaren self-assigned this Apr 7, 2021
@SeanNaren
Contributor

Hey @fishbotics thanks for your issue!

I used your notebook and created a script I could run on multiple GPUs:

import os

import pytorch_lightning as pl
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

tmpdir = os.getcwd()
import torch
from pytorch_lightning import LightningModule


class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


# This one is snagged from https://github.com/WangYueFt/dgcnn
# It has about 2M parameters, which is big enough to cause this bug
# to come up on my machine
class BoringModel(LightningModule):
    def __init__(self, k=20, emb_dims=1034, dropout=0.5, output_channels=256):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.k = k
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm1d(emb_dims)

        self.conv1 = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=1, bias=False),
            self.bn1,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False),
            self.bn2,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64 * 2, 128, kernel_size=1, bias=False),
            self.bn3,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(128 * 2, 256, kernel_size=1, bias=False),
            self.bn4,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.conv5 = nn.Sequential(
            nn.Conv1d(512, emb_dims, kernel_size=1, bias=False),
            self.bn5,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.linear1 = nn.Linear(emb_dims * 2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=dropout)
        self.linear3 = nn.Linear(256, output_channels)

    def forward(self, x):
        batch_size = x.size(0)
        x = self.get_graph_feature(x, k=self.k)
        x = self.conv1(x)
        x1 = x.max(dim=-1, keepdim=False)[0]

        x = self.get_graph_feature(x1, k=self.k)
        x = self.conv2(x)
        x2 = x.max(dim=-1, keepdim=False)[0]

        x = self.get_graph_feature(x2, k=self.k)
        x = self.conv3(x)
        x3 = x.max(dim=-1, keepdim=False)[0]

        x = self.get_graph_feature(x3, k=self.k)
        x = self.conv4(x)
        x4 = x.max(dim=-1, keepdim=False)[0]

        x = torch.cat((x1, x2, x3, x4), dim=1)

        x = self.conv5(x)
        x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
        x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)
        x = torch.cat((x1, x2), 1)

        x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2)
        x = self.dp1(x)
        x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2)
        x = self.dp2(x)
        x = self.linear3(x)
        return x

    def knn(self, x, k):
        inner = -2 * torch.matmul(x.transpose(2, 1), x)
        xx = torch.sum(x ** 2, dim=1, keepdim=True)
        pairwise_distance = -xx - inner - xx.transpose(2, 1)
        idx = pairwise_distance.topk(k=k, dim=-1)[1]  # (batch_size, num_points, k)
        return idx

    def get_graph_feature(self, x, k=20, idx=None):
        batch_size = x.size(0)
        num_points = x.size(2)
        x = x.view(batch_size, -1, num_points)
        if idx is None:
            idx = self.knn(x, k=k)  # (batch_size, num_points, k)

        idx_base = (
                torch.arange(0, batch_size).view(-1, 1, 1) * num_points
        ).type_as(idx)

        idx = idx + idx_base

        idx = idx.view(-1)

        _, num_dims, _ = x.size()

        x = x.transpose(2, 1)
        feature = x.reshape(batch_size * num_points, -1)[idx, :]
        feature = feature.view(batch_size, num_points, k, num_dims)
        x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

        feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2)

        return feature

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def validation_epoch_end(self, outputs) -> None:
        torch.stack([x['x'] for x in outputs]).mean()

    def test_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        self.log('fake_test_acc', loss)
        return {"y": loss}

    def test_epoch_end(self, outputs) -> None:
        torch.stack([x["y"] for x in outputs]).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


if __name__ == '__main__':
    num_samples = 10000
    train = RandomDataset(32, num_samples)
    train = DataLoader(train, batch_size=32)

    val = RandomDataset(32, num_samples)
    val = DataLoader(val, batch_size=32)

    test = RandomDataset(32, num_samples)
    test = DataLoader(test, batch_size=32)
    # init model
    model = BoringModel()

    # Initialize a trainer
    trainer = pl.Trainer(
        gpus=8,
        auto_select_gpus=True,
        accelerator='ddp',
        precision=16,
        max_epochs=1,
        progress_bar_refresh_rate=20
    )

    # Train the model ⚡
    trainer.fit(model, train, val)

    trainer.test(test_dataloaders=test)

I noticed that when using auto_select_gpus=True without DeepSpeed, additional memory is allocated on GPU 0 in proportion to the number of GPUs. This may be the cause of the issue here; I can investigate further.

Could you verify that using gpus=-1 without auto_select_gpus works on a machine where all GPUs are available? This can be a short-term workaround while we debug why additional memory is allocated when auto_select_gpus is enabled.
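
For concreteness, a minimal sketch of that workaround based on the Trainer arguments from the original report (your exact arguments may differ):

import pytorch_lightning as pl

# Short-term workaround: use all visible GPUs via gpus=-1 and leave
# auto_select_gpus at its default (False), keeping the DeepSpeed plugin.
trainer = pl.Trainer(
    gpus=-1,
    accelerator='ddp',
    plugins='deepspeed',
    precision=16,
)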

Semi-related: am I supposed to specify an accelerator when using DeepSpeed? In the docs, none is specified, but when I run without an accelerator, it complains and says I should be setting one.

Apologies for the confusion! The warning message can be ignored when plugins='deepspeed'. It comes from the way we currently select plugins in the backend, which raises the warning incorrectly. #6090 will fix this, but I'll see what I can do in the short term to get it into our minor bug-fix release!
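
In other words, a configuration along these lines (sketched from the discussion above, not copied from the docs) should work, and the accelerator warning it triggers can be ignored for now:

import pytorch_lightning as pl

# plugins='deepspeed' without an explicit accelerator, as the docs describe;
# the resulting warning about the accelerator can be ignored until the fix lands.
trainer = pl.Trainer(
    gpus=-1,
    plugins='deepspeed',
    precision=16,
)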

@fishbotics
Author

Yes, that seems to work! Can you explain a little about what's going on under the hood with auto_select_gpus, such that ignoring it and just using -1 helps so much?

Thanks for your help!

@SeanNaren
Contributor

Awesome! There is definitely a bug with auto_select_gpus and DeepSpeed that needs investigation.

auto_select_gpus allocates a tensor on each GPU; if the allocation succeeds we use that GPU, and if it fails the GPU is excluded from training. This is useful when other jobs may be running on the cluster and you need a simple availability check.
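
A rough sketch of that kind of check (illustrative only, not Lightning's actual implementation; pick_free_gpus and the probe tensor are made up for the example):

import torch

def pick_free_gpus(num_gpus_requested):
    """Try a tiny allocation on each visible GPU and keep the devices
    where the allocation succeeds."""
    free = []
    for device_id in range(torch.cuda.device_count()):
        try:
            # Probe the device with a small tensor allocation.
            probe = torch.zeros(1, device=f"cuda:{device_id}")
            del probe
            torch.cuda.empty_cache()
            free.append(device_id)
        except RuntimeError:
            # Allocation failed (e.g. the GPU is saturated by another job), so skip it.
            continue
        if len(free) == num_gpus_requested:
            break
    return free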

@edenlightning edenlightning added the priority: 1 label Jul 1, 2021
@stale

stale bot commented Aug 1, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix label Aug 1, 2021
@stale stale bot closed this as completed Aug 9, 2021