Process never ends when sending tensors through multiprocessing queues in Python 3.12+ on macOS #153050

Open · rafalh opened this issue May 7, 2025 · 7 comments
Labels
module: deadlock - Problems related to deadlocks (hang without exiting)
module: macos - Mac OS related issues
module: multiprocessing - Related to torch.multiprocessing
needs reproduction - Someone else needs to try reproducing the issue given the instructions. No action needed from user
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


rafalh commented May 7, 2025

🐛 Describe the bug

If a tensor is sent through a multiprocessing queue, something blocks the process from exiting after the end of the script is reached (I have to press Ctrl+C to end the program).
It seems to be related to the resource tracker process (multiprocessing.resource_tracker.ResourceTracker) that Python starts automatically: when the main process should exit, I can still see the resource tracker child process in the process tree, and if I kill it the main process exits successfully.
The problem occurs in Python 3.12 but not in Python 3.11. I am using macOS Sequoia. I tried running the examples in an Ubuntu container and could not reproduce the problem there, so it may be macOS specific. Multiple Torch versions are affected: I tested 2.2.0 (the oldest one that installs successfully on Python 3.12) and 2.7.0 (the latest).
Calling multiprocessing.set_start_method("fork") fixes the issue (the default start method on macOS is spawn), but fork is not recommended according to the Python docs. The spawn and forkserver start methods do not help.
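For reference, the workaround looks roughly like this (a minimal sketch, not my actual code; the examples below can then be run unchanged):

import multiprocessing

if __name__ == "__main__":
    # Workaround: must run before any Queue, Process, or DataLoader worker is
    # created. With the default "spawn" (and with "forkserver") the hang still
    # occurs for me; "fork" is discouraged on macOS per the Python docs.
    multiprocessing.set_start_method("fork")
    # ... then run either of the examples below unchanged ...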

Example using DataLoader:

from torch.utils.data import Dataset, DataLoader

class DummyDataset(Dataset):
    def __getitem__(self, index: int) -> int:
        return 1

    def __len__(self) -> int:
        return 10

def main() -> None:
    dataset = DummyDataset()
    data_loader = DataLoader(dataset, num_workers=1)
    for batch_idx, batch in enumerate(data_loader):
        print(batch_idx, batch)
    print("DONE?")

if __name__ == "__main__":
    main()

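A per-loader variant of the same workaround (just a sketch, untested here; make_loader is a hypothetical helper) would pass a fork context through DataLoader's multiprocessing_context argument instead of changing the global start method:

import multiprocessing

from torch.utils.data import DataLoader

def make_loader(dataset):
    # Scope the fork start method to this loader's workers only; "fork" is
    # still discouraged on macOS, so treat this as a sketch, not a recommendation.
    fork_ctx = multiprocessing.get_context("fork")
    return DataLoader(dataset, num_workers=1, multiprocessing_context=fork_ctx)
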
Example using just a tensor and a queue:

import torch.multiprocessing as multiprocessing
from torch import Tensor

def worker(q):
    q.put(Tensor(0))
    print("worker thread ended")

def main() -> None:
    q = multiprocessing.Queue()
    w = multiprocessing.Process(target=worker, args=(q,))
    w.start()
    w.join()
    print(q.get())
    print("DONE?")

if __name__ == "__main__":
    main()

In both cases the program does not exit after printing "DONE?" (unless interrupted with Ctrl+C), and the process tree looks like this:

~/tmp$ pstree 48529
-+= 48529 rafal.harabien /opt/homebrew/Cellar/[email protected]/3.12.10/Frameworks/Python.framework/Versions/3.12/Resources/Python.app/Contents/MacOS/Python /Users/rafal.harabien/minimal_mp_hang.py
 \--- 48530 rafal.harabien /opt/homebrew/Cellar/[email protected]/3.12.10/Frameworks/Python.framework/Versions/3.12/Resources/Python.app/Contents/MacOS/Python -c from multiprocessing.resource_tracker import main;main(6)

The second example works fine when sending non-tensor values, e.g. int.
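For comparison, the control case that exits cleanly is the same script with only the worker body changed to put a plain int:

def worker(q):
    # Control: sending a plain Python int instead of a tensor; with this
    # change the script exits normally after printing "DONE?".
    q.put(1)
    print("worker thread ended")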

Versions

((venv_py312) ) ~/tmp$ python collect_env.py
/Users/rafal.harabien/tmp/venv_py312/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py:276: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:81.)
cpu = _conversion_method_template(device=torch.device("cpu"))
Collecting environment information...
PyTorch version: 2.7.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.4.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.3)
CMake version: version 4.0.1
Libc version: N/A

Python version: 3.12.10 (main, Apr 8 2025, 11:35:47) [Clang 16.0.0 (clang-1600.0.26.6)] (64-bit runtime)
Python platform: macOS-15.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Pro

Versions of relevant libraries:
[pip3] torch==2.7.0
[conda] No relevant packages

cc @VitalyFedyunin @albanD @malfet

malfet added the module: multiprocessing, module: macos, and module: deadlock labels May 7, 2025
albanD added the triaged label May 7, 2025

albanD commented May 7, 2025

Any chance you could share a stack trace of the hang?

From the two macOS machines I could run it on, it either errors out or finishes just fine...

albanD added the needs reproduction label May 7, 2025

rafalh commented May 7, 2025

Any chance you could share a stack trace of the hang?

When I press Ctrl+C the process just exits without printing anything:

$ python3 tensor_queue_hang.py
worker thread ended
tensor([])
DONE?
^C

I tried attaching with GDB installed from Homebrew, but I got a "Don't know how to attach." error.

I did get something from Activity Monitor's Sample Process function, though.

Main process:

Call graph:
    2497 Thread_234184   DispatchQueue_1: com.apple.main-thread  (serial)
      2497 start  (in dyld) + 6000  [0x19fc5eb4c]
        2497 Py_BytesMain  (in Python) + 40  [0x1018e9ca4]
          2497 pymain_main  (in Python) + 304  [0x1018e9c04]
            2497 Py_RunMain  (in Python) + 260  [0x1018e954c]
              2497 Py_FinalizeEx  (in Python) + 192  [0x1018bd374]
                2497 finalize_modules  (in Python) + 1624  [0x1018bdd7c]
                  2497 _PyModule_ClearDict  (in Python) + 644  [0x1017cd9e8]
                    2497 insertdict  (in Python) + 644  [0x1017b8b78]
                      2497 method_dealloc  (in Python) + 244  [0x10177d70c]
                        2497 subtype_dealloc  (in Python) + 1140  [0x1017eef34]
                          2497 slot_tp_finalize  (in Python) + 124  [0x1017f13b0]
                            2497 PyObject_CallOneArg  (in Python) + 112  [0x10177aefc]
                              2497 _PyEval_EvalFrameDefault  (in Python) + 42248  [0x10186f484]
                                2497 cfunction_vectorcall_FASTCALL  (in Python) + 96  [0x1017cb5f4]
                                  2497 os_waitpid  (in Python) + 92  [0x1018f4608]
                                    2497 __wait4  (in libsystem_kernel.dylib) + 8  [0x19ffc3204]

Resource tracker process:

Call graph:
    2504 Thread_234206   DispatchQueue_1: com.apple.main-thread  (serial)
      2504 start  (in dyld) + 6000  [0x19fc5eb4c]
        2504 Py_BytesMain  (in Python) + 40  [0x102db1ca4]
          2504 pymain_main  (in Python) + 304  [0x102db1c04]
            2504 Py_RunMain  (in Python) + 720  [0x102db1718]
              2504 PyRun_SimpleStringFlags  (in Python) + 64  [0x102d8d2e8]
                2504 PyRun_StringFlags  (in Python) + 124  [0x102d8d3bc]
                  2504 run_mod  (in Python) + 132  [0x102d8dd74]
                    2504 run_eval_code_obj  (in Python) + 88  [0x102d8fc94]
                      2504 PyEval_EvalCode  (in Python) + 184  [0x102d2cd0c]
                        2504 _PyEval_EvalFrameDefault  (in Python) + 38412  [0x102d36588]
                          2504 buffered_iternext  (in Python) + 160  [0x102dde9a8]
                            2504 PyObject_VectorcallMethod  (in Python) + 148  [0x102c440c4]
                              2504 method_vectorcall_FASTCALL  (in Python) + 112  [0x102c4df20]
                                2504 _buffered_readline  (in Python) + 504  [0x102ddeff8]
                                  2504 _bufferedreader_fill_buffer  (in Python) + 64  [0x102ddf5d4]
                                    2504 _bufferedreader_raw_read  (in Python) + 156  [0x102ddfd18]
                                      2504 PyObject_VectorcallMethod  (in Python) + 148  [0x102c440c4]
                                        2504 method_vectorcall_FASTCALL_KEYWORDS_METHOD  (in Python) + 136  [0x102c4e2ac]
                                          2504 _io_FileIO_readinto  (in Python) + 172  [0x102dda38c]
                                            2504 _Py_read  (in Python) + 76  [0x102dafa74]
                                              2504 read  (in libsystem_kernel.dylib) + 8  [0x19ffbc7dc]

Does it help?

Edit: after adding some prints I determined that this os.waitpid call comes from the ResourceTracker class's _stop_locked method. So it seems that, for some reason, closing the resource tracker pipe in the main process did not terminate the tracker process as it should, and the main process waits for it to exit indefinitely.

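The prints were roughly of this shape (a rough sketch of the idea rather than my exact code; the wrapper name is mine):

import os
import traceback

# Wrap os.waitpid so that whoever calls it during interpreter shutdown first
# dumps the Python stack that reached it; in my case the stack pointed at
# ResourceTracker._stop_locked in multiprocessing/resource_tracker.py.
_orig_waitpid = os.waitpid

def _traced_waitpid(pid, options):
    traceback.print_stack()
    print(f"waitpid({pid}, {options})", flush=True)
    return _orig_waitpid(pid, options)

os.waitpid = _traced_waitpid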

malfet commented May 7, 2025

I cannot reproduce the hang on my end, though to be fair I'm using a local build rather than 2.7.0.
[Edit] Just tried with a 3.12 venv created from Homebrew + 2.7.0, still no luck:

% /opt/homebrew/bin/python3.12 -mvenv test-hang
% source test-hang/bin/activate
% pip install torch
...
% python bug-153050.py 
/Users/nshulga/test-hang/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py:276: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:81.)
  cpu = _conversion_method_template(device=torch.device("cpu"))
/Users/nshulga/test-hang/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py:276: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:81.)
  cpu = _conversion_method_template(device=torch.device("cpu"))
worker thread ended
tensor([])
DONE?

@rafalh do you mind sharing the output of pip freeze from your venv?


rafalh commented May 8, 2025

@rafalh do you mind sharing output of pip freeze on your venv?

((venv_py312) ) ~$ pip freeze              
filelock==3.18.0
fsspec==2025.3.2
Jinja2==3.1.6
MarkupSafe==3.0.2
mpmath==1.3.0
networkx==3.4.2
setuptools==80.3.1
sympy==1.14.0
torch==2.7.0
typing_extensions==4.13.2


rafalh commented May 8, 2025

This is basically how I test it:

brew install [email protected]
python3.12 -m venv venv_py312
. venv_py312/bin/activate
pip install torch==2.7.0
python script.py

A teammate reproduced it as well, so it is not limited to my MacBook.

@IsaevIlya

We are experiencing the same issue on macOS, both on x86 and ARM architectures. The problem is specific to CPython 3.12.10, as everything works correctly with CPython 3.12.9. The issue doesn't occur on other Unix-based systems or with different Python versions. We haven't determined whether the root cause lies in PyTorch or Python itself.

@rafalh
Copy link
Author

rafalh commented May 8, 2025

This looks suspicious (from the Python 3.12.10 changelog):

gh-88887: Fixing multiprocessing Resource Tracker process leaking, usually observed when running Python as PID 1.

Also this:

gh-118761: Reverts a change in the previous release attempting to make some stdlib imports used within the subprocess module lazy as this was causing errors during del finalizers calling methods such as terminate, or kill, or send_signal.

I haven't tested with an older Python version yet. I'll try tomorrow.

Edit:
I tested with Python installed by pyenv:
3.12.9 works fine
3.12.10 hangs

jet-tong added a commit to awslabs/s3-connector-for-pytorch that referenced this issue May 9, 2025
CPython 3.12.10 caused hanging issues on macOS as it is unable to clean up
multiprocessing resource tracker processes. See PyTorch issue #153050:
pytorch/pytorch#153050
IsaevIlya pushed a commit to awslabs/s3-connector-for-pytorch that referenced this issue May 9, 2025
CPython 3.12.10 caused hanging issues on macOS as it is unable to clean up
multiprocessing resource tracker processes. See PyTorch issue #153050:
pytorch/pytorch#153050