Lightning is very slow between epochs, compared to PyTorch. #10389


Closed
TheMrZZ opened this issue Nov 6, 2021 · 62 comments
Labels
bug Something isn't working help wanted Open to be worked on performance priority: 1 Medium priority task

Comments

@TheMrZZ

TheMrZZ commented Nov 6, 2021

I converted some PyTorch code to Lightning. The dataset is loaded lazily by the train & eval dataloaders.

However, when moving the code to Lightning, I noticed a huge slowdown. After digging around, I noticed that there was a ~10-second delay between each epoch. For comparison, an epoch takes ~4s with my vanilla PyTorch code.

I first thought it was a data-loading problem, but during the 10s delay no data is loaded (at least that's what my print statements tell me).

I think the issue is related to the number of workers, because setting num_workers=0 solves the problem (but is slower overall, since a single worker is not enough). I know starting workers is slow; however, I have persistent_workers=True, and this does not happen in plain PyTorch. My data loaders also have pin_memory=True (removing pin_memory does not solve the problem).
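
For anyone who wants to measure this gap directly, a minimal callback along these lines can time it (just a sketch; the EpochGapTimer name and the printing are illustrative, not my actual code):

import time

import pytorch_lightning as pl


class EpochGapTimer(pl.Callback):
    """Print how long the trainer sits idle between the end of one training epoch and the start of the next."""

    def __init__(self):
        self.last_epoch_end = None

    def on_train_epoch_end(self, trainer, pl_module):
        self.last_epoch_end = time.monotonic()

    def on_train_epoch_start(self, trainer, pl_module):
        if self.last_epoch_end is not None:
            print(f"Gap since previous epoch ended: {time.monotonic() - self.last_epoch_end:.2f}s")


# usage: pl.Trainer(callbacks=[EpochGapTimer()], ...)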

Since this is company code, I cannot disclose the before/after, but I'll try to "anonymize" some code if necessary. Here is the lightning module:

class RawModule(pl.LightningModule):
    def __init__(self):
        super(RawModule, self).__init__()

        self.encoder1 = nn.Sequential(...)
        self.encoder2 = nn.Sequential(...)

    def forward(self, data1, data2):
        result1 = self.encoder1(data1)
        result2 = self.encoder2(data2)

        result1 = result1.view(result1.size(0), -1)
        result2 = result2.view(result2.size(0), -1)

        result1 = F.normalize(result1, p=2, dim=1)
        result2 = F.normalize(result2, p=2, dim=1)

        return result1, result2

    
    def calculate_loss(self, batch):
        x, r, y = batch
        a, v = self.forward(r, x)

        d = nn.functional.cosine_similarity(a, v)
        loss = logloss(d.unsqueeze(1), y)

        return loss


class Module(RawModule):
    def training_step(self, batch, batch_idx):
        loss = self.calculate_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.calculate_loss(batch)
        self.log("validation_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer


if __name__ == '__main__':
    # stuff...

    train_loader = data_utils.DataLoader(
        train_dataset, batch_size=256, shuffle=True,
        num_workers=5, persistent_workers=True,
        pin_memory=True,
    )

    val_loader = data_utils.DataLoader(
        test_dataset, batch_size=256,
        num_workers=2, persistent_workers=True,
        pin_memory=True,
    )

    # Model
    load_from_pytorch = True

    if checkpoint_path is None:
        model = Module()

        if load_from_pytorch:
            if not checkpoint_path:
                raise ValueError("Please provide a checkpoint path")
            model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
    else:
        model = Module.load_from_checkpoint(checkpoint_path)


    trainer = pl.Trainer(
        gpus=1,
        max_epochs=5,
        check_val_every_n_epoch=10,
        log_every_n_steps=5,
    )
    trainer.fit(model, train_loader, val_loader)

Here is the result of profiler="simple":

Action                                  |  Mean duration (s)    |Num calls              |  Total time (s)       |  Percentage %         |
----------------------------------------------------------------------------------------------------------------------------------------
Total                                   |  -                    |_                      |  48.813               |  100 %                |
----------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                      |  27.922               |1                      |  27.922               |  57.202               |
fetch_next_sanity_check_batch           |  4.4013               |3                      |  13.204               |  27.05                |
get_sanity_check_batch                  |  4.4013               |3                      |  13.204               |  27.05                |
fetch_next_train_batch                  |  1.2734               |10                     |  12.734               |  26.087               |
get_train_batch                         |  1.2734               |10                     |  12.734               |  26.087               |
run_training_batch                      |  0.47733              |9                      |  4.296                |  8.8009               |
optimizer_step_with_closure_0           |  0.40089              |9                      |  3.608                |  7.3915               |
validation_step                         |  0.664                |2                      |  1.328                |  2.7206               |
evaluation_step_and_end                 |  0.664                |2                      |  1.328                |  2.7206               |
training_step_and_backward              |  0.12644              |9                      |  1.138                |  2.3313               |
backward                                |  0.096889             |9                      |  0.872                |  1.7864               |
training_step                           |  0.029556             |9                      |  0.266                |  0.54494              |
model_forward                           |  0.029556             |9                      |  0.266                |  0.54494              |
on_train_start                          |  0.016                |1                      |  0.016                |  0.032778             |

Here is the result of profiler="advanced": https://pastebin.com/q3C5P826.
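
For reference, this is roughly how these profilers are turned on; the object form of AdvancedProfiler (argument names from memory, they may differ between versions) lets you write the very long advanced report to a file instead of the console:

import pytorch_lightning as pl
from pytorch_lightning.profiler import AdvancedProfiler

# string shortcuts, as used above
trainer = pl.Trainer(gpus=1, max_epochs=5, profiler="simple")      # summary table
# trainer = pl.Trainer(gpus=1, max_epochs=5, profiler="advanced")  # cProfile output

# object form, to write the advanced report to a file
profiler = AdvancedProfiler(dirpath=".", filename="perf_report")
trainer = pl.Trainer(gpus=1, max_epochs=5, profiler=profiler)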

Finally, here is a video demonstrating the problem. I'm printing every piece of data loading to show that it's not the issue.
https://user-images.githubusercontent.com/30944236/140587623-ae184fa3-370a-42be-8593-200026d11ba4.mp4

Random information:

  • OS: Windows 10
  • CPU: AMD Ryzen 5 5600X 6 Core
  • GPU: Nvidia RTX 3070
  • PyTorch version: 1.10.0
  • PyTorch Lightning version: 1.5.0
  • CUDA version: 11.5
  • How did I install PyTorch: conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
  • Python 3.8

cc @tchaton @rohitgr7 @Borda @akihironitta

@TheMrZZ TheMrZZ added bug Something isn't working help wanted Open to be worked on labels Nov 6, 2021
@TheMrZZ
Author

TheMrZZ commented Nov 7, 2021

After a lot of digging around, I managed to pin down the line causing the problem.

It's line 142 in loops/epoch/training_epoch_loop.py:

class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
    ...
    def on_run_start(self, data_fetcher: AbstractDataFetcher, **kwargs: Any) -> None:
        # hook
        self.trainer.logger_connector.on_epoch_start()
        self.trainer.call_hook("on_epoch_start")
        self.trainer.call_hook("on_train_epoch_start")
        self.trainer.fit_loop.epoch_progress.increment_started()

        self._reload_dataloader_state_dict(data_fetcher)
-->     self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_idx + 1)

Therefore, the culprit is:

def _update_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator:
    """Attach the dataloader."""
    if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
        # restore iteration
        dataloader_iter = enumerate(data_fetcher, batch_idx)
    else:
        dataloader_iter = iter(data_fetcher)
    return dataloader_iter
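
Note that enumerate(data_fetcher, batch_idx) implicitly calls iter() on data_fetcher, so whatever the fetcher's __iter__ does is executed once per epoch. A toy illustration (nothing to do with Lightning's actual classes):

class ToyFetcher:
    def __iter__(self):
        print("__iter__ called")  # any per-iteration setup happens here, every epoch
        return iter(range(3))


fetcher = ToyFetcher()

for epoch in range(2):  # "epochs"
    for batch_idx, batch in enumerate(fetcher, 0):
        pass
# prints "__iter__ called" twice -- once per epoch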

The on_run_start hook is called from loops/base.py:

class Loop(ABC, Generic[T]):
    ...
    def run(self, *args: Any, **kwargs: Any) -> T:
        if self.skip:
            return self.on_skip()

        self.reset()

-->     self.on_run_start(*args, **kwargs)
        ...

And this run method is called from loops/fit_loop.py:

class FitLoop(Loop):
    ...
    def advance(self) -> None:
        """Runs one whole epoch."""
        dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
        data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)

        with self.trainer.profiler.profile("run_training_epoch"):
-->         self.epoch_loop.run(data_fetcher)
        ...

The problem is the data_fetcher. It's indeed related to the dataloaders, as expected. I still don't know the root cause, but I'll try to find it.

@a-kore

a-kore commented Nov 7, 2021

Just a guess, but maybe having a different number of workers between the training and validation dataloaders means new workers are being spun up and torn down between epochs?

Try making the number of workers equal for both the train and val dataloaders (5-6, based on your CPU).
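
Something like this, for example (the TensorDataset objects are placeholders just to make the snippet self-contained):

import torch
from torch.utils.data import DataLoader, TensorDataset

train_dataset = TensorDataset(torch.randn(1024, 10))  # placeholder
test_dataset = TensorDataset(torch.randn(256, 10))    # placeholder

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True,
                          num_workers=5, persistent_workers=True, pin_memory=True)
val_loader = DataLoader(test_dataset, batch_size=256,
                        num_workers=5, persistent_workers=True, pin_memory=True)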

@TheMrZZ
Author

TheMrZZ commented Nov 7, 2021

Just a guess, but maybe having a different number of workers between the training and validation dataloaders means new workers are being spun up and torn down between epochs?

Try making the number of workers equal for both the train and val dataloaders (5-6, based on your CPU).

I just tried this; it does not solve the problem. It would have been surprising IMO, since this doesn't cause any problem with vanilla PyTorch.

@a-kore

a-kore commented Nov 7, 2021

I see, that's interesting. Like you said, it does seem to be a data-loading issue. Maybe try removing

        check_val_every_n_epoch=10,
        log_every_n_steps=5

from the trainer call, explicitly set "reload_dataloaders_every_epoch=False", and see what happens.

Other than that, I'd try a fresh install of pytorch-lightning in a new venv.

@TheMrZZ
Author

TheMrZZ commented Nov 7, 2021

I see, that's interesting. Like you said, it does seem to be a data-loading issue. Maybe try removing

        check_val_every_n_epoch=10,
        log_every_n_steps=5

from the trainer call, explicitly set "reload_dataloaders_every_epoch=False", and see what happens.

Other than that, I'd try a fresh install of pytorch-lightning in a new venv.

I just tried that, but it has no effect. The Trainer's methods _reset_eval_dataloader and reset_train_dataloader are never called (except once at the beginning), so it doesn't look like it's Lightning manually resetting the data loaders.

I also tried a fresh conda environment, but that didn't work either. I'm still trying to understand how the data loaders are getting reset, but I can't find anything really interesting yet.

@kaushikb11 kaushikb11 added the priority: 0 High priority task label Nov 7, 2021
@marcm-ml
Contributor

marcm-ml commented Nov 8, 2021

I had a similar observation where the data_fetcher caused unusually long run times. For me it was indeed fixed by completely disabling multiprocess data loading (num_workers=0), although I have not tried setting "reload_dataloaders_every_epoch=False". Interesting to see that others have the same issue/observation. Funnily enough, setting num_workers=0 is what led me to open #10182. Perhaps there is something more to this?

@TheMrZZ
Author

TheMrZZ commented Nov 8, 2021

TL;DR: I just commented out the self.reset() line of AbstractDataFetcher, located at line 198 of pytorch_lightning/utilities/fetching.py. My code runs ~20x faster. It probably isn't the correct way to fix things, but my training runs work as well as they did before.

After a bunch of fiddling around, I decided to create a custom DataLoader and override the __iter__ method. I discovered the problem was that the _iterator attribute of the DataLoader was always set to None somewhere between epochs. When _iterator is None, the DataLoader is reset and needs to start everything from scratch.

# Original DataLoader
class DataLoader(Generic[T_co]):
    def __iter__(self):
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

# Custom DataLoader
class CustomDataLoader(DataLoader):
    def __iter__(self) -> '_BaseDataLoaderIter':
        print(f'\n> DataLoader __iter__ with {self._iterator=} starting.\n')
        return super().__iter__()

As you can see in the normal DataLoader, having self._iterator set to None causes a call to self._get_iterator(), which reloads everything.

I decided to override _iterator with a custom property (getter & setter), to print the stack trace when ._iterator is set to None:

class CustomDataLoader(DataLoader):
    ...
    @property
    def _iterator(self):
        return self.__iterator

    @_iterator.setter
    def _iterator(self, value):
        if value is None:
            print('\nSetting __iterator to None. Stack trace:')
            import traceback
            traceback.print_stack()
        self.__iterator = value

(I could also use the debugger for this)

This leads to two different yet very similar stack traces (for the evaluation and training loaders, respectively):

  File "env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1370, in _run_sanity_check
    self._evaluation_loop.run()
  File "env\lib\site-packages\pytorch_lightning\loops\base.py", line 144, in run
    self.advance(*args, **kwargs)
  File "env\lib\site-packages\pytorch_lightning\loops\dataloader\evaluation_loop.py", line 109, in advance
    dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
  File "env\lib\site-packages\pytorch_lightning\loops\base.py", line 139, in run
    self.on_run_start(*args, **kwargs)
  File "env\lib\site-packages\pytorch_lightning\loops\epoch\evaluation_epoch_loop.py", line 87, in on_run_start
    self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_progress.current.ready)
  File "env\lib\site-packages\pytorch_lightning\loops\utilities.py", line 121, in _update_dataloader_iter
    dataloader_iter = enumerate(data_fetcher, batch_idx)
  File "env\lib\site-packages\pytorch_lightning\utilities\fetching.py", line 198, in __iter__
    self.reset()
  File "env\lib\site-packages\pytorch_lightning\utilities\fetching.py", line 214, in reset
    CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
  File "env\lib\site-packages\pytorch_lightning\trainer\supporters.py", line 498, in _shutdown_workers_and_reset_iterator
    dataloader._iterator = None
  File "env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1314, in _run_train
    self.fit_loop.run()
  File "env\lib\site-packages\pytorch_lightning\loops\base.py", line 144, in run
    self.advance(*args, **kwargs)
  File "env\lib\site-packages\pytorch_lightning\loops\fit_loop.py", line 234, in advance
    self.epoch_loop.run(data_fetcher)
  File "env\lib\site-packages\pytorch_lightning\loops\base.py", line 139, in run
    self.on_run_start(*args, **kwargs)
  File "env\lib\site-packages\pytorch_lightning\loops\epoch\training_epoch_loop.py", line 142, in on_run_start
    self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_idx + 1)
  File "env\lib\site-packages\pytorch_lightning\loops\utilities.py", line 121, in _update_dataloader_iter
    dataloader_iter = enumerate(data_fetcher, batch_idx)
  File "env\lib\site-packages\pytorch_lightning\utilities\fetching.py", line 198, in __iter__
    self.reset()
  File "env\lib\site-packages\pytorch_lightning\utilities\fetching.py", line 212, in reset
    self.dataloader.reset()
  File "env\lib\site-packages\pytorch_lightning\trainer\supporters.py", line 504, in reset
    apply_to_collection(self.loaders, DataLoader, self._shutdown_workers_and_reset_iterator)
  File "env\lib\site-packages\pytorch_lightning\utilities\apply_func.py", line 92, in apply_to_collection
    return function(data, *args, **kwargs)
  File "env\lib\site-packages\pytorch_lightning\trainer\supporters.py", line 498, in _shutdown_workers_and_reset_iterator
    dataloader._iterator = None

Well... We're nearly there. It looks like advancing one epoch calls self.reset() on the DataFetcher itself, which then resets the DataLoader and leads to our problem.

Indeed, when checking AbstractDataFetcher, we have this:

class AbstractDataFetcher(...):
    def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
        if self.dataloader is None:
            raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
-->     self.reset()
        self.dataloader_iter = iter(self.dataloader)
        self._apply_patch()
        self.prefetching(self.prefetch_batches)
        return self

# And iter(AbstractDataFetcher) is called here, in utilities.py:
def _update_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator:
    """Attach the dataloader."""
    if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
        # restore iteration
-->     dataloader_iter = enumerate(data_fetcher, batch_idx)
    else:
        dataloader_iter = iter(data_fetcher)
    return dataloader_iter

So... I guess we found the problem? Each time a new epoch runs, it calls self.advance on the FitLoop/EvalLoop, which then calls self.on_run_start. on_run_start creates the _dataloader_iter (which is normal behavior), which itself calls _update_dataloader_iter. This function, through enumerate, calls AbstractDataFetcher.__iter__, which calls self.reset(), entirely reloading the DataLoader.

Notice that self.reset() is also called in the __init__ method of AbstractDataFetcher, for setup purposes.
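
To get a feel for what that reset costs, here is a standalone timing sketch in plain PyTorch (nothing Lightning-specific, and the numbers will vary with your machine); iterating with persistent_workers=False approximates what the forced reset does, since the workers get respawned on every iter():

import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def run_epochs(loader, label, n_epochs=3):
    for epoch in range(n_epochs):
        start = time.monotonic()
        for _ in loader:
            pass
        print(f"{label} epoch {epoch}: {time.monotonic() - start:.2f}s")


if __name__ == "__main__":  # guard required for DataLoader workers on Windows
    dataset = TensorDataset(torch.randn(20_000, 64))

    # workers kept alive across epochs
    persistent = DataLoader(dataset, batch_size=256, num_workers=4, persistent_workers=True)
    # workers torn down and respawned on every iter() -- roughly what the reset forces
    respawned = DataLoader(dataset, batch_size=256, num_workers=4, persistent_workers=False)

    run_epochs(persistent, "persistent")
    run_epochs(respawned, "respawned")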

I just commented out the self.reset() line of AbstractDataFetcher, located at line 198 of pytorch_lightning/utilities/fetching.py.
While it speeds up the code a lot, and the entire training seems to work, it probably breaks a bunch of things? I'd wait until the Lightning team fixes the problem before trying anything serious.

What a ride.

@a-kore

a-kore commented Nov 9, 2021

Nice find. I'm on an older version (1.3.7) for a project I'm working on, and I can't find a fetching.py under utilities. It must be something fairly new; maybe it was meant to be inside a conditional for reload_dataloaders_every_epoch.

@carmocca
Contributor

Should have been fixed by #10434, which landed with the 1.5.1 release.

@TheMrZZ
Author

TheMrZZ commented Nov 10, 2021

Tested the new 1.5.1 release today, looks like performance is back on track. Thanks to everyone!

@TheMrZZ TheMrZZ closed this as completed Nov 10, 2021
@tchaton
Contributor

tchaton commented Nov 11, 2021

Dear @TheMrZZ,

Thanks for your investigation and happy we solved this ugly bug.

Best,
T.C

@isvogor-foi
Contributor

Hi everyone... This topic is very interesting, as I'm hitting the same issue.
I'm comparing the same implementation in Torch and in Lightning. I came across this post and noticed I was using the ancient Lightning 1.4.7, so I updated to 1.5.9. I repeated my test and nothing changed... Torch was still significantly faster than Lightning.
So, as @TheMrZZ suggested, I commented out the reset in the __iter__ function and repeated the test. Sure enough, Lightning was... lightning fast now! I'm loading images, and I got the following:

  • Lightning (vanilla 1.5.9): 11 images/s, total runtime 407s
  • Lightning (reset commented out): 108 images/s, total runtime 56s
  • Torch (unmodified): 40 images/s, total runtime 373s

Apparently the performance issue was fixed in 1.5.1; however, it seems the reset line is still there in 1.5.9. So I'm curious: why do we need to reset the data fetcher after each epoch?

@carmocca
Contributor

@isvogor-foi Are you saying you are again experiencing this problem with version 1.5.9 but not 1.5.1?

If so, can you try the versions in-between and report back your findings?

@isvogor-foi
Contributor

@carmocca Hi, well, I didn't try 1.5.1; I tried only 1.4.7 and 1.5.9. I'll see whether I can try 1.5.1 and get back to you on this!

@isvogor-foi
Contributor

isvogor-foi commented Feb 1, 2022

@carmocca Hi... So I ran a test with 1.5.1. The performance issues are there... A simple resnet18 + ImageNet with 15000 images, for 15 epochs on a V100, 4 workers, batch size 256, prefetch factor 2. The plot below shows the images/s loading rate.
As for the runtime:

  • Torch: 381.46s
  • Lightning: 1354.31s

The data is on a local scratch drive, and for process creation I made sure that both approaches use fork instead of spawn.

(plot: images/s loading rate over time, Torch vs. Lightning)

However, as already said by @TheMrZZ, removing the self.reset in __iter__ of fetching.py changes everything. Lightning performance increases severalfold and outperforms Torch.

So the question remains: why is self.reset necessary if it hurts performance so much?

@carmocca
Contributor

carmocca commented Feb 1, 2022

@isvogor-foi Happy to look at the issue if you share that vanilla Lightning example.

@isvogor-foi
Contributor

@carmocca Sure, by vanilla I mean it's the one taken from the official implementations:

Lightning: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/imagenet.py
Torch: https://github.com/pytorch/examples/blob/master/imagenet/main.py

Try running those on the same device, with the same image dataset, batch size, etc. I've tried on my machine, an AWS instance, even Colab, and Torch is always much faster.

@amin-nejad
Contributor

Sounds like this issue should be reopened

@isvogor-foi
Contributor

@amin-nejad I agree. @TheMrZZ shall we reopen?
@carmocca did you test?

@carmocca
Contributor

carmocca commented Feb 8, 2022

Hi @isvogor-foi!

I had a look, but I'm not observing any speed differences after commenting out reset.

Can you describe exactly what changes you are making? Is the same behaviour reproducible on master?
You can also reach me in our Slack if you find that easier.

@isvogor-foi
Contributor

isvogor-foi commented Feb 8, 2022

@carmocca This is very curious. There should be some difference.
It should take effect if you use multiple epochs, e.g. 50. It should be faster, since the DataLoader will not be reset and recreated after each epoch; this recreation is usually expensive. Do you know whether you're using the "fork" or the "spawn" start method?
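
To check which start method is actually in use (get_start_method comes from the standard multiprocessing API, and multiprocessing_context is a standard DataLoader argument):

import torch.multiprocessing as mp

# "fork" by default on Linux, "spawn" on Windows/macOS
print(mp.get_start_method(allow_none=True))

# it can also be forced per DataLoader:
# DataLoader(dataset, num_workers=4, multiprocessing_context="fork")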

Aha, it's also important not to use a LightningDataModule.

@isvogor-foi
Contributor

I downloaded the latest version, 1.6.0dev, and saw there are some changes. I also retried with MNIST, and (@TheMrZZ) it seems that commenting out reset() is dangerous: it only runs the first epoch and then kills the following epochs.
That said, Torch still runs much faster in my case. I use the same custom dataloader for both, with these examples:
Lightning: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/imagenet.py
Torch: https://github.com/pytorch/examples/blob/master/imagenet/main.py
If I find anything else, I'll report back.

@carmocca
Contributor

carmocca commented Feb 9, 2022

Here are some simple time results:

This was using 1 NVIDIA GeForce RTX 3090 and PyTorch Lightning 1.6.0dev, commit 8394770, torch==1.10.1, torchmetrics==0.7.0, torchvision==0.11.2, Python 3.8.12.

PyTorch Lightning

real    1m28.276s
user    5m29.551s
sys     0m34.544s

I used https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/imagenet.py verbatim, called with

time python imagenet.py fit --model.data_path /home/imagenet/data --trainer.limit_train_batches=100 --trainer.limit_val_batches=0 --trainer.max_epochs=2

PyTorch

real    1m29.662s
user    5m15.496s
sys     0m31.530s

I used https://github.com/pytorch/examples/blob/master/imagenet/main.py with the following changes applied, to disable validation and stop at 100 batches:

249c249
<         acc1 = validate(val_loader, model, criterion, args)
---
>         #acc1 = validate(val_loader, model, criterion, args)
252,253c252,253
<         is_best = acc1 > best_acc1
<         best_acc1 = max(acc1, best_acc1)
---
>         is_best = False
>         best_acc1 = best_acc1
281a282,283
>         if i == 100:
>             break
time python torch_imagenet.py /home/imagenet/data --epochs=2 --gpu=0

So basically the same speed.

@isvogor-foi
Contributor

isvogor-foi commented Feb 9, 2022

Hm... very interesting. I did the same with the MNIST example that comes with the Lightning examples and noticed that Torch was only slightly better in my case, a second or so. However, I've also dissected the advance call in training_epoch_loop.py. Rather than explaining the details, I've added some arrows to indicate the execution timeline.

(screenshot: training_epoch_loop.py advance() annotated with arrows marking the execution timeline)

So the training loop is this:

with self.trainer.profiler.profile("run_training_batch"):
    batch_output = self.batch_loop.run(batch, batch_idx)

and before it, this part takes a lot of time:

  response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
  if response == -1:
      self.batch_progress.increment_processed()
      raise StopIteration

and after it there is another long call:

  self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
  self.trainer.call_hook("on_batch_end")
  self.trainer.logger_connector.on_batch_end()

Therefore excessive logging can slow it down.
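
If logging is the bottleneck, here is a sketch of how to make it cheaper (it reuses the Module class from the original post; the subclass name is made up):

import pytorch_lightning as pl


class LessChattyModule(Module):  # "Module" as defined in the original post
    def training_step(self, batch, batch_idx):
        loss = self.calculate_loss(batch)
        # aggregate per epoch instead of pushing a value to the logger every step
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss


trainer = pl.Trainer(gpus=1, max_epochs=5, log_every_n_steps=50)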

So, in my particular case, I was using Lightning 1.5.9 with one V100, and the aforementioned hook calls seem to add up over time; in my long experiment with 100 epochs, batch size 256, and 35k images, Torch performs better. This was with just 1 GPU; I didn't try multiple GPUs. I think we can leave it at that.

@carmocca thanks!

@carapas

carapas commented Mar 24, 2022

Hi @isvogor-foi, I had the same issue as reported here, and after updating pytorch-lightning I didn't see any improvement either. After reading the OP's blog post (https://medium.com/@florian-ernst/finding-why-pytorch-lightning-made-my-training-4x-slower-ae64a4720bd1), I noticed that I was missing the persistent_workers=True flag on my DataLoader:

# My data Loader parameters
DataLoader(
  train_dataset, batch_size=64, shuffle=True, num_workers=n_workers,
  persistent_workers=True, pin_memory=True,
)

Hopefully this will help you! Performance was much improved for me.

@jzazo

jzazo commented Oct 28, 2022

I am not using datamodules, I am passing dataloaders directly to the trainer. I am also not using save_hyperparameters functionality. But I will check whatever I am attaching to the module in case there is any serialization happening. I will profile the training as well. I think I will have some time next week to verify this and I will report back.

@Borda Borda added this to the v1.8.x milestone Oct 31, 2022
@Borda Borda added priority: 1 Medium priority task and removed priority: 0 High priority task labels Oct 31, 2022
@awaelchli
Contributor

awaelchli commented Oct 31, 2022

@lminer You are correct, this is the right way to handle this when using save_hyperparameters().

@jzazo My recommendation is to check whether checkpointing is taking a long time (this happens between epochs). You could check by simply setting enable_checkpointing=False in the Trainer. Please feel free to report back in a new issue with your findings. Or ping us on slack if you need more guidance on debugging this.
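
For example (just a sketch; the ModelCheckpoint arguments shown are the standard ones, and "validation_loss" is the metric name logged earlier in this thread):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# 1) quick test: rule checkpointing out entirely
trainer = pl.Trainer(enable_checkpointing=False)

# 2) if checkpointing is the culprit, write less often and keep fewer files
checkpoint_cb = ModelCheckpoint(
    monitor="validation_loss",
    save_top_k=1,
    every_n_epochs=5,
    save_weights_only=True,
)
trainer = pl.Trainer(callbacks=[checkpoint_cb])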

@rohitgr7 rohitgr7 changed the title Lightning is very slow between epochs, compared to Pytorch. Lightning is very slow between epochs, compared to PyTorch. Oct 31, 2022
@carmocca carmocca removed this from the v1.8.x milestone Nov 1, 2022
@is-jlehrer

This still seems to be the case, especially with dataloader startup time

@Jason94

Jason94 commented Jul 5, 2023

Is this still fixed in 2.0.4? I'm still seeing this behavior. This code runs lightning fast (albeit with warnings about not having any workers):

    train_data_loader = DataLoader(
        NameDataset("data/training.csv", char_to_int),
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=True,
        # num_workers=4,
    )
    eval_data_loader = DataLoader(
        NameDataset("data/eval.csv", char_to_int, debug=False),
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=False,
        # num_workers=4,
    )

    lightning_model = PlContactEncoder(
        model, criterion, SIMILARITY_METRIC(0.5, return_distance=True), LEARNING_RATE
    )
    lightning_model.to(device)

    trainer = pl.Trainer(max_epochs=N_EPOCHS, max_steps=50)
    trainer.fit(
        model=lightning_model,
        train_dataloaders=train_data_loader,
        val_dataloaders=eval_data_loader,
    )

Adding the workers makes the warnings go away, but training freezes for ~15 seconds before validation or a new epoch:

    train_data_loader = DataLoader(
        NameDataset("data/training.csv", char_to_int),
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=True,
        num_workers=4,
    )
    eval_data_loader = DataLoader(
        NameDataset("data/eval.csv", char_to_int, debug=False),
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=False,
        num_workers=4,
    )

    lightning_model = PlContactEncoder(
        model, criterion, SIMILARITY_METRIC(0.5, return_distance=True), LEARNING_RATE
    )
    lightning_model.to(device)

    trainer = pl.Trainer(max_epochs=N_EPOCHS, max_steps=50)
    trainer.fit(
        model=lightning_model,
        train_dataloaders=train_data_loader,
        val_dataloaders=eval_data_loader,
    )

@aktgpt

aktgpt commented Jan 7, 2024

Hi,

This issue still persists in 2.1.3. I'm directly passing the dataloader to the Trainer:

train_dataloader = DataLoader(train_dataset, batch_size=64, collate_fn=collate_fn_train, drop_last=True, shuffle=True, pin_memory=True, num_workers=8, prefetch_factor=8, persistent_workers=True)

but the training freezes for a few seconds after each epoch. When I ran the advanced profiler on run_training_epoch, I saw that reset is called every epoch (dataloader.py:1086(_reset) and training_epoch_loop.py:143(reset)). The training is extremely slow if I set num_workers=0.

Is there a fix/workaround for this?

@aktgpt

aktgpt commented Jan 16, 2024

Is there any update on this? It seems like I'm not the only one facing this issue. I started using PTL with 2.x and would prefer not to downgrade to 1.5.1 to make this issue go away. Is there a 2.x version that doesn't have this issue?
@awaelchli @Borda @carmocca

@leventt

leventt commented Mar 21, 2024

This is indeed still an issue with the latest version... I checked a few things, such as persistent workers, made a small comparison against vanilla PyTorch with a nested for loop, and I can confirm that "lightning" may be an inaccurate way to describe this library right now.

@otakudj

otakudj commented Apr 25, 2024

I think this issue should be reopened. I hit this problem with PL version 2.0.3.

@jordan7186

Perhaps this issue persists? I'm still experiencing a similar freeze before validation; my version is 2.2.4.

@Lunamos
Contributor

Lunamos commented Aug 1, 2024

This issue still exists in version 2.3.3.
With a higher num_workers, the time between epochs is significantly longer. I tested the influence of saving checkpoints or hyperparameters and found that these settings do not affect the runtime. The bug is the same as the initial finding and is caused by the dataloader.

An easy fix is setting num_workers=0 or adding persistent_workers=True when instantiating the dataloader (a quick way to verify this is sketched below).

We have let an ugly bug that was already fixed once seriously affect the speed of the entire Lightning 2.x line.

I think this issue needs to be reopened and fixed as soon as possible. @awaelchli
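
To verify whether the workers are actually being respawned every epoch, here is a standalone sketch (worker_init_fn is a standard DataLoader argument; it fires each time a worker process starts, so with persistent workers it should print only once per worker for the whole run):

import os

import torch
from torch.utils.data import DataLoader, TensorDataset


def report_worker_start(worker_id):
    # if this prints on every epoch, the workers are being torn down and respawned
    print(f"worker {worker_id} started, pid={os.getpid()}")


if __name__ == "__main__":
    dataset = TensorDataset(torch.zeros(64, 1))
    loader = DataLoader(dataset, batch_size=8, num_workers=2,
                        persistent_workers=True, worker_init_fn=report_worker_start)
    for epoch in range(3):
        for _ in loader:
            pass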

@leventt

leventt commented Aug 1, 2024 via email

@KalinNonchev

Hi, version v2.4 still has this issue, and the recommendations mentioned are not working.

@meakbiyik

meakbiyik commented Aug 26, 2024

Folks, I might have a solution.

TL;DR: use OMP_NUM_THREADS=1 MKL_NUM_THREADS=1

The problem stems from a combination of weird torch defaults, using Slurm or some comparable scheduler without containerization, and a large num_workers count.

By default, torch uses as many threads as possible for interop and intraop operations. This "as many" is determined by the number of CPU cores in your system (see here).

If you are using a scheduler such as Slurm, torch will think that you have access to all the CPUs in your machine (since the node resources are visible to the job) even if you have limited the number of cores allocated to the job. Therefore, e.g. in a 100-core node, torch will spawn hundreds of threads for each worker, suffocating your system.

The solution is to reduce the number of threads that torch can spawn using the above-mentioned environment variables (1 is not required, I believe, but keeping it somewhat close to the actual number of CPUs would be smart). Alternatively, use containerization, people. Don't let Slurm pull you into its evil ways.
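
Concretely, something along these lines (a sketch; pick thread counts that match your actual CPU allocation):

import os

# must be set before torch is imported to affect the OpenMP/MKL thread pools
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

import torch

# these can also be capped from Python right after import
torch.set_num_threads(4)          # intra-op parallelism
torch.set_num_interop_threads(4)  # inter-op parallelism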

In my experiments, this seems to resolve a couple of deadlocks I had been hitting and considerably improves the behavior for this particular issue. There is still some delay when switching between train and validation workers, which might be a bug on the Lightning side (verification needed), but at least the training is now manageable.

This might be the same issue as #4450, or pretty much most other non-reproducible performance issues in torch/lightning repos.

@amirshamaeisynex

I tracked my problem down to evaluation_loop.py in PL.
The line iter(data_fetcher)  # creates the iterator inside the fetcher takes too long to run. I guess the data_fetcher is the culprit here.

@amirshamaeisynex

In my case, persistent_workers=True solved the issue.

@Borda
Member

Borda commented Oct 16, 2024

In my case, persistent_workers=True solved the issue.

Sounds good, would you give it a try?

@JacobHelwig

I noticed that decreasing the number of model parameters resulted in a significant speedup in the time between epochs, so I tried disabling checkpointing. This made the time between epochs basically zero.

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, Subset

from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.utilities.types import STEP_OUTPUT

from argparse import ArgumentParser

class RandomDataset(Dataset):
    """
    .. warning::  This is meant for testing/debugging and is experimental.
    """

    def __init__(self, size: int, length: int):
        self.len = length
        self.data = torch.randn(length, 1, size, size)

    def __getitem__(self, index: int) -> Tensor:
        return self.data[index]

    def __len__(self) -> int:
        return self.len


class BoringDataModule(LightningDataModule):
    """
    .. warning::  This is meant for testing/debugging and is experimental.
    """

    def __init__(self, size: int) -> None:
        super().__init__()
        self.length = 100
        self.random_full = RandomDataset(size=size, length=self.length)

    def setup(self, stage: str) -> None:
        n = self.length // 2
        self.random_train = Subset(self.random_full, indices=range(n))
        self.random_val = Subset(self.random_full, indices=range(n, 2 * n))

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.random_train,num_workers=16,persistent_workers=True)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.random_val,num_workers=16,persistent_workers=True)


class DemoModel(LightningModule):
    """
    .. warning::  This is meant for testing/debugging and is experimental.
    """

    def __init__(self, dim: int = 10, size: int=32, learning_rate: float = 0.02):
        super().__init__()
        self.model = Net(dim=dim, size=size)
        self.learning_rate = learning_rate

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    def step(self, batch: Any, batch_nb: int) -> STEP_OUTPUT:
        x = batch
        x = self(x)
        return x.sum()
    
    def training_step(self, batch: Any, batch_nb: int) -> STEP_OUTPUT:
        return self.step(batch, batch_nb)
    
    def validation_step(self, batch: Any, batch_nb: int) -> STEP_OUTPUT:
        return self.step(batch, batch_nb)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


class Net(nn.Module):
    """
    .. warning::  This is meant for testing/debugging and is experimental.
    """

    def __init__(self, dim, size) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, dim, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim, 2 * dim, 3, 1, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(dim // 2 * size * size, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--dim", type=int, required=True)
    args = parser.parse_args()

    size = 32
    model = DemoModel(dim=args.dim, size=32)

    datamodule = BoringDataModule(size=size)
    trainer = Trainer(devices=1, max_epochs=2,enable_checkpointing=False)
    trainer.fit(model, datamodule=datamodule)

@meliksahturker

In my case, it was not the data loader. It was the Trainer writing model checkpoints to disk. Setting lightning.Trainer(enable_checkpointing=False) solved the issue for me.

@martinigoyanes

martinigoyanes commented Dec 1, 2024

I am having the same issue; the delay is almost 5 min between epochs with a 50M-parameter model (~200MB). I am using 2.2.2.

As I decrease the number of parameters of my model, the hang decreases.

Also, if I trigger the same fit twice on a 12M-parameter model with just one batch, the first run takes 80s and the next one takes 100ms... I understand this is because the checkpoint is available in the second run, but even if I disable checkpointing, the first run does not get any faster.

@fgdfgfthgr-fox

Same here; the pause between epochs is much more obvious if the batch size is huge. Setting persistent_workers=True reduces some of the slowdown, but it is still very slow.
