diff --git a/redis/backoff.py b/redis/backoff.py index e236764d71..22a3ed0abb 100644 --- a/redis/backoff.py +++ b/redis/backoff.py @@ -31,6 +31,15 @@ def __init__(self, backoff: float) -> None: """`backoff`: backoff time in seconds""" self._backoff = backoff + def __hash__(self) -> int: + return hash((self._backoff,)) + + def __eq__(self, other) -> bool: + if not isinstance(other, ConstantBackoff): + return NotImplemented + + return self._backoff == other._backoff + def compute(self, failures: int) -> float: return self._backoff @@ -53,6 +62,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE): self._cap = cap self._base = base + def __hash__(self) -> int: + return hash((self._base, self._cap)) + + def __eq__(self, other) -> bool: + if not isinstance(other, ExponentialBackoff): + return NotImplemented + + return self._base == other._base and self._cap == other._cap + def compute(self, failures: int) -> float: return min(self._cap, self._base * 2**failures) @@ -68,6 +86,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None self._cap = cap self._base = base + def __hash__(self) -> int: + return hash((self._base, self._cap)) + + def __eq__(self, other) -> bool: + if not isinstance(other, FullJitterBackoff): + return NotImplemented + + return self._base == other._base and self._cap == other._cap + def compute(self, failures: int) -> float: return random.uniform(0, min(self._cap, self._base * 2**failures)) @@ -83,6 +110,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None self._cap = cap self._base = base + def __hash__(self) -> int: + return hash((self._base, self._cap)) + + def __eq__(self, other) -> bool: + if not isinstance(other, EqualJitterBackoff): + return NotImplemented + + return self._base == other._base and self._cap == other._cap + def compute(self, failures: int) -> float: temp = min(self._cap, self._base * 2**failures) / 2 return temp + random.uniform(0, temp) @@ -100,6 +136,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None self._base = base self._previous_backoff = 0 + def __hash__(self) -> int: + return hash((self._base, self._cap)) + + def __eq__(self, other) -> bool: + if not isinstance(other, DecorrelatedJitterBackoff): + return NotImplemented + + return self._base == other._base and self._cap == other._cap + def reset(self) -> None: self._previous_backoff = 0 @@ -121,6 +166,15 @@ def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None self._cap = cap self._base = base + def __hash__(self) -> int: + return hash((self._base, self._cap)) + + def __eq__(self, other) -> bool: + if not isinstance(other, EqualJitterBackoff): + return NotImplemented + + return self._base == other._base and self._cap == other._cap + def compute(self, failures: int) -> float: return min(self._cap, random.random() * self._base * 2**failures) diff --git a/redis/retry.py b/redis/retry.py index ca9ea76f24..c93f34e65f 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -34,6 +34,19 @@ def __init__( self._retries = retries self._supported_errors = supported_errors + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Retry): + return NotImplemented + + return ( + self._backoff == other._backoff + and self._retries == other._retries + and set(self._supported_errors) == set(other._supported_errors) + ) + + def __hash__(self) -> int: + return hash((self._backoff, self._retries, frozenset(self._supported_errors))) + def update_supported_errors( self, specified_errors: Iterable[Type[Exception]] ) -> None: diff --git a/tests/test_retry.py b/tests/test_retry.py index 926fe28313..4f4f04caca 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,7 +1,16 @@ from unittest.mock import patch import pytest -from redis.backoff import AbstractBackoff, ExponentialBackoff, NoBackoff +from redis.backoff import ( + AbstractBackoff, + ConstantBackoff, + DecorrelatedJitterBackoff, + EqualJitterBackoff, + ExponentialBackoff, + ExponentialWithJitterBackoff, + FullJitterBackoff, + NoBackoff, +) from redis.client import Redis from redis.connection import Connection, UnixDomainSocketConnection from redis.exceptions import ( @@ -80,6 +89,40 @@ def test_retry_on_error_retry(self, Class, retries): assert c.retry._retries == retries +@pytest.mark.parametrize( + "args", + [ + (ConstantBackoff(0), 0), + (ConstantBackoff(10), 5), + (NoBackoff(), 0), + ] + + [ + backoff + for Backoff in ( + DecorrelatedJitterBackoff, + EqualJitterBackoff, + ExponentialBackoff, + ExponentialWithJitterBackoff, + FullJitterBackoff, + ) + for backoff in ((Backoff(), 2), (Backoff(25), 5), (Backoff(25, 5), 5)) + ], +) +def test_retry_eq_and_hashable(args): + assert Retry(*args) == Retry(*args) + + # create another retry object with different parameters + copy = list(args) + if isinstance(copy[0], ConstantBackoff): + copy[1] = 9000 + else: + copy[0] = ConstantBackoff(9000) + + assert Retry(*args) != Retry(*copy) + assert Retry(*copy) != Retry(*args) + assert len({Retry(*args), Retry(*args), Retry(*copy), Retry(*copy)}) == 2 + + class TestRetry: "Test that Retry calls backoff and retries the expected number of times"