Skip to content

Commit 06023d9

Browse files
committed
bpo-17013: Add Mock.call_event to allow waiting for calls
New methods allow tests to wait for calls executing in other threads.
1 parent 11b2ae7 commit 06023d9

File tree

6 files changed

+198
-1
lines changed

6 files changed

+198
-1
lines changed

Doc/library/unittest.mock-examples.rst

+19
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,25 @@ the ``something`` method:
7070
>>> real.method()
7171
>>> real.something.assert_called_once_with(1, 2, 3)
7272

73+
When testing mutltithreaded code it may be important to ensure that certain
74+
method is eventually called, e.g. as a result of scheduling asynchronous
75+
operation. :attr:`~Mock.call_event` exposes methods that allow to assert that:
76+
77+
>>> from threading import Timer
78+
>>>
79+
>>> class ProductionClass:
80+
... def method(self):
81+
... self.t1 = Timer(0.1, self.something, args=(1, 2, 3))
82+
... self.t1.start()
83+
... self.t2 = Timer(0.1, self.something, args=(4, 5, 6))
84+
... self.t2.start()
85+
... def something(self, a, b, c):
86+
... pass
87+
...
88+
>>> real = ProductionClass()
89+
>>> real.something = MagicMock()
90+
>>> real.method()
91+
>>> real.something.call_event.wait_for_call(call(4, 5, 6), timeout=1.0)
7392

7493

7594
Mock for Method Calls on an Object

Doc/library/unittest.mock.rst

+14
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,20 @@ the *new_callable* argument to :func:`patch`.
515515
>>> mock.call_count
516516
2
517517

518+
.. attribute:: call_event
519+
520+
An object that can be used in multithreaded tests to assert that a call was made.
521+
522+
- :meth:`wait(/, skip=0, timeout=None)` asserts that mock is called
523+
*skip* + 1 times during the *timeout*
524+
525+
- :meth:`wait_for(predicate, /, timeout=None)` asserts that
526+
*predicate* was ``True`` at least once during the *timeout*;
527+
*predicate* receives exactly one positional argument: the mock itself
528+
529+
- :meth:`wait_for_call(call, /, skip=0, timeout=None)` asserts that
530+
*call* has happened at least *skip* + 1 times during the *timeout*
531+
518532
.. attribute:: return_value
519533

520534
Set this to configure the value returned by calling the mock:

Lib/unittest/mock.py

+74
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
'mock_open',
2121
'PropertyMock',
2222
'seal',
23+
'CallEvent'
2324
)
2425

2526

@@ -31,6 +32,7 @@
3132
import sys
3233
import builtins
3334
import pkgutil
35+
import threading
3436
from asyncio import iscoroutinefunction
3537
from types import CodeType, ModuleType, MethodType
3638
from unittest.util import safe_repr
@@ -224,6 +226,7 @@ def reset_mock():
224226
ret.reset_mock()
225227

226228
funcopy.called = False
229+
funcopy.call_event = CallEvent(mock)
227230
funcopy.call_count = 0
228231
funcopy.call_args = None
229232
funcopy.call_args_list = _CallList()
@@ -446,6 +449,7 @@ def __init__(
446449
__dict__['_mock_delegate'] = None
447450

448451
__dict__['_mock_called'] = False
452+
__dict__['_mock_call_event'] = CallEvent(self)
449453
__dict__['_mock_call_args'] = None
450454
__dict__['_mock_call_count'] = 0
451455
__dict__['_mock_call_args_list'] = _CallList()
@@ -547,6 +551,7 @@ def __class__(self):
547551
return self._spec_class
548552

549553
called = _delegating_property('called')
554+
call_event = _delegating_property('call_event')
550555
call_count = _delegating_property('call_count')
551556
call_args = _delegating_property('call_args')
552557
call_args_list = _delegating_property('call_args_list')
@@ -584,6 +589,7 @@ def reset_mock(self, visited=None,*, return_value=False, side_effect=False):
584589
visited.append(id(self))
585590

586591
self.called = False
592+
self.call_event = CallEvent(self)
587593
self.call_args = None
588594
self.call_count = 0
589595
self.mock_calls = _CallList()
@@ -1111,6 +1117,7 @@ def _mock_call(self, /, *args, **kwargs):
11111117
def _increment_mock_call(self, /, *args, **kwargs):
11121118
self.called = True
11131119
self.call_count += 1
1120+
self.call_event._notify()
11141121

11151122
# handle call_args
11161123
# needs to be set here so assertions on call arguments pass before
@@ -2411,6 +2418,73 @@ def _format_call_signature(name, args, kwargs):
24112418
return message % formatted_args
24122419

24132420

2421+
class CallEvent(object):
2422+
def __init__(self, mock):
2423+
self._mock = mock
2424+
self._condition = threading.Condition()
2425+
2426+
def wait(self, /, skip=0, timeout=None):
2427+
"""
2428+
Wait for any call.
2429+
2430+
:param skip: How many calls will be skipped.
2431+
As a result, the mock should be called at least
2432+
``skip + 1`` times.
2433+
2434+
:param timeout: See :meth:`threading.Condition.wait`.
2435+
"""
2436+
def predicate(mock):
2437+
return mock.call_count > skip
2438+
2439+
self.wait_for(predicate, timeout=timeout)
2440+
2441+
def wait_for_call(self, call, /, skip=0, timeout=None):
2442+
"""
2443+
Wait for a given call.
2444+
2445+
:param skip: How many calls will be skipped.
2446+
As a result, the call should happen at least
2447+
``skip + 1`` times.
2448+
2449+
:param timeout: See :meth:`threading.Condition.wait`.
2450+
"""
2451+
def predicate(mock):
2452+
return mock.call_args_list.count(call) > skip
2453+
2454+
self.wait_for(predicate, timeout=timeout)
2455+
2456+
def wait_for(self, predicate, /, timeout=None):
2457+
"""
2458+
Wait for a given predicate to become True.
2459+
2460+
:param predicate: A callable that receives mock which result
2461+
will be interpreted as a boolean value.
2462+
The final predicate value is the return value.
2463+
2464+
:param timeout: See :meth:`threading.Condition.wait`.
2465+
"""
2466+
try:
2467+
self._condition.acquire()
2468+
2469+
def _predicate():
2470+
return predicate(self._mock)
2471+
2472+
b = self._condition.wait_for(_predicate, timeout)
2473+
2474+
if not b:
2475+
msg = (f"{self._mock._mock_name or 'mock'} was not called before"
2476+
f" timeout({timeout}).")
2477+
raise AssertionError(msg)
2478+
finally:
2479+
self._condition.release()
2480+
2481+
def _notify(self):
2482+
try:
2483+
self._condition.acquire()
2484+
self._condition.notify_all()
2485+
finally:
2486+
self._condition.release()
2487+
24142488

24152489
class _Call(tuple):
24162490
"""

Lib/unittest/test/testmock/support.py

+16
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import concurrent.futures
2+
import time
3+
4+
15
target = {'foo': 'FOO'}
26

37

@@ -14,3 +18,15 @@ def wibble(self): pass
1418

1519
class X(object):
1620
pass
21+
22+
23+
def call_after_delay(func, /, *args, **kwargs):
24+
time.sleep(kwargs.pop('delay'))
25+
func(*args, **kwargs)
26+
27+
28+
def run_async(func, /, *args, executor=None, delay=0, **kwargs):
29+
if executor is None:
30+
executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
31+
32+
executor.submit(call_after_delay, func, *args, **kwargs, delay=delay)

Lib/unittest/test/testmock/testmock.py

+73-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from test.support import ALWAYS_EQ
77
import unittest
8-
from unittest.test.testmock.support import is_instance
8+
from unittest.test.testmock.support import is_instance, run_async
99
from unittest import mock
1010
from unittest.mock import (
1111
call, DEFAULT, patch, sentinel,
@@ -2249,6 +2249,78 @@ class Foo():
22492249
f'{__name__}.Typos', autospect=True, set_spec=True, auto_spec=True):
22502250
pass
22512251

2252+
def test_wait_until_called_before(self):
2253+
mock = Mock(spec=Something)()
2254+
mock.method_1()
2255+
mock.method_1.call_event.wait()
2256+
mock.method_1.assert_called_once()
2257+
2258+
def test_wait_until_called(self):
2259+
mock = Mock(spec=Something)()
2260+
run_async(mock.method_1, delay=0.01)
2261+
mock.method_1.call_event.wait()
2262+
mock.method_1.assert_called_once()
2263+
2264+
def test_wait_until_called_magic_method(self):
2265+
mock = MagicMock(spec=Something)()
2266+
run_async(mock.method_1.__str__, delay=0.01)
2267+
mock.method_1.__str__.call_event.wait()
2268+
mock.method_1.__str__.assert_called_once()
2269+
2270+
def test_wait_until_called_timeout(self):
2271+
mock = Mock(spec=Something)()
2272+
run_async(mock.method_1, delay=0.2)
2273+
2274+
with self.assertRaises(AssertionError):
2275+
mock.method_1.call_event.wait(timeout=0.1)
2276+
2277+
mock.method_1.assert_not_called()
2278+
mock.method_1.call_event.wait()
2279+
mock.method_1.assert_called_once()
2280+
2281+
def test_wait_until_any_call_positional(self):
2282+
mock = Mock(spec=Something)()
2283+
run_async(mock.method_1, 1, delay=0.1)
2284+
run_async(mock.method_1, 2, delay=0.2)
2285+
run_async(mock.method_1, 3, delay=0.3)
2286+
2287+
for arg in (1, 2, 3):
2288+
self.assertNotIn(call(arg), mock.method_1.mock_calls)
2289+
mock.method_1.call_event.wait_for(lambda m: call(arg) in m.call_args_list)
2290+
mock.method_1.assert_called_with(arg)
2291+
2292+
def test_wait_until_any_call_keywords(self):
2293+
mock = Mock(spec=Something)()
2294+
run_async(mock.method_1, a=1, delay=0.1)
2295+
run_async(mock.method_1, a=2, delay=0.2)
2296+
run_async(mock.method_1, a=3, delay=0.3)
2297+
2298+
for arg in (1, 2, 3):
2299+
self.assertNotIn(call(arg), mock.method_1.mock_calls)
2300+
mock.method_1.call_event.wait_for(lambda m: call(a=arg) in m.call_args_list)
2301+
mock.method_1.assert_called_with(a=arg)
2302+
2303+
def test_wait_until_any_call_no_argument(self):
2304+
mock = Mock(spec=Something)()
2305+
mock.method_1(1)
2306+
mock.method_1.assert_called_once_with(1)
2307+
2308+
with self.assertRaises(AssertionError):
2309+
mock.method_1.call_event.wait_for(lambda m: call() in m.call_args_list, timeout=0.01)
2310+
2311+
mock.method_1()
2312+
mock.method_1.call_event.wait_for(lambda m: call() in m.call_args_list, timeout=0.01)
2313+
2314+
def test_wait_until_call(self):
2315+
mock = Mock(spec=Something)()
2316+
mock.method_1()
2317+
run_async(mock.method_1, 1, a=1, delay=0.1)
2318+
2319+
with self.assertRaises(AssertionError):
2320+
mock.method_1.call_event.wait_for_call(call(1, a=1), timeout=0.01)
2321+
2322+
mock.method_1.call_event.wait_for_call(call(1, a=1))
2323+
22522324

22532325
if __name__ == '__main__':
22542326
unittest.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Add :attr:`call_event` to :class:`Mock` that allows to wait for the calls in
2+
multithreaded tests. Patch by Ilya Kulakov.

0 commit comments

Comments
 (0)