Skip to content

Commit 0066487

Browse files
author
Bartek Ogryczak
authored
chore(typing): stricter decorator type checking for leaky bucket (#74687)
Adds stricter decorator type checking for leaky bucket
1 parent 62d415c commit 0066487

File tree

2 files changed

+19
-16
lines changed

2 files changed

+19
-16
lines changed

Diff for: src/sentry/ratelimits/leaky_bucket.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
from collections.abc import Callable
66
from dataclasses import dataclass
77
from time import time
8-
from typing import Any
8+
from typing import Any, ParamSpec, TypeVar
99

1010
from django.conf import settings
1111

1212
from sentry.exceptions import InvalidConfiguration
1313
from sentry.utils import redis
1414

15+
P = ParamSpec("P")
16+
R = TypeVar("R")
17+
1518
logger = logging.getLogger(__name__)
1619

1720

@@ -113,9 +116,9 @@ def get_bucket_state(self, key: str | None = None) -> LeakyBucketLimitInfo:
113116
def decorator(
114117
self,
115118
key_override: str | None = None,
116-
limited_handler: Callable[[LeakyBucketLimitInfo, dict[str, Any]], Any] | None = None,
119+
limited_handler: Callable[[LeakyBucketLimitInfo, dict[str, Any]], R] | None = None,
117120
raise_exception: bool = False,
118-
) -> Callable[[Any], Any]:
121+
) -> Callable[[Callable[P, R]], Callable[P, R]]:
119122
"""
120123
Decorator to limit the rate of requests
121124
@@ -181,7 +184,7 @@ def my_function():
181184
182185
"""
183186

184-
def decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
187+
def decorator(func: Callable[P, R]) -> Callable[P, R]:
185188
@functools.wraps(func)
186189
def wrapper(*args: Any, **kwargs: Any) -> Any:
187190
try:

Diff for: tests/sentry/ratelimits/test_leaky_bucket.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any
1+
from typing import Any, Never
22
from unittest import mock
33

44
import pytest
@@ -64,7 +64,7 @@ def test_drip_rate(self) -> None:
6464

6565
def test_decorator(self) -> None:
6666
@self.limiter("foo")
67-
def foo() -> None:
67+
def foo() -> Never:
6868
assert False, "This should not be executed when limited"
6969

7070
with freeze_time("2077-09-13"):
@@ -75,7 +75,7 @@ def foo() -> None:
7575
assert foo() is None
7676

7777
@self.limiter("bar", raise_exception=True)
78-
def bar() -> None:
78+
def bar() -> Never:
7979
assert False, "This should not be executed when limited"
8080

8181
with freeze_time("2077-09-13"):
@@ -88,23 +88,23 @@ def bar() -> None:
8888

8989
last_info: list[LeakyBucketLimitInfo] = []
9090

91-
def callback(info: LeakyBucketLimitInfo, context: dict[str, Any]) -> LeakyBucketLimitInfo:
91+
def callback(info: LeakyBucketLimitInfo, context: dict[str, Any]) -> str:
9292
last_info.append(info)
93-
return info
93+
return "rate limited"
9494

9595
@self.limiter("baz", limited_handler=callback)
96-
def baz() -> bool:
97-
return True
96+
def baz() -> str:
97+
return "normal value"
9898

9999
with freeze_time("2077-09-13"):
100100
for i in range(5):
101-
assert baz() is True
101+
assert baz() == "normal value"
102102
assert len(last_info) == 0
103103

104-
info = baz()
105-
assert info
104+
baz_rv = baz()
105+
assert baz_rv == "rate limited"
106106
assert len(last_info) == 1
107-
assert last_info[0] == info
107+
info = last_info[0]
108108
assert info.wait_time > 0
109109
assert info.current_level == 5
110110

@@ -114,7 +114,7 @@ def test_decorator_default_key(self) -> None:
114114
with mock.patch.object(limiter, "_redis_key", wraps=limiter._redis_key) as _redis_key_spy:
115115

116116
@limiter()
117-
def foo() -> None:
117+
def foo() -> Any:
118118
pass
119119

120120
foo()

0 commit comments

Comments
 (0)