Skip to content

Commit 52b4eb6

Browse files
committed
Added warns to assert warnings are thrown
Works in a similar manner to `raises`, but for warnings instead of exceptions. Also refactored `recwarn.py` so that all the warning recording and checking use the same core code.
1 parent aac371c commit 52b4eb6

File tree

7 files changed

+400
-145
lines changed

7 files changed

+400
-145
lines changed

CHANGELOG

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
2.8.0.dev (compared to 2.7.X)
22
-----------------------------
33

4+
- Add 'warns' to assert that warnings are thrown (like 'raises').
5+
Thanks to Eric Hunsberger for the PR.
6+
47
- Fix #683: Do not apply an already applied mark. Thanks ojake for the PR.
58

69
- Deal with capturing failures better so fewer exceptions get lost to

_pytest/python.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,8 +1052,8 @@ def getlocation(function, curdir):
10521052

10531053
# builtin pytest.raises helper
10541054

1055-
def raises(ExpectedException, *args, **kwargs):
1056-
""" assert that a code block/function call raises @ExpectedException
1055+
def raises(expected_exception, *args, **kwargs):
1056+
""" assert that a code block/function call raises @expected_exception
10571057
and raise a failure exception otherwise.
10581058
10591059
This helper produces a ``py.code.ExceptionInfo()`` object.
@@ -1101,23 +1101,23 @@ def raises(ExpectedException, *args, **kwargs):
11011101
11021102
"""
11031103
__tracebackhide__ = True
1104-
if ExpectedException is AssertionError:
1104+
if expected_exception is AssertionError:
11051105
# we want to catch a AssertionError
11061106
# replace our subclass with the builtin one
11071107
# see https://github.com/pytest-dev/pytest/issues/176
11081108
from _pytest.assertion.util import BuiltinAssertionError \
1109-
as ExpectedException
1109+
as expected_exception
11101110
msg = ("exceptions must be old-style classes or"
11111111
" derived from BaseException, not %s")
1112-
if isinstance(ExpectedException, tuple):
1113-
for exc in ExpectedException:
1112+
if isinstance(expected_exception, tuple):
1113+
for exc in expected_exception:
11141114
if not inspect.isclass(exc):
11151115
raise TypeError(msg % type(exc))
1116-
elif not inspect.isclass(ExpectedException):
1117-
raise TypeError(msg % type(ExpectedException))
1116+
elif not inspect.isclass(expected_exception):
1117+
raise TypeError(msg % type(expected_exception))
11181118

11191119
if not args:
1120-
return RaisesContext(ExpectedException)
1120+
return RaisesContext(expected_exception)
11211121
elif isinstance(args[0], str):
11221122
code, = args
11231123
assert isinstance(code, str)
@@ -1130,19 +1130,19 @@ def raises(ExpectedException, *args, **kwargs):
11301130
py.builtin.exec_(code, frame.f_globals, loc)
11311131
# XXX didn'T mean f_globals == f_locals something special?
11321132
# this is destroyed here ...
1133-
except ExpectedException:
1133+
except expected_exception:
11341134
return py.code.ExceptionInfo()
11351135
else:
11361136
func = args[0]
11371137
try:
11381138
func(*args[1:], **kwargs)
1139-
except ExpectedException:
1139+
except expected_exception:
11401140
return py.code.ExceptionInfo()
11411141
pytest.fail("DID NOT RAISE")
11421142

11431143
class RaisesContext(object):
1144-
def __init__(self, ExpectedException):
1145-
self.ExpectedException = ExpectedException
1144+
def __init__(self, expected_exception):
1145+
self.expected_exception = expected_exception
11461146
self.excinfo = None
11471147

11481148
def __enter__(self):
@@ -1161,7 +1161,7 @@ def __exit__(self, *tp):
11611161
exc_type, value, traceback = tp
11621162
tp = exc_type, exc_type(value), traceback
11631163
self.excinfo.__init__(tp)
1164-
return issubclass(self.excinfo.type, self.ExpectedException)
1164+
return issubclass(self.excinfo.type, self.expected_exception)
11651165

11661166
#
11671167
# the basic pytest Function item
@@ -2123,4 +2123,3 @@ def get_scope_node(node, scope):
21232123
return node.session
21242124
raise ValueError("unknown scope")
21252125
return node.getparent(cls)
2126-

_pytest/recwarn.py

Lines changed: 147 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
""" recording warnings during test function execution. """
22

3+
import inspect
4+
import py
35
import sys
46
import warnings
7+
import pytest
58

69

7-
def pytest_funcarg__recwarn(request):
10+
@pytest.yield_fixture
11+
def recwarn(request):
812
"""Return a WarningsRecorder instance that provides these methods:
913
1014
* ``pop(category=None)``: return last warning matching the category.
@@ -13,83 +17,169 @@ def pytest_funcarg__recwarn(request):
1317
See http://docs.python.org/library/warnings.html for information
1418
on warning categories.
1519
"""
16-
if sys.version_info >= (2,7):
17-
oldfilters = warnings.filters[:]
18-
warnings.simplefilter('default')
19-
def reset_filters():
20-
warnings.filters[:] = oldfilters
21-
request.addfinalizer(reset_filters)
2220
wrec = WarningsRecorder()
23-
request.addfinalizer(wrec.finalize)
24-
return wrec
21+
with wrec:
22+
warnings.simplefilter('default')
23+
yield wrec
24+
2525

2626
def pytest_namespace():
27-
return {'deprecated_call': deprecated_call}
27+
return {'deprecated_call': deprecated_call,
28+
'warns': warns}
29+
2830

2931
def deprecated_call(func, *args, **kwargs):
30-
""" assert that calling ``func(*args, **kwargs)``
31-
triggers a DeprecationWarning.
32+
"""Assert that ``func(*args, **kwargs)`` triggers a DeprecationWarning.
3233
"""
33-
l = []
34-
oldwarn_explicit = getattr(warnings, 'warn_explicit')
35-
def warn_explicit(*args, **kwargs):
36-
l.append(args)
37-
oldwarn_explicit(*args, **kwargs)
38-
oldwarn = getattr(warnings, 'warn')
39-
def warn(*args, **kwargs):
40-
l.append(args)
41-
oldwarn(*args, **kwargs)
42-
43-
warnings.warn_explicit = warn_explicit
44-
warnings.warn = warn
45-
try:
34+
wrec = WarningsRecorder()
35+
with wrec:
36+
warnings.simplefilter('always') # ensure all warnings are triggered
4637
ret = func(*args, **kwargs)
47-
finally:
48-
warnings.warn_explicit = oldwarn_explicit
49-
warnings.warn = oldwarn
50-
if not l:
38+
39+
if not any(r.category is DeprecationWarning for r in wrec):
5140
__tracebackhide__ = True
52-
raise AssertionError("%r did not produce DeprecationWarning" %(func,))
41+
raise AssertionError("%r did not produce DeprecationWarning" % (func,))
42+
5343
return ret
5444

5545

56-
class RecordedWarning:
57-
def __init__(self, message, category, filename, lineno, line):
46+
def warns(expected_warning, *args, **kwargs):
47+
"""Assert that code raises a particular class of warning.
48+
49+
Specifically, the input @expected_warning can be a warning class or
50+
tuple of warning classes, and the code must return that warning
51+
(if a single class) or one of those warnings (if a tuple).
52+
53+
This helper produces a list of ``warnings.WarningMessage`` objects,
54+
one for each warning raised.
55+
56+
This function can be used as a context manager, or any of the other ways
57+
``pytest.raises`` can be used::
58+
59+
>>> with warns(RuntimeWarning):
60+
... warnings.warn("my warning", RuntimeWarning)
61+
"""
62+
wcheck = WarningsChecker(expected_warning)
63+
if not args:
64+
return wcheck
65+
elif isinstance(args[0], str):
66+
code, = args
67+
assert isinstance(code, str)
68+
frame = sys._getframe(1)
69+
loc = frame.f_locals.copy()
70+
loc.update(kwargs)
71+
72+
with wcheck:
73+
code = py.code.Source(code).compile()
74+
py.builtin.exec_(code, frame.f_globals, loc)
75+
else:
76+
func = args[0]
77+
with wcheck:
78+
return func(*args[1:], **kwargs)
79+
80+
81+
class RecordedWarning(object):
82+
def __init__(self, message, category, filename, lineno, file, line):
5883
self.message = message
5984
self.category = category
6085
self.filename = filename
6186
self.lineno = lineno
87+
self.file = file
6288
self.line = line
6389

64-
class WarningsRecorder:
65-
def __init__(self):
66-
self.list = []
67-
def showwarning(message, category, filename, lineno, line=0):
68-
self.list.append(RecordedWarning(
69-
message, category, filename, lineno, line))
70-
try:
71-
self.old_showwarning(message, category,
72-
filename, lineno, line=line)
73-
except TypeError:
74-
# < python2.6
75-
self.old_showwarning(message, category, filename, lineno)
76-
self.old_showwarning = warnings.showwarning
77-
warnings.showwarning = showwarning
90+
91+
class WarningsRecorder(object):
92+
"""A context manager to record raised warnings.
93+
94+
Adapted from `warnings.catch_warnings`.
95+
"""
96+
97+
def __init__(self, module=None):
98+
self._module = sys.modules['warnings'] if module is None else module
99+
self._entered = False
100+
self._list = []
101+
102+
@property
103+
def list(self):
104+
"""The list of recorded warnings."""
105+
return self._list
106+
107+
def __getitem__(self, i):
108+
"""Get a recorded warning by index."""
109+
return self._list[i]
110+
111+
def __iter__(self):
112+
"""Iterate through the recorded warnings."""
113+
return iter(self._list)
114+
115+
def __len__(self):
116+
"""The number of recorded warnings."""
117+
return len(self._list)
78118

79119
def pop(self, cls=Warning):
80-
""" pop the first recorded warning, raise exception if not exists."""
81-
for i, w in enumerate(self.list):
120+
"""Pop the first recorded warning, raise exception if not exists."""
121+
for i, w in enumerate(self._list):
82122
if issubclass(w.category, cls):
83-
return self.list.pop(i)
123+
return self._list.pop(i)
84124
__tracebackhide__ = True
85-
assert 0, "%r not found in %r" %(cls, self.list)
86-
87-
#def resetregistry(self):
88-
# warnings.onceregistry.clear()
89-
# warnings.__warningregistry__.clear()
125+
raise AssertionError("%r not found in warning list" % cls)
90126

91127
def clear(self):
92-
self.list[:] = []
128+
"""Clear the list of recorded warnings."""
129+
self._list[:] = []
130+
131+
def __enter__(self):
132+
if self._entered:
133+
__tracebackhide__ = True
134+
raise RuntimeError("Cannot enter %r twice" % self)
135+
self._entered = True
136+
self._filters = self._module.filters
137+
self._module.filters = self._filters[:]
138+
self._showwarning = self._module.showwarning
139+
140+
def showwarning(message, category, filename, lineno,
141+
file=None, line=None):
142+
self._list.append(RecordedWarning(
143+
message, category, filename, lineno, file, line))
144+
145+
# still perform old showwarning functionality
146+
self._showwarning(message, category, filename, lineno,
147+
file=file, line=line)
148+
149+
self._module.showwarning = showwarning
150+
return self
151+
152+
def __exit__(self, *exc_info):
153+
if not self._entered:
154+
__tracebackhide__ = True
155+
raise RuntimeError("Cannot exit %r without entering first" % self)
156+
self._module.filters = self._filters
157+
self._module.showwarning = self._showwarning
158+
159+
160+
class WarningsChecker(WarningsRecorder):
161+
def __init__(self, expected_warning=None, module=None):
162+
super(WarningsChecker, self).__init__(module=module)
163+
164+
msg = ("exceptions must be old-style classes or "
165+
"derived from Warning, not %s")
166+
if isinstance(expected_warning, tuple):
167+
for exc in expected_warning:
168+
if not inspect.isclass(exc):
169+
raise TypeError(msg % type(exc))
170+
elif inspect.isclass(expected_warning):
171+
expected_warning = (expected_warning,)
172+
elif expected_warning is not None:
173+
raise TypeError(msg % type(expected_warning))
174+
175+
self.expected_warning = expected_warning
176+
177+
def __exit__(self, *exc_info):
178+
super(WarningsChecker, self).__exit__(*exc_info)
93179

94-
def finalize(self):
95-
warnings.showwarning = self.old_showwarning
180+
# only check if we're not currently handling an exception
181+
if all(a is None for a in exc_info):
182+
if self.expected_warning is not None:
183+
if not any(r.category in self.expected_warning for r in self):
184+
__tracebackhide__ = True
185+
pytest.fail("DID NOT WARN")

doc/en/assert.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,16 @@ like documenting unfixed bugs (where the test describes what "should" happen)
114114
or bugs in dependencies.
115115

116116

117+
.. _`assertwarns`:
118+
119+
Assertions about expected warnings
120+
-----------------------------------------
121+
122+
.. versionadded:: 2.8
123+
124+
You can check that code raises a particular warning using
125+
:ref:`pytest.warns <warns>`.
126+
117127

118128
.. _newreport:
119129

0 commit comments

Comments
 (0)