Skip to content

Commit bea7299

Browse files
authored
Fix retries in async mode (#2180)
* Avoid mutating a global retry_on_error list * Make retries config consistent in sync and async * Fix async retries * Add new TestConnectionConstructorWithRetry tests
1 parent 3370298 commit bea7299

File tree

6 files changed

+83
-9
lines changed

6 files changed

+83
-9
lines changed

redis/asyncio/client.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def __init__(
158158
encoding_errors: str = "strict",
159159
decode_responses: bool = False,
160160
retry_on_timeout: bool = False,
161+
retry_on_error: Optional[list] = None,
161162
ssl: bool = False,
162163
ssl_keyfile: Optional[str] = None,
163164
ssl_certfile: Optional[str] = None,
@@ -176,8 +177,10 @@ def __init__(
176177
):
177178
"""
178179
Initialize a new Redis client.
179-
To specify a retry policy, first set `retry_on_timeout` to `True`
180-
then set `retry` to a valid `Retry` object
180+
To specify a retry policy for specific errors, first set
181+
`retry_on_error` to a list of the error/s to retry on, then set
182+
`retry` to a valid `Retry` object.
183+
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
181184
"""
182185
kwargs: Dict[str, Any]
183186
# auto_close_connection_pool only has an effect if connection_pool is
@@ -188,6 +191,10 @@ def __init__(
188191
auto_close_connection_pool if connection_pool is None else False
189192
)
190193
if not connection_pool:
194+
if not retry_on_error:
195+
retry_on_error = []
196+
if retry_on_timeout is True:
197+
retry_on_error.append(TimeoutError)
191198
kwargs = {
192199
"db": db,
193200
"username": username,
@@ -197,6 +204,7 @@ def __init__(
197204
"encoding_errors": encoding_errors,
198205
"decode_responses": decode_responses,
199206
"retry_on_timeout": retry_on_timeout,
207+
"retry_on_error": retry_on_error,
200208
"retry": copy.deepcopy(retry),
201209
"max_connections": max_connections,
202210
"health_check_interval": health_check_interval,
@@ -461,7 +469,10 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
461469
is not a TimeoutError
462470
"""
463471
await conn.disconnect()
464-
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
472+
if (
473+
conn.retry_on_error is None
474+
or isinstance(error, tuple(conn.retry_on_error)) is False
475+
):
465476
raise error
466477

467478
# COMMAND EXECUTION AND PROTOCOL PARSING

redis/asyncio/connection.py

+17
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,7 @@ class Connection:
578578
"socket_type",
579579
"redis_connect_func",
580580
"retry_on_timeout",
581+
"retry_on_error",
581582
"health_check_interval",
582583
"next_health_check",
583584
"last_active_at",
@@ -606,6 +607,7 @@ def __init__(
606607
socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
607608
socket_type: int = 0,
608609
retry_on_timeout: bool = False,
610+
retry_on_error: Union[list, _Sentinel] = SENTINEL,
609611
encoding: str = "utf-8",
610612
encoding_errors: str = "strict",
611613
decode_responses: bool = False,
@@ -631,12 +633,19 @@ def __init__(
631633
self.socket_keepalive_options = socket_keepalive_options or {}
632634
self.socket_type = socket_type
633635
self.retry_on_timeout = retry_on_timeout
636+
if retry_on_error is SENTINEL:
637+
retry_on_error = []
634638
if retry_on_timeout:
639+
retry_on_error.append(TimeoutError)
640+
self.retry_on_error = retry_on_error
641+
if retry_on_error:
635642
if not retry:
636643
self.retry = Retry(NoBackoff(), 1)
637644
else:
638645
# deep-copy the Retry object as it is mutable
639646
self.retry = copy.deepcopy(retry)
647+
# Update the retry's supported errors with the specified errors
648+
self.retry.update_supported_errors(retry_on_error)
640649
else:
641650
self.retry = Retry(NoBackoff(), 0)
642651
self.health_check_interval = health_check_interval
@@ -1169,6 +1178,7 @@ def __init__(
11691178
encoding_errors: str = "strict",
11701179
decode_responses: bool = False,
11711180
retry_on_timeout: bool = False,
1181+
retry_on_error: Union[list, _Sentinel] = SENTINEL,
11721182
parser_class: Type[BaseParser] = DefaultParser,
11731183
socket_read_size: int = 65536,
11741184
health_check_interval: float = 0.0,
@@ -1190,12 +1200,19 @@ def __init__(
11901200
self.socket_timeout = socket_timeout
11911201
self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None
11921202
self.retry_on_timeout = retry_on_timeout
1203+
if retry_on_error is SENTINEL:
1204+
retry_on_error = []
11931205
if retry_on_timeout:
1206+
retry_on_error.append(TimeoutError)
1207+
self.retry_on_error = retry_on_error
1208+
if retry_on_error:
11941209
if retry is None:
11951210
self.retry = Retry(NoBackoff(), 1)
11961211
else:
11971212
# deep-copy the Retry object as it is mutable
11981213
self.retry = copy.deepcopy(retry)
1214+
# Update the retry's supported errors with the specified errors
1215+
self.retry.update_supported_errors(retry_on_error)
11991216
else:
12001217
self.retry = Retry(NoBackoff(), 0)
12011218
self.health_check_interval = health_check_interval

redis/asyncio/retry.py

+8
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@ def __init__(
3535
self._retries = retries
3636
self._supported_errors = supported_errors
3737

38+
def update_supported_errors(self, specified_errors: list):
39+
"""
40+
Updates the supported errors with the specified error types
41+
"""
42+
self._supported_errors = tuple(
43+
set(self._supported_errors + tuple(specified_errors))
44+
)
45+
3846
async def call_with_retry(
3947
self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any]
4048
) -> T:

redis/client.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,7 @@ def __init__(
914914
errors=None,
915915
decode_responses=False,
916916
retry_on_timeout=False,
917-
retry_on_error=[],
917+
retry_on_error=None,
918918
ssl=False,
919919
ssl_keyfile=None,
920920
ssl_certfile=None,
@@ -958,6 +958,8 @@ def __init__(
958958
)
959959
)
960960
encoding_errors = errors
961+
if not retry_on_error:
962+
retry_on_error = []
961963
if retry_on_timeout is True:
962964
retry_on_error.append(TimeoutError)
963965
kwargs = {

redis/connection.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ def __init__(
515515
socket_keepalive_options=None,
516516
socket_type=0,
517517
retry_on_timeout=False,
518-
retry_on_error=[],
518+
retry_on_error=SENTINEL,
519519
encoding="utf-8",
520520
encoding_errors="strict",
521521
decode_responses=False,
@@ -547,6 +547,8 @@ def __init__(
547547
self.socket_keepalive_options = socket_keepalive_options or {}
548548
self.socket_type = socket_type
549549
self.retry_on_timeout = retry_on_timeout
550+
if retry_on_error is SENTINEL:
551+
retry_on_error = []
550552
if retry_on_timeout:
551553
# Add TimeoutError to the errors list to retry on
552554
retry_on_error.append(TimeoutError)
@@ -1065,7 +1067,7 @@ def __init__(
10651067
encoding_errors="strict",
10661068
decode_responses=False,
10671069
retry_on_timeout=False,
1068-
retry_on_error=[],
1070+
retry_on_error=SENTINEL,
10691071
parser_class=DefaultParser,
10701072
socket_read_size=65536,
10711073
health_check_interval=0,
@@ -1088,6 +1090,8 @@ def __init__(
10881090
self.password = password
10891091
self.socket_timeout = socket_timeout
10901092
self.retry_on_timeout = retry_on_timeout
1093+
if retry_on_error is SENTINEL:
1094+
retry_on_error = []
10911095
if retry_on_timeout:
10921096
# Add TimeoutError to the errors list to retry on
10931097
retry_on_error.append(TimeoutError)

tests/test_asyncio/test_retry.py

+35-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from redis.asyncio.connection import Connection, UnixDomainSocketConnection
44
from redis.asyncio.retry import Retry
55
from redis.backoff import AbstractBackoff, NoBackoff
6-
from redis.exceptions import ConnectionError
6+
from redis.exceptions import ConnectionError, TimeoutError
77

88

99
class BackoffMock(AbstractBackoff):
@@ -22,23 +22,55 @@ def compute(self, failures):
2222
class TestConnectionConstructorWithRetry:
2323
"Test that the Connection constructors properly handles Retry objects"
2424

25+
@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
26+
def test_retry_on_error_set(self, Class):
27+
class CustomError(Exception):
28+
pass
29+
30+
retry_on_error = [ConnectionError, TimeoutError, CustomError]
31+
c = Class(retry_on_error=retry_on_error)
32+
assert c.retry_on_error == retry_on_error
33+
assert isinstance(c.retry, Retry)
34+
assert c.retry._retries == 1
35+
assert set(c.retry._supported_errors) == set(retry_on_error)
36+
37+
@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
38+
def test_retry_on_error_not_set(self, Class):
39+
c = Class()
40+
assert c.retry_on_error == []
41+
assert isinstance(c.retry, Retry)
42+
assert c.retry._retries == 0
43+
2544
@pytest.mark.parametrize("retry_on_timeout", [False, True])
2645
@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
27-
def test_retry_on_timeout_boolean(self, Class, retry_on_timeout):
46+
def test_retry_on_timeout(self, Class, retry_on_timeout):
2847
c = Class(retry_on_timeout=retry_on_timeout)
2948
assert c.retry_on_timeout == retry_on_timeout
3049
assert isinstance(c.retry, Retry)
3150
assert c.retry._retries == (1 if retry_on_timeout else 0)
3251

3352
@pytest.mark.parametrize("retries", range(10))
3453
@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
35-
def test_retry_on_timeout_retry(self, Class, retries: int):
54+
def test_retry_with_retry_on_timeout(self, Class, retries: int):
3655
retry_on_timeout = retries > 0
3756
c = Class(retry_on_timeout=retry_on_timeout, retry=Retry(NoBackoff(), retries))
3857
assert c.retry_on_timeout == retry_on_timeout
3958
assert isinstance(c.retry, Retry)
4059
assert c.retry._retries == retries
4160

61+
@pytest.mark.parametrize("retries", range(10))
62+
@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
63+
def test_retry_with_retry_on_error(self, Class, retries: int):
64+
class CustomError(Exception):
65+
pass
66+
67+
retry_on_error = [ConnectionError, TimeoutError, CustomError]
68+
c = Class(retry_on_error=retry_on_error, retry=Retry(NoBackoff(), retries))
69+
assert c.retry_on_error == retry_on_error
70+
assert isinstance(c.retry, Retry)
71+
assert c.retry._retries == retries
72+
assert set(c.retry._supported_errors) == set(retry_on_error)
73+
4274

4375
class TestRetry:
4476
"Test that Retry calls backoff and retries the expected number of times"

0 commit comments

Comments
 (0)