Skip to content

Commit a9ef0c5

Browse files
Make PythonParser resumable (#2510)
* PythonParser is now resumable if _stream IO is interrupted * Add test for parse resumability * Clear PythonParser state when connection or parsing errors occur. * disable test for cluster mode. * Perform "closed" check in a single place. * Update tests * Simplify code. * Remove reduntant test, EOF is detected inside _readline() * Make syncronous PythonParser restartable on error, same as HiredisParser Fix sync PythonParser * Add CHANGES * isort * Move MockStream and MockSocket into their own files
1 parent a947728 commit a9ef0c5

File tree

7 files changed

+269
-42
lines changed

7 files changed

+269
-42
lines changed

Diff for: CHANGES

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Make PythonParser resumable in case of error (#2510)
12
* Add `timeout=None` in `SentinelConnectionManager.read_response`
23
* Documentation fix: password protected socket connection (#2374)
34
* Allow `timeout=None` in `PubSub.get_message()` to wait forever

Diff for: redis/asyncio/connection.py

+51-18
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,18 @@ async def read_response(
208208
class PythonParser(BaseParser):
209209
"""Plain Python parsing class"""
210210

211-
__slots__ = BaseParser.__slots__ + ("encoder",)
211+
__slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks")
212212

213213
def __init__(self, socket_read_size: int):
214214
super().__init__(socket_read_size)
215215
self.encoder: Optional[Encoder] = None
216+
self._buffer = b""
217+
self._chunks = []
218+
self._pos = 0
219+
220+
def _clear(self):
221+
self._buffer = b""
222+
self._chunks.clear()
216223

217224
def on_connect(self, connection: "Connection"):
218225
"""Called when the stream connects"""
@@ -227,8 +234,11 @@ def on_disconnect(self):
227234
if self._stream is not None:
228235
self._stream = None
229236
self.encoder = None
237+
self._clear()
230238

231239
async def can_read_destructive(self) -> bool:
240+
if self._buffer:
241+
return True
232242
if self._stream is None:
233243
raise RedisError("Buffer is closed.")
234244
try:
@@ -237,14 +247,23 @@ async def can_read_destructive(self) -> bool:
237247
except asyncio.TimeoutError:
238248
return False
239249

240-
async def read_response(
250+
async def read_response(self, disable_decoding: bool = False):
251+
if self._chunks:
252+
# augment parsing buffer with previously read data
253+
self._buffer += b"".join(self._chunks)
254+
self._chunks.clear()
255+
self._pos = 0
256+
response = await self._read_response(disable_decoding=disable_decoding)
257+
# Successfully parsing a response allows us to clear our parsing buffer
258+
self._clear()
259+
return response
260+
261+
async def _read_response(
241262
self, disable_decoding: bool = False
242263
) -> Union[EncodableT, ResponseError, None]:
243264
if not self._stream or not self.encoder:
244265
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
245266
raw = await self._readline()
246-
if not raw:
247-
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
248267
response: Any
249268
byte, response = raw[:1], raw[1:]
250269

@@ -258,6 +277,7 @@ async def read_response(
258277
# if the error is a ConnectionError, raise immediately so the user
259278
# is notified
260279
if isinstance(error, ConnectionError):
280+
self._clear() # Successful parse
261281
raise error
262282
# otherwise, we're dealing with a ResponseError that might belong
263283
# inside a pipeline response. the connection's read_response()
@@ -282,7 +302,7 @@ async def read_response(
282302
if length == -1:
283303
return None
284304
response = [
285-
(await self.read_response(disable_decoding)) for _ in range(length)
305+
(await self._read_response(disable_decoding)) for _ in range(length)
286306
]
287307
if isinstance(response, bytes) and disable_decoding is False:
288308
response = self.encoder.decode(response)
@@ -293,25 +313,38 @@ async def _read(self, length: int) -> bytes:
293313
Read `length` bytes of data. These are assumed to be followed
294314
by a '\r\n' terminator which is subsequently discarded.
295315
"""
296-
if self._stream is None:
297-
raise RedisError("Buffer is closed.")
298-
try:
299-
data = await self._stream.readexactly(length + 2)
300-
except asyncio.IncompleteReadError as error:
301-
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
302-
return data[:-2]
316+
want = length + 2
317+
end = self._pos + want
318+
if len(self._buffer) >= end:
319+
result = self._buffer[self._pos : end - 2]
320+
else:
321+
tail = self._buffer[self._pos :]
322+
try:
323+
data = await self._stream.readexactly(want - len(tail))
324+
except asyncio.IncompleteReadError as error:
325+
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
326+
result = (tail + data)[:-2]
327+
self._chunks.append(data)
328+
self._pos += want
329+
return result
303330

304331
async def _readline(self) -> bytes:
305332
"""
306333
read an unknown number of bytes up to the next '\r\n'
307334
line separator, which is discarded.
308335
"""
309-
if self._stream is None:
310-
raise RedisError("Buffer is closed.")
311-
data = await self._stream.readline()
312-
if not data.endswith(b"\r\n"):
313-
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
314-
return data[:-2]
336+
found = self._buffer.find(b"\r\n", self._pos)
337+
if found >= 0:
338+
result = self._buffer[self._pos : found]
339+
else:
340+
tail = self._buffer[self._pos :]
341+
data = await self._stream.readline()
342+
if not data.endswith(b"\r\n"):
343+
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
344+
result = (tail + data)[:-2]
345+
self._chunks.append(data)
346+
self._pos += len(result) + 2
347+
return result
315348

316349

317350
class HiredisParser(BaseParser):

Diff for: redis/connection.py

+42-16
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,6 @@ def read(self, length):
232232
self._buffer.seek(self.bytes_read)
233233
data = self._buffer.read(length)
234234
self.bytes_read += len(data)
235-
236-
# purge the buffer when we've consumed it all so it doesn't
237-
# grow forever
238-
if self.bytes_read == self.bytes_written:
239-
self.purge()
240-
241235
return data[:-2]
242236

243237
def readline(self):
@@ -251,23 +245,44 @@ def readline(self):
251245
data = buf.readline()
252246

253247
self.bytes_read += len(data)
248+
return data[:-2]
254249

255-
# purge the buffer when we've consumed it all so it doesn't
256-
# grow forever
257-
if self.bytes_read == self.bytes_written:
258-
self.purge()
250+
def get_pos(self):
251+
"""
252+
Get current read position
253+
"""
254+
return self.bytes_read
259255

260-
return data[:-2]
256+
def rewind(self, pos):
257+
"""
258+
Rewind the buffer to a specific position, to re-start reading
259+
"""
260+
self.bytes_read = pos
261261

262262
def purge(self):
263-
self._buffer.seek(0)
264-
self._buffer.truncate()
265-
self.bytes_written = 0
263+
"""
264+
After a successful read, purge the read part of buffer
265+
"""
266+
unread = self.bytes_written - self.bytes_read
267+
268+
# Only if we have read all of the buffer do we truncate, to
269+
# reduce the amount of memory thrashing. This heuristic
270+
# can be changed or removed later.
271+
if unread > 0:
272+
return
273+
274+
if unread > 0:
275+
# move unread data to the front
276+
view = self._buffer.getbuffer()
277+
view[:unread] = view[-unread:]
278+
self._buffer.truncate(unread)
279+
self.bytes_written = unread
266280
self.bytes_read = 0
281+
self._buffer.seek(0)
267282

268283
def close(self):
269284
try:
270-
self.purge()
285+
self.bytes_written = self.bytes_read = 0
271286
self._buffer.close()
272287
except Exception:
273288
# issue #633 suggests the purge/close somehow raised a
@@ -315,6 +330,17 @@ def can_read(self, timeout):
315330
return self._buffer and self._buffer.can_read(timeout)
316331

317332
def read_response(self, disable_decoding=False):
333+
pos = self._buffer.get_pos()
334+
try:
335+
result = self._read_response(disable_decoding=disable_decoding)
336+
except BaseException:
337+
self._buffer.rewind(pos)
338+
raise
339+
else:
340+
self._buffer.purge()
341+
return result
342+
343+
def _read_response(self, disable_decoding=False):
318344
raw = self._buffer.readline()
319345
if not raw:
320346
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
@@ -355,7 +381,7 @@ def read_response(self, disable_decoding=False):
355381
if length == -1:
356382
return None
357383
response = [
358-
self.read_response(disable_decoding=disable_decoding)
384+
self._read_response(disable_decoding=disable_decoding)
359385
for i in range(length)
360386
]
361387
if isinstance(response, bytes) and disable_decoding is False:

Diff for: tests/mocks.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Various mocks for testing
2+
3+
4+
class MockSocket:
5+
"""
6+
A class simulating an readable socket, optionally raising a
7+
special exception every other read.
8+
"""
9+
10+
class TestError(BaseException):
11+
pass
12+
13+
def __init__(self, data, interrupt_every=0):
14+
self.data = data
15+
self.counter = 0
16+
self.pos = 0
17+
self.interrupt_every = interrupt_every
18+
19+
def tick(self):
20+
self.counter += 1
21+
if not self.interrupt_every:
22+
return
23+
if (self.counter % self.interrupt_every) == 0:
24+
raise self.TestError()
25+
26+
def recv(self, bufsize):
27+
self.tick()
28+
bufsize = min(5, bufsize) # truncate the read size
29+
result = self.data[self.pos : self.pos + bufsize]
30+
self.pos += len(result)
31+
return result
32+
33+
def recv_into(self, buffer, nbytes=0, flags=0):
34+
self.tick()
35+
if nbytes == 0:
36+
nbytes = len(buffer)
37+
nbytes = min(5, nbytes) # truncate the read size
38+
result = self.data[self.pos : self.pos + nbytes]
39+
self.pos += len(result)
40+
buffer[: len(result)] = result
41+
return len(result)

Diff for: tests/test_asyncio/mocks.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import asyncio
2+
3+
# Helper Mocking classes for the tests.
4+
5+
6+
class MockStream:
7+
"""
8+
A class simulating an asyncio input buffer, optionally raising a
9+
special exception every other read.
10+
"""
11+
12+
class TestError(BaseException):
13+
pass
14+
15+
def __init__(self, data, interrupt_every=0):
16+
self.data = data
17+
self.counter = 0
18+
self.pos = 0
19+
self.interrupt_every = interrupt_every
20+
21+
def tick(self):
22+
self.counter += 1
23+
if not self.interrupt_every:
24+
return
25+
if (self.counter % self.interrupt_every) == 0:
26+
raise self.TestError()
27+
28+
async def read(self, want):
29+
self.tick()
30+
want = 5
31+
result = self.data[self.pos : self.pos + want]
32+
self.pos += len(result)
33+
return result
34+
35+
async def readline(self):
36+
self.tick()
37+
find = self.data.find(b"\n", self.pos)
38+
if find >= 0:
39+
result = self.data[self.pos : find + 1]
40+
else:
41+
result = self.data[self.pos :]
42+
self.pos += len(result)
43+
return result
44+
45+
async def readexactly(self, length):
46+
self.tick()
47+
result = self.data[self.pos : self.pos + length]
48+
if len(result) < length:
49+
raise asyncio.IncompleteReadError(result, None)
50+
self.pos += len(result)
51+
return result

Diff for: tests/test_asyncio/test_connection.py

+41-7
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
import pytest
77

8+
import redis
89
from redis.asyncio.connection import (
10+
BaseParser,
911
Connection,
1012
PythonParser,
1113
UnixDomainSocketConnection,
@@ -16,23 +18,27 @@
1618
from tests.conftest import skip_if_server_version_lt
1719

1820
from .compat import mock
21+
from .mocks import MockStream
1922

2023

2124
@pytest.mark.onlynoncluster
2225
async def test_invalid_response(create_redis):
2326
r = await create_redis(single_connection_client=True)
2427

2528
raw = b"x"
29+
fake_stream = MockStream(raw + b"\r\n")
2630

27-
parser: "PythonParser" = r.connection._parser
28-
if not isinstance(parser, PythonParser):
29-
pytest.skip("PythonParser only")
30-
stream_mock = mock.Mock(parser._stream)
31-
stream_mock.readline.return_value = raw + b"\r\n"
32-
with mock.patch.object(parser, "_stream", stream_mock):
31+
parser: BaseParser = r.connection._parser
32+
with mock.patch.object(parser, "_stream", fake_stream):
3333
with pytest.raises(InvalidResponse) as cm:
3434
await parser.read_response()
35-
assert str(cm.value) == f"Protocol Error: {raw!r}"
35+
if isinstance(parser, PythonParser):
36+
assert str(cm.value) == f"Protocol Error: {raw!r}"
37+
else:
38+
assert (
39+
str(cm.value) == f'Protocol error, got "{raw.decode()}" as reply type byte'
40+
)
41+
await r.connection.disconnect()
3642

3743

3844
@skip_if_server_version_lt("4.0.0")
@@ -112,3 +118,31 @@ async def test_connect_timeout_error_without_retry():
112118
await conn.connect()
113119
assert conn._connect.call_count == 1
114120
assert str(e.value) == "Timeout connecting to server"
121+
122+
123+
@pytest.mark.onlynoncluster
124+
async def test_connection_parse_response_resume(r: redis.Redis):
125+
"""
126+
This test verifies that the Connection parser,
127+
be that PythonParser or HiredisParser,
128+
can be interrupted at IO time and then resume parsing.
129+
"""
130+
conn = Connection(**r.connection_pool.connection_kwargs)
131+
await conn.connect()
132+
message = (
133+
b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n"
134+
b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n"
135+
)
136+
137+
conn._parser._stream = MockStream(message, interrupt_every=2)
138+
for i in range(100):
139+
try:
140+
response = await conn.read_response()
141+
break
142+
except MockStream.TestError:
143+
pass
144+
145+
else:
146+
pytest.fail("didn't receive a response")
147+
assert response
148+
assert i > 0

0 commit comments

Comments
 (0)