Skip to content

Commit 9beb1cf

Browse files
committed
Use context manager to manage DelayProxy
1 parent 7c6ed4a commit 9beb1cf

File tree

1 file changed

+121
-113
lines changed

1 file changed

+121
-113
lines changed

tests/test_asyncio/test_cwe_404.py

+121-113
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,21 @@ def __init__(self, addr, redis_addr, delay: float = 0.0):
2727
self.delay = delay
2828
self.send_event = asyncio.Event()
2929

30+
async def __aenter__(self):
31+
await self.start()
32+
return self
33+
34+
async def __aexit__(self, *args):
35+
await self.stop()
36+
3037
async def start(self):
3138
# test that we can connect to redis
3239
async with async_timeout(2):
3340
_, redis_writer = await asyncio.open_connection(*self.redis_addr)
3441
redis_writer.close()
35-
self.server = await asyncio.start_server(self.handle, *self.addr)
42+
self.server = await asyncio.start_server(
43+
self.handle, *self.addr, reuse_address=True
44+
)
3645
self.ROUTINE = asyncio.create_task(self.server.serve_forever())
3746

3847
@contextlib.contextmanager
@@ -95,91 +104,89 @@ async def test_standalone(delay, redis_addr):
95104

96105
# create a tcp socket proxy that relays data to Redis and back,
97106
# inserting 0.1 seconds of delay
98-
dp = DelayProxy(addr=("127.0.0.1", 5380), redis_addr=redis_addr)
99-
await dp.start()
100-
101-
for b in [True, False]:
102-
# note that we connect to proxy, rather than to Redis directly
103-
async with Redis(host="127.0.0.1", port=5380, single_connection_client=b) as r:
104-
105-
await r.set("foo", "foo")
106-
await r.set("bar", "bar")
107-
108-
async def op(r):
109-
with dp.set_delay(delay * 2):
110-
return await r.get(
111-
"foo"
112-
) # <-- this is the operation we want to cancel
113-
114-
dp.send_event.clear()
115-
t = asyncio.create_task(op(r))
116-
# Wait until the task has sent, and then some, to make sure it has
117-
# settled on the read.
118-
await dp.send_event.wait()
119-
await asyncio.sleep(0.01) # a little extra time for prudence
120-
t.cancel()
121-
with pytest.raises(asyncio.CancelledError):
122-
await t
123-
124-
# make sure that our previous request, cancelled while waiting for
125-
# a repsponse, didn't leave the connection open andin a bad state
126-
assert await r.get("bar") == b"bar"
127-
assert await r.ping()
128-
assert await r.get("foo") == b"foo"
129-
130-
await dp.stop()
107+
async with DelayProxy(addr=("127.0.0.1", 5380), redis_addr=redis_addr) as dp:
108+
109+
for b in [True, False]:
110+
# note that we connect to proxy, rather than to Redis directly
111+
async with Redis(
112+
host="127.0.0.1", port=5380, single_connection_client=b
113+
) as r:
114+
115+
await r.set("foo", "foo")
116+
await r.set("bar", "bar")
117+
118+
async def op(r):
119+
with dp.set_delay(delay * 2):
120+
return await r.get(
121+
"foo"
122+
) # <-- this is the operation we want to cancel
123+
124+
dp.send_event.clear()
125+
t = asyncio.create_task(op(r))
126+
# Wait until the task has sent, and then some, to make sure it has
127+
# settled on the read.
128+
await dp.send_event.wait()
129+
await asyncio.sleep(0.01) # a little extra time for prudence
130+
t.cancel()
131+
with pytest.raises(asyncio.CancelledError):
132+
await t
133+
134+
# make sure that our previous request, cancelled while waiting for
135+
# a repsponse, didn't leave the connection open andin a bad state
136+
assert await r.get("bar") == b"bar"
137+
assert await r.ping()
138+
assert await r.get("foo") == b"foo"
131139

132140

133141
@pytest.mark.onlynoncluster
134142
@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2])
135143
async def test_standalone_pipeline(delay, redis_addr):
136-
dp = DelayProxy(addr=("127.0.0.1", 5380), redis_addr=redis_addr)
137-
await dp.start()
138-
for b in [True, False]:
139-
async with Redis(host="127.0.0.1", port=5380, single_connection_client=b) as r:
140-
await r.set("foo", "foo")
141-
await r.set("bar", "bar")
142-
143-
pipe = r.pipeline()
144-
145-
pipe2 = r.pipeline()
146-
pipe2.get("bar")
147-
pipe2.ping()
148-
pipe2.get("foo")
149-
150-
async def op(pipe):
151-
with dp.set_delay(delay * 2):
152-
return await pipe.get(
153-
"foo"
154-
).execute() # <-- this is the operation we want to cancel
155-
156-
dp.send_event.clear()
157-
t = asyncio.create_task(op(pipe))
158-
# wait until task has settled on the read
159-
await dp.send_event.wait()
160-
await asyncio.sleep(0.01)
161-
t.cancel()
162-
with pytest.raises(asyncio.CancelledError):
163-
await t
164-
165-
# we have now cancelled the pieline in the middle of a request, make sure
166-
# that the connection is still usable
167-
pipe.get("bar")
168-
pipe.ping()
169-
pipe.get("foo")
170-
await pipe.reset()
171-
172-
# check that the pipeline is empty after reset
173-
assert await pipe.execute() == []
174-
175-
# validating that the pipeline can be used as it could previously
176-
pipe.get("bar")
177-
pipe.ping()
178-
pipe.get("foo")
179-
assert await pipe.execute() == [b"bar", True, b"foo"]
180-
assert await pipe2.execute() == [b"bar", True, b"foo"]
181-
182-
await dp.stop()
144+
async with DelayProxy(addr=("127.0.0.1", 5380), redis_addr=redis_addr) as dp:
145+
for b in [True, False]:
146+
async with Redis(
147+
host="127.0.0.1", port=5380, single_connection_client=b
148+
) as r:
149+
await r.set("foo", "foo")
150+
await r.set("bar", "bar")
151+
152+
pipe = r.pipeline()
153+
154+
pipe2 = r.pipeline()
155+
pipe2.get("bar")
156+
pipe2.ping()
157+
pipe2.get("foo")
158+
159+
async def op(pipe):
160+
with dp.set_delay(delay * 2):
161+
return await pipe.get(
162+
"foo"
163+
).execute() # <-- this is the operation we want to cancel
164+
165+
dp.send_event.clear()
166+
t = asyncio.create_task(op(pipe))
167+
# wait until task has settled on the read
168+
await dp.send_event.wait()
169+
await asyncio.sleep(0.01)
170+
t.cancel()
171+
with pytest.raises(asyncio.CancelledError):
172+
await t
173+
174+
# we have now cancelled the pieline in the middle of a request, make sure
175+
# that the connection is still usable
176+
pipe.get("bar")
177+
pipe.ping()
178+
pipe.get("foo")
179+
await pipe.reset()
180+
181+
# check that the pipeline is empty after reset
182+
assert await pipe.execute() == []
183+
184+
# validating that the pipeline can be used as it could previously
185+
pipe.get("bar")
186+
pipe.ping()
187+
pipe.get("foo")
188+
assert await pipe.execute() == [b"bar", True, b"foo"]
189+
assert await pipe2.execute() == [b"bar", True, b"foo"]
183190

184191

185192
@pytest.mark.onlycluster
@@ -202,9 +209,6 @@ def remap(address):
202209
proxy = DelayProxy(addr=("127.0.0.1", remapped), redis_addr=forward_addr)
203210
proxies.append(proxy)
204211

205-
# start proxies
206-
await asyncio.gather(*[p.start() for p in proxies])
207-
208212
def all_clear():
209213
for p in proxies:
210214
p.send_event.clear()
@@ -221,32 +225,36 @@ def set_delay(delay: float):
221225
stack.enter_context(p.set_delay(delay))
222226
yield
223227

224-
with contextlib.closing(
225-
RedisCluster.from_url(f"redis://127.0.0.1:{remap_base}", address_remap=remap)
226-
) as r:
227-
await r.initialize()
228-
await r.set("foo", "foo")
229-
await r.set("bar", "bar")
230-
231-
async def op(r):
232-
with set_delay(delay):
233-
return await r.get("foo")
234-
235-
all_clear()
236-
t = asyncio.create_task(op(r))
237-
# Wait for whichever DelayProxy gets the request first
238-
await wait_for_send()
239-
await asyncio.sleep(0.01)
240-
t.cancel()
241-
with pytest.raises(asyncio.CancelledError):
242-
await t
243-
244-
# try a number of requests to excercise all the connections
245-
async def doit():
246-
assert await r.get("bar") == b"bar"
247-
assert await r.ping()
248-
assert await r.get("foo") == b"foo"
249-
250-
await asyncio.gather(*[doit() for _ in range(10)])
251-
252-
await asyncio.gather(*(p.stop() for p in proxies))
228+
async with contextlib.AsyncExitStack() as stack:
229+
for p in proxies:
230+
await stack.enter_async_context(p)
231+
232+
with contextlib.closing(
233+
RedisCluster.from_url(
234+
f"redis://127.0.0.1:{remap_base}", address_remap=remap
235+
)
236+
) as r:
237+
await r.initialize()
238+
await r.set("foo", "foo")
239+
await r.set("bar", "bar")
240+
241+
async def op(r):
242+
with set_delay(delay):
243+
return await r.get("foo")
244+
245+
all_clear()
246+
t = asyncio.create_task(op(r))
247+
# Wait for whichever DelayProxy gets the request first
248+
await wait_for_send()
249+
await asyncio.sleep(0.01)
250+
t.cancel()
251+
with pytest.raises(asyncio.CancelledError):
252+
await t
253+
254+
# try a number of requests to excercise all the connections
255+
async def doit():
256+
assert await r.get("bar") == b"bar"
257+
assert await r.ping()
258+
assert await r.get("foo") == b"foo"
259+
260+
await asyncio.gather(*[doit() for _ in range(10)])

0 commit comments

Comments
 (0)