Skip to content

Commit 075fba7

Browse files
dvora-hchayim
authored andcommitted
Fix bug: client side caching causes unexpected disconnections (#3160)
* fix disconnects * skip test in cluster --------- Co-authored-by: Chayim <[email protected]>
1 parent c32ac68 commit 075fba7

File tree

5 files changed

+61
-19
lines changed

5 files changed

+61
-19
lines changed

redis/_parsers/resp3.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def _read_response(self, disable_decoding=False, push_request=False):
117117
)
118118
for _ in range(int(response))
119119
]
120-
self.handle_push_response(response, disable_decoding, push_request)
120+
response = self.handle_push_response(
121+
response, disable_decoding, push_request
122+
)
121123
else:
122124
raise InvalidResponse(f"Protocol Error: {raw!r}")
123125

redis/client.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -560,10 +560,10 @@ def execute_command(self, *args, **options):
560560
pool = self.connection_pool
561561
conn = self.connection or pool.get_connection(command_name, **options)
562562
response_from_cache = conn._get_from_local_cache(args)
563-
if response_from_cache is not None:
564-
return response_from_cache
565-
else:
566-
try:
563+
try:
564+
if response_from_cache is not None:
565+
return response_from_cache
566+
else:
567567
response = conn.retry.call_with_retry(
568568
lambda: self._send_command_parse_response(
569569
conn, command_name, *args, **options
@@ -572,9 +572,9 @@ def execute_command(self, *args, **options):
572572
)
573573
conn._add_to_local_cache(args, response, keys)
574574
return response
575-
finally:
576-
if not self.connection:
577-
pool.release(conn)
575+
finally:
576+
if not self.connection:
577+
pool.release(conn)
578578

579579
def parse_response(self, connection, command_name, **options):
580580
"""Parses a response from the Redis server"""

redis/commands/core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2011,7 +2011,7 @@ def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT:
20112011
options = {}
20122012
if not args:
20132013
options[EMPTY_RESPONSE] = []
2014-
options["keys"] = keys
2014+
options["keys"] = args
20152015
return self.execute_command("MGET", *args, **options)
20162016

20172017
def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT:

redis/connection.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import copy
22
import os
3-
import select
43
import socket
54
import ssl
65
import sys
@@ -609,11 +608,6 @@ def pack_commands(self, commands):
609608
output.append(SYM_EMPTY.join(pieces))
610609
return output
611610

612-
def _socket_is_empty(self):
613-
"""Check if the socket is empty"""
614-
r, _, _ = select.select([self._sock], [], [], 0)
615-
return not bool(r)
616-
617611
def _cache_invalidation_process(
618612
self, data: List[Union[str, Optional[List[str]]]]
619613
) -> None:
@@ -639,7 +633,7 @@ def _get_from_local_cache(self, command: str):
639633
or command[0] not in self.cache_whitelist
640634
):
641635
return None
642-
while not self._socket_is_empty():
636+
while self.can_read():
643637
self.read_response(push_request=True)
644638
return self.client_cache.get(command)
645639

@@ -1187,12 +1181,15 @@ def get_connection(self, command_name: str, *keys, **options) -> "Connection":
11871181
try:
11881182
# ensure this connection is connected to Redis
11891183
connection.connect()
1190-
# connections that the pool provides should be ready to send
1191-
# a command. if not, the connection was either returned to the
1184+
# if client caching is not enabled connections that the pool
1185+
# provides should be ready to send a command.
1186+
# if not, the connection was either returned to the
11921187
# pool before all data has been read or the socket has been
11931188
# closed. either way, reconnect and verify everything is good.
1189+
# (if caching enabled the connection will not always be ready
1190+
# to send a command because it may contain invalidation messages)
11941191
try:
1195-
if connection.can_read():
1192+
if connection.can_read() and connection.client_cache is None:
11961193
raise ConnectionError("Connection has data")
11971194
except (ConnectionError, OSError):
11981195
connection.disconnect()

tests/test_cache.py

+43
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,49 @@ def test_cache_return_copy(self, r):
146146
check = cache.get(("LRANGE", "mylist", 0, -1))
147147
assert check == [b"baz", b"bar", b"foo"]
148148

149+
@pytest.mark.onlynoncluster
150+
@pytest.mark.parametrize(
151+
"r",
152+
[{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
153+
indirect=True,
154+
)
155+
def test_csc_not_cause_disconnects(self, r):
156+
r, cache = r
157+
id1 = r.client_id()
158+
r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1, "f": 1})
159+
assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"]
160+
id2 = r.client_id()
161+
162+
# client should get value from client cache
163+
assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"]
164+
assert cache.get(("MGET", "a", "b", "c", "d", "e", "f")) == [
165+
"1",
166+
"1",
167+
"1",
168+
"1",
169+
"1",
170+
"1",
171+
]
172+
173+
r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2, "f": 2})
174+
id3 = r.client_id()
175+
# client should get value from redis server post invalidate messages
176+
assert r.mget("a", "b", "c", "d", "e", "f") == ["2", "2", "2", "2", "2", "2"]
177+
178+
r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3, "f": 3})
179+
# need to check that we get correct value 3 and not 2
180+
assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"]
181+
# client should get value from client cache
182+
assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"]
183+
184+
r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4, "f": 4})
185+
# need to check that we get correct value 4 and not 3
186+
assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"]
187+
# client should get value from client cache
188+
assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"]
189+
id4 = r.client_id()
190+
assert id1 == id2 == id3 == id4
191+
149192

150193
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
151194
@pytest.mark.onlycluster

0 commit comments

Comments
 (0)