Skip to content

Commit e43552a

Browse files
authored
Merge pull request #182 from scorphus/scoped-mockers
Add mocker for class, module, package and session scopes
2 parents d0cceef + 762f907 commit e43552a

File tree

2 files changed

+102
-2
lines changed

2 files changed

+102
-2
lines changed

src/pytest_mock/plugin.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,7 @@ def __call__(self, *args, **kwargs):
199199
return self._start_patch(self.mock_module.patch, *args, **kwargs)
200200

201201

202-
@pytest.yield_fixture
203-
def mocker(pytestconfig):
202+
def _mocker(pytestconfig):
204203
"""
205204
return an object that has the same interface to the `mock` module, but
206205
takes care of automatically undoing all patches after each test method.
@@ -210,6 +209,13 @@ def mocker(pytestconfig):
210209
result.stopall()
211210

212211

212+
mocker = pytest.yield_fixture()(_mocker) # default scope is function
213+
class_mocker = pytest.yield_fixture(scope="class")(_mocker)
214+
module_mocker = pytest.yield_fixture(scope="module")(_mocker)
215+
package_mocker = pytest.yield_fixture(scope="package")(_mocker)
216+
session_mocker = pytest.yield_fixture(scope="session")(_mocker)
217+
218+
213219
_mock_module_patches = []
214220
_mock_module_originals = {}
215221

tests/test_pytest_mock.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,3 +820,97 @@ def test_foo(mocker):
820820
py_fn.remove()
821821
result = testdir.runpytest()
822822
result.stdout.fnmatch_lines("* 1 passed *")
823+
824+
825+
def test_used_with_class_scope(testdir):
826+
"""..."""
827+
testdir.makepyfile(
828+
"""
829+
import pytest
830+
import random
831+
import unittest
832+
833+
def get_random_number():
834+
return random.randint(0, 1)
835+
836+
@pytest.fixture(autouse=True, scope="class")
837+
def randint_mock(class_mocker):
838+
return class_mocker.patch("random.randint", lambda x, y: 5)
839+
840+
class TestGetRandomNumber(unittest.TestCase):
841+
def test_get_random_number(self):
842+
assert get_random_number() == 5
843+
"""
844+
)
845+
result = testdir.runpytest_subprocess()
846+
assert "AssertionError" not in result.stderr.str()
847+
result.stdout.fnmatch_lines("* 1 passed in *")
848+
849+
850+
def test_used_with_module_scope(testdir):
851+
"""..."""
852+
testdir.makepyfile(
853+
"""
854+
import pytest
855+
import random
856+
857+
def get_random_number():
858+
return random.randint(0, 1)
859+
860+
@pytest.fixture(autouse=True, scope="module")
861+
def randint_mock(module_mocker):
862+
return module_mocker.patch("random.randint", lambda x, y: 5)
863+
864+
def test_get_random_number():
865+
assert get_random_number() == 5
866+
"""
867+
)
868+
result = testdir.runpytest_subprocess()
869+
assert "AssertionError" not in result.stderr.str()
870+
result.stdout.fnmatch_lines("* 1 passed in *")
871+
872+
873+
def test_used_with_package_scope(testdir):
874+
"""..."""
875+
testdir.makepyfile(
876+
"""
877+
import pytest
878+
import random
879+
880+
def get_random_number():
881+
return random.randint(0, 1)
882+
883+
@pytest.fixture(autouse=True, scope="package")
884+
def randint_mock(package_mocker):
885+
return package_mocker.patch("random.randint", lambda x, y: 5)
886+
887+
def test_get_random_number():
888+
assert get_random_number() == 5
889+
"""
890+
)
891+
result = testdir.runpytest_subprocess()
892+
assert "AssertionError" not in result.stderr.str()
893+
result.stdout.fnmatch_lines("* 1 passed in *")
894+
895+
896+
def test_used_with_session_scope(testdir):
897+
"""..."""
898+
testdir.makepyfile(
899+
"""
900+
import pytest
901+
import random
902+
903+
def get_random_number():
904+
return random.randint(0, 1)
905+
906+
@pytest.fixture(autouse=True, scope="session")
907+
def randint_mock(session_mocker):
908+
return session_mocker.patch("random.randint", lambda x, y: 5)
909+
910+
def test_get_random_number():
911+
assert get_random_number() == 5
912+
"""
913+
)
914+
result = testdir.runpytest_subprocess()
915+
assert "AssertionError" not in result.stderr.str()
916+
result.stdout.fnmatch_lines("* 1 passed in *")

0 commit comments

Comments
 (0)