Skip to content

Commit cf2ae67

Browse files
committed
clean up DelayProxy, fix comments
1 parent 6f6b6f6 commit cf2ae67

File tree

1 file changed

+30
-32
lines changed

1 file changed

+30
-32
lines changed

tests/test_asyncio/test_cwe_404.py

+30-32
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,12 @@ def redis_addr(request):
1717
return host, int(port)
1818

1919

20-
async def pipe(
21-
reader: asyncio.StreamReader,
22-
writer: asyncio.StreamWriter,
23-
proxy: "DelayProxy",
24-
name="",
25-
event: asyncio.Event = None,
26-
):
27-
while True:
28-
data = await reader.read(1000)
29-
if not data:
30-
break
31-
if event:
32-
event.set()
33-
await asyncio.sleep(proxy.delay)
34-
writer.write(data)
35-
await writer.drain()
36-
37-
3820
class DelayProxy:
3921
def __init__(self, addr, redis_addr, delay: float):
4022
self.addr = addr
4123
self.redis_addr = redis_addr
4224
self.delay = delay
4325
self.send_event = asyncio.Event()
44-
self.redis_streams = None
4526

4627
async def start(self):
4728
# test that we can connect to redis
@@ -52,31 +33,48 @@ async def start(self):
5233
self.ROUTINE = asyncio.create_task(self.server.serve_forever())
5334

5435
@contextlib.contextmanager
55-
def override(self, delay: float = 0.0):
36+
def set_delay(self, delay: float = 0.0):
5637
"""
5738
Allow to override the delay for parts of tests which aren't time dependent,
5839
to speed up execution.
5940
"""
60-
old = self.delay
41+
old_delay = self.delay
6142
self.delay = delay
6243
try:
6344
yield
6445
finally:
65-
self.delay = old
46+
self.delay = old_delay
6647

6748
async def handle(self, reader, writer):
6849
# establish connection to redis
6950
redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
7051
try:
7152
pipe1 = asyncio.create_task(
72-
pipe(reader, redis_writer, self, "to redis:", self.send_event)
53+
self.pipe(reader, redis_writer, "to redis:", self.send_event)
7354
)
74-
pipe2 = asyncio.create_task(pipe(redis_reader, writer, self, "from redis:"))
55+
pipe2 = asyncio.create_task(self.pipe(redis_reader, writer, "from redis:"))
7556
await asyncio.gather(pipe1, pipe2)
7657
finally:
7758
redis_writer.close()
7859
redis_reader.close()
7960

61+
async def pipe(
62+
self,
63+
reader: asyncio.StreamReader,
64+
writer: asyncio.StreamWriter,
65+
name="",
66+
event: asyncio.Event = None,
67+
):
68+
while True:
69+
data = await reader.read(1000)
70+
if not data:
71+
break
72+
if event:
73+
event.set()
74+
await asyncio.sleep(self.delay)
75+
writer.write(data)
76+
await writer.drain()
77+
8078
async def stop(self):
8179
# clean up enough so that we can reuse the looper
8280
self.ROUTINE.cancel()
@@ -101,7 +99,7 @@ async def test_standalone(delay, redis_addr):
10199
# note that we connect to proxy, rather than to Redis directly
102100
async with Redis(host="127.0.0.1", port=5380, single_connection_client=b) as r:
103101

104-
with dp.override():
102+
with dp.set_delay(0):
105103
await r.set("foo", "foo")
106104
await r.set("bar", "bar")
107105

@@ -117,7 +115,7 @@ async def test_standalone(delay, redis_addr):
117115

118116
# make sure that our previous request, cancelled while waiting for
119117
# a repsponse, didn't leave the connection open andin a bad state
120-
with dp.override():
118+
with dp.set_delay(0):
121119
assert await r.get("bar") == b"bar"
122120
assert await r.ping()
123121
assert await r.get("foo") == b"foo"
@@ -132,7 +130,7 @@ async def test_standalone_pipeline(delay, redis_addr):
132130
await dp.start()
133131
for b in [True, False]:
134132
async with Redis(host="127.0.0.1", port=5380, single_connection_client=b) as r:
135-
with dp.override():
133+
with dp.set_delay(0):
136134
await r.set("foo", "foo")
137135
await r.set("bar", "bar")
138136

@@ -154,7 +152,7 @@ async def test_standalone_pipeline(delay, redis_addr):
154152

155153
# we have now cancelled the pieline in the middle of a request, make sure
156154
# that the connection is still usable
157-
with dp.override():
155+
with dp.set_delay(0):
158156
pipe.get("bar")
159157
pipe.ping()
160158
pipe.get("foo")
@@ -205,10 +203,10 @@ async def any_wait():
205203
)
206204

207205
@contextlib.contextmanager
208-
def all_override(delay: int = 0):
206+
def set_delay(delay: int = 0):
209207
with contextlib.ExitStack() as stack:
210208
for p in proxies:
211-
stack.enter_context(p.override(delay=delay))
209+
stack.enter_context(p.delay_as(delay))
212210
yield
213211

214212
# start proxies
@@ -222,9 +220,9 @@ def all_override(delay: int = 0):
222220
await r.set("bar", "bar")
223221

224222
all_clear()
225-
with all_override(delay=delay):
223+
with set_delay(delay=delay):
226224
t = asyncio.create_task(r.get("foo"))
227-
# cannot wait on the send event, we don't know which node will be used
225+
# One of the proxies will handle our request, wait for it to send
228226
await any_wait()
229227
await asyncio.sleep(delay)
230228
t.cancel()

0 commit comments

Comments
 (0)