1
+ from typing import Any , Callable
2
+
1
3
import pytest
2
4
3
5
from starlette .applications import Starlette
4
6
from starlette .background import BackgroundTask
5
7
from starlette .middleware .errors import ServerErrorMiddleware
8
+ from starlette .requests import Request
6
9
from starlette .responses import JSONResponse , Response
7
10
from starlette .routing import Route
11
+ from starlette .testclient import TestClient
12
+ from starlette .types import Receive , Scope , Send
13
+
14
+ TestClientFactory = Callable [..., TestClient ]
8
15
9
16
10
- def test_handler (test_client_factory ):
11
- async def app (scope , receive , send ):
17
+ def test_handler (
18
+ test_client_factory : TestClientFactory ,
19
+ ) -> None :
20
+ async def app (scope : Scope , receive : Receive , send : Send ) -> None :
12
21
raise RuntimeError ("Something went wrong" )
13
22
14
- def error_500 (request , exc ) :
23
+ def error_500 (request : Request , exc : Exception ) -> JSONResponse :
15
24
return JSONResponse ({"detail" : "Server Error" }, status_code = 500 )
16
25
17
26
app = ServerErrorMiddleware (app , handler = error_500 )
@@ -21,8 +30,8 @@ def error_500(request, exc):
21
30
assert response .json () == {"detail" : "Server Error" }
22
31
23
32
24
- def test_debug_text (test_client_factory ) :
25
- async def app (scope , receive , send ) :
33
+ def test_debug_text (test_client_factory : TestClientFactory ) -> None :
34
+ async def app (scope : Scope , receive : Receive , send : Send ) -> None :
26
35
raise RuntimeError ("Something went wrong" )
27
36
28
37
app = ServerErrorMiddleware (app , debug = True )
@@ -33,8 +42,8 @@ async def app(scope, receive, send):
33
42
assert "RuntimeError: Something went wrong" in response .text
34
43
35
44
36
- def test_debug_html (test_client_factory ) :
37
- async def app (scope , receive , send ) :
45
+ def test_debug_html (test_client_factory : TestClientFactory ) -> None :
46
+ async def app (scope : Scope , receive : Receive , send : Send ) -> None :
38
47
raise RuntimeError ("Something went wrong" )
39
48
40
49
app = ServerErrorMiddleware (app , debug = True )
@@ -45,8 +54,8 @@ async def app(scope, receive, send):
45
54
assert "RuntimeError" in response .text
46
55
47
56
48
- def test_debug_after_response_sent (test_client_factory ) :
49
- async def app (scope , receive , send ) :
57
+ def test_debug_after_response_sent (test_client_factory : TestClientFactory ) -> None :
58
+ async def app (scope : Scope , receive : Receive , send : Send ) -> None :
50
59
response = Response (b"" , status_code = 204 )
51
60
await response (scope , receive , send )
52
61
raise RuntimeError ("Something went wrong" )
@@ -57,12 +66,12 @@ async def app(scope, receive, send):
57
66
client .get ("/" )
58
67
59
68
60
- def test_debug_not_http (test_client_factory ) :
69
+ def test_debug_not_http (test_client_factory : TestClientFactory ) -> None :
61
70
"""
62
71
DebugMiddleware should just pass through any non-http messages as-is.
63
72
"""
64
73
65
- async def app (scope , receive , send ) :
74
+ async def app (scope : Scope , receive : Receive , send : Send ) -> None :
66
75
raise RuntimeError ("Something went wrong" )
67
76
68
77
app = ServerErrorMiddleware (app )
@@ -73,17 +82,17 @@ async def app(scope, receive, send):
73
82
pass # pragma: nocover
74
83
75
84
76
- def test_background_task (test_client_factory ) :
85
+ def test_background_task (test_client_factory : TestClientFactory ) -> None :
77
86
accessed_error_handler = False
78
87
79
- def error_handler (request , exc ) :
88
+ def error_handler (request : Request , exc : Exception ) -> Any :
80
89
nonlocal accessed_error_handler
81
90
accessed_error_handler = True
82
91
83
- def raise_exception ():
92
+ def raise_exception () -> None :
84
93
raise Exception ("Something went wrong" )
85
94
86
- async def endpoint (request ) :
95
+ async def endpoint (request : Request ) -> Response :
87
96
task = BackgroundTask (raise_exception )
88
97
return Response (status_code = 204 , background = task )
89
98
0 commit comments