Skip to content

Commit 681abde

Browse files
mariocj89tirkarthi
andcommitted
mock: Add EventMock class
Add a new class that allows to wait for a call to happen by using `Event` objects. This mock class can be used to test and validate expectations of multithreading code. Co-authored-by: Karthikeyan Singaravelan <[email protected]>
1 parent c7437e2 commit 681abde

File tree

4 files changed

+263
-0
lines changed

4 files changed

+263
-0
lines changed

Doc/library/unittest.mock.rst

+35
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,10 @@ The Mock Class
204204
import asyncio
205205
import inspect
206206
import unittest
207+
import threading
207208
from unittest.mock import sentinel, DEFAULT, ANY
208209
from unittest.mock import patch, call, Mock, MagicMock, PropertyMock, AsyncMock
210+
from unittest.mock import ThreadingMock
209211
from unittest.mock import mock_open
210212

211213
:class:`Mock` is a flexible mock object intended to replace the use of stubs and
@@ -1097,6 +1099,39 @@ object::
10971099
[call('foo'), call('bar')]
10981100

10991101

1102+
.. class:: ThreadingMock(spec=None, side_effect=None, return_value=DEFAULT, wraps=None, name=None, spec_set=None, unsafe=False, **kwargs)
1103+
1104+
A version of :class:`MagicMock` for multithreading tests. The
1105+
:class:`ThreadingMock` object provides extra methods to wait for a call to
1106+
happen on a different thread, rather than assert on it immediately.
1107+
1108+
.. method:: wait_until_called(mock_timeout=None)
1109+
1110+
Waits until the the mock is called.
1111+
If ``mock_timeout`` is set, after that number of seconds waiting,
1112+
it raises an :exc:`AssertionError`, waits forever otherwise.
1113+
1114+
>>> mock = ThreadingMock()
1115+
>>> thread = threading.Thread(target=mock)
1116+
>>> thread.start()
1117+
>>> mock.wait_until_called(mock_timeout=1)
1118+
>>> thread.join()
1119+
1120+
.. method:: wait_until_any_call(*args, mock_timeout=None, **kwargs)
1121+
1122+
Waits until the the mock is called with the specified arguments.
1123+
If ``mock_timeout`` is set, after that number of seconds waiting,
1124+
it raises an :exc:`AssertionError`, waits forever otherwise.
1125+
1126+
>>> mock = ThreadingMock()
1127+
>>> thread = threading.Thread(target=mock, args=(1,2,), kwargs={"arg": "thing"})
1128+
>>> thread.start()
1129+
>>> mock.wait_until_any_call(1, 2, arg="thing")
1130+
>>> thread.join()
1131+
1132+
.. versionadded:: 3.10
1133+
1134+
11001135
Calling
11011136
~~~~~~~
11021137

Lib/unittest/mock.py

+55
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
'call',
1515
'create_autospec',
1616
'AsyncMock',
17+
'ThreadingMock',
1718
'FILTER_DIR',
1819
'NonCallableMock',
1920
'NonCallableMagicMock',
@@ -31,6 +32,7 @@
3132
import sys
3233
import builtins
3334
from asyncio import iscoroutinefunction
35+
import threading
3436
from types import CodeType, ModuleType, MethodType
3537
from unittest.util import safe_repr
3638
from functools import wraps, partial
@@ -2851,6 +2853,59 @@ def __set__(self, obj, val):
28512853
self(val)
28522854

28532855

2856+
class ThreadingMock(MagicMock):
2857+
"""
2858+
A mock that can be used to wait until on calls happening
2859+
in a different thread.
2860+
"""
2861+
2862+
def __init__(self, *args, **kwargs):
2863+
_safe_super(ThreadingMock, self).__init__(*args, **kwargs)
2864+
self.__dict__["_event"] = threading.Event()
2865+
self.__dict__["_expected_calls"] = []
2866+
self.__dict__["_events_lock"] = threading.Lock()
2867+
2868+
def __get_event(self, expected_args, expected_kwargs):
2869+
with self._events_lock:
2870+
for args, kwargs, event in self._expected_calls:
2871+
if (args, kwargs) == (expected_args, expected_kwargs):
2872+
return event
2873+
new_event = threading.Event()
2874+
self._expected_calls.append((expected_args, expected_kwargs, new_event))
2875+
return new_event
2876+
2877+
2878+
def _mock_call(self, *args, **kwargs):
2879+
ret_value = _safe_super(ThreadingMock, self)._mock_call(*args, **kwargs)
2880+
2881+
call_event = self.__get_event(args, kwargs)
2882+
call_event.set()
2883+
2884+
self._event.set()
2885+
2886+
return ret_value
2887+
2888+
def wait_until_called(self, mock_timeout=None):
2889+
"""Wait until the mock object is called.
2890+
2891+
`mock_timeout` - time to wait for in seconds, waits forever otherwise.
2892+
"""
2893+
if not self._event.wait(timeout=mock_timeout):
2894+
msg = (f"{self._mock_name or 'mock'} was not called before"
2895+
f" timeout({mock_timeout}).")
2896+
raise AssertionError(msg)
2897+
2898+
def wait_until_any_call(self, *args, mock_timeout=None, **kwargs):
2899+
"""Wait until the mock object is called with given args.
2900+
2901+
`mock_timeout` - time to wait for in seconds, waits forever otherwise.
2902+
"""
2903+
event = self.__get_event(args, kwargs)
2904+
if not event.wait(timeout=mock_timeout):
2905+
expected_string = self._format_mock_call_signature(args, kwargs)
2906+
raise AssertionError(f'{expected_string} call not found')
2907+
2908+
28542909
def seal(mock):
28552910
"""Disable the automatic generation of child mocks.
28562911
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import time
2+
import unittest
3+
import concurrent.futures
4+
5+
from unittest.mock import patch, ThreadingMock, call
6+
7+
8+
class Something:
9+
10+
def method_1(self):
11+
pass
12+
13+
def method_2(self):
14+
pass
15+
16+
17+
class TestThreadingMock(unittest.TestCase):
18+
19+
def _call_after_delay(self, func, /, *args, **kwargs):
20+
time.sleep(kwargs.pop('delay'))
21+
func(*args, **kwargs)
22+
23+
24+
def setUp(self):
25+
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
26+
27+
def tearDown(self):
28+
self._executor.shutdown()
29+
30+
def run_async(self, func, /, *args, delay=0, **kwargs):
31+
self._executor.submit(self._call_after_delay, func, *args, **kwargs, delay=delay)
32+
33+
def _make_mock(self, *args, **kwargs):
34+
return ThreadingMock(*args, **kwargs)
35+
36+
def test_instance_check(self):
37+
waitable_mock = self._make_mock()
38+
39+
with patch(f'{__name__}.Something', waitable_mock):
40+
something = Something()
41+
42+
self.assertIsInstance(something.method_1, ThreadingMock)
43+
self.assertIsInstance(
44+
something.method_1().method_2(), ThreadingMock)
45+
46+
47+
def test_side_effect(self):
48+
waitable_mock = self._make_mock()
49+
50+
with patch(f'{__name__}.Something', waitable_mock):
51+
something = Something()
52+
something.method_1.side_effect = [1]
53+
54+
self.assertEqual(something.method_1(), 1)
55+
56+
57+
def test_spec(self):
58+
waitable_mock = self._make_mock(spec=Something)
59+
60+
with patch(f'{__name__}.Something', waitable_mock) as m:
61+
something = m()
62+
63+
self.assertIsInstance(something.method_1, ThreadingMock)
64+
self.assertIsInstance(
65+
something.method_1().method_2(), ThreadingMock)
66+
67+
with self.assertRaises(AttributeError):
68+
m.test
69+
70+
71+
def test_wait_until_called(self):
72+
waitable_mock = self._make_mock(spec=Something)
73+
74+
with patch(f'{__name__}.Something', waitable_mock):
75+
something = Something()
76+
self.run_async(something.method_1, delay=0.01)
77+
something.method_1.wait_until_called()
78+
something.method_1.assert_called_once()
79+
80+
81+
def test_wait_until_called_called_before(self):
82+
waitable_mock = self._make_mock(spec=Something)
83+
84+
with patch(f'{__name__}.Something', waitable_mock):
85+
something = Something()
86+
something.method_1()
87+
something.method_1.wait_until_called()
88+
something.method_1.assert_called_once()
89+
90+
91+
def test_wait_until_called_magic_method(self):
92+
waitable_mock = self._make_mock(spec=Something)
93+
94+
with patch(f'{__name__}.Something', waitable_mock):
95+
something = Something()
96+
self.run_async(something.method_1.__str__, delay=0.01)
97+
something.method_1.__str__.wait_until_called()
98+
something.method_1.__str__.assert_called_once()
99+
100+
101+
def test_wait_until_called_timeout(self):
102+
waitable_mock = self._make_mock(spec=Something)
103+
104+
with patch(f'{__name__}.Something', waitable_mock):
105+
something = Something()
106+
self.run_async(something.method_1, delay=0.2)
107+
with self.assertRaises(AssertionError):
108+
something.method_1.wait_until_called(mock_timeout=0.1)
109+
something.method_1.assert_not_called()
110+
111+
something.method_1.wait_until_called()
112+
something.method_1.assert_called_once()
113+
114+
115+
def test_wait_until_any_call_positional(self):
116+
waitable_mock = self._make_mock(spec=Something)
117+
118+
with patch(f'{__name__}.Something', waitable_mock):
119+
something = Something()
120+
self.run_async(something.method_1, 1, delay=0.1)
121+
self.run_async(something.method_1, 2, delay=0.2)
122+
self.run_async(something.method_1, 3, delay=0.3)
123+
self.assertNotIn(call(1), something.method_1.mock_calls)
124+
125+
something.method_1.wait_until_any_call(1)
126+
something.method_1.assert_called_once_with(1)
127+
self.assertNotIn(call(2), something.method_1.mock_calls)
128+
self.assertNotIn(call(3), something.method_1.mock_calls)
129+
130+
something.method_1.wait_until_any_call(3)
131+
self.assertIn(call(2), something.method_1.mock_calls)
132+
something.method_1.wait_until_any_call(2)
133+
134+
135+
def test_wait_until_any_call_keywords(self):
136+
waitable_mock = self._make_mock(spec=Something)
137+
138+
with patch(f'{__name__}.Something', waitable_mock):
139+
something = Something()
140+
self.run_async(something.method_1, a=1, delay=0.1)
141+
self.run_async(something.method_1, b=2, delay=0.2)
142+
self.run_async(something.method_1, c=3, delay=0.3)
143+
self.assertNotIn(call(a=1), something.method_1.mock_calls)
144+
145+
something.method_1.wait_until_any_call(a=1)
146+
something.method_1.assert_called_once_with(a=1)
147+
self.assertNotIn(call(b=2), something.method_1.mock_calls)
148+
self.assertNotIn(call(c=3), something.method_1.mock_calls)
149+
150+
something.method_1.wait_until_any_call(c=3)
151+
self.assertIn(call(b=2), something.method_1.mock_calls)
152+
something.method_1.wait_until_any_call(b=2)
153+
154+
def test_wait_until_any_call_no_argument(self):
155+
waitable_mock = self._make_mock(spec=Something)
156+
157+
with patch(f'{__name__}.Something', waitable_mock):
158+
something = Something()
159+
something.method_1(1)
160+
161+
something.method_1.assert_called_once_with(1)
162+
with self.assertRaises(AssertionError):
163+
something.method_1.wait_until_any_call(mock_timeout=0.1)
164+
165+
something.method_1()
166+
something.method_1.wait_until_any_call()
167+
168+
169+
if __name__ == "__main__":
170+
unittest.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Add `ThreadingMock` to :mod:`unittest.mock` that can be used to create
2+
Mock objects that can wait until they are called. Patch by Karthikeyan
3+
Singaravelan and Mario Corchero.

0 commit comments

Comments
 (0)