batch_size selected by auto_scale_batch_size triggers out of memory error #9625


Closed
cowwoc opened this issue Sep 21, 2021 · 18 comments · Fixed by #14372
Labels: bug (Something isn't working), help wanted (Open to be worked on), priority: 2 (Low priority task), tuner

Comments

@cowwoc
Contributor

cowwoc commented Sep 21, 2021

🐛 Bug

To Reproduce

When I run:

import math
import os
from typing import Optional, Tuple

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch import Tensor
from torch.utils.data import DataLoader, Subset, Dataset

DETERMINISTIC = True
DETERMINISTIC_SEED = 41
# Debugging "one of the variables needed for gradient computation has been modified by an inplace operation"
torch.autograd.set_detect_anomaly(True)


class OutdoorTemperatureDataset(Dataset):
    def __init__(self, batch_size: int):
        self.batch_size = batch_size
        self.input_horizon = 60 * 5 * 2
        self.output_horizon = 1
        self.total_horizon = self.input_horizon + self.output_horizon
        # TODO: Why does crash only occur if tensor contains at least (batch_size + 134) entries?
        self.outdoor_temperature = torch.tensor([1.0]).repeat(batch_size + 134, self.total_horizon)

    def __getitem__(self, index) -> Tuple[Tensor, Tensor]:
        samples = torch.stack([self.outdoor_temperature[index]])
        # Convert [features, samples] to [samples, features]
        samples = samples.permute(1, 0)
        x = samples[:self.input_horizon, :]
        y = samples[self.input_horizon:, 0]
        return x, y

    def __len__(self):
        return self.outdoor_temperature.shape[0]


class ProcessContext:
    def __init__(self, dataset: OutdoorTemperatureDataset):
        self.input_horizon = dataset.input_horizon
        self.output_horizon = dataset.output_horizon
        train_size = max(1,
                         min(len(dataset) - 1,
                             math.ceil(len(dataset) * 0.9)))
        val_size = len(dataset) - train_size
        assert train_size > 0
        assert val_size > 0
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            Subset(dataset, range(0, (train_size + val_size))),
            [train_size, val_size])

    def get_train_dataset(self):
        return self.train_dataset

    def get_validation_dataset(self):
        return self.val_dataset

    def get_model(self, learning_rate: float, max_epochs: int, hidden_layer_size: int, batch_size: int):
        return Predictor(self.train_dataset, self.val_dataset, self.input_horizon, self.output_horizon,
                         learning_rate=learning_rate, max_epochs=max_epochs,
                         hidden_layer_size=hidden_layer_size, batch_size=batch_size)


class Predictor(LightningModule):
    def __init__(self, train_dataset: Dataset, val_dataset: Dataset, input_horizon: int, output_horizon: int,
                 learning_rate: float, max_epochs: int, hidden_layer_size: int, batch_size: int):
        super(Predictor, self).__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.input_horizon = input_horizon
        self.output_horizon = output_horizon
        self.total_horizon = self.input_horizon + self.output_horizon
        self.max_epochs = max_epochs
        self.learning_rate = learning_rate
        self.hidden_layer_size = hidden_layer_size

        self.input_norm = nn.LayerNorm(1)
        self.layer_norm = nn.LayerNorm(self.hidden_layer_size)
        self.gru = nn.GRU(1, self.hidden_layer_size, 1)

        self.linear_layer = nn.Linear(self.hidden_layer_size, self.output_horizon)
        self.loss_function = F.mse_loss
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(dataset=self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(dataset=self.val_dataset, batch_size=self.batch_size, pin_memory=True)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)

    def forward(self, input):
        output = self.input_norm(input)
        # Input shape is [batch, sequence, feature] but lstm/gru expects [sequence, batch, feature]
        output = output.permute(1, 0, 2)
        output, _ = self.gru(output)
        # Extract the hidden layer of the last element of the sequence
        output = output[-1, :, :]
        output = F.relu(output)
        output = self.layer_norm(output)
        output = self.linear_layer(output)
        return output

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        input, expected = batch

        actual = self(input)
        return self.loss_function(actual, expected)

    def validation_step(self, batch, batch_index) -> Optional[STEP_OUTPUT]:
        input, expected = batch

        actual = self(input)
        return self.loss_function(actual, expected)


def train(dataset: OutdoorTemperatureDataset, learning_rate: float, max_epochs: int,
          hidden_layer_size: int) -> float:
    process_context = ProcessContext(dataset)
    model = process_context.get_model(learning_rate, max_epochs, hidden_layer_size, dataset.batch_size)
    model.learning_rate = learning_rate

    trainer = Trainer(gpus=-1, benchmark=not DETERMINISTIC,  # precision=16,
                      weights_summary=None, max_epochs=max_epochs, deterministic=DETERMINISTIC,
                      auto_scale_batch_size=True)
    trainer.tune(model)
    trainer.fit(model)
    return trainer.logged_metrics["val_loss"]


def main():
    if DETERMINISTIC:
        # https://pytorch.org/docs/stable/notes/randomness.html
        pl.seed_everything(DETERMINISTIC_SEED, workers=True)
        torch.use_deterministic_algorithms(True)
        # https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM
        os.putenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    LEARNING_RATE = 0.001
    MAX_EPOCHS = 1000
    HIDDEN_LAYER_SIZE = 64
    batch_size = 10240
    dataset = OutdoorTemperatureDataset(batch_size)
    train(dataset, LEARNING_RATE, MAX_EPOCHS, HIDDEN_LAYER_SIZE)


if __name__ == "__main__":
    main()

I get the following output:

C:\Users\Gili\Documents\daikin-one\python\Scripts\python.exe C:/Users/Gili/Documents/daikin-one/aggregator/src/main/python/com.holdmyspot.hvac.aggregator/testcase.py
Global seed set to 41
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\trainer\data_loading.py:105: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Global seed set to 41
C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\trainer\data_loading.py:105: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Batch size 2 succeeded, trying batch size 4
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 41
Batch size 4 succeeded, trying batch size 8
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 41
Batch size 8 succeeded, trying batch size 16
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 41
Batch size 16 succeeded, trying batch size 32
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 41
Batch size 32 succeeded, trying batch size 64
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 41
Batch size 64 succeeded, trying batch size 128
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 41
Batch size 128 succeeded, trying batch size 256
C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\trainer\data_loading.py:326: UserWarning: The number of training samples (37) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 41
Batch size 256 succeeded, trying batch size 512
C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\trainer\data_loading.py:326: UserWarning: The number of training samples (19) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 41
Batch size 512 succeeded, trying batch size 1024
C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\trainer\data_loading.py:326: UserWarning: The number of training samples (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 41
Batch size 1024 succeeded, trying batch size 2048
C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\trainer\data_loading.py:326: UserWarning: The number of training samples (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 41
Batch size 2048 succeeded, trying batch size 4096
C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\trainer\data_loading.py:326: UserWarning: The number of training samples (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 41
--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{}
--------------------------------------------------------------------------------
Batch size 4096 succeeded, trying batch size 8192
C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\trainer\data_loading.py:326: UserWarning: The number of training samples (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Global seed set to 41
Batch size 8192 failed, trying batch size 4096
Finished batch size finder, will continue with full run using batch size 4096
Restoring states from the checkpoint file at C:\Users\Gili\Documents\daikin-one\aggregator\src\main\python\com.holdmyspot.hvac.aggregator\scale_batch_size_temp_model.ckpt
Restored all states from the checkpoint file at C:\Users\Gili\Documents\daikin-one\aggregator\src\main\python\com.holdmyspot.hvac.aggregator\scale_batch_size_temp_model.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Epoch 0:   0%|          | 0/521 [00:00<?, ?it/s] Global seed set to 41
Traceback (most recent call last):
  File "C:\Users\Gili\Documents\daikin-one\aggregator\src\main\python\com.holdmyspot.hvac.aggregator\testcase.py", line 154, in <module>
    main()
  File "C:\Users\Gili\Documents\daikin-one\aggregator\src\main\python\com.holdmyspot.hvac.aggregator\testcase.py", line 150, in main
    train(dataset, LEARNING_RATE, MAX_EPOCHS, HIDDEN_LAYER_SIZE)
  File "C:\Users\Gili\Documents\daikin-one\aggregator\src\main\python\com.holdmyspot.hvac.aggregator\testcase.py", line 134, in train
    trainer.fit(model)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 552, in fit
    self._run(model)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 917, in _run
    self._dispatch()
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 985, in _dispatch
    self.accelerator.start_training(self)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\accelerators\accelerator.py", line 92, in start_training
    self.training_type_plugin.start_training(trainer)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py", line 161, in start_training
    self._results = trainer.run_stage()
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 995, in run_stage
    return self._run_train()
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1044, in _run_train
    self.fit_loop.run()
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\loops\base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\loops\fit_loop.py", line 200, in advance
    epoch_output = self.epoch_loop.run(train_dataloader)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\loops\base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\loops\epoch\training_epoch_loop.py", line 130, in advance
    batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\loops\batch\training_batch_loop.py", line 100, in run
    super().run(batch, batch_idx, dataloader_idx)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\loops\base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\loops\batch\training_batch_loop.py", line 147, in advance
    result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\loops\batch\training_batch_loop.py", line 201, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\loops\batch\training_batch_loop.py", line 395, in _optimizer_step
    model_ref.optimizer_step(
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\core\lightning.py", line 1616, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\core\optimizer.py", line 206, in step
    self.__optimizer_step(closure=closure, profiler_name=profiler_name, **kwargs)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\core\optimizer.py", line 128, in __optimizer_step
    trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\accelerators\accelerator.py", line 296, in optimizer_step
    self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\accelerators\accelerator.py", line 303, in run_optimizer_step
    self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py", line 226, in optimizer_step
    optimizer.step(closure=lambda_closure, **kwargs)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\torch\optim\optimizer.py", line 88, in wrapper
    return func(*args, **kwargs)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\torch\autograd\grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\torch\optim\adam.py", line 66, in step
    loss = closure()
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\loops\batch\training_batch_loop.py", line 235, in _training_step_and_backward_closure
    result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\loops\batch\training_batch_loop.py", line 536, in training_step_and_backward
    result = self._training_step(split_batch, batch_idx, opt_idx, hiddens)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\loops\batch\training_batch_loop.py", line 306, in _training_step
    training_step_output = self.trainer.accelerator.training_step(step_kwargs)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\accelerators\accelerator.py", line 193, in training_step
    return self.training_type_plugin.training_step(*step_kwargs.values())
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py", line 172, in training_step
    return self.model.training_step(*args, **kwargs)
  File "C:\Users\Gili\Documents\daikin-one\aggregator\src\main\python\com.holdmyspot.hvac.aggregator\testcase.py", line 114, in training_step
    actual = self(input)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Gili\Documents\daikin-one\aggregator\src\main\python\com.holdmyspot.hvac.aggregator\testcase.py", line 103, in forward
    output, _ = self.gru(output)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Gili\Documents\daikin-one\python\lib\site-packages\torch\nn\modules\rnn.py", line 837, in forward
    result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
RuntimeError: CUDA out of memory. Tried to allocate 7.03 GiB (GPU 0; 10.00 GiB total capacity; 2.46 GiB already allocated; 5.26 GiB free; 2.49 GiB reserved in total by PyTorch)

Process finished with exit code 1

Expected behavior

I expect the batch_size selected by the tuner to fit in GPU memory, but it does not. This is 100% reproducible on my machine.

Environment

Collecting environment information...
PyTorch version: 1.9.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Pro
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.9.7 (tags/v3.9.7:1016ef3, Aug 30 2021, 20:19:38) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19043-SP0
Is CUDA available: True
CUDA runtime version: 11.4.120
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3080
Nvidia driver version: 471.96
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] pytorch-lightning==1.4.7
[pip3] torch==1.9.0+cu111
[pip3] torch-tb-profiler==0.2.1
[pip3] torchaudio==0.9.0
[pip3] torchmetrics==0.5.1
[pip3] torchvision==0.10.0+cu111
[conda] Could not collect
@cowwoc cowwoc added bug Something isn't working help wanted Open to be worked on labels Sep 21, 2021
@cowwoc
Contributor Author

cowwoc commented Sep 21, 2021

I read that this functionality is based on https://github.com/BlackHC/toma and noticed that they invoke https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py#L10 after an out-of-memory error occurs. I don't know if you do the same, but when I tried adding:

gc.collect()
torch.cuda.empty_cache()

immediately after trainer.tune(model), it did not make a difference for me. However, if I set batch_size=4096 manually at the bottom of the testcase (instead of running the batch size tuner), it works just fine. So there seems to be some bug in the batch_size tuner... Maybe it's holding onto memory (preventing GC), which prevents subsequent runs from succeeding.
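
For concreteness, a minimal sketch of that manual path, reusing the classes and constants from the test case above (the hard-coded 4096 mirrors the value the tuner settled on and is only illustrative):

def train_fixed_batch_size(dataset: OutdoorTemperatureDataset, learning_rate: float,
                           max_epochs: int, hidden_layer_size: int) -> float:
    # Same as train() above, but with a hand-picked batch size and no trainer.tune(),
    # so no tuner state is left behind before fit() starts.
    process_context = ProcessContext(dataset)
    model = process_context.get_model(learning_rate, max_epochs, hidden_layer_size,
                                      batch_size=4096)
    trainer = Trainer(gpus=-1, benchmark=not DETERMINISTIC, weights_summary=None,
                      max_epochs=max_epochs, deterministic=DETERMINISTIC)
    trainer.fit(model)
    return trainer.logged_metrics["val_loss"]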

@tchaton
Contributor

tchaton commented Sep 22, 2021

@SkafteNicki Mind having a look into this?

@tchaton tchaton added tuner priority: 2 Low priority task labels Sep 22, 2021
@grudloff

I believe I am experiencing the same issue. I manage to work around it by manually calling scale_batch_size and then deleting the associated trainer, like this:

trainer = pl.Trainer(...)
trainer.tuner.scale_batch_size(model, ...)
del trainer
# then train the model
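
A slightly more complete sketch of the same idea; `model`, the Trainer arguments, and max_epochs stand in for your own setup, and it is assumed here that the found size should be written back to model.batch_size explicitly:

import gc

import pytorch_lightning as pl
import torch

# Tune in a throwaway Trainer, release it, then fit with a fresh one.
tune_trainer = pl.Trainer(gpus=-1, max_epochs=1)
new_size = tune_trainer.tuner.scale_batch_size(model, mode="binsearch")
model.batch_size = new_size  # make sure the model uses the value the tuner found

del tune_trainer
gc.collect()
torch.cuda.empty_cache()  # try to hand the tuner's cached blocks back to the driver

fit_trainer = pl.Trainer(gpus=-1, max_epochs=1000)
fit_trainer.fit(model)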

@grudloff

Related issues #8028 #8257

@twsl
Contributor

twsl commented Oct 3, 2021

I experienced a similar behaviour to @cowwoc's, especially in combination with auto_lr_find=True.

@wollschlager

I am having the same issue. Also, the batch size that "binsearch" finds for me is smaller than what I can use when I select it manually:

  • found 32
  • rejected everything above 33, ..
  • manually I can use 52

@carmocca carmocca added this to the v1.6.x milestone Oct 27, 2021
@veronicamorfi

I have a similar issue when using binsearch batch scaling. I found that in batch_size_scaling.py, _run_binsearch_scaling does not call reset_train_dataloader when you run into an OOM error, so the dataloader's batch_size is kept at the failed value.

@AAnoosheh

AAnoosheh commented Oct 29, 2021

I have a similar issue when using binsearch batch scaling. I found that in batch_size_scaling.py, _run_binsearch_scaling does not call reset_train_dataloader when you run into an OOM error, so the dataloader's batch_size is kept at the failed value.

I tried "power" scaling too and the issue remains, so it appears it's not just that, but this is also a bug to note. (Binsearch can return a different value from Power mode, as it searches further after dropping the size down, so it may be that your final Power-mode size was lower than what Binsearch found and happens to work?)

When inspecting the GPU memory after the tuner returns, 10 GB was still reserved somehow, which convinces me this is in fact a leak of some sort. (GC and cache clearing do nothing for it.)
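
For anyone trying to confirm this locally, the allocator's counters can be printed around the tuner call; this is plain torch.cuda introspection, nothing Lightning-specific, with `trainer` and `model` standing in for your own objects:

import gc

import torch

def report_cuda_memory(tag: str) -> None:
    # Allocator statistics for the current device, in GiB.
    allocated = torch.cuda.memory_allocated() / 2 ** 30
    reserved = torch.cuda.memory_reserved() / 2 ** 30
    print(f"{tag}: allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB")

report_cuda_memory("before tune")
trainer.tune(model)
gc.collect()
torch.cuda.empty_cache()
report_cuda_memory("after tune")  # reserved staying high despite empty_cache() points to live references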

@awaelchli awaelchli modified the milestones: v1.6.x, 1.5.x Nov 3, 2021
@twsl
Contributor

twsl commented Nov 29, 2021

Any news on this? I tried to debug this a bit, but I haven't found any additional hints yet. #10243 makes sense, but doesn't seem to solve the problem.

@caillonantoine

When a CUDA OOM error is encountered, the train loaders are not reloaded, so the tuner basically keeps using the highest batch size it reached, regardless of whether that size yielded an OOM error.

One workaround is to force the loader to reload at the beginning of each loop iteration, i.e. in pytorch_lightning/tuner/batch_size_scaling.py:

def _run_power_scaling(trainer: "pl.Trainer", model: "pl.LightningModule",
                       new_size: int, batch_arg_name: str,
                       max_trials: int) -> int:
    """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
    for _ in range(max_trials):
        trainer.reset_train_dataloader(model)  # FORCE LOADER RESET
        ...  # rest of the original loop body unchanged

and

def _run_binsearch_scaling(trainer: "pl.Trainer", model: "pl.LightningModule",
                           new_size: int, batch_arg_name: str,
                           max_trials: int) -> int:
    """Batch scaling mode where the size is initially is doubled at each iteration
    until an OOM error is encountered. Hereafter, the batch size is further
    refined using a binary search"""
    low = 1
    high = None
    count = 0
    while True:
        trainer.reset_train_dataloader(model)  # FORCE LOADER RESET
        ...  # rest of the original loop body unchanged
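
If editing the installed package is not an option, a cruder variant of the same idea can be tried from user code after tuning; this assumes trainer.reset_train_dataloader(model) is reachable in the installed version and that the tuner has already restored the final batch size on the model:

trainer.tune(model)
trainer.reset_train_dataloader(model)  # rebuild the train loader with the restored batch size
trainer.fit(model)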

@circlecrystal
Contributor


I tested simply calling the tuner.lr_find function without calling auto_scale_batch_size first... and it reports an OOM error as well.

@hendrik-b

hendrik-b commented May 12, 2022

I have a similar issue (using a data module). As far as I can see, the tuner only sends the data to the GPU in the first iteration. Then the batch size is increased, and during the next call of self.fit_loop.run() the skip property of the loop is True, which skips the whole processing of the model (including sending data to the GPU), so the higher batch size is considered OK and the iteration continues.

@thiyagu-lily

Do we have any updates on this? I'm still running into the same error, but only when both lr_finder and auto_scale_batch_size are set to true!

@carmocca carmocca removed this from the pl:1.6.x milestone Jul 28, 2022
@noamsgl

noamsgl commented Jan 23, 2023

Still running into this issue (CUDA out of memory error) on pytorch-lightning 1.8.6, with:

  auto_select_gpus=True,
  auto_scale_batch_size="binsearch",
  auto_lr_find=False

@twsl
Contributor

twsl commented Jan 24, 2023

@noamsgl There won't be a solution in PL soon, as this is caused by the way PyTorch reserves and allocates memory; see pytorch/pytorch#72117
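
If allocator fragmentation is indeed the culprit, one mitigation that can be tried from outside Lightning (available in recent PyTorch versions) is capping the size of cached blocks the allocator is allowed to split; this is an upstream knob discussed in the linked issue, not a fix for the tuner itself:

import os

# Must be set before the first CUDA allocation in the process.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"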

@jlehrer1

Note: still happening as of this date.

@mrembalski

I encounter the same issue; why is this closed?

@DeepUpSonnemann

I am seeing this still, too.
