Skip to content

Commit eb27a5d

Browse files
committed
asyncio: #8 consider contextvars
Make it so everything gets executed in the same asyncio context
1 parent 4ad767d commit eb27a5d

File tree

10 files changed

+294
-2
lines changed

10 files changed

+294
-2
lines changed

alt_pytest_asyncio/async_converters.py

+60-1
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,75 @@
1414
from functools import wraps
1515

1616

17-
def convert_fixtures(fixturedef, request, node):
17+
def convert_fixtures(ctx, fixturedef, request, node):
1818
"""Used to replace async fixtures"""
1919
if not hasattr(fixturedef, "func"):
2020
return
2121

22+
if hasattr(fixturedef.func, "__alt_asyncio_pytest_converted__"):
23+
return
24+
2225
if inspect.iscoroutinefunction(fixturedef.func):
2326
convert_async_coroutine_fixture(fixturedef, request, node)
27+
fixturedef.func.__alt_asyncio_pytest_converted__ = True
2428

2529
elif inspect.isasyncgenfunction(fixturedef.func):
2630
convert_async_gen_fixture(fixturedef, request, node)
31+
fixturedef.func.__alt_asyncio_pytest_converted__ = True
32+
33+
elif inspect.isgeneratorfunction(fixturedef.func):
34+
convert_sync_gen_fixture(ctx, fixturedef)
35+
fixturedef.func.__alt_asyncio_pytest_converted__ = True
36+
37+
else:
38+
convert_sync_fixture(ctx, fixturedef)
39+
fixturedef.func.__alt_asyncio_pytest_converted__ = True
40+
41+
42+
def convert_sync_fixture(ctx, fixturedef):
43+
"""
44+
Used to make sure a non-async fixture is run in our
45+
asyncio contextvars
46+
"""
47+
original = fixturedef.func
48+
49+
@wraps(original)
50+
def run_fixture(*args, **kwargs):
51+
try:
52+
ctx.run(lambda: None)
53+
run = lambda func, *a, **kw: ctx.run(func, *a, **kw)
54+
except RuntimeError:
55+
run = lambda func, *a, **kw: func(*a, **kw)
56+
57+
return run(original, *args, **kwargs)
58+
59+
fixturedef.func = run_fixture
60+
61+
62+
def convert_sync_gen_fixture(ctx, fixturedef):
63+
"""
64+
Used to make sure a non-async generator fixture is run in our
65+
asyncio contextvars
66+
"""
67+
original = fixturedef.func
68+
69+
@wraps(original)
70+
def run_fixture(*args, **kwargs):
71+
try:
72+
ctx.run(lambda: None)
73+
run = lambda func, *a, **kw: ctx.run(func, *a, **kw)
74+
except RuntimeError:
75+
run = lambda func, *a, **kw: func(*a, **kw)
76+
77+
cm = original(*args, **kwargs)
78+
value = run(cm.__next__)
79+
try:
80+
yield value
81+
run(cm.__next__)
82+
except StopIteration:
83+
pass
84+
85+
fixturedef.func = run_fixture
2786

2887

2988
def converted_async_test(test_tasks, func, timeout, *args, **kwargs):

alt_pytest_asyncio/plugin.py

+63-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
from alt_pytest_asyncio.async_converters import convert_fixtures, converted_async_test
2+
from alt_pytest_asyncio import async_converters
3+
4+
from _pytest._code.code import ExceptionInfo
5+
from functools import partial, wraps
6+
from collections import defaultdict
7+
from unittest import mock
8+
import contextvars
9+
import inspect
110
import asyncio
211
import inspect
312
import sys
@@ -19,14 +28,48 @@ def __init__(self, loop=None):
1928
self.own_loop = True
2029
loop = asyncio.new_event_loop()
2130
asyncio.set_event_loop(loop)
31+
2232
self.loop = loop
33+
self.ctx = contextvars.copy_context()
2334

2435
def pytest_configure(self, config):
2536
"""Register our timeout marker which is used to signify async timeouts"""
2637
config.addinivalue_line(
2738
"markers", "async_timeout(length): mark async test to have a timeout"
2839
)
2940

41+
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
42+
def pytest_sessionstart(self, session):
43+
"""
44+
Python won't let me create the task with it's own context And
45+
run_until_complete doesn't let me propagate a context.
46+
47+
So I monkey patch call_soon to force the context So I don't need my own
48+
Task class
49+
50+
This obviously won't work if Task doesn't something other than
51+
``call_soon(self._step)``.
52+
"""
53+
original = self.loop.call_soon
54+
55+
@wraps(original)
56+
def wrapped_call_soon(callback, *args, context=None):
57+
is_alt_pytest = False
58+
for frm in inspect.stack():
59+
mod = inspect.getmodule(frm[0])
60+
if mod is async_converters:
61+
is_alt_pytest = True
62+
break
63+
64+
if is_alt_pytest:
65+
return original(callback, *args, context=self.ctx)
66+
else:
67+
return original(callback, *args, context=context)
68+
69+
self.loop_patch = mock.patch.object(self.loop, "call_soon", wrapped_call_soon)
70+
self.loop_patch.start()
71+
yield
72+
3073
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
3174
def pytest_sessionfinish(self, session, exitstatus):
3275
"""
@@ -54,10 +97,13 @@ def pytest_sessionfinish(self, session, exitstatus):
5497
finally:
5598
self.loop.close()
5699

100+
if hasattr(self, "loop_patch"):
101+
self.loop_patch.stop()
102+
57103
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
58104
def pytest_fixture_setup(self, fixturedef, request):
59105
"""Convert async fixtures to sync fixtures"""
60-
convert_fixtures(fixturedef, request, request.node)
106+
convert_fixtures(self.ctx, fixturedef, request, request.node)
61107
yield
62108

63109
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
@@ -76,6 +122,22 @@ def pytest_pyfunc_call(self, pyfuncitem):
76122

77123
o = pyfuncitem.obj
78124
pyfuncitem.obj = wraps(o)(partial(converted_async_test, self.test_tasks, o, timeout))
125+
else:
126+
127+
original = pyfuncitem.obj
128+
129+
@wraps(original)
130+
def run_obj(*args, **kwargs):
131+
try:
132+
self.ctx.run(lambda: None)
133+
run = lambda func, *a, **kw: self.ctx.run(func, *a, **kw)
134+
except RuntimeError:
135+
run = lambda func, *a, **kw: func(*a, **kw)
136+
137+
run(original, *args, **kwargs)
138+
139+
pyfuncitem.obj = run_obj
140+
79141
yield
80142

81143

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ tests = [
2525
"nest-asyncio==1.0.0",
2626
"noseOfYeti[black]==2.4.1",
2727
"pytest==7.3.0",
28+
"pytest-order==1.2.1"
2829
]
2930

3031
[project.entry-points.pytest11]

tests/code_contextvars/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from contextvars import ContextVar
2+
import string
3+
4+
allvars = {}
5+
6+
7+
class Empty:
8+
pass
9+
10+
11+
def assertVarsEmpty(excluding=None):
12+
for letter, var in allvars.items():
13+
if not excluding or letter not in excluding:
14+
assert var.get(Empty) is Empty
15+
16+
17+
for letter in string.ascii_letters:
18+
var = ContextVar(letter)
19+
locals()[letter] = var
20+
allvars[letter] = var
21+
22+
__all__ = ["allvars", "Empty", "assertVarsEmpty"] + sorted(allvars)

tests/conftest.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
11
import pytest
2+
from pathlib import Path
3+
import asyncio
4+
import pytest
5+
import struct
6+
import socket
7+
import sys
8+
9+
sys.path.append(str(Path(__file__).parent / "code_contextvars"))
210

311
pytest_plugins = ["pytester"]
412

tests/test_contextvars/__init__.py

Whitespace-only changes.

tests/test_contextvars/conftest.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import contextvars_for_test as ctxvars
2+
3+
import pytest
4+
5+
6+
@pytest.fixture(scope="session", autouse=True)
7+
async def a_set_conftest_fixture_session_autouse():
8+
ctxvars.a.set("a_set_conftest_fixture_session_autouse")
9+
10+
11+
@pytest.fixture(scope="session", autouse=True)
12+
async def b_set_conftest_cm_session_autouse():
13+
ctxvars.b.set("b_set_conftest_cm_session_autouse")
14+
yield
15+
16+
17+
@pytest.fixture()
18+
async def c_set_conftest_fixture_test():
19+
ctxvars.c.set("c_set_conftest_fixture_test")
20+
21+
22+
@pytest.fixture()
23+
async def c_set_conftest_cm_test():
24+
token = ctxvars.c.set("c_set_conftest_cm_test")
25+
try:
26+
yield
27+
finally:
28+
ctxvars.c.reset(token)
+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# coding: spec
2+
3+
import contextvars_for_test as ctxvars
4+
import pytest
5+
6+
7+
@pytest.fixture(scope="module", autouse=True)
8+
def d_set_conftest_cm_module_autouse():
9+
token = ctxvars.d.set("d_set_conftest_fixture_test")
10+
try:
11+
yield
12+
finally:
13+
ctxvars.d.reset(token)
14+
15+
16+
@pytest.fixture(scope="module", autouse=True)
17+
async def e_set_conftest_cm_test():
18+
assert ctxvars.f.get(ctxvars.Empty) is ctxvars.Empty
19+
ctxvars.e.set("e_set_conftest_cm_module")
20+
yield
21+
assert ctxvars.f.get() == "f_set_conftest_cm_module"
22+
23+
24+
@pytest.fixture(scope="module", autouse=True)
25+
async def f_set_conftest_cm_module(e_set_conftest_cm_test):
26+
assert ctxvars.e.get() == "e_set_conftest_cm_module"
27+
ctxvars.f.set("f_set_conftest_cm_module")
28+
yield
29+
30+
31+
@pytest.mark.order(1)
32+
async it "gets session modified vars":
33+
assert ctxvars.a.get() == "a_set_conftest_fixture_session_autouse"
34+
assert ctxvars.b.get() == "b_set_conftest_cm_session_autouse"
35+
assert ctxvars.d.get() == "d_set_conftest_fixture_test"
36+
assert ctxvars.e.get() == "e_set_conftest_cm_module"
37+
assert ctxvars.f.get() == "f_set_conftest_cm_module"
38+
ctxvars.assertVarsEmpty(excluding=("a", "b", "d", "e", "f"))
39+
40+
41+
@pytest.mark.order(2)
42+
async it "can use a fixture to change the var", c_set_conftest_fixture_test:
43+
assert ctxvars.a.get() == "a_set_conftest_fixture_session_autouse"
44+
assert ctxvars.b.get() == "b_set_conftest_cm_session_autouse"
45+
assert ctxvars.d.get() == "d_set_conftest_fixture_test"
46+
assert ctxvars.e.get() == "e_set_conftest_cm_module"
47+
assert ctxvars.f.get() == "f_set_conftest_cm_module"
48+
49+
assert ctxvars.c.get() == "c_set_conftest_fixture_test"
50+
ctxvars.assertVarsEmpty(excluding=("a", "b", "c", "d", "e", "f"))
51+
52+
53+
@pytest.mark.order(3)
54+
async it "does not reset contextvars for you":
55+
"""
56+
It's too hard to know when a contextvar should be reset. It should
57+
be up to whatever sets the contextvar to know when it should be unset
58+
"""
59+
assert ctxvars.a.get() == "a_set_conftest_fixture_session_autouse"
60+
assert ctxvars.b.get() == "b_set_conftest_cm_session_autouse"
61+
assert ctxvars.d.get() == "d_set_conftest_fixture_test"
62+
assert ctxvars.e.get() == "e_set_conftest_cm_module"
63+
assert ctxvars.f.get() == "f_set_conftest_cm_module"
64+
65+
assert ctxvars.c.get() == "c_set_conftest_fixture_test"
66+
ctxvars.assertVarsEmpty(excluding=("a", "b", "c", "d", "e", "f"))
67+
68+
69+
@pytest.mark.order(4)
70+
async it "works in context manager fixtures", c_set_conftest_cm_test:
71+
"""
72+
It's too hard to know when a contextvar should be reset. It should
73+
be up to whatever sets the contextvar to know when it should be unset
74+
"""
75+
assert ctxvars.a.get() == "a_set_conftest_fixture_session_autouse"
76+
assert ctxvars.b.get() == "b_set_conftest_cm_session_autouse"
77+
assert ctxvars.d.get() == "d_set_conftest_fixture_test"
78+
assert ctxvars.e.get() == "e_set_conftest_cm_module"
79+
assert ctxvars.f.get() == "f_set_conftest_cm_module"
80+
81+
assert ctxvars.c.get() == "c_set_conftest_cm_test"
82+
ctxvars.assertVarsEmpty(excluding=("a", "b", "c", "d", "e", "f"))
83+
84+
85+
@pytest.mark.order(5)
86+
it "resets the contextvar successfully when cm attempts that":
87+
"""
88+
It's too hard to know when a contextvar should be reset. It should
89+
be up to whatever sets the contextvar to know when it should be unset
90+
"""
91+
assert ctxvars.a.get() == "a_set_conftest_fixture_session_autouse"
92+
assert ctxvars.b.get() == "b_set_conftest_cm_session_autouse"
93+
assert ctxvars.d.get() == "d_set_conftest_fixture_test"
94+
assert ctxvars.e.get() == "e_set_conftest_cm_module"
95+
assert ctxvars.f.get() == "f_set_conftest_cm_module"
96+
97+
assert ctxvars.c.get() == "c_set_conftest_fixture_test"
98+
ctxvars.assertVarsEmpty(excluding=("a", "b", "c", "d", "e", "f"))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# coding: spec
2+
3+
import contextvars_for_test as ctxvars
4+
import pytest
5+
6+
7+
@pytest.mark.order(100)
8+
it "only keeps values from other module fixtures if they haven't been cleaned up":
9+
assert ctxvars.a.get() == "a_set_conftest_fixture_session_autouse"
10+
assert ctxvars.b.get() == "b_set_conftest_cm_session_autouse"
11+
assert ctxvars.c.get() == "c_set_conftest_fixture_test"
12+
assert ctxvars.e.get() == "e_set_conftest_cm_module"
13+
assert ctxvars.f.get() == "f_set_conftest_cm_module"
14+
ctxvars.assertVarsEmpty(excluding=("a", "b", "c", "e", "f"))

0 commit comments

Comments
 (0)