From 5b8ef2cc59e334badbc1addb56df84a91458d394 Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Wed, 14 Feb 2018 08:48:54 -0500 Subject: [PATCH] Fix how thread states are created (#1276) Having pybind11 keep its own internal thread state can lead to an inconsistent situation where the Python interpreter has a thread state but pybind does not, and then when gil_scoped_acquire is called, pybind creates a new thread state instead of using the one created by the Python interpreter. This change gets rid of pybind's internal thread state and always uses the one created by the Python interpreter. --- include/pybind11/detail/internals.h | 3 --- include/pybind11/pybind11.h | 26 +++----------------------- tests/CMakeLists.txt | 1 + tests/conftest.py | 5 +++++ tests/test_threads.cpp | 27 +++++++++++++++++++++++++++ tests/test_threads.py | 19 +++++++++++++++++++ 6 files changed, 55 insertions(+), 26 deletions(-) create mode 100644 tests/test_threads.cpp create mode 100644 tests/test_threads.py diff --git a/include/pybind11/detail/internals.h b/include/pybind11/detail/internals.h index e39f38695f..e33922e35b 100644 --- a/include/pybind11/detail/internals.h +++ b/include/pybind11/detail/internals.h @@ -79,7 +79,6 @@ struct internals { PyTypeObject *default_metaclass; PyObject *instance_base; #if defined(WITH_THREAD) - decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x PyInterpreterState *istate = nullptr; #endif }; @@ -166,8 +165,6 @@ PYBIND11_NOINLINE inline internals &get_internals() { #if defined(WITH_THREAD) PyEval_InitThreads(); PyThreadState *tstate = PyThreadState_Get(); - internals_ptr->tstate = PyThread_create_key(); - PyThread_set_key_value(internals_ptr->tstate, tstate); internals_ptr->istate = tstate->interp; #endif builtins[id] = capsule(internals_pp); diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 977045d623..0d1d666ad5 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -1755,7 +1755,7 @@ class gil_scoped_acquire { public: PYBIND11_NOINLINE gil_scoped_acquire() { auto const &internals = detail::get_internals(); - tstate = (PyThreadState *) PyThread_get_key_value(internals.tstate); + tstate = PyGILState_GetThisThreadState(); if (!tstate) { tstate = PyThreadState_New(internals.istate); @@ -1764,10 +1764,6 @@ class gil_scoped_acquire { pybind11_fail("scoped_acquire: could not create thread state!"); #endif tstate->gilstate_counter = 0; - #if PY_MAJOR_VERSION < 3 - PyThread_delete_key_value(internals.tstate); - #endif - PyThread_set_key_value(internals.tstate, tstate); } else { release = detail::get_thread_state_unchecked() != tstate; } @@ -1806,7 +1802,6 @@ class gil_scoped_acquire { #endif PyThreadState_Clear(tstate); PyThreadState_DeleteCurrent(); - PyThread_delete_key_value(detail::get_internals().tstate); release = false; } } @@ -1825,30 +1820,15 @@ class gil_scoped_release { public: explicit gil_scoped_release(bool disassoc = false) : disassoc(disassoc) { // `get_internals()` must be called here unconditionally in order to initialize - // `internals.tstate` for subsequent `gil_scoped_acquire` calls. Otherwise, an + // `internals.istate` for subsequent `gil_scoped_acquire` calls. Otherwise, an // initialization race could occur as multiple threads try `gil_scoped_acquire`. - const auto &internals = detail::get_internals(); + detail::get_internals(); tstate = PyEval_SaveThread(); - if (disassoc) { - auto key = internals.tstate; - #if PY_MAJOR_VERSION < 3 - PyThread_delete_key_value(key); - #else - PyThread_set_key_value(key, nullptr); - #endif - } } ~gil_scoped_release() { if (!tstate) return; PyEval_RestoreThread(tstate); - if (disassoc) { - auto key = detail::get_internals().tstate; - #if PY_MAJOR_VERSION < 3 - PyThread_delete_key_value(key); - #endif - PyThread_set_key_value(key, tstate); - } } private: PyThreadState *tstate; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 8f2f300ef7..410de1899a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -57,6 +57,7 @@ set(PYBIND11_TEST_FILES test_smart_ptr.cpp test_stl.cpp test_stl_binders.cpp + test_threads.cpp test_virtual_functions.cpp ) diff --git a/tests/conftest.py b/tests/conftest.py index f4c228260b..8630f829c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -199,6 +199,10 @@ def pytest_namespace(): from pybind11_tests.eigen import have_eigen except ImportError: have_eigen = False + try: + import threading + except ImportError: + threading = None pypy = platform.python_implementation() == "PyPy" skipif = pytest.mark.skipif @@ -210,6 +214,7 @@ def pytest_namespace(): reason="eigen and/or numpy are not installed"), 'requires_eigen_and_scipy': skipif(not have_eigen or not scipy, reason="eigen and/or scipy are not installed"), + 'requires_threading': skipif(not threading, reason="no threading"), 'unsupported_on_pypy': skipif(pypy, reason="unsupported on PyPy"), 'unsupported_on_py2': skipif(sys.version_info.major < 3, reason="unsupported on Python 2.x"), diff --git a/tests/test_threads.cpp b/tests/test_threads.cpp new file mode 100644 index 0000000000..3d1e88e000 --- /dev/null +++ b/tests/test_threads.cpp @@ -0,0 +1,27 @@ +#include "pybind11_tests.h" + +#if defined(WITH_THREAD) + +#include + +static bool check_threadstate() { + return PyGILState_GetThisThreadState() == PyThreadState_Get(); +} + +TEST_SUBMODULE(threads, m) { + m.def("check_pythread", []() -> bool { + py::gil_scoped_acquire acquire1; + return check_threadstate(); + }, py::call_guard()) + .def("check_cthread", []() -> bool { + bool result = false; + std::thread thread([&result]() { + py::gil_scoped_acquire acquire; + result = check_threadstate(); + }); + thread.join(); + return result; + }, py::call_guard()); +} + +#endif diff --git a/tests/test_threads.py b/tests/test_threads.py new file mode 100644 index 0000000000..b52a346232 --- /dev/null +++ b/tests/test_threads.py @@ -0,0 +1,19 @@ +import pytest + +pytestmark = pytest.requires_threading + +with pytest.suppress(ImportError): + import threading + from pybind11_tests import threads as t + + +def test_threads(): + def pythread_routine(): + threading.current_thread()._return = t.check_pythread() + + thread = threading.Thread(target=pythread_routine) + thread.start() + thread.join() + assert thread._return + + assert t.check_cthread()