
Dataloader hangs. Potential deadlock with set_num_threads in worker processes? #75147


Open
KimbingNg opened this issue Apr 2, 2022 · 3 comments
Labels: module: dataloader (Related to torch.utils.data.DataLoader and Sampler), module: deadlock (Problems related to deadlocks (hang without exiting)), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


KimbingNg commented Apr 2, 2022

🐛 Bug

I have a main process running with PID 52422. Sometimes it gets stuck while iterating over my DataLoader with num_workers > 0 during training.
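
For context, here is a minimal sketch of the kind of setup described above (illustrative only: the real dataset and training loop are not shown in this issue, and RandomDataset below is a hypothetical stand-in):

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class RandomDataset(Dataset):
    # Hypothetical stand-in for the real dataset used in training.
    def __len__(self):
        return 1024

    def __getitem__(self, idx):
        # Returns a NumPy-backed sample; in this environment NumPy/MKL loads
        # libiomp5 (see the library paths in the backtraces below).
        return torch.from_numpy(np.random.rand(3, 32, 32).astype(np.float32))

# num_workers > 0: worker subprocesses are created, via fork() by default on Linux.
loader = DataLoader(RandomDataset(), batch_size=64, num_workers=4)

for batch in loader:  # the run occasionally never receives the next batch and hangs here
    pass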

The threads of the main process:

(gdb) info threads
  Id   Target Id         Frame
* 1    Thread 0x7fc73ac5d700 (LWP 52422) "" select_poll_poll () at /usr/local/src/conda/python-3.8.12/Modules/clinic/selectmodule.c.h:219
  2    Thread 0x7fc6b59e6700 (LWP 52574) "" pthread_cond_wait@@GLIBC_2.3.2 () at ../sysdeps/unix/sysv/linux/x86_64/pthread_cond_wait.S:185
  3    Thread 0x7fc6b61e7700 (LWP 52575) "" pthread_cond_wait@@GLIBC_2.3.2 () at ../sysdeps/unix/sysv/linux/x86_64/pthread_cond_wait.S:185
  4    Thread 0x7fc6b69e8700 (LWP 52576) "" pthread_cond_wait@@GLIBC_2.3.2 () at ../sysdeps/unix/sysv/linux/x86_64/pthread_cond_wait.S:185
  5    Thread 0x7fc6b71e9700 (LWP 52742) "" 0x00007fc739c909c8 in accept4 (fd=7, addr=..., addr_len=0x7fc6b71e8e58, flags=524288) at ../sysdeps/unix/sysv/linux/accept4.c:40
  6    Thread 0x7fc6bae16700 (LWP 52748) "" 0x00007fc739c8384d in poll () at ../sysdeps/unix/syscall-template.S:84
  7    Thread 0x7fc6ba615700 (LWP 52776) "" pthread_cond_wait@@GLIBC_2.3.2 () at ../sysdeps/unix/sysv/linux/x86_64/pthread_cond_wait.S:185
  8    Thread 0x7fc6cab3d700 (LWP 53017) "" 0x00007fc73a879a15 in futex_abstimed_wait_cancelable (private=0, abstime=0x7fc6cab3c270, expected=0, futex_word=0x7fc6940008f0)
    at ../sysdeps/unix/sysv/linux/futex-internal.h:205
  9    Thread 0x7fc6ca2f0700 (LWP 56184) "" 0x00007fc73a879a15 in futex_abstimed_wait_cancelable (private=0, abstime=0x7fc6ca2ef270, expected=0, futex_word=0x7fc62c1f5510)
    at ../sysdeps/unix/sysv/linux/futex-internal.h:205
  10   Thread 0x7fc6c1fff700 (LWP 56250) "" pthread_cond_wait@@GLIBC_2.3.2 () at ../sysdeps/unix/sysv/linux/x86_64/pthread_cond_wait.S:185
  11   Thread 0x7fc6c832c700 (LWP 26359) "" 0x00007fc73a879827 in futex_abstimed_wait_cancelable (private=0, abstime=0x0, expected=0, futex_word=0x7fc630000ae0)
    at ../sysdeps/unix/sysv/linux/futex-internal.h:205
  12   Thread 0x7fc6c0ffd700 (LWP 26360) "" 0x00007fc73a879827 in futex_abstimed_wait_cancelable (private=0, abstime=0x0, expected=0, futex_word=0x7fc6240008c0)
    at ../sysdeps/unix/sysv/linux/futex-internal.h:205
  13   Thread 0x7fc6b9e14700 (LWP 26361) "" 0x00007fc73a879827 in futex_abstimed_wait_cancelable (private=0, abstime=0x0, expected=0, futex_word=0x7fc4a000fb80)
    at ../sysdeps/unix/sysv/linux/futex-internal.h:205
  14   Thread 0x7fc6c17fe700 (LWP 26362) "" 0x00007fc73a879827 in futex_abstimed_wait_cancelable (private=0, abstime=0x0, expected=0, futex_word=0x7fc4a4000d80)
    at ../sysdeps/unix/sysv/linux/futex-internal.h:205
The process tree:

52422
|__  26345
|__  26346
|__  26347
|__  26351

It has 4 child processes (26345, 26346, 26347, 26351). One of them (26346) is blocked in a pthread_mutex_lock call, and the others are stuck in poll or accept4 calls.

The backtrace of 26346:

(gdb) bt
#0  __lll_lock_wait () at ../sysdeps/unix/sysv/linux/x86_64/lowlevellock.S:135
#1  0x00007fc73a873dbd in __GI___pthread_mutex_lock (mutex=0x5606dd44d080) at ../nptl/pthread_mutex_lock.c:80
#2  0x00007fc733b1b7b1 in __kmp_lock_suspend_mx () from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/numpy/core/../../../../libiomp5.so
#3  0x00007fc733acd435 in __kmp_free_thread () from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/numpy/core/../../../../libiomp5.so                   [390/7275]
#4  0x00007fc733acd5e0 in __kmp_set_num_threads () from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/numpy/core/../../../../libiomp5.so
#5  0x00007fc70f72ee95 in at::set_num_threads(int) () from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#6  0x00007fc72667a1b6 in THPModule_setNumThreads(_object*, _object*) () from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#7  0x00005606d3fa4f2b in cfunction_vectorcall_O () at /home/conda/feedstock_root/build_artifacts/python-split_1634073040851/work/Objects/methodobject.c:486
#8  0x00005606d404d133 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x5606e74b8628, callable=0x7fc727922590)
    at /home/conda/feedstock_root/build_artifacts/python-split_1634073040851/work/Include/cpython/abstract.h:115
#9  call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5606d4528e10)
    at /home/conda/feedstock_root/build_artifacts/python-split_1634073040851/work/Python/ceval.c:4963
#10 _PyEval_EvalFrameDefault () at /home/conda/feedstock_root/build_artifacts/python-split_1634073040851/work/Python/ceval.c:3469
#11 0x00005606d402bfc6 in PyEval_EvalFrameEx (throwflag=0, f=0x5606e74b83e0) at /home/conda/feedstock_root/build_artifacts/python-split_1634073040851/work/Python/ceval.c:738
#12 function_code_fastcall (globals=<optimized out>, nargs=<optimized out>, args=<optimized out>, co=<optimized out>)
    at /home/conda/feedstock_root/build_artifacts/python-split_1634073040851/work/Objects/call.c:284
#13 _PyFunction_Vectorcall () at /home/conda/feedstock_root/build_artifacts/python-split_1634073040851/work/Objects/call.c:411
#14 0x00005606d3fa116e in PyVectorcall_Call (kwargs=<optimized out>, tuple=<optimized out>, callable=0x7fc6d227db80)
    at /home/conda/feedstock_root/build_artifacts/python-split_1634073040851/work/Objects/call.c:200
#15 PyObject_Call () at /home/conda/feedstock_root/build_artifacts/python-split_1634073040851/work/Objects/call.c:228
#16 0x00005606d404a4ef in do_call_core (kwdict=0x7fc6c9714b80, callargs=0x7fc6cac0e700, func=0x7fc6d227db80, tstate=<optimized out>)
    at /home/conda/feedstock_root/build_artifacts/python-split_1634073040851/work/Python/ceval.c:5010
#17 _PyEval_EvalFrameDefault () at /home/conda/feedstock_root/build_artifacts/python-split_1634073040851/work/Python/ceval.c:3559
#18 0x00005606d402bfc6 in PyEval_EvalFrameEx (throwflag=0, f=0x7fc6c8e7c6c0) at /home/conda/feedstock_root/build_artifacts/python-split_1634073040851/work/Python/ceval.c:738
#19 function_code_fastcall (globals=<optimized out>, nargs=<optimized out>, args=<optimized out>, co=<optimized out>)

However, the mutex that 26346 is waiting on is already held by another thread, even though there is no other thread at all in process 26346 (the remaining 3 child processes all have at least 2 threads).

I found that the owner of the mutex has TID 52574, which is one of the threads of the main process (52422):

(gdb) info threads
  Id   Target Id         Frame
* 1    Thread 0x7fc73ac5d700 (LWP 26346) "" 0x00007fc73a873dbd in __GI___pthread_mutex_lock (mutex=0x5606dd44d080) at ../nptl/pthread_mutex_lock.c:80
(gdb) p *mutex
$39 = {__data = {__lock = 2, __count = 0, __owner = 52574, __nusers = 1, __kind = 0, __spins = 0, __elision = 0, __list = {__prev = 0x0, __next = 0x0}},
  __size = "\002\000\000\000\000\000\000\000^\315\000\000\001", '\000' <repeats 26 times>, __align = 2}

The backtrace of the owner thread (52574) in the main process is shown below:

(gdb) f 2
#2  0x00007fc733aef0b3 in bool __kmp_wait_template<kmp_flag_64<false, true>, true, false, true>(kmp_info*, kmp_flag_64<false, true>*, void*) [clone .constprop.0] ()
   from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/numpy/core/../../../../libiomp5.so
(gdb) bt
#0  pthread_cond_wait@@GLIBC_2.3.2 () at ../sysdeps/unix/sysv/linux/x86_64/pthread_cond_wait.S:185
#1  0x00007fc733b1bb72 in void __kmp_suspend_64<false, true>(int, kmp_flag_64<false, true>*) ()
   from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/numpy/core/../../../../libiomp5.so
#2  0x00007fc733aef0b3 in bool __kmp_wait_template<kmp_flag_64<false, true>, true, false, true>(kmp_info*, kmp_flag_64<false, true>*, void*) [clone .constprop.0] ()
   from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/numpy/core/../../../../libiomp5.so
#3  0x00007fc733aef522 in __kmp_hyper_barrier_release(barrier_type, kmp_info*, int, int, int, void*) ()
   from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/numpy/core/../../../../libiomp5.so
#4  0x00007fc733af7b26 in __kmp_fork_barrier(int, int) () from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/numpy/core/../../../../libiomp5.so
#5  0x00007fc733acf89e in __kmp_launch_thread () from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/numpy/core/../../../../libiomp5.so
#6  0x00007fc733b1c232 in __kmp_launch_worker(void*) () from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/numpy/core/../../../../libiomp5.so
#7  0x00007fc73a8716ba in start_thread (arg=0x7fc6b59e6700) at pthread_create.c:333
#8  0x00007fc739c8f51d in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:109

Why is the mutex in the child process (26346) held by a thread (TID 52574) of the parent process (52422)? I suspect this is the cause of the potential deadlock in PyTorch.

Any help? Thanks!

Versions

My environment:

PyTorch version: 1.11.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 16.04.7 LTS (x86_64)
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
Clang version: Could not collect
CMake version: version 3.14.5
Libc version: glibc-2.23

Python version: 3.8.12 | packaged by conda-forge | (default, Oct 12 2021, 21:59:51)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-4.15.0-142-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti
GPU 2: GeForce GTX 1080 Ti
GPU 3: GeForce GTX 1080 Ti
GPU 4: GeForce GTX 1080 Ti
GPU 5: GeForce GTX 1080 Ti
GPU 6: GeForce GTX 1080 Ti
GPU 7: GeForce GTX 1080 Ti

Nvidia driver version: 440.36
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] pytorch-lightning==1.5.10
[pip3] torch==1.11.0
[pip3] torch-tb-profiler==0.3.1
[pip3] torchfile==0.1.0
[pip3] torchmetrics==0.6.0
[pip3] torchvision==0.10.0a0
[conda] cudatoolkit               10.2.89              h8f6ccaa_9    https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
[conda] libblas                   3.9.0            12_linux64_mkl    https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
[conda] libcblas                  3.9.0            12_linux64_mkl    https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
[conda] liblapack                 3.9.0            12_linux64_mkl    https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
[conda] liblapacke                3.9.0            12_linux64_mkl    https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
[conda] magma                     2.5.4                h5da55e3_2    https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
[conda] mkl                       2021.4.0           h8d4b97c_729    https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
[conda] numpy                     1.21.2           py38he2449b9_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
[conda] pytorch-gpu               1.9.0           cuda102py38hf05f184_1    https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
[conda] pytorch-lightning         1.5.10                   pypi_0    pypi
[conda] torch                     1.11.0                   pypi_0    pypi
[conda] torch-tb-profiler         0.3.1                    pypi_0    pypi
[conda] torchfile                 0.1.0                    pypi_0    pypi
[conda] torchmetrics              0.6.0              pyhd8ed1ab_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
[conda] torchvision               0.10.1          py38cuda102h1e64cea_0_cuda    https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge

cc @ssnl @VitalyFedyunin @ejguan @NivekT


KimbingNg commented Apr 2, 2022

Interestingly, when I interrupt the main process with kill -INT 52422, it does not terminate; instead, the hang clears and it is able to run again. How can this happen?

KimbingNg (Author) commented:

I believe the cause of the issue is that one of the worker processes hangs forever at this line:

torch.set_num_threads(1)

while trying to acquire a mutex that is already locked.

(gdb) bt
#0  __lll_lock_wait () at ../sysdeps/unix/sysv/linux/x86_64/lowlevellock.S:135
#1  0x00007fc73a873dbd in __GI___pthread_mutex_lock (mutex=0x5606dd44d080) at ../nptl/pthread_mutex_lock.c:80
#2  0x00007fc733b1b7b1 in __kmp_lock_suspend_mx () from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/numpy/core/../../../../libiomp5.so
#3  0x00007fc733acd435 in __kmp_free_thread () from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/numpy/core/../../../../libiomp5.so                   [390/7275]
#4  0x00007fc733acd5e0 in __kmp_set_num_threads () from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/numpy/core/../../../../libiomp5.so
#5  0x00007fc70f72ee95 in at::set_num_threads(int) () from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#6  0x00007fc72667a1b6 in THPModule_setNumThreads(_object*, _object*) () from /home/foo/anaconda3/envs/python/lib/python3.8/site-packages/torch/lib/libtorch_python.so
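
The shape of this hang matches the general fork-while-a-lock-is-held hazard: fork() copies the mutex in its locked state, but the owning thread does not exist in the child, so nothing can ever release it. Below is a minimal, self-contained Python illustration of just that mechanism (it does not use PyTorch or OpenMP; it only demonstrates why the child blocks forever):

import os
import threading
import time

lock = threading.Lock()

def hold_lock():
    with lock:
        time.sleep(5)  # keep the lock held while the parent forks

threading.Thread(target=hold_lock, daemon=True).start()
time.sleep(0.5)  # make sure the helper thread owns the lock by now

pid = os.fork()
if pid == 0:
    # The child contains only the forking thread, but its copy of the lock
    # was made while the (now nonexistent) helper thread held it, so this
    # acquire blocks forever -- the same shape as the libiomp5 mutex above.
    lock.acquire()
    os._exit(0)

os.waitpid(pid, 0)  # never returns: the child is deadlocked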

KimbingNg changed the title from "Dataloader hangs. Potential deadlock when creating subprocess?" to "Dataloader hangs. Potential deadlock with set_num_threads in worker processes?" on Apr 3, 2022
mruberry added the module: dataloader, module: deadlock, and triaged labels on Apr 4, 2022
pconrad-insitro commented:

Hi @KimbingNg, thanks for reporting this! I ran into the same issue, with exactly the same stack traces, and your notes gave me some great pointers to keep digging.

This appears to be the same core issue as: #17199

In particular, #17199 (comment) and #17199 (comment) suggest that this is a known incompatibility between OpenMP runtimes and fork + threads (the linked comments discuss GNU OpenMP; the backtraces here show Intel's libiomp5, loaded via NumPy/MKL). That issue proposes some possible workarounds, but switching your DataLoaders to spawn-based multiprocessing should be a straightforward fix, as sketched below.
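
A minimal sketch of that workaround (TensorDataset here is just a placeholder for the real dataset; anything handed to the workers must be picklable when using spawn):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(1024, dtype=torch.float32).unsqueeze(1))

# 'spawn' workers start from a fresh interpreter instead of fork()ing the
# parent, so they do not inherit libiomp5's internal mutexes in a locked state.
loader = DataLoader(dataset, batch_size=64, num_workers=4,
                    multiprocessing_context="spawn")

for batch in loader:
    pass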
