Skip to content

Commit 0d530a6

Browse files
authored
feat(utils): Add helpers for circuit breaker and circuit breaker tests (#74559)
This is a follow-up to #74557, which added the beginnings of a rate-limit-based circuit breaker, in the form of a new `CircuitBreaker` class. In this PR, various helpers, for checking the state of the breaker and the underlying rate limiters and for communicating with redis, have been added to the class. It also adds a `MockCircuitBreaker` subclass for use in tests, which contains a number of methods for mocking both circuit breaker and rate limiter state. Note that though these helpers don't have accompanying tests, they are tested (indirectly) in the final PR in the series[1], as part of testing the methods which use them. [1] #74560
1 parent 0066487 commit 0d530a6

File tree

2 files changed

+328
-7
lines changed

2 files changed

+328
-7
lines changed

Diff for: src/sentry/utils/circuit_breaker2.py

+124-2
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
"""
77

88
import logging
9+
import time
910
from enum import Enum
10-
from typing import NotRequired, TypedDict
11+
from typing import Any, Literal, NotRequired, TypedDict, overload
1112

1213
from django.conf import settings
1314

14-
from sentry.ratelimits.sliding_windows import Quota, RedisSlidingWindowRateLimiter
15+
from sentry.ratelimits.sliding_windows import Quota, RedisSlidingWindowRateLimiter, RequestedQuota
1516

1617
logger = logging.getLogger(__name__)
1718

@@ -180,3 +181,124 @@ def __init__(self, key: str, config: CircuitBreakerConfig):
180181
default_recovery_duration,
181182
)
182183
self.recovery_duration = default_recovery_duration
184+
185+
def _get_from_redis(self, keys: list[str]) -> Any:
186+
for key in keys:
187+
self.redis_pipeline.get(key)
188+
return self.redis_pipeline.execute()
189+
190+
def _set_in_redis(self, keys_values_and_timeouts: list[tuple[str, Any, int]]) -> None:
191+
for key, value, timeout in keys_values_and_timeouts:
192+
self.redis_pipeline.set(key, value, timeout)
193+
self.redis_pipeline.execute()
194+
195+
def _get_state_and_remaining_time(
196+
self,
197+
) -> tuple[CircuitBreakerState, int | None]:
198+
"""
199+
Return the current state of the breaker (OK, BROKEN, or in RECOVERY), along with the
200+
number of seconds until that state expires (or `None` when in OK state, as it has no
201+
expiry).
202+
"""
203+
now = int(time.time())
204+
205+
try:
206+
broken_state_expiry, recovery_state_expiry = self._get_from_redis(
207+
[self.broken_state_key, self.recovery_state_key]
208+
)
209+
except Exception:
210+
logger.exception("Couldn't get state from redis for circuit breaker '%s'", self.key)
211+
212+
# Default to letting traffic through so the breaker doesn't become a single point of failure
213+
return (CircuitBreakerState.OK, None)
214+
215+
# The BROKEN state key should always expire before the RECOVERY state one, so check it first
216+
if broken_state_expiry is not None:
217+
broken_state_seconds_left = int(broken_state_expiry) - now
218+
219+
# In theory there should always be time left (the key should have expired otherwise),
220+
# but race conditions/caching/etc means we should check, just to be sure
221+
if broken_state_seconds_left > 0:
222+
return (CircuitBreakerState.BROKEN, broken_state_seconds_left)
223+
224+
if recovery_state_expiry is not None:
225+
recovery_state_seconds_left = int(recovery_state_expiry) - now
226+
if recovery_state_seconds_left > 0:
227+
return (CircuitBreakerState.RECOVERY, recovery_state_seconds_left)
228+
229+
return (CircuitBreakerState.OK, None)
230+
231+
@overload
232+
def _get_controlling_quota(
233+
self, state: Literal[CircuitBreakerState.OK, CircuitBreakerState.RECOVERY]
234+
) -> Quota:
235+
...
236+
237+
@overload
238+
def _get_controlling_quota(self, state: Literal[CircuitBreakerState.BROKEN]) -> None:
239+
...
240+
241+
@overload
242+
def _get_controlling_quota(self) -> Quota | None:
243+
...
244+
245+
def _get_controlling_quota(self, state: CircuitBreakerState | None = None) -> Quota | None:
246+
"""
247+
Return the Quota corresponding to the given breaker state (or the current breaker state, if
248+
no state is provided). If the state is question is the BROKEN state, return None.
249+
"""
250+
controlling_quota_by_state = {
251+
CircuitBreakerState.OK: self.primary_quota,
252+
CircuitBreakerState.BROKEN: None,
253+
CircuitBreakerState.RECOVERY: self.recovery_quota,
254+
}
255+
256+
_state = state or self._get_state_and_remaining_time()[0]
257+
258+
return controlling_quota_by_state[_state]
259+
260+
@overload
261+
def _get_remaining_error_quota(self, quota: None, window_end: int | None) -> None:
262+
...
263+
264+
@overload
265+
def _get_remaining_error_quota(self, quota: Quota, window_end: int | None) -> int:
266+
...
267+
268+
@overload
269+
def _get_remaining_error_quota(self, quota: None) -> None:
270+
...
271+
272+
@overload
273+
def _get_remaining_error_quota(self, quota: Quota) -> int:
274+
...
275+
276+
@overload
277+
def _get_remaining_error_quota(self) -> int | None:
278+
...
279+
280+
def _get_remaining_error_quota(
281+
self, quota: Quota | None = None, window_end: int | None = None
282+
) -> int | None:
283+
"""
284+
Get the number of allowable errors remaining in the given quota for the time window ending
285+
at the given time.
286+
287+
If no quota is given, in OK and RECOVERY states, return the current controlling quota's
288+
remaining errors. In BROKEN state, return None.
289+
290+
If no time window end is given, return the current amount of quota remaining.
291+
"""
292+
if not quota:
293+
quota = self._get_controlling_quota()
294+
if quota is None: # BROKEN state
295+
return None
296+
297+
now = int(time.time())
298+
window_end = window_end or now
299+
300+
_, result = self.limiter.check_within_quotas(
301+
[RequestedQuota(self.key, quota.limit, [quota])], window_end
302+
)
303+
304+
return result[0].granted

Diff for: tests/sentry/utils/test_circuit_breaker2.py

+204-5
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1+
import time
2+
from typing import Any
13
from unittest import TestCase
24
from unittest.mock import ANY, MagicMock, patch
35

46
from django.conf import settings
57
from redis.client import Pipeline
68

7-
from sentry.ratelimits.sliding_windows import Quota, RedisSlidingWindowRateLimiter
9+
from sentry.ratelimits.sliding_windows import (
10+
GrantedQuota,
11+
Quota,
12+
RedisSlidingWindowRateLimiter,
13+
RequestedQuota,
14+
)
815
from sentry.testutils.helpers.datetime import freeze_time
9-
from sentry.utils.circuit_breaker2 import CircuitBreaker, CircuitBreakerConfig
16+
from sentry.utils.circuit_breaker2 import CircuitBreaker, CircuitBreakerConfig, CircuitBreakerState
1017

1118
# Note: These need to be relatively big. If the limit is too low, the RECOVERY quota isn't big
1219
# enough to be useful, and if the window is too short, redis (which doesn't seem to listen to the
@@ -18,11 +25,203 @@
1825
}
1926

2027

28+
class MockCircuitBreaker(CircuitBreaker):
29+
"""
30+
A circuit breaker with extra methods useful for mocking state.
31+
32+
To understand the methods below, it helps to understand the `RedisSlidingWindowRateLimiter`
33+
which powers the circuit breaker. Details can be found in
34+
https://github.com/getsentry/sentry-redis-tools/blob/d4f3dc883b1137d82b6b7a92f4b5b41991c1fc8a/sentry_redis_tools/sliding_windows_rate_limiter.py,
35+
(which is the implementation behind the rate limiter) but TL;DR, quota usage during the time
36+
window is tallied in buckets ("granules"), and as time passes the window slides forward one
37+
granule at a time. To be able to mimic this, most of the methods here operate at the granule
38+
level.
39+
"""
40+
41+
def _set_breaker_state(
42+
self, state: CircuitBreakerState, seconds_left: int | None = None
43+
) -> None:
44+
"""
45+
Adjust redis keys to force the breaker into the given state. If no remaining seconds are
46+
given, puts the breaker at the beginning of its time in the given state.
47+
"""
48+
now = int(time.time())
49+
50+
if state == CircuitBreakerState.OK:
51+
self._delete_from_redis([self.broken_state_key, self.recovery_state_key])
52+
53+
elif state == CircuitBreakerState.BROKEN:
54+
broken_state_timeout = seconds_left or self.broken_state_duration
55+
broken_state_end = now + broken_state_timeout
56+
recovery_timeout = broken_state_timeout + self.recovery_duration
57+
recovery_end = now + recovery_timeout
58+
59+
self._set_in_redis(
60+
[
61+
(self.broken_state_key, broken_state_end, broken_state_timeout),
62+
(self.recovery_state_key, recovery_end, recovery_timeout),
63+
]
64+
)
65+
66+
elif state == CircuitBreakerState.RECOVERY:
67+
recovery_timeout = seconds_left or self.recovery_duration
68+
recovery_end = now + recovery_timeout
69+
70+
self._delete_from_redis([self.broken_state_key])
71+
self._set_in_redis([(self.recovery_state_key, recovery_end, recovery_timeout)])
72+
73+
assert self._get_state_and_remaining_time() == (
74+
state,
75+
(
76+
None
77+
if state == CircuitBreakerState.OK
78+
else (
79+
broken_state_timeout
80+
if state == CircuitBreakerState.BROKEN
81+
else recovery_timeout
82+
)
83+
),
84+
)
85+
86+
def _add_quota_usage(
87+
self,
88+
quota: Quota,
89+
amount_used: int,
90+
granule_or_window_end: int | None = None,
91+
) -> None:
92+
"""
93+
Add to the usage total of the given quota, in the granule or window ending at the given
94+
time. If a window (rather than a granule) end time is given, usage will be added to the
95+
final granule.
96+
97+
If no end time is given, the current time will be used.
98+
"""
99+
now = int(time.time())
100+
window_end_time = granule_or_window_end or now
101+
102+
self.limiter.use_quotas(
103+
[RequestedQuota(self.key, amount_used, [quota])],
104+
[GrantedQuota(self.key, amount_used, [])],
105+
window_end_time,
106+
)
107+
108+
def _clear_quota(self, quota: Quota, window_end: int | None = None) -> list[int]:
109+
"""
110+
Clear usage of the given quota up until the end of the given time window. If no window end
111+
is given, clear the quota up to the present.
112+
113+
Returns the list of granule values which were cleared.
114+
"""
115+
now = int(time.time())
116+
window_end_time = window_end or now
117+
granule_end_times = self._get_granule_end_times(quota, window_end_time)
118+
num_granules = len(granule_end_times)
119+
previous_granule_values = [0] * num_granules
120+
121+
current_total_quota_used = quota.limit - self._get_remaining_error_quota(
122+
quota, window_end_time
123+
)
124+
if current_total_quota_used != 0:
125+
# Empty the granules one by one, starting with the oldest.
126+
#
127+
# To empty each granule, we need to add negative quota usage, which means we need to
128+
# know how much usage is currently in each granule. Unfortunately, the limiter will only
129+
# report quota usage at the window level, not the granule level. To get around this, we
130+
# start with a window ending with the oldest granule. Any granules before it will have
131+
# expired, so the window usage will equal the granule usage.ending in that granule will
132+
# have a total usage equal to that of the granule.
133+
#
134+
# Once we zero-out the granule, we can move the window one granule forward. It will now
135+
# consist of expired granules, the granule we just set to 0, and the granule we care
136+
# about. Thus the window usage will again match the granule usage, which we can use to
137+
# empty the granule. We then just repeat the pattern until we've reached the end of the
138+
# window we want to clear.
139+
for i, granule_end_time in enumerate(granule_end_times):
140+
granule_quota_used = quota.limit - self._get_remaining_error_quota(
141+
quota, granule_end_time
142+
)
143+
previous_granule_values[i] = granule_quota_used
144+
self._add_quota_usage(quota, -granule_quota_used, granule_end_time)
145+
146+
new_total_quota_used = quota.limit - self._get_remaining_error_quota(
147+
quota, window_end_time
148+
)
149+
assert new_total_quota_used == 0
150+
151+
return previous_granule_values
152+
153+
def _get_granule_end_times(
154+
self, quota: Quota, window_end: int, newest_first: bool = False
155+
) -> list[int]:
156+
"""
157+
Given a quota and the end of the time window it's covering, return the timestamps
158+
corresponding to the end of each granule.
159+
"""
160+
window_duration = quota.window_seconds
161+
granule_duration = quota.granularity_seconds
162+
num_granules = window_duration // granule_duration
163+
164+
# Walk backwards through the granules
165+
end_times_newest_first = [
166+
window_end - num_granules_ago * granule_duration
167+
for num_granules_ago in range(num_granules)
168+
]
169+
170+
return end_times_newest_first if newest_first else list(reversed(end_times_newest_first))
171+
172+
def _set_granule_values(
173+
self,
174+
quota: Quota,
175+
values: list[int | None],
176+
window_end: int | None = None,
177+
) -> None:
178+
"""
179+
Set the usage in each granule of the given quota, for the time window ending at the given
180+
time.
181+
182+
If no ending time is given, the current time is used.
183+
184+
The list of values should be ordered from oldest to newest and must contain the same number
185+
of elements as the window has granules. To only change some of the values, pass `None` in
186+
the spot of any value which should remain unchanged. (For example, in a two-granule window,
187+
to only change the older granule, pass `[3, None]`.)
188+
"""
189+
window_duration = quota.window_seconds
190+
granule_duration = quota.granularity_seconds
191+
num_granules = window_duration // granule_duration
192+
193+
if len(values) != num_granules:
194+
raise Exception(
195+
f"Exactly {num_granules} granule values must be provided. "
196+
+ "To leave an existing value as is, include `None` in its spot."
197+
)
198+
199+
now = int(time.time())
200+
window_end_time = window_end or now
201+
202+
previous_values = self._clear_quota(quota, window_end_time)
203+
204+
for i, granule_end_time, value in zip(
205+
range(num_granules), self._get_granule_end_times(quota, window_end_time), values
206+
):
207+
# When we cleared the quota above, we set each granule's value to 0, so here "adding"
208+
# usage is actually setting usage
209+
if value is not None:
210+
self._add_quota_usage(quota, value, granule_end_time)
211+
else:
212+
self._add_quota_usage(quota, previous_values[i], granule_end_time)
213+
214+
def _delete_from_redis(self, keys: list[str]) -> Any:
215+
for key in keys:
216+
self.redis_pipeline.delete(key)
217+
return self.redis_pipeline.execute()
218+
219+
21220
@freeze_time()
22221
class CircuitBreakerTest(TestCase):
23222
def setUp(self) -> None:
24223
self.config = DEFAULT_CONFIG
25-
self.breaker = CircuitBreaker("dogs_are_great", self.config)
224+
self.breaker = MockCircuitBreaker("dogs_are_great", self.config)
26225

27226
# Clear all existing keys from redis
28227
self.breaker.redis_pipeline.flushall()
@@ -78,7 +277,7 @@ def test_fixes_too_loose_recovery_limit(self, mock_logger: MagicMock):
78277
(False, mock_logger.warning),
79278
]:
80279
settings.DEBUG = settings_debug_value
81-
breaker = CircuitBreaker("dogs_are_great", config)
280+
breaker = MockCircuitBreaker("dogs_are_great", config)
82281

83282
expected_log_function.assert_called_with(
84283
"Circuit breaker '%s' has a recovery error limit (%d) greater than or equal"
@@ -104,7 +303,7 @@ def test_fixes_mismatched_state_durations(self, mock_logger: MagicMock):
104303
(False, mock_logger.warning),
105304
]:
106305
settings.DEBUG = settings_debug_value
107-
breaker = CircuitBreaker("dogs_are_great", config)
306+
breaker = MockCircuitBreaker("dogs_are_great", config)
108307

109308
expected_log_function.assert_called_with(
110309
"Circuit breaker '%s' has BROKEN and RECOVERY state durations (%d and %d sec, respectively)"

0 commit comments

Comments
 (0)