Skip to content

Media type encoding support #646

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions openapi_core/deserializing/media_types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,37 @@
from json import loads as json_loads
from xml.etree.ElementTree import fromstring as xml_loads
from collections import defaultdict

from openapi_core.deserializing.media_types.datatypes import (
MediaTypeDeserializersDict,
)
from openapi_core.deserializing.media_types.factories import (
MediaTypeDeserializersFactory,
)
from openapi_core.deserializing.media_types.util import binary_loads
from openapi_core.deserializing.media_types.util import data_form_loads
from openapi_core.deserializing.media_types.util import json_loads
from openapi_core.deserializing.media_types.util import plain_loads
from openapi_core.deserializing.media_types.util import urlencoded_form_loads
from openapi_core.deserializing.media_types.util import xml_loads
from openapi_core.deserializing.styles import style_deserializers_factory

__all__ = ["media_type_deserializers_factory"]

media_type_deserializers: MediaTypeDeserializersDict = {
"text/html": plain_loads,
"text/plain": plain_loads,
"application/json": json_loads,
"application/vnd.api+json": json_loads,
"application/xml": xml_loads,
"application/xhtml+xml": xml_loads,
"application/x-www-form-urlencoded": urlencoded_form_loads,
"multipart/form-data": data_form_loads,
}
# Registry of deserializer callables keyed by mimetype.
#
# A ``defaultdict`` is used so that a mimetype without an explicit entry
# resolves to ``binary_loads``. Standard ``defaultdict`` semantics apply:
# a missed ``registry[mimetype]`` lookup also stores the fallback under
# that key, while ``in`` / ``.get()`` do not trigger the fallback.
media_type_deserializers: MediaTypeDeserializersDict = defaultdict(
    lambda: binary_loads,
    {
        # textual payloads
        "text/html": plain_loads,
        "text/plain": plain_loads,
        # raw binary payloads
        "application/octet-stream": binary_loads,
        # JSON payloads
        "application/json": json_loads,
        "application/vnd.api+json": json_loads,
        # XML payloads
        "application/xml": xml_loads,
        "application/xhtml+xml": xml_loads,
        # form payloads
        "application/x-www-form-urlencoded": urlencoded_form_loads,
        "multipart/form-data": data_form_loads,
    },
)

# Module-level factory wired with the default style deserializers and the
# registry above.
media_type_deserializers_factory = MediaTypeDeserializersFactory(
    style_deserializers_factory,
    media_type_deserializers=media_type_deserializers,
)
169 changes: 159 additions & 10 deletions openapi_core/deserializing/media_types/deserializers.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,182 @@
import warnings
from typing import Any
from typing import Mapping
from typing import Optional
from typing import cast
from xml.etree.ElementTree import ParseError

from jsonschema_path import SchemaPath

from openapi_core.deserializing.media_types.datatypes import (
DeserializerCallable,
)
from openapi_core.deserializing.media_types.datatypes import (
MediaTypeDeserializersDict,
)
from openapi_core.deserializing.media_types.exceptions import (
MediaTypeDeserializeError,
)
from openapi_core.deserializing.styles.factories import (
StyleDeserializersFactory,
)
from openapi_core.schema.encodings import get_content_type
from openapi_core.schema.parameters import get_style_and_explode
from openapi_core.schema.protocols import SuportsGetAll
from openapi_core.schema.protocols import SuportsGetList
from openapi_core.schema.schemas import get_properties


class MediaTypesDeserializer:
    """Dispatch a raw value to the deserializer callable for its mimetype.

    Two registries are consulted: ``extra_media_type_deserializers`` first,
    then ``media_type_deserializers``.
    """

    def __init__(
        self,
        media_type_deserializers: Optional[MediaTypeDeserializersDict] = None,
        extra_media_type_deserializers: Optional[
            MediaTypeDeserializersDict
        ] = None,
    ):
        """Store both registries; ``None`` is replaced by a fresh empty dict.

        The mappings are kept by reference (not copied), so an empty mapping
        passed by the caller is preserved as-is.
        """
        if media_type_deserializers is None:
            media_type_deserializers = {}
        if extra_media_type_deserializers is None:
            extra_media_type_deserializers = {}
        self.media_type_deserializers = media_type_deserializers
        self.extra_media_type_deserializers = extra_media_type_deserializers

    def deserialize(self, mimetype: str, value: Any, **parameters: str) -> Any:
        """Deserialize ``value`` with the callable registered for ``mimetype``.

        ``parameters`` are forwarded to the callable as keyword arguments.

        Raises:
            MediaTypeDeserializeError: when the callable raises ParseError,
                ValueError, TypeError or AttributeError. The original error
                is attached as ``__cause__``.
        """
        # The lookup stays outside the ``try`` so that lookup failures are
        # not reported as deserialization errors.
        deserializer_callable = self.get_deserializer_callable(mimetype)

        try:
            return deserializer_callable(value, **parameters)
        except (ParseError, ValueError, TypeError, AttributeError) as exc:
            # Chain explicitly so the underlying parse failure is reported
            # as the direct cause of the domain error.
            raise MediaTypeDeserializeError(mimetype, value) from exc

    def get_deserializer_callable(
        self,
        mimetype: str,
    ) -> DeserializerCallable:
        """Return the callable for ``mimetype``; extra entries take precedence.

        The base registry is indexed with ``[]``, so what happens for an
        unknown mimetype depends on the mapping supplied (a plain dict
        raises KeyError).
        """
        if mimetype in self.extra_media_type_deserializers:
            return self.extra_media_type_deserializers[mimetype]
        return self.media_type_deserializers[mimetype]


class CallableMediaTypeDeserializer:
class MediaTypeDeserializer:
def __init__(
self,
style_deserializers_factory: StyleDeserializersFactory,
media_types_deserializer: MediaTypesDeserializer,
mimetype: str,
deserializer_callable: Optional[DeserializerCallable] = None,
schema: Optional[SchemaPath] = None,
encoding: Optional[SchemaPath] = None,
**parameters: str,
):
self.style_deserializers_factory = style_deserializers_factory
self.media_types_deserializer = media_types_deserializer
self.mimetype = mimetype
self.deserializer_callable = deserializer_callable
self.schema = schema
self.encoding = encoding
self.parameters = parameters

def deserialize(self, value: Any) -> Any:
if self.deserializer_callable is None:
warnings.warn(f"Unsupported {self.mimetype} mimetype")
return value
deserialized = self.media_types_deserializer.deserialize(
self.mimetype, value, **self.parameters
)

try:
return self.deserializer_callable(value, **self.parameters)
except (ParseError, ValueError, TypeError, AttributeError):
raise MediaTypeDeserializeError(self.mimetype, value)
if (
self.mimetype != "application/x-www-form-urlencoded"
and not self.mimetype.startswith("multipart")
):
return deserialized

# decode multipart request bodies
return self.decode(deserialized)

def evolve(
    self, mimetype: str, schema: Optional[SchemaPath]
) -> "MediaTypeDeserializer":
    """Create a sibling deserializer for a different mimetype and schema.

    The new instance is of the same class and shares this instance's
    style deserializers factory and media types deserializer. Only
    ``mimetype`` and ``schema`` are passed on; no encoding or extra
    parameters are forwarded to the constructor.
    """
    constructor = self.__class__
    shared_args = (
        self.style_deserializers_factory,
        self.media_types_deserializer,
        mimetype,
    )
    return constructor(*shared_args, schema=schema)

def decode(self, location: Mapping[str, Any]) -> Mapping[str, Any]:
    """Decode every schema-declared property out of ``location``.

    Each property returned by ``get_properties(self.schema)`` is decoded
    with ``decode_property``. When that raises ``KeyError`` the property
    takes its schema ``default`` if one is declared, and is otherwise
    left out of the result.

    Returns:
        A dict mapping property names to their decoded values.
    """
    # schema is required for multipart
    assert self.schema is not None
    properties = {}
    for prop_name, prop_schema in get_properties(self.schema).items():
        try:
            properties[prop_name] = self.decode_property(
                prop_name, prop_schema, location
            )
        except KeyError:
            # Value absent from the payload: fall back to the schema
            # default, or skip the property entirely.
            if "default" not in prop_schema:
                continue
            properties[prop_name] = prop_schema["default"]

    return properties

def decode_property(
    self,
    prop_name: str,
    prop_schema: SchemaPath,
    location: Mapping[str, Any],
) -> Any:
    """Decode one property, choosing between style and content-type decoding.

    Without an encoding entry for ``prop_name`` the property is decoded by
    content type. With an entry, style-based decoding is used as soon as
    the entry declares any of ``style``, ``explode`` or ``allowReserved``;
    otherwise content-type decoding is used with that entry.
    """
    encoding = self.encoding
    if encoding is None or prop_name not in encoding:
        return self.decode_property_content_type(
            prop_name, prop_schema, location
        )

    prep_encoding = encoding / prop_name
    style_keys = ("style", "explode", "allowReserved")
    if any(key in prep_encoding for key in style_keys):
        return self.decode_property_style(
            prop_name, prop_schema, location, prep_encoding
        )

    return self.decode_property_content_type(
        prop_name, prop_schema, location, prep_encoding
    )

def decode_property_style(
    self,
    prop_name: str,
    prop_schema: SchemaPath,
    location: Mapping[str, Any],
    prep_encoding: SchemaPath,
) -> Any:
    """Decode a property using the style/explode settings of its encoding.

    Style and explode are resolved from ``prep_encoding`` with ``"query"``
    as the default location, then a style deserializer created for the
    property is applied to ``location``.
    """
    style, explode = get_style_and_explode(
        prep_encoding, default_location="query"
    )
    factory = self.style_deserializers_factory
    style_deserializer = factory.create(
        style, explode, prop_schema, name=prop_name
    )
    return style_deserializer.deserialize(location)

def decode_property_content_type(
    self,
    prop_name: str,
    prop_schema: SchemaPath,
    location: Mapping[str, Any],
    prep_encoding: Optional[SchemaPath] = None,
) -> Any:
    """Decode a property with the deserializer for its content type.

    The content type comes from ``get_content_type(prop_schema,
    prep_encoding)`` and a dedicated deserializer is derived via
    ``evolve``. For ``array`` properties every value returned by the
    location's ``getlist`` / ``getall`` accessor is deserialized and a
    list is returned. Otherwise the single ``location[prop_name]`` value
    is deserialized.

    Raises:
        KeyError: propagated from ``location[prop_name]`` (or from the
            multi-value accessor, if it raises) when the property is
            absent.
    """
    prop_content_type = get_content_type(prop_schema, prep_encoding)
    prop_deserializer = self.evolve(
        prop_content_type,
        prop_schema,
    )
    prop_schema_type = prop_schema.getkey("type", "")
    if prop_schema_type == "array":
        # ``getlist`` is checked first so that a location implementing
        # both protocols keeps yielding the ``getlist`` result.
        if isinstance(location, SuportsGetList):
            values = location.getlist(prop_name)
            return [prop_deserializer.deserialize(item) for item in values]
        if isinstance(location, SuportsGetAll):
            values = location.getall(prop_name)
            return [prop_deserializer.deserialize(item) for item in values]
        # Plain mappings expose neither accessor; fall through to the
        # single-value path instead of failing on an unbound local.

    value = location[prop_name]
    return prop_deserializer.deserialize(value)
40 changes: 24 additions & 16 deletions openapi_core/deserializing/media_types/factories.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,60 @@
from typing import Mapping
from typing import Optional

from jsonschema_path import SchemaPath

from openapi_core.deserializing.media_types.datatypes import (
DeserializerCallable,
)
from openapi_core.deserializing.media_types.datatypes import (
MediaTypeDeserializersDict,
)
from openapi_core.deserializing.media_types.deserializers import (
CallableMediaTypeDeserializer,
MediaTypeDeserializer,
)
from openapi_core.deserializing.media_types.deserializers import (
MediaTypesDeserializer,
)
from openapi_core.deserializing.styles.factories import (
StyleDeserializersFactory,
)


class MediaTypeDeserializersFactory:
def __init__(
self,
style_deserializers_factory: StyleDeserializersFactory,
media_type_deserializers: Optional[MediaTypeDeserializersDict] = None,
):
self.style_deserializers_factory = style_deserializers_factory
if media_type_deserializers is None:
media_type_deserializers = {}
self.media_type_deserializers = media_type_deserializers

def create(
self,
mimetype: str,
schema: Optional[SchemaPath] = None,
parameters: Optional[Mapping[str, str]] = None,
encoding: Optional[SchemaPath] = None,
extra_media_type_deserializers: Optional[
MediaTypeDeserializersDict
] = None,
) -> CallableMediaTypeDeserializer:
) -> MediaTypeDeserializer:
if parameters is None:
parameters = {}
if extra_media_type_deserializers is None:
extra_media_type_deserializers = {}
deserialize_callable = self.get_deserializer_callable(
mimetype,
extra_media_type_deserializers=extra_media_type_deserializers,
media_types_deserializer = MediaTypesDeserializer(
self.media_type_deserializers,
extra_media_type_deserializers,
)

return CallableMediaTypeDeserializer(
mimetype, deserialize_callable, **parameters
return MediaTypeDeserializer(
self.style_deserializers_factory,
media_types_deserializer,
mimetype,
schema=schema,
encoding=encoding,
**parameters,
)

def get_deserializer_callable(
self,
mimetype: str,
extra_media_type_deserializers: MediaTypeDeserializersDict,
) -> Optional[DeserializerCallable]:
if mimetype in extra_media_type_deserializers:
return extra_media_type_deserializers[mimetype]
return self.media_type_deserializers.get(mimetype)
51 changes: 41 additions & 10 deletions openapi_core/deserializing/media_types/util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
from email.parser import Parser
from json import loads
from typing import Any
from typing import Dict
from typing import Mapping
from typing import Union
from urllib.parse import parse_qsl
from xml.etree.ElementTree import Element
from xml.etree.ElementTree import fromstring

from werkzeug.datastructures import ImmutableMultiDict


def binary_loads(value: Union[str, bytes], **parameters: str) -> bytes:
    """Return ``value`` as bytes.

    Text is encoded with the ``charset`` parameter when given, otherwise
    with UTF-8. Non-``str`` input is returned unchanged.
    """
    if not isinstance(value, str):
        return value
    return value.encode(parameters.get("charset", "utf-8"))


def plain_loads(value: Union[str, bytes], **parameters: str) -> str:
Expand All @@ -18,20 +32,37 @@ def plain_loads(value: Union[str, bytes], **parameters: str) -> str:
return value


def urlencoded_form_loads(value: Any, **parameters: str) -> Dict[str, Any]:
def json_loads(value: Union[str, bytes], **parameters: str) -> Any:
    """Parse ``value`` as JSON and return the resulting Python object.

    ``parameters`` is accepted but not used by this deserializer.
    """
    parsed = loads(value)
    return parsed


def xml_loads(value: Union[str, bytes], **parameters: str) -> Element:
    """Parse ``value`` as an XML document and return its root element.

    ``parameters`` is accepted but not used by this deserializer.
    """
    # NOTE(review): the Python documentation warns that
    # xml.etree.ElementTree is not hardened against maliciously
    # constructed data - confirm whether this input can be untrusted.
    root = fromstring(value)
    return root


def urlencoded_form_loads(value: Any, **parameters: str) -> Mapping[str, Any]:
    """Parse a URL-encoded form body into a plain dict.

    Pairs come from ``parse_qsl``; when a field name occurs more than
    once only its last value is kept. ``parameters`` is accepted but not
    used by this deserializer.
    """
    return {name: field for name, field in parse_qsl(value)}


def data_form_loads(
value: Union[str, bytes], **parameters: str
) -> Dict[str, Any]:
) -> Mapping[str, Any]:
if isinstance(value, bytes):
value = value.decode("ASCII", errors="surrogateescape")
boundary = ""
if "boundary" in parameters:
boundary = parameters["boundary"]
parser = Parser()
parts = parser.parsestr(value, headersonly=False)
return {
part.get_param("name", header="content-disposition"): part.get_payload(
decode=True
)
for part in parts.get_payload()
}
mimetype = "multipart/form-data"
header = f'Content-Type: {mimetype}; boundary="{boundary}"'
text = "\n\n".join([header, value])
parts = parser.parsestr(text, headersonly=False)
return ImmutableMultiDict(
[
(
part.get_param("name", header="content-disposition"),
part.get_payload(decode=True),
)
for part in parts.get_payload()
]
)
Loading