Skip to content

gh-124397: Add threading.iter_locked #133908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions Doc/library/threading.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,48 @@ This module defines the following functions:
of the result, even when terminated.


.. function:: iter_locked(iterable)

Convert an iterable into an iterator that performs iteration using locks.

The ``iter_locked`` makes non-atomic iterators atomic::

class non_atomic_iterator:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure non_atomic_iterator is a good example of something that exists in the wild, because it's not thread-safe on GIL-ful builds either. The thread can arbitrarily switch during execution of Python code, so this would need a lock anyway if concurrent iteration were a use case. For free-threading, __next__ itself should have its own locking to prevent races. How was this design reached?


def __init__(self, it):
self.it = iter(it)

def __iter__(self):
return self

def __next__(self):
a = next(self.it)
b = next(self.it)
return a, b

atomic_iterator = iter_locked(non_atomic_iterator(range(10)))

The ``iter_locked`` allows concurrent iteration over generator objects. For example::

def count():
i = 0
while True:
i += 1
yield i
concurrent_iterator = iter_locked(count())

The implementation is roughly equivalent to::

class iter_locked(Iterator):
def __init__(self, it):
self._it = iter(it)
self._lock = Lock()
def __next__(self):
with self._lock:
return next(self._it)

.. versionadded:: next

.. function:: main_thread()

Return the main :class:`Thread` object. In normal conditions, the
Expand Down
81 changes: 81 additions & 0 deletions Lib/test/test_free_threading/test_threading_iter_locked.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import unittest
from threading import Thread, Barrier, iter_locked
from test.support import threading_helper


threading_helper.requires_working_threading(module=True)

class non_atomic_iterator:
    """Iterator that yields consecutive pairs drawn from another iterable.

    Each ``__next__`` call pulls two items from the wrapped iterator with
    two separate ``next()`` calls, so a single step is not atomic.
    """

    def __init__(self, it):
        self.it = iter(it)

    def __iter__(self):
        return self

    def __next__(self):
        # Tuple elements are evaluated left to right: first item, then second.
        # StopIteration from either call propagates to the caller.
        return next(self.it), next(self.it)

def count():
    """Generate the integers 1, 2, 3, ... without end."""
    n = 1
    while True:
        yield n
        n += 1

class iter_lockedThreading(unittest.TestCase):
    """Stress tests in which several threads share one ``iter_locked`` iterator.

    Worker threads only record what they receive; every check is made in the
    main thread so that a failure is reported by unittest instead of being
    lost inside a worker.
    """

    @threading_helper.reap_threads
    def test_iter_locked(self):
        """Each pair of a two-step ``__next__`` is delivered intact, exactly once."""
        number_of_threads = 10
        number_of_iterations = 10
        barrier = Barrier(number_of_threads)
        data = tuple(range(400))
        # non_atomic_iterator(data) yields (0, 1), (2, 3), ..., (398, 399).
        expected = list(zip(data[::2], data[1::2]))

        def work(it, seen):
            # Release all workers at once to maximise contention.
            barrier.wait()
            while True:
                try:
                    # list.append is safe to call from several threads.
                    seen.append(next(it))
                except StopIteration:
                    break

        for _ in range(number_of_iterations):
            iter_locked_iterator = iter_locked(non_atomic_iterator(data))
            seen = []
            worker_threads = [
                Thread(target=work, args=[iter_locked_iterator, seen])
                for _worker in range(number_of_threads)
            ]

            with threading_helper.start_threads(worker_threads):
                pass

            # A torn, duplicated or missing pair makes this comparison fail.
            self.assertEqual(sorted(seen), expected)

    @threading_helper.reap_threads
    def test_iter_locked_generator(self):
        """A generator shared through ``iter_locked`` yields each value once."""
        number_of_threads = 5
        number_of_iterations = 4
        items_per_thread = 1_000
        barrier = Barrier(number_of_threads)
        expected = list(range(1, number_of_threads * items_per_thread + 1))

        def work(it, seen):
            barrier.wait()
            # count() never finishes, so StopIteration is not expected here.
            for _ in range(items_per_thread):
                seen.append(next(it))

        for _ in range(number_of_iterations):
            generator = iter_locked(count())
            seen = []
            worker_threads = [
                Thread(target=work, args=[generator, seen])
                for _worker in range(number_of_threads)
            ]

            with threading_helper.start_threads(worker_threads):
                pass

            # Together the workers must have consumed 1..N with no gaps or
            # repeats; a worker that died early leaves the list short.
            self.assertEqual(sorted(seen), expected)

if __name__ == "__main__":
unittest.main()
9 changes: 9 additions & 0 deletions Lib/test/test_threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -2416,6 +2416,15 @@ def run_last():
self.assertIn("RuntimeError: can't register atexit after shutdown",
err.decode())

class IterLockedTests(unittest.TestCase):
    """Single-threaded behaviour of ``threading.iter_locked``."""

    def test_iter_locked(self):
        # Wrapping an iterable must not change the items it produces.
        iterables = ("123", [], [1, 2, 3], tuple(), (1, 2, 3))
        for iterable in iterables:
            wrapped = threading.iter_locked(iterable)
            self.assertEqual(list(wrapped), list(iterable))

        # Arguments that are not iterable are rejected with TypeError.
        for non_iterable in [1, None, True, sys]:
            with self.assertRaises(TypeError):
                threading.iter_locked(non_iterable)

if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion Lib/threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
'Barrier', 'BrokenBarrierError', 'Timer', 'ThreadError',
'setprofile', 'settrace', 'local', 'stack_size',
'excepthook', 'ExceptHookArgs', 'gettrace', 'getprofile',
'setprofile_all_threads','settrace_all_threads']
'setprofile_all_threads','settrace_all_threads', 'iter_locked']

# Rename some stuff so "from threading import *" is safe
_start_joinable_thread = _thread.start_joinable_thread
Expand All @@ -42,6 +42,7 @@
get_ident = _thread.get_ident
_get_main_thread_ident = _thread._get_main_thread_ident
_is_main_interpreter = _thread._is_main_interpreter
iter_locked = _thread.iter_locked
try:
get_native_id = _thread.get_native_id
_HAVE_THREAD_NATIVE_ID = True
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add :func:`threading.iter_locked` to make concurrent iteration over an iterable execute using a lock.
119 changes: 117 additions & 2 deletions Modules/_threadmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# include <signal.h> // SIGINT
#endif

#include "clinic/_threadmodule.c.h"

// ThreadError is just an alias to PyExc_RuntimeError
#define ThreadError PyExc_RuntimeError
Expand All @@ -30,6 +29,7 @@ static struct PyModuleDef thread_module;
// Module state
typedef struct {
PyTypeObject *excepthook_type;
PyTypeObject *iter_locked_type;
PyTypeObject *lock_type;
PyTypeObject *local_type;
PyTypeObject *local_dummy_type;
Expand All @@ -48,6 +48,15 @@ get_thread_state(PyObject *module)
return (thread_module_state *)state;
}

/* Return the _thread module state reached through the type `tp`.
 *
 * The module is looked up with PyType_GetModuleByDef(); a failed lookup is
 * only caught by assert(), so callers must pass a type created from
 * `thread_module`.
 */
static inline thread_module_state *
find_state_by_type(PyTypeObject *tp)
{
    PyObject *module = PyType_GetModuleByDef(tp, &thread_module);
    assert(module != NULL);
    return get_thread_state(module);
}

#include "clinic/_threadmodule.c.h"

#ifdef MS_WINDOWS
typedef HRESULT (WINAPI *PF_GET_THREAD_DESCRIPTION)(HANDLE, PCWSTR*);
Expand All @@ -59,8 +68,10 @@ static PF_SET_THREAD_DESCRIPTION pSetThreadDescription = NULL;

/*[clinic input]
module _thread
class _thread.iter_locked "iter_locked_object *" "find_state_by_type(type)->iter_locked_type"

[clinic start generated code]*/
/*[clinic end generated code: output=da39a3ee5e6b4b0d input=be8dbe5cc4b16df7]*/
/*[clinic end generated code: output=da39a3ee5e6b4b0d input=cc495aee1743488d]*/


// _ThreadHandle type
Expand Down Expand Up @@ -731,6 +742,99 @@ static PyType_Spec ThreadHandle_Type_spec = {
ThreadHandle_Type_slots,
};

/* iter_locked object **************************************************************/

/* Instance layout of iter_locked: a thin wrapper around one iterator. */
typedef struct {
PyObject_HEAD
PyObject *it;  /* wrapped iterator: stored by tp_new, released in dealloc */
} iter_locked_object;

/* Cast a PyObject pointer to iter_locked_object; performs no type check. */
#define iter_locked_object_CAST(op) ((iter_locked_object *)(op))

/*[clinic input]
@classmethod
_thread.iter_locked.__new__
iterable: object
/
Make an iterator thread-safe.
[clinic start generated code]*/

static PyObject *
_thread_iter_locked_impl(PyTypeObject *type, PyObject *iterable)
/*[clinic end generated code: output=4a8ad5a25f7c09ba input=ae6124177726e809]*/
{
    /* PyObject_GetIter() returns a new reference, or NULL with an exception
       set when `iterable` cannot be iterated. */
    PyObject *iterator = PyObject_GetIter(iterable);
    if (iterator == NULL) {
        return NULL;
    }

    PyObject *self = type->tp_alloc(type, 0);
    if (self == NULL) {
        /* Allocation failed: drop the iterator we just obtained. */
        Py_DECREF(iterator);
        return NULL;
    }

    /* The new object takes over the iterator reference. */
    iter_locked_object_CAST(self)->it = iterator;
    return self;
}

/* tp_dealloc: release the wrapped iterator and free the object. */
static void
iter_locked_dealloc(PyObject *op)
{
    /* Fetch the type first; it is still needed after `op` has been freed. */
    PyTypeObject *type = Py_TYPE(op);

    /* Stop GC tracking before the object's fields are torn down. */
    PyObject_GC_UnTrack(op);
    Py_DECREF(iter_locked_object_CAST(op)->it);
    type->tp_free(op);

    /* Drop the reference to the type only after the instance is freed. */
    Py_DECREF(type);
}

/* tp_traverse: report the objects this instance references to the GC,
   namely its type and the wrapped iterator. */
static int
iter_locked_traverse(PyObject *op, visitproc visit, void *arg)
{
    iter_locked_object *self = iter_locked_object_CAST(op);

    Py_VISIT(Py_TYPE(self));
    Py_VISIT(self->it);
    return 0;
}

/* tp_iternext: fetch the next item from the wrapped iterator while holding
 * the critical section of the wrapper object.
 *
 * Returns a new reference, or NULL. NULL with no exception set means the
 * iterator is exhausted (PyIter_Next() already cleared StopIteration);
 * NULL with an exception set means the underlying iterator failed.
 *
 * NOTE(review): critical sections are documented as no-ops on builds with
 * the GIL, so this presumably adds no mutual exclusion there -- confirm
 * that matches the behaviour the documentation promises.
 */
static PyObject *
iter_locked_next(PyObject *op)
{
    iter_locked_object *self = iter_locked_object_CAST(op);
    PyObject *item = NULL;

    Py_BEGIN_CRITICAL_SECTION(op);  // lock on op or lz->it?
    item = PyIter_Next(self->it);
    Py_END_CRITICAL_SECTION();

    /* Both "exhausted" and "error" are passed up as NULL unchanged. */
    return item;
}

/* Slot table for the iter_locked heap type. */
static PyType_Slot iter_locked_slots[] = {
{Py_tp_dealloc, iter_locked_dealloc},
{Py_tp_getattro, PyObject_GenericGetAttr},
{Py_tp_doc, (void *)_thread_iter_locked__doc__},
{Py_tp_traverse, iter_locked_traverse},
{Py_tp_iter, PyObject_SelfIter},  /* iter(obj) returns obj itself */
{Py_tp_iternext, iter_locked_next},
{Py_tp_new, _thread_iter_locked},  /* Argument Clinic wrapper around _thread_iter_locked_impl */
{Py_tp_free, PyObject_GC_Del},  /* GC-aware free, matching Py_TPFLAGS_HAVE_GC in the spec */
{0, NULL},  /* sentinel */
};

/* Type spec used by thread_module_exec() to create the iter_locked type.
   The type is GC-enabled, may be subclassed (BASETYPE) and is immutable. */
static PyType_Spec iter_locked_spec = {
.name = "threading.iter_locked",
.basicsize = sizeof(iter_locked_object),
.flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE |
Py_TPFLAGS_IMMUTABLETYPE),
.slots = iter_locked_slots,
};

/* Lock objects */

typedef struct {
Expand Down Expand Up @@ -2631,6 +2735,15 @@ thread_module_exec(PyObject *module)
return -1;
}

// iter_locked
state->iter_locked_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &iter_locked_spec, NULL);
if (state->iter_locked_type == NULL) {
return -1;
}
if (PyModule_AddType(module, state->iter_locked_type) < 0) {
return -1;
}

// Lock
state->lock_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &lock_type_spec, NULL);
if (state->lock_type == NULL) {
Expand Down Expand Up @@ -2739,6 +2852,7 @@ thread_module_traverse(PyObject *module, visitproc visit, void *arg)
{
thread_module_state *state = get_thread_state(module);
Py_VISIT(state->excepthook_type);
Py_VISIT(state->iter_locked_type);
Py_VISIT(state->lock_type);
Py_VISIT(state->local_type);
Py_VISIT(state->local_dummy_type);
Expand All @@ -2751,6 +2865,7 @@ thread_module_clear(PyObject *module)
{
thread_module_state *state = get_thread_state(module);
Py_CLEAR(state->excepthook_type);
Py_CLEAR(state->iter_locked_type);
Py_CLEAR(state->lock_type);
Py_CLEAR(state->local_type);
Py_CLEAR(state->local_dummy_type);
Expand Down
34 changes: 32 additions & 2 deletions Modules/clinic/_threadmodule.c.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading