Skip to content

Commit 5b9fe2e

Browse files
committed
add equality and hashability to Retry and backoff classes
1 parent 120517f commit 5b9fe2e

File tree

3 files changed

+115
-1
lines changed

3 files changed

+115
-1
lines changed

redis/backoff.py

+58
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ def __init__(self, backoff: float) -> None:
3131
"""`backoff`: backoff time in seconds"""
3232
self._backoff = backoff
3333

34+
def __hash__(self) -> int:
35+
return hash((self._backoff,))
36+
37+
def __eq__(self, other) -> bool:
38+
if not isinstance(other, ConstantBackoff):
39+
return NotImplemented
40+
41+
return self._backoff == other._backoff
42+
3443
def compute(self, failures: int) -> float:
3544
return self._backoff
3645

@@ -53,6 +62,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE):
5362
self._cap = cap
5463
self._base = base
5564

65+
def __hash__(self) -> int:
66+
return hash((self._base, self._cap))
67+
68+
def __eq__(self, other) -> bool:
69+
if not isinstance(other, ExponentialBackoff):
70+
return NotImplemented
71+
72+
return self._base == other._base and self._cap == other._cap
73+
5674
def compute(self, failures: int) -> float:
5775
return min(self._cap, self._base * 2**failures)
5876

@@ -68,6 +86,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None
6886
self._cap = cap
6987
self._base = base
7088

89+
def __hash__(self) -> int:
90+
return hash((self._base, self._cap))
91+
92+
def __eq__(self, other) -> bool:
93+
if not isinstance(other, FullJitterBackoff):
94+
return NotImplemented
95+
96+
return self._base == other._base and self._cap == other._cap
97+
7198
def compute(self, failures: int) -> float:
7299
return random.uniform(0, min(self._cap, self._base * 2**failures))
73100

@@ -83,6 +110,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None
83110
self._cap = cap
84111
self._base = base
85112

113+
def __hash__(self) -> int:
114+
return hash((self._base, self._cap))
115+
116+
def __eq__(self, other) -> bool:
117+
if not isinstance(other, EqualJitterBackoff):
118+
return NotImplemented
119+
120+
return self._base == other._base and self._cap == other._cap
121+
86122
def compute(self, failures: int) -> float:
87123
temp = min(self._cap, self._base * 2**failures) / 2
88124
return temp + random.uniform(0, temp)
@@ -100,6 +136,19 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None
100136
self._base = base
101137
self._previous_backoff = 0
102138

139+
def __hash__(self) -> int:
140+
return hash((self._base, self._cap))
141+
142+
def __eq__(self, other) -> bool:
143+
if not isinstance(other, DecorrelatedJitterBackoff):
144+
return NotImplemented
145+
146+
return (
147+
self._base == other._base
148+
and self._cap == other._cap
149+
and self._previous_backoff == other._previous_backoff
150+
)
151+
103152
def reset(self) -> None:
104153
self._previous_backoff = 0
105154

@@ -121,6 +170,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None
121170
self._cap = cap
122171
self._base = base
123172

173+
def __hash__(self) -> int:
174+
return hash((self._base, self._cap))
175+
176+
def __eq__(self, other) -> bool:
177+
if not isinstance(other, EqualJitterBackoff):
178+
return NotImplemented
179+
180+
return self._base == other._base and self._cap == other._cap
181+
124182
def compute(self, failures: int) -> float:
125183
return min(self._cap, random.random() * self._base * 2**failures)
126184

redis/retry.py

+13
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ def __init__(
3434
self._retries = retries
3535
self._supported_errors = supported_errors
3636

37+
def __eq__(self, other: Any) -> bool:
38+
if not isinstance(other, Retry):
39+
return NotImplemented
40+
41+
return (
42+
self._backoff == other._backoff
43+
and self._retries == other._retries
44+
and set(self._supported_errors) == set(other._supported_errors)
45+
)
46+
47+
def __hash__(self) -> int:
48+
return hash((self._backoff, self._retries, frozenset(self._supported_errors)))
49+
3750
def update_supported_errors(
3851
self, specified_errors: Iterable[Type[Exception]]
3952
) -> None:

tests/test_retry.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
from unittest.mock import patch
22

33
import pytest
4-
from redis.backoff import AbstractBackoff, ExponentialBackoff, NoBackoff
4+
from redis.backoff import (
5+
AbstractBackoff,
6+
ConstantBackoff,
7+
DecorrelatedJitterBackoff,
8+
EqualJitterBackoff,
9+
ExponentialBackoff,
10+
ExponentialWithJitterBackoff,
11+
FullJitterBackoff,
12+
NoBackoff,
13+
)
514
from redis.client import Redis
615
from redis.connection import Connection, UnixDomainSocketConnection
716
from redis.exceptions import (
@@ -80,6 +89,40 @@ def test_retry_on_error_retry(self, Class, retries):
8089
assert c.retry._retries == retries
8190

8291

92+
@pytest.mark.parametrize(
93+
"args",
94+
[
95+
(ConstantBackoff(0), 0),
96+
(ConstantBackoff(10), 5),
97+
(NoBackoff(), 0),
98+
]
99+
+ [
100+
backoff
101+
for Backoff in (
102+
DecorrelatedJitterBackoff,
103+
EqualJitterBackoff,
104+
ExponentialBackoff,
105+
ExponentialWithJitterBackoff,
106+
FullJitterBackoff,
107+
)
108+
for backoff in ((Backoff(), 2), (Backoff(25), 5), (Backoff(25, 5), 5))
109+
],
110+
)
111+
def test_retry_eq_and_hashable(args):
112+
assert Retry(*args) == Retry(*args)
113+
114+
# create another retry object with different parameters
115+
copy = list(args)
116+
if isinstance(copy[0], ConstantBackoff):
117+
copy[1] = 9000
118+
else:
119+
copy[0] = ConstantBackoff(9000)
120+
121+
assert Retry(*args) != Retry(*copy)
122+
assert Retry(*copy) != Retry(*args)
123+
assert len({Retry(*args), Retry(*args), Retry(*copy), Retry(*copy)}) == 2
124+
125+
83126
class TestRetry:
84127
"Test that Retry calls backoff and retries the expected number of times"
85128

0 commit comments

Comments
 (0)