Skip to content

Commit 7009cad

Browse files
committed
Check ImportMode exhaustively
1 parent 2a31c1a commit 7009cad

File tree

3 files changed

+101
-4
lines changed

3 files changed

+101
-4
lines changed

src/_pytest/compat.py

+36
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434

3535
if TYPE_CHECKING:
36+
from typing import NoReturn
3637
from typing import Type
3738
from typing_extensions import Final
3839

@@ -401,3 +402,38 @@ def __get__(self, instance, owner=None): # noqa: F811
401402
from collections import OrderedDict
402403

403404
order_preserving_dict = OrderedDict
405+
406+
407+
# Perform exhaustiveness checking.
408+
#
409+
# Consider this example:
410+
#
411+
# MyUnion = Union[int, str]
412+
#
413+
# def handle(x: MyUnion) -> int {
414+
# if isinstance(x, int):
415+
# return 1
416+
# elif isinstance(x, str):
417+
# return 2
418+
# else:
419+
# raise Exception('unreachable')
420+
#
421+
# Now suppose we add a new variant:
422+
#
423+
# MyUnion = Union[int, str, bytes]
424+
#
425+
# After doing this, we must remember ourselves to go and update the handle
426+
# function to handle the new variant.
427+
#
428+
# With `assert_never` we can do better:
429+
#
430+
# // throw new Error('unreachable');
431+
# return assert_never(x)
432+
#
433+
# Now, if we forget to handle the new variant, the type-checker will emit a
434+
# compile-time error, instead of the runtime error we would have gotten
435+
# previously.
436+
#
437+
# This also work for Enums (if you use `is` to compare) and Literals.
438+
def assert_never(value: "NoReturn") -> "NoReturn":
439+
assert False, "Unhandled value: {} ({})".format(value, type(value).__name__)

src/_pytest/pathlib.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import py
2626

27+
from _pytest.compat import assert_never
2728
from _pytest.outcomes import skip
2829
from _pytest.warning_types import PytestWarning
2930

@@ -463,7 +464,7 @@ def import_path(
463464
if not path.exists():
464465
raise ImportError(path)
465466

466-
if mode == ImportMode.importlib:
467+
if mode is ImportMode.importlib:
467468
module_name = path.stem
468469

469470
for meta_importer in sys.meta_path:
@@ -495,13 +496,14 @@ def import_path(
495496
# change sys.path permanently: restoring it at the end of this function would cause surprising
496497
# problems because of delayed imports: for example, a conftest.py file imported by this function
497498
# might have local imports, which would fail at runtime if we restored sys.path.
498-
if mode == ImportMode.append:
499+
if mode is ImportMode.append:
499500
if str(pkg_root) not in sys.path:
500501
sys.path.append(str(pkg_root))
501-
else:
502-
assert mode == ImportMode.prepend, "unexpected mode: {}".format(mode)
502+
elif mode is ImportMode.prepend:
503503
if str(pkg_root) != sys.path[0]:
504504
sys.path.insert(0, str(pkg_root))
505+
else:
506+
assert_never(mode)
505507

506508
importlib.import_module(module_name)
507509

testing/test_compat.py

+59
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
1+
import enum
12
import sys
23
from functools import partial
34
from functools import wraps
5+
from typing import Union
46

57
import pytest
68
from _pytest.compat import _PytestWrapper
9+
from _pytest.compat import assert_never
710
from _pytest.compat import cached_property
811
from _pytest.compat import get_real_func
912
from _pytest.compat import is_generator
1013
from _pytest.compat import safe_getattr
1114
from _pytest.compat import safe_isclass
15+
from _pytest.compat import TYPE_CHECKING
1216
from _pytest.outcomes import OutcomeException
1317

18+
if TYPE_CHECKING:
19+
from typing_extensions import Literal
20+
1421

1522
def test_is_generator():
1623
def zap():
@@ -205,3 +212,55 @@ def prop(self) -> int:
205212
assert ncalls == 1
206213
assert c2.prop == 2
207214
assert c1.prop == 1
215+
216+
217+
def test_assert_never_union() -> None:
218+
x = 10 # type: Union[int, str]
219+
220+
if isinstance(x, int):
221+
pass
222+
else:
223+
with pytest.raises(AssertionError):
224+
assert_never(x) # type: ignore[arg-type]
225+
226+
if isinstance(x, int):
227+
pass
228+
elif isinstance(x, str):
229+
pass
230+
else:
231+
assert_never(x)
232+
233+
234+
def test_assert_never_enum() -> None:
235+
E = enum.Enum("E", "a b")
236+
x = E.a # type: E
237+
238+
if x is E.a:
239+
pass
240+
else:
241+
with pytest.raises(AssertionError):
242+
assert_never(x) # type: ignore[arg-type]
243+
244+
if x is E.a:
245+
pass
246+
elif x is E.b:
247+
pass
248+
else:
249+
assert_never(x)
250+
251+
252+
def test_assert_never_literal() -> None:
253+
x = "a" # type: Literal["a", "b"]
254+
255+
if x == "a":
256+
pass
257+
else:
258+
with pytest.raises(AssertionError):
259+
assert_never(x) # type: ignore[arg-type]
260+
261+
if x == "a":
262+
pass
263+
elif x == "b":
264+
pass
265+
else:
266+
assert_never(x)

0 commit comments

Comments
 (0)