Skip to content

gh-61215: Add Mock.call_event to allow waiting for calls #20759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions Doc/library/unittest.mock-examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,25 @@ the ``something`` method:
>>> real.method()
>>> real.something.assert_called_once_with(1, 2, 3)

When testing mutltithreaded code it may be important to ensure that certain
method is eventually called, e.g. as a result of scheduling asynchronous
operation. :attr:`~Mock.call_event` exposes methods that allow to assert that:

>>> from threading import Timer
>>>
>>> class ProductionClass:
... def method(self):
... self.t1 = Timer(0.1, self.something, args=(1, 2, 3))
... self.t1.start()
... self.t2 = Timer(0.1, self.something, args=(4, 5, 6))
... self.t2.start()
... def something(self, a, b, c):
... pass
...
>>> real = ProductionClass()
>>> real.something = MagicMock()
>>> real.method()
>>> real.something.call_event.wait_for_call(call(4, 5, 6), timeout=1.0)


Mock for Method Calls on an Object
Expand Down
14 changes: 14 additions & 0 deletions Doc/library/unittest.mock.rst
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,20 @@ the *new_callable* argument to :func:`patch`.
>>> mock.call_count
2

.. attribute:: call_event

An object that can be used in multithreaded tests to assert that a call was made.

- :meth:`wait(/, skip=0, timeout=None)` asserts that mock is called
*skip* + 1 times during the *timeout*

- :meth:`wait_for(predicate, /, timeout=None)` asserts that
*predicate* was ``True`` at least once during the *timeout*;
*predicate* receives exactly one positional argument: the mock itself

- :meth:`wait_for_call(call, /, skip=0, timeout=None)` asserts that
*call* has happened at least *skip* + 1 times during the *timeout*

.. attribute:: return_value

Set this to configure the value returned by calling the mock:
Expand Down
74 changes: 74 additions & 0 deletions Lib/unittest/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
'mock_open',
'PropertyMock',
'seal',
'CallEvent'
)


Expand All @@ -31,6 +32,7 @@
import sys
import builtins
import pkgutil
import threading
from asyncio import iscoroutinefunction
from types import CodeType, ModuleType, MethodType
from unittest.util import safe_repr
Expand Down Expand Up @@ -224,6 +226,7 @@ def reset_mock():
ret.reset_mock()

funcopy.called = False
funcopy.call_event = CallEvent(mock)
funcopy.call_count = 0
funcopy.call_args = None
funcopy.call_args_list = _CallList()
Expand Down Expand Up @@ -446,6 +449,7 @@ def __init__(
__dict__['_mock_delegate'] = None

__dict__['_mock_called'] = False
__dict__['_mock_call_event'] = CallEvent(self)
__dict__['_mock_call_args'] = None
__dict__['_mock_call_count'] = 0
__dict__['_mock_call_args_list'] = _CallList()
Expand Down Expand Up @@ -547,6 +551,7 @@ def __class__(self):
return self._spec_class

called = _delegating_property('called')
call_event = _delegating_property('call_event')
call_count = _delegating_property('call_count')
call_args = _delegating_property('call_args')
call_args_list = _delegating_property('call_args_list')
Expand Down Expand Up @@ -584,6 +589,7 @@ def reset_mock(self, visited=None,*, return_value=False, side_effect=False):
visited.append(id(self))

self.called = False
self.call_event = CallEvent(self)
self.call_args = None
self.call_count = 0
self.mock_calls = _CallList()
Expand Down Expand Up @@ -1111,6 +1117,7 @@ def _mock_call(self, /, *args, **kwargs):
def _increment_mock_call(self, /, *args, **kwargs):
self.called = True
self.call_count += 1
self.call_event._notify()

# handle call_args
# needs to be set here so assertions on call arguments pass before
Expand Down Expand Up @@ -2411,6 +2418,73 @@ def _format_call_signature(name, args, kwargs):
return message % formatted_args


class CallEvent(object):
def __init__(self, mock):
self._mock = mock
self._condition = threading.Condition()

def wait(self, /, skip=0, timeout=None):
"""
Wait for any call.

:param skip: How many calls will be skipped.
As a result, the mock should be called at least
``skip + 1`` times.

:param timeout: See :meth:`threading.Condition.wait`.
"""
def predicate(mock):
return mock.call_count > skip

self.wait_for(predicate, timeout=timeout)

def wait_for_call(self, call, /, skip=0, timeout=None):
"""
Wait for a given call.

:param skip: How many calls will be skipped.
As a result, the call should happen at least
``skip + 1`` times.

:param timeout: See :meth:`threading.Condition.wait`.
"""
def predicate(mock):
return mock.call_args_list.count(call) > skip

self.wait_for(predicate, timeout=timeout)

def wait_for(self, predicate, /, timeout=None):
"""
Wait for a given predicate to become True.

:param predicate: A callable that receives mock which result
will be interpreted as a boolean value.
The final predicate value is the return value.

:param timeout: See :meth:`threading.Condition.wait`.
"""
try:
self._condition.acquire()

def _predicate():
return predicate(self._mock)

b = self._condition.wait_for(_predicate, timeout)

if not b:
msg = (f"{self._mock._mock_name or 'mock'} was not called before"
f" timeout({timeout}).")
raise AssertionError(msg)
finally:
self._condition.release()

def _notify(self):
try:
self._condition.acquire()
self._condition.notify_all()
finally:
self._condition.release()


class _Call(tuple):
"""
Expand Down
16 changes: 16 additions & 0 deletions Lib/unittest/test/testmock/support.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import concurrent.futures
import time


target = {'foo': 'FOO'}


Expand All @@ -14,3 +18,15 @@ def wibble(self): pass

class X(object):
pass


def call_after_delay(func, /, *args, **kwargs):
time.sleep(kwargs.pop('delay'))
func(*args, **kwargs)


def run_async(func, /, *args, executor=None, delay=0, **kwargs):
if executor is None:
executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)

executor.submit(call_after_delay, func, *args, **kwargs, delay=delay)
74 changes: 73 additions & 1 deletion Lib/unittest/test/testmock/testmock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from test.support import ALWAYS_EQ
import unittest
from unittest.test.testmock.support import is_instance
from unittest.test.testmock.support import is_instance, run_async
from unittest import mock
from unittest.mock import (
call, DEFAULT, patch, sentinel,
Expand Down Expand Up @@ -2249,6 +2249,78 @@ class Foo():
f'{__name__}.Typos', autospect=True, set_spec=True, auto_spec=True):
pass

def test_wait_until_called_before(self):
mock = Mock(spec=Something)()
mock.method_1()
mock.method_1.call_event.wait()
mock.method_1.assert_called_once()

def test_wait_until_called(self):
mock = Mock(spec=Something)()
run_async(mock.method_1, delay=0.01)
mock.method_1.call_event.wait()
mock.method_1.assert_called_once()

def test_wait_until_called_magic_method(self):
mock = MagicMock(spec=Something)()
run_async(mock.method_1.__str__, delay=0.01)
mock.method_1.__str__.call_event.wait()
mock.method_1.__str__.assert_called_once()

def test_wait_until_called_timeout(self):
mock = Mock(spec=Something)()
run_async(mock.method_1, delay=0.2)

with self.assertRaises(AssertionError):
mock.method_1.call_event.wait(timeout=0.1)

mock.method_1.assert_not_called()
mock.method_1.call_event.wait()
mock.method_1.assert_called_once()

def test_wait_until_any_call_positional(self):
mock = Mock(spec=Something)()
run_async(mock.method_1, 1, delay=0.1)
run_async(mock.method_1, 2, delay=0.2)
run_async(mock.method_1, 3, delay=0.3)

for arg in (1, 2, 3):
self.assertNotIn(call(arg), mock.method_1.mock_calls)
mock.method_1.call_event.wait_for(lambda m: call(arg) in m.call_args_list)
mock.method_1.assert_called_with(arg)

def test_wait_until_any_call_keywords(self):
mock = Mock(spec=Something)()
run_async(mock.method_1, a=1, delay=0.1)
run_async(mock.method_1, a=2, delay=0.2)
run_async(mock.method_1, a=3, delay=0.3)

for arg in (1, 2, 3):
self.assertNotIn(call(arg), mock.method_1.mock_calls)
mock.method_1.call_event.wait_for(lambda m: call(a=arg) in m.call_args_list)
mock.method_1.assert_called_with(a=arg)

def test_wait_until_any_call_no_argument(self):
mock = Mock(spec=Something)()
mock.method_1(1)
mock.method_1.assert_called_once_with(1)

with self.assertRaises(AssertionError):
mock.method_1.call_event.wait_for(lambda m: call() in m.call_args_list, timeout=0.01)

mock.method_1()
mock.method_1.call_event.wait_for(lambda m: call() in m.call_args_list, timeout=0.01)

def test_wait_until_call(self):
mock = Mock(spec=Something)()
mock.method_1()
run_async(mock.method_1, 1, a=1, delay=0.1)

with self.assertRaises(AssertionError):
mock.method_1.call_event.wait_for_call(call(1, a=1), timeout=0.01)

mock.method_1.call_event.wait_for_call(call(1, a=1))


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add :attr:`call_event` to :class:`Mock` that allows to wait for the calls in
multithreaded tests. Patch by Ilya Kulakov.