Skip to content

Commit 2076707

Browse files
committed
Starlette middleware
1 parent 09f065b commit 2076707

File tree

12 files changed

+719
-28
lines changed

12 files changed

+719
-28
lines changed

Diff for: docs/integrations.rst

+51
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,57 @@ Starlette
402402

403403
This section describes integration with `Starlette <https://www.starlette.io>`__ ASGI framework.
404404

405+
Middleware
406+
~~~~~~~~~~
407+
408+
Starlette can be integrated by middleware. Add ``StarletteOpenAPIMiddleware`` with ``spec`` to your ``middleware`` list.
409+
410+
.. code-block:: python
411+
:emphasize-lines: 1,6
412+
413+
from openapi_core.contrib.starlette.middlewares import StarletteOpenAPIMiddleware
414+
from starlette.applications import Starlette
415+
from starlette.middleware import Middleware
416+
417+
middleware = [
418+
Middleware(StarletteOpenAPIMiddleware, spec=spec),
419+
]
420+
421+
app = Starlette(
422+
# ...
423+
middleware=middleware,
424+
)
425+
426+
After that you have access to unmarshal result object with all validated request data from endpoint through ``openapi`` key of request's scope directory.
427+
428+
.. code-block:: python
429+
430+
async def get_endpoint(req):
431+
# get parameters object with path, query, cookies and headers parameters
432+
validated_params = req.scope["openapi"].parameters
433+
# or specific location parameters
434+
validated_path_params = req.scope["openapi"].parameters.path
435+
436+
# get body
437+
validated_body = req.scope["openapi"].body
438+
439+
# get security data
440+
validated_security = req.scope["openapi"].security
441+
442+
You can skip response validation process: by setting ``response_cls`` to ``None``
443+
444+
.. code-block:: python
445+
:emphasize-lines: 2
446+
447+
middleware = [
448+
Middleware(StarletteOpenAPIMiddleware, spec=spec, response_cls=None),
449+
]
450+
451+
app = Starlette(
452+
# ...
453+
middleware=middleware,
454+
)
455+
405456
Low level
406457
~~~~~~~~~
407458

Diff for: openapi_core/contrib/starlette/handlers.py

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""OpenAPI core contrib starlette handlers module"""
2+
from typing import Any
3+
from typing import Callable
4+
from typing import Dict
5+
from typing import Iterable
6+
from typing import Optional
7+
from typing import Type
8+
9+
from starlette.middleware.base import RequestResponseEndpoint
10+
from starlette.requests import Request
11+
from starlette.responses import JSONResponse
12+
from starlette.responses import Response
13+
14+
from openapi_core.templating.media_types.exceptions import MediaTypeNotFound
15+
from openapi_core.templating.paths.exceptions import OperationNotFound
16+
from openapi_core.templating.paths.exceptions import PathNotFound
17+
from openapi_core.templating.paths.exceptions import ServerNotFound
18+
from openapi_core.templating.security.exceptions import SecurityNotFound
19+
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult
20+
21+
22+
class StarletteOpenAPIErrorsHandler:
23+
OPENAPI_ERROR_STATUS: Dict[Type[BaseException], int] = {
24+
ServerNotFound: 400,
25+
SecurityNotFound: 403,
26+
OperationNotFound: 405,
27+
PathNotFound: 404,
28+
MediaTypeNotFound: 415,
29+
}
30+
31+
def __call__(
32+
self,
33+
errors: Iterable[Exception],
34+
) -> JSONResponse:
35+
data_errors = [self.format_openapi_error(err) for err in errors]
36+
data = {
37+
"errors": data_errors,
38+
}
39+
data_error_max = max(data_errors, key=self.get_error_status)
40+
return JSONResponse(data, status_code=data_error_max["status"])
41+
42+
@classmethod
43+
def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]:
44+
if error.__cause__ is not None:
45+
error = error.__cause__
46+
return {
47+
"title": str(error),
48+
"status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400),
49+
"type": str(type(error)),
50+
}
51+
52+
@classmethod
53+
def get_error_status(cls, error: Dict[str, Any]) -> str:
54+
return str(error["status"])
55+
56+
57+
class StarletteOpenAPIValidRequestHandler:
58+
def __init__(self, request: Request, call_next: RequestResponseEndpoint):
59+
self.request = request
60+
self.call_next = call_next
61+
62+
async def __call__(
63+
self, request_unmarshal_result: RequestUnmarshalResult
64+
) -> Response:
65+
self.request.scope["openapi"] = request_unmarshal_result
66+
return await self.call_next(self.request)

Diff for: openapi_core/contrib/starlette/middlewares.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""OpenAPI core contrib starlette middlewares module"""
2+
from typing import Callable
3+
4+
from aioitertools.builtins import list as alist
5+
from aioitertools.itertools import tee as atee
6+
from jsonschema_path import SchemaPath
7+
from starlette.middleware.base import BaseHTTPMiddleware
8+
from starlette.middleware.base import RequestResponseEndpoint
9+
from starlette.requests import Request
10+
from starlette.responses import Response
11+
from starlette.responses import StreamingResponse
12+
from starlette.types import ASGIApp
13+
14+
from openapi_core.contrib.starlette.handlers import (
15+
StarletteOpenAPIErrorsHandler,
16+
)
17+
from openapi_core.contrib.starlette.handlers import (
18+
StarletteOpenAPIValidRequestHandler,
19+
)
20+
from openapi_core.contrib.starlette.requests import StarletteOpenAPIRequest
21+
from openapi_core.contrib.starlette.responses import StarletteOpenAPIResponse
22+
from openapi_core.unmarshalling.processors import AsyncUnmarshallingProcessor
23+
24+
25+
class StarletteOpenAPIMiddleware(
26+
BaseHTTPMiddleware, AsyncUnmarshallingProcessor[Request, Response]
27+
):
28+
request_cls = StarletteOpenAPIRequest
29+
response_cls = StarletteOpenAPIResponse
30+
valid_request_handler_cls = StarletteOpenAPIValidRequestHandler
31+
errors_handler = StarletteOpenAPIErrorsHandler()
32+
33+
def __init__(self, app: ASGIApp, spec: SchemaPath):
34+
BaseHTTPMiddleware.__init__(self, app)
35+
AsyncUnmarshallingProcessor.__init__(self, spec)
36+
37+
async def dispatch(
38+
self, request: Request, call_next: RequestResponseEndpoint
39+
) -> Response:
40+
valid_request_handler = self.valid_request_handler_cls(
41+
request, call_next
42+
)
43+
response = await self.handle_request(
44+
request, valid_request_handler, self.errors_handler
45+
)
46+
return await self.handle_response(
47+
request, response, self.errors_handler
48+
)
49+
50+
async def _get_openapi_request(
51+
self, request: Request
52+
) -> StarletteOpenAPIRequest:
53+
body = await request.body()
54+
return self.request_cls(request, body)
55+
56+
async def _get_openapi_response(
57+
self, response: Response
58+
) -> StarletteOpenAPIResponse:
59+
assert self.response_cls is not None
60+
data = None
61+
if isinstance(response, StreamingResponse):
62+
body_iter1, body_iter2 = atee(response.body_iterator)
63+
response.body_iterator = body_iter2
64+
data = b"".join(
65+
[
66+
chunk.encode(response.charset)
67+
if not isinstance(chunk, bytes)
68+
else chunk
69+
async for chunk in body_iter1
70+
]
71+
)
72+
return self.response_cls(response, data=data)
73+
74+
def _validate_response(self) -> bool:
75+
return self.response_cls is not None

Diff for: openapi_core/contrib/starlette/requests.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
class StarletteOpenAPIRequest:
11-
def __init__(self, request: Request):
11+
def __init__(self, request: Request, body: Optional[bytes] = None):
1212
if not isinstance(request, Request):
1313
raise TypeError(f"'request' argument is not type of {Request}")
1414
self.request = request
@@ -19,7 +19,7 @@ def __init__(self, request: Request):
1919
cookie=self.request.cookies,
2020
)
2121

22-
self._get_body = AsyncToSync(self.request.body, force_new_loop=True)
22+
self._body = body
2323

2424
@property
2525
def host_url(self) -> str:
@@ -35,13 +35,7 @@ def method(self) -> str:
3535

3636
@property
3737
def body(self) -> Optional[bytes]:
38-
body = self._get_body()
39-
if body is None:
40-
return None
41-
if isinstance(body, bytes):
42-
return body
43-
assert isinstance(body, str)
44-
return body.encode("utf-8")
38+
return self._body
4539

4640
@property
4741
def content_type(self) -> str:

Diff for: openapi_core/typing.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Awaitable
12
from typing import Callable
23
from typing import Iterable
34
from typing import TypeVar
@@ -11,3 +12,6 @@
1112

1213
ErrorsHandlerCallable = Callable[[Iterable[Exception]], ResponseType]
1314
ValidRequestHandlerCallable = Callable[[RequestUnmarshalResult], ResponseType]
15+
AsyncValidRequestHandlerCallable = Callable[
16+
[RequestUnmarshalResult], Awaitable[ResponseType]
17+
]

Diff for: openapi_core/unmarshalling/processors.py

+72
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from openapi_core.protocols import Request
99
from openapi_core.protocols import Response
1010
from openapi_core.shortcuts import get_classes
11+
from openapi_core.typing import AsyncValidRequestHandlerCallable
1112
from openapi_core.typing import ErrorsHandlerCallable
1213
from openapi_core.typing import RequestType
1314
from openapi_core.typing import ResponseType
@@ -90,3 +91,74 @@ def handle_response(
9091
if response_unmarshal_result.errors:
9192
return errors_handler(response_unmarshal_result.errors)
9293
return response
94+
95+
96+
class AsyncUnmarshallingProcessor(Generic[RequestType, ResponseType]):
97+
def __init__(
98+
self,
99+
spec: SchemaPath,
100+
request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None,
101+
response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None,
102+
**unmarshaller_kwargs: Any,
103+
):
104+
if (
105+
request_unmarshaller_cls is None
106+
or response_unmarshaller_cls is None
107+
):
108+
classes = get_classes(spec)
109+
if request_unmarshaller_cls is None:
110+
request_unmarshaller_cls = classes.request_unmarshaller_cls
111+
if response_unmarshaller_cls is None:
112+
response_unmarshaller_cls = classes.response_unmarshaller_cls
113+
114+
self.request_processor = RequestUnmarshallingProcessor(
115+
spec,
116+
request_unmarshaller_cls,
117+
**unmarshaller_kwargs,
118+
)
119+
self.response_processor = ResponseUnmarshallingProcessor(
120+
spec,
121+
response_unmarshaller_cls,
122+
**unmarshaller_kwargs,
123+
)
124+
125+
async def _get_openapi_request(self, request: RequestType) -> Request:
126+
raise NotImplementedError
127+
128+
async def _get_openapi_response(self, response: ResponseType) -> Response:
129+
raise NotImplementedError
130+
131+
def _validate_response(self) -> bool:
132+
raise NotImplementedError
133+
134+
async def handle_request(
135+
self,
136+
request: RequestType,
137+
valid_handler: AsyncValidRequestHandlerCallable[ResponseType],
138+
errors_handler: ErrorsHandlerCallable[ResponseType],
139+
) -> ResponseType:
140+
openapi_request = await self._get_openapi_request(request)
141+
request_unmarshal_result = self.request_processor.process(
142+
openapi_request
143+
)
144+
if request_unmarshal_result.errors:
145+
return errors_handler(request_unmarshal_result.errors)
146+
result = await valid_handler(request_unmarshal_result)
147+
return result
148+
149+
async def handle_response(
150+
self,
151+
request: RequestType,
152+
response: ResponseType,
153+
errors_handler: ErrorsHandlerCallable[ResponseType],
154+
) -> ResponseType:
155+
if not self._validate_response():
156+
return response
157+
openapi_request = await self._get_openapi_request(request)
158+
openapi_response = await self._get_openapi_response(response)
159+
response_unmarshal_result = self.response_processor.process(
160+
openapi_request, openapi_response
161+
)
162+
if response_unmarshal_result.errors:
163+
return errors_handler(response_unmarshal_result.errors)
164+
return response

Diff for: poetry.lock

+31-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)