Lightning is very slow between epochs, compared to PyTorch. #10389
After a lot of digging around, I managed to pin down the line causing the problem. It's line 142 in loops/epoch/training_epoch_loop.py:

```python
class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
    ...
    def on_run_start(self, data_fetcher: AbstractDataFetcher, **kwargs: Any) -> None:
        # hook
        self.trainer.logger_connector.on_epoch_start()
        self.trainer.call_hook("on_epoch_start")
        self.trainer.call_hook("on_train_epoch_start")
        self.trainer.fit_loop.epoch_progress.increment_started()

        self._reload_dataloader_state_dict(data_fetcher)
-->     self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_idx + 1)
```

Therefore, the culprit is:

```python
def _update_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator:
    """Attach the dataloader."""
    if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
        # restore iteration
        dataloader_iter = enumerate(data_fetcher, batch_idx)
    else:
        dataloader_iter = iter(data_fetcher)
    return dataloader_iter
```

The on_run_start hook is called from loops/base.py:

```python
class Loop(ABC, Generic[T]):
    ...
    def run(self, *args: Any, **kwargs: Any) -> T:
        if self.skip:
            return self.on_skip()

        self.reset()

-->     self.on_run_start(*args, **kwargs)
        ...
```

And from this:

```python
class FitLoop(Loop):
    ...
    def advance(self) -> None:
        """Runs one whole epoch."""
        dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
        data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)

        with self.trainer.profiler.profile("run_training_epoch"):
-->         self.epoch_loop.run(data_fetcher)
        ...
```

The problem is the ... |
Just a guess, but maybe having a different number of workers between the training and validation steps means it's spinning up new workers and tearing them down between epochs? Try making the number of workers equal for both the train and val dataloaders (5-6, based on your CPU).
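For illustration, a minimal sketch with synthetic datasets (the sizes and the worker count are placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-ins for the real datasets; the point is only that both
# loaders use the same number of workers.
train_dataset = TensorDataset(torch.randn(1024, 8), torch.randint(0, 2, (1024,)))
val_dataset = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=6)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=6)  # same worker count
```
|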
I just tried this, it does not solve the problem. It would have been weird IMO, since it does not cause any problem with vanilla Pytorch. |
I see, that's interesting. Like you said, it does seem to be a data-loading issue. Maybe try removing `reload_dataloaders_every_epoch` from the trainer call and explicitly setting `reload_dataloaders_every_epoch=False`, and see what happens. Other than that, I'd try a fresh install of pytorch-lightning in a new venv. |
I just tried that, but it has no effect. I also tried a fresh conda environment, but that didn't work either. I'm still trying to understand how the data loaders are getting reset, but I can't find anything really interesting yet. |
I had a similar observation where data_fetcher caused unusually long run times. For me it was indeed fixed by completely disabling multiprocess data loading (`num_workers=0`), although I have not tried setting `reload_dataloaders_every_epoch=False`. Interesting to see that others have the same issue/observation. Funnily, setting `num_workers=0` is what led me to open #10182. Perhaps there is something more to this? |
TL;DR: I just commented out the offending reset line (detailed below) and the delay between epochs disappeared.

After a bunch of fiddling around, I decided to create a custom DataLoader and overload the `__iter__` method:

```python
# Original DataLoader
class DataLoader(Generic[T_co]):
    def __iter__(self):
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()


# Custom DataLoader
class CustomDataLoader(DataLoader):
    def __iter__(self) -> '_BaseDataLoaderIter':
        print(f'\n> DataLoader __iter__ with {self._iterator=} starting.\n')
        return super().__iter__()
```

As you can see in the normal DataLoader, having `_iterator` set back to `None` means the workers get re-created on the next `__iter__` call, even with `persistent_workers=True`.

I decided to override `_iterator` with a property whose setter prints a stack trace whenever it is set to `None`:

```python
class CustomDataLoader(DataLoader):
    ...
    @property
    def _iterator(self):
        return self.__iterator

    @_iterator.setter
    def _iterator(self, value):
        if value is None:
            print('\nSetting __iterator to None. Stack trace:')
            import traceback
            traceback.print_stack()
        self.__iterator = value
        return self.__iterator
```

(I could also use the debugger for this.) This leads to 2 different yet very similar stack traces (respectively for the evaluation & training loaders).
Well... We're nearly there. It looks like advancing one epoch ends up calling `iter()` on the data fetcher. Indeed, when checking AbstractDataFetcher, we have this:

```python
class AbstractDataFetcher(...):

    def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
        if self.dataloader is None:
            raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
-->     self.reset()
        self.dataloader_iter = iter(self.dataloader)
        self._apply_patch()
        self.prefetching(self.prefetch_batches)
        return self


# And iter(AbstractDataFetcher) is called here, in utilities.py:
def _update_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator:
    """Attach the dataloader."""
    if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
        # restore iteration
-->     dataloader_iter = enumerate(data_fetcher, batch_idx)
    else:
        dataloader_iter = iter(data_fetcher)
    return dataloader_iter
```

So... I guess we found the problem? Each time a new epoch runs, it calls `iter()` on the data fetcher, which calls `self.reset()` and ends up throwing away the DataLoader's workers, persistent or not.

Notice that I just commented out the `self.reset()` line highlighted above to confirm this, and the delay between epochs disappeared.

What a ride.
Nice find. I'm on an older version (1.3.7) for a project I'm working on, and I can't find a "fetching.py" under utilities. It must be something fairly new; maybe it was meant to be inside a conditional for "reload_dataloaders_every_epoch". |
Should have been fixed by #10434 which landed with the 1.5.1 release |
Tested the new 1.5.1 release today, looks like performance is back on track. Thanks to everyone! |
Dear @TheMrZZ, Thanks for your investigation and happy we solved this ugly bug. Best, |
Hi everyone... This topic is very interesting, as I'm hitting the same issue.
Apparently the performance issue has been fixed in 1.5.1; however, it seems that with 1.5.9 the reset line is still there. So I'm curious: why do we need to reset the data fetcher after each epoch? |
@isvogor-foi Are you saying you are again experiencing this problem with version 1.5.9 but not 1.5.1? If so, can you try the versions in-between and report back your findings? |
@carmocca Hi, well, I didn't try 1.5.1; I only tried 1.4.7 and 1.5.9. I'll see whether I can try 1.5.1 and get back to you on this! |
@carmocca Hi... So I ran a test with 1.5.1. The performance issues are there... A simple resnet18 + ImageNet with 15000 images, for 15 epochs on a V100, 4 workers, batch size 256, prefetch factor 2. The plot below shows the img/s loading rate.
However, as already said by @TheMrZZ, removing the reset line improves things. So the question remains: why is the data fetcher reset after every epoch? |
@isvogor-foi Happy to look at the issue if you share that vanilla lightning example |
@carmocca Sure, vanilla meaning it's the one taken from the official implementation. Lightning: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/imagenet.py Try running those on the same device, with the same image dataset, batch size, etc. I've tried on my machine, an AWS instance, even Colab, and Torch is always much better. |
Sounds like this issue should be reopened |
@amin-nejad I agree. @TheMrZZ shall we reopen? |
Hi @isvogor-foi! I had a look but I'm not observing any speed differences after commenting out the reset. Can you describe exactly what changes you are making? Is the same behaviour reproducible on master? |
@carmocca This is very curious. There should be some difference. Aha, it's also important not to use LightningDataModule. |
I downloaded the latest version, 1.6.0dev, and saw there are some changes. I also retried with MNIST, and (@TheMrZZ) it seems that commenting out the reset line ... |
Here are some simple time results. This was using 1 GPU.

**PyTorch Lightning**

I used https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/imagenet.py verbatim, called with

```
time python imagenet.py fit --model.data_path /home/imagenet/data --trainer.limit_train_batches=100 --trainer.limit_val_batches=0 --trainer.max_epochs=2
```

**PyTorch**

I used https://github.com/pytorch/examples/blob/master/imagenet/main.py with the following changes applied to disable validation and stop at 100 batches:

```diff
249c249
< acc1 = validate(val_loader, model, criterion, args)
---
> #acc1 = validate(val_loader, model, criterion, args)
252,253c252,253
< is_best = acc1 > best_acc1
< best_acc1 = max(acc1, best_acc1)
---
> is_best = False
> best_acc1 = best_acc1
281a282,283
> if i == 100:
>     break
```

called with

```
time python torch_imagenet.py /home/imagenet/data --epochs=2 --gpu=0
```

So basically the same speed. |
Hm... very interesting. I did the same and looked at where the time goes. The training loop itself is fine, but the part right before it takes a lot of time, and there is another long call right after it (profiler screenshots omitted). Therefore, excessive logging can slow it down. So, in my particular case, I was using Lightning 1.5.9 with one V100, and the aforementioned hook calls seem to build up over time, so in my long experiment (100 epochs, batch size 256, 35k images) Torch performs better. This was with just 1 GPU; I didn't try multiple GPUs. I think we can leave it at that. @carmocca thanks!
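If per-step logging really is the bottleneck, a possible mitigation is simply to log less often. A minimal sketch, assuming a toy LightningModule (the module, metric name, and frequency are illustrative, not from my actual experiment):

```python
import torch
import pytorch_lightning as pl

class TinyModule(pl.LightningModule):  # illustrative module, not from this thread
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x).squeeze(-1), y)
        # Log a per-epoch aggregate instead of logging on every step,
        # which reduces the per-batch logger/hook overhead.
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

# Lowering the Trainer-level logging frequency also reduces per-batch work.
trainer = pl.Trainer(max_epochs=1, log_every_n_steps=200)
```

Whether this recovers the full gap depends on how much of the overhead comes from logging versus the dataloader resets discussed above.
|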
Hi @isvogor-foi, I had the same issue as reported here, and after updating pytorch-lightning I didn't see any improvement either. After reading the OP's blog (https://medium.com/@florian-ernst/finding-why-pytorch-lightning-made-my-training-4x-slower-ae64a4720bd1), I noticed that I was missing the `persistent_workers=True` flag on my DataLoader:
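Something along these lines (a minimal sketch with a synthetic dataset, not my actual code):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in dataset; only the loader arguments matter here.
dataset = TensorDataset(torch.randn(1024, 16), torch.randint(0, 10, (1024,)))

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,  # keep the worker processes alive across epochs
)
```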
Hopefully this will help you! Performance was much improved for me. |
I am not using datamodules, I am passing dataloaders directly to the trainer. I am also not using save_hyperparameters functionality. But I will check whatever I am attaching to the module in case there is any serialization happening. I will profile the training as well. I think I will have some time next week to verify this and I will report back. |
@lminer You are correct, this is the right way to handle this when using multiple workers.

@jzazo My recommendation is to check whether checkpointing is taking a long time (this happens between epochs). You could check by simply disabling checkpointing for one run, as in the sketch below.
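A minimal sketch of that check, assuming a Lightning version recent enough to expose `enable_checkpointing` (the argument name differs on older releases):

```python
import pytorch_lightning as pl

# If the gap between epochs disappears with checkpointing turned off,
# the time is going into checkpoint I/O rather than into the dataloaders.
trainer = pl.Trainer(max_epochs=2, enable_checkpointing=False)
```
|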
This still seems to be the case, especially with dataloader startup time |
Is this still fixed on 2.0.4? I'm still seeing this behavior. This code runs lightning fast (albeit with warnings about not having any workers):

```python
train_data_loader = DataLoader(
    NameDataset("data/training.csv", char_to_int),
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    # num_workers=4,
)
eval_data_loader = DataLoader(
    NameDataset("data/eval.csv", char_to_int, debug=False),
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=False,
    # num_workers=4,
)

lightning_model = PlContactEncoder(
    model, criterion, SIMILARITY_METRIC(0.5, return_distance=True), LEARNING_RATE
)
lightning_model.to(device)

trainer = pl.Trainer(max_epochs=N_EPOCHS, max_steps=50)
trainer.fit(
    model=lightning_model,
    train_dataloaders=train_data_loader,
    val_dataloaders=eval_data_loader,
)
```

Adding the workers makes the warnings go away, but it freezes for ~15 seconds before validation or a new epoch:

```python
train_data_loader = DataLoader(
    NameDataset("data/training.csv", char_to_int),
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)
eval_data_loader = DataLoader(
    NameDataset("data/eval.csv", char_to_int, debug=False),
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

lightning_model = PlContactEncoder(
    model, criterion, SIMILARITY_METRIC(0.5, return_distance=True), LEARNING_RATE
)
lightning_model.to(device)

trainer = pl.Trainer(max_epochs=N_EPOCHS, max_steps=50)
trainer.fit(
    model=lightning_model,
    train_dataloaders=train_data_loader,
    val_dataloaders=eval_data_loader,
)
```
|
Hi, this issue still persists in 2.1.3. I'm directly passing the dataloaders to the Trainer. Is there a fix/workaround for this? |
Is there any update on this? It seems like I'm not the only one facing this issue. I started using PTL with 2.x and would prefer not to downgrade to 1.5.1 to make this issue go away. Is there a 2.x version that doesn't have this issue? |
This is indeed still an issue with the latest version... I have checked a few things, such as persistent workers and so on, and made a small comparison against vanilla PyTorch with a nested for loop, and I can confirm "lightning" may be an inaccurate way to describe this library right now.
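For context, a sketch of the kind of plain-PyTorch nested-loop baseline I mean (synthetic data and a toy model; the real comparison used my actual training setup):

```python
import time
import torch
from torch.utils.data import DataLoader, TensorDataset

def main():
    model = torch.nn.Linear(16, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = torch.nn.CrossEntropyLoss()
    dataset = TensorDataset(torch.randn(4096, 16), torch.randint(0, 2, (4096,)))
    loader = DataLoader(dataset, batch_size=64, num_workers=4, persistent_workers=True)

    for epoch in range(3):
        start = time.perf_counter()
        for x, y in loader:                 # the plain nested-for-loop training loop
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()
            optimizer.step()
        print(f"epoch {epoch}: {time.perf_counter() - start:.2f}s")

if __name__ == "__main__":  # guard required because num_workers > 0 spawns subprocesses
    main()
```

Timing the equivalent LightningModule run against this makes the gap between epochs easy to see.
|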
I think this issue should be reopened. I hit this problem with pl version 2.0.3. |
Perhaps this issue persists? I'm still experiencing a similar freeze before validation, on version 2.2.4. |
This issue still exists in version 2.3.3. With higher num_workers, the time between epochs is significantly longer. I tested the influence of saving checkpoints or hyperparameters and found that these settings do not affect the runtime; the bug is the same as the initial finding and is caused by the dataloader. An easy fix can be setting `num_workers=0` or adding `persistent_workers=True` when instantiating the dataloader.

We let such an ugly bug that had been fixed before seriously affect the operation speed of the entire Lightning 2.0. I think this issue needs to be reopened and fixed as soon as possible. @awaelchli |
`persistent_workers=True` did not actually work around the issue when I was testing. |
Hi, version v2.4 still has this issue, and the recommendations mentioned are not working. |
Folks, I might have a solution. TL;DR: cap the number of threads torch is allowed to spawn, e.g. via the OMP_NUM_THREADS environment variable.

The problem stems from a combination of weird torch defaults, using Slurm or some comparable scheduler without containerization, and a large num_workers count. By default, torch uses as many threads as possible for interop and intraop operations. This "as many" is determined by the number of CPU cores in your system (see here). If you are using a scheduler such as Slurm, torch will think you have access to all the CPUs in the machine (since the node's resources are visible to the job) even if you have limited the number of cores allocated to the job. Therefore, e.g. on a 100-core node, torch will spawn hundreds of threads for each worker, suffocating your system.

The solution is to reduce the number of threads torch can spawn using the above-mentioned environment variables (1 is not required, I believe, but keeping it somewhat close to the actual number of allocated CPUs would be smart). Alternatively, use containerization, people. Don't let Slurm pull you into its evil ways.

In my experiments, this seems to resolve a couple of deadlocks I had been hitting, and it considerably improves the behavior for this particular issue. There is still some delay when switching between train and validation workers, which might be a bug on the Lightning side (verification needed), but at least training is now manageable.

This might be the same issue as #4450, or pretty much most other non-reproducible performance issues in the torch/lightning repos.
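A sketch of what that capping can look like in practice; the value 4 is arbitrary and should roughly match the cores your job is actually allocated:

```python
import os

# Environment variables must be set before torch initializes its thread pools,
# so export them in the job script or set them before importing torch.
os.environ.setdefault("OMP_NUM_THREADS", "4")
os.environ.setdefault("MKL_NUM_THREADS", "4")

import torch

# The same limits can also be applied explicitly from Python.
torch.set_num_threads(4)          # intra-op thread pool
torch.set_num_interop_threads(4)  # inter-op thread pool
```
|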
I tracked down my problem to evaluation_loop.py in PL. |
In my case, `persistent_workers=True` solved the issue. |
Sounds good, would you give it a try? |
I noticed that decreasing the number of model parameters resulted in a significant speedup in the time between epochs, so I tried disabling checkpointing. This made the time between epochs basically zero.
|
In my case, it was not the data loader; it was the Trainer writing model checkpoints to disk. Disabling checkpointing fixed it. |
I am having the same issue: the delay is almost 5 min between epochs with a 50M-parameter model (~200MB). As I decrease the number of parameters of my model, the hang decreases. Also, if I trigger the same fit twice on a 12M-parameter model with just one batch, the first run takes 80s and the next one takes 100ms... I understand this is because the checkpoint is available in the second run, but even if I disable checkpointing the first run does not get any faster. |
Same here: the pause between epochs is much more obvious if the batch size is huge. Setting `persistent_workers=True` reduces the slowdown somewhat, but it is still very slow. |
I converted some PyTorch code to Lightning. The dataset is loaded lazily by the train & eval dataloaders.
However, when moving the code to Lightning, I noticed a huge slowdown. After digging around, I noticed that there was a ~10 second delay between each epoch. For comparison, in my vanilla PyTorch code, an epoch takes ~4s.
I first thought it was a data loading problem, but during the 10s delay no data is loaded (at least that's what my `print` statements tell me).

I think the issue is related to the number of workers, because setting `num_workers=0` solves the problem (but is slower in the end, since only one worker is not enough). I know starting workers is slow, however I have `persistent_workers=True`, and this does not happen in normal PyTorch. My data loaders also have `pin_memory=True` (removing pin_memory does not solve the problem).

Since this is company code, I cannot disclose the before/after, but I'll try to "anonymize" some code if necessary. Here is the lightning module:
Here is the result of `profiler="simple"`:

Here is the result of `profiler="advanced"`: https://pastebin.com/q3C5P826

Finally, here is a video demonstrating the problem. I'm printing each piece of data loading, to prove it's not the issue.
https://user-images.githubusercontent.com/30944236/140587623-ae184fa3-370a-42be-8593-200026d11ba4.mp4
Random information:

conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
cc @tchaton @rohitgr7 @Borda @akihironitta