Skip to content

add equality and hashability to Retry and backoff classes #3628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions redis/backoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ def __init__(self, backoff: float) -> None:
"""`backoff`: backoff time in seconds"""
self._backoff = backoff

def __hash__(self) -> int:
return hash((self._backoff,))

def __eq__(self, other) -> bool:
if not isinstance(other, ConstantBackoff):
return NotImplemented

return self._backoff == other._backoff

def compute(self, failures: int) -> float:
return self._backoff

Expand All @@ -53,6 +62,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE):
self._cap = cap
self._base = base

def __hash__(self) -> int:
return hash((self._base, self._cap))

def __eq__(self, other) -> bool:
if not isinstance(other, ExponentialBackoff):
return NotImplemented

return self._base == other._base and self._cap == other._cap

def compute(self, failures: int) -> float:
return min(self._cap, self._base * 2**failures)

Expand All @@ -68,6 +86,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None
self._cap = cap
self._base = base

def __hash__(self) -> int:
return hash((self._base, self._cap))

def __eq__(self, other) -> bool:
if not isinstance(other, FullJitterBackoff):
return NotImplemented

return self._base == other._base and self._cap == other._cap

def compute(self, failures: int) -> float:
return random.uniform(0, min(self._cap, self._base * 2**failures))

Expand All @@ -83,6 +110,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None
self._cap = cap
self._base = base

def __hash__(self) -> int:
return hash((self._base, self._cap))

def __eq__(self, other) -> bool:
if not isinstance(other, EqualJitterBackoff):
return NotImplemented

return self._base == other._base and self._cap == other._cap

def compute(self, failures: int) -> float:
temp = min(self._cap, self._base * 2**failures) / 2
return temp + random.uniform(0, temp)
Expand All @@ -100,6 +136,19 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None
self._base = base
self._previous_backoff = 0

def __hash__(self) -> int:
return hash((self._base, self._cap))

def __eq__(self, other) -> bool:
if not isinstance(other, DecorrelatedJitterBackoff):
return NotImplemented

return (
self._base == other._base
and self._cap == other._cap
and self._previous_backoff == other._previous_backoff
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that you should not compare the _previous_backoff fields- they represent the progress/state of the objects, not the configuration. For example, it is also not included in the hash of the object.

)

def reset(self) -> None:
self._previous_backoff = 0

Expand All @@ -121,6 +170,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None
self._cap = cap
self._base = base

def __hash__(self) -> int:
return hash((self._base, self._cap))

def __eq__(self, other) -> bool:
if not isinstance(other, EqualJitterBackoff):
return NotImplemented

return self._base == other._base and self._cap == other._cap

def compute(self, failures: int) -> float:
return min(self._cap, random.random() * self._base * 2**failures)

Expand Down
13 changes: 13 additions & 0 deletions redis/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,19 @@ def __init__(
self._retries = retries
self._supported_errors = supported_errors

def __eq__(self, other: Any) -> bool:
if not isinstance(other, Retry):
return NotImplemented

return (
self._backoff == other._backoff
and self._retries == other._retries
and set(self._supported_errors) == set(other._supported_errors)
)

def __hash__(self) -> int:
return hash((self._backoff, self._retries, frozenset(self._supported_errors)))

def update_supported_errors(
self, specified_errors: Iterable[Type[Exception]]
) -> None:
Expand Down
45 changes: 44 additions & 1 deletion tests/test_retry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from unittest.mock import patch

import pytest
from redis.backoff import AbstractBackoff, ExponentialBackoff, NoBackoff
from redis.backoff import (
AbstractBackoff,
ConstantBackoff,
DecorrelatedJitterBackoff,
EqualJitterBackoff,
ExponentialBackoff,
ExponentialWithJitterBackoff,
FullJitterBackoff,
NoBackoff,
)
from redis.client import Redis
from redis.connection import Connection, UnixDomainSocketConnection
from redis.exceptions import (
Expand Down Expand Up @@ -80,6 +89,40 @@ def test_retry_on_error_retry(self, Class, retries):
assert c.retry._retries == retries


@pytest.mark.parametrize(
"args",
[
(ConstantBackoff(0), 0),
(ConstantBackoff(10), 5),
(NoBackoff(), 0),
]
+ [
backoff
for Backoff in (
DecorrelatedJitterBackoff,
EqualJitterBackoff,
ExponentialBackoff,
ExponentialWithJitterBackoff,
FullJitterBackoff,
)
for backoff in ((Backoff(), 2), (Backoff(25), 5), (Backoff(25, 5), 5))
],
)
def test_retry_eq_and_hashable(args):
assert Retry(*args) == Retry(*args)

# create another retry object with different parameters
copy = list(args)
if isinstance(copy[0], ConstantBackoff):
copy[1] = 9000
else:
copy[0] = ConstantBackoff(9000)

assert Retry(*args) != Retry(*copy)
assert Retry(*copy) != Retry(*args)
assert len({Retry(*args), Retry(*args), Retry(*copy), Retry(*copy)}) == 2


class TestRetry:
"Test that Retry calls backoff and retries the expected number of times"

Expand Down
Loading