From 863ba3f7f83fb09a8bce17d74602d0e0f91a0c65 Mon Sep 17 00:00:00 2001 From: p1c2u Date: Tue, 17 Oct 2023 16:55:50 +0000 Subject: [PATCH] Media type encoding support --- .../deserializing/media_types/__init__.py | 32 ++- .../media_types/deserializers.py | 169 +++++++++++- .../deserializing/media_types/factories.py | 40 +-- .../deserializing/media_types/util.py | 51 +++- .../deserializing/styles/factories.py | 13 +- openapi_core/schema/encodings.py | 40 +++ openapi_core/schema/parameters.py | 20 +- openapi_core/validation/validators.py | 24 +- tests/integration/test_petstore.py | 3 +- .../test_media_types_deserializers.py | 241 +++++++++++++++--- .../test_styles_deserializers.py | 8 +- 11 files changed, 538 insertions(+), 103 deletions(-) create mode 100644 openapi_core/schema/encodings.py diff --git a/openapi_core/deserializing/media_types/__init__.py b/openapi_core/deserializing/media_types/__init__.py index 70331f9b..fd4a0ae1 100644 --- a/openapi_core/deserializing/media_types/__init__.py +++ b/openapi_core/deserializing/media_types/__init__.py @@ -1,5 +1,4 @@ -from json import loads as json_loads -from xml.etree.ElementTree import fromstring as xml_loads +from collections import defaultdict from openapi_core.deserializing.media_types.datatypes import ( MediaTypeDeserializersDict, @@ -7,23 +6,32 @@ from openapi_core.deserializing.media_types.factories import ( MediaTypeDeserializersFactory, ) +from openapi_core.deserializing.media_types.util import binary_loads from openapi_core.deserializing.media_types.util import data_form_loads +from openapi_core.deserializing.media_types.util import json_loads from openapi_core.deserializing.media_types.util import plain_loads from openapi_core.deserializing.media_types.util import urlencoded_form_loads +from openapi_core.deserializing.media_types.util import xml_loads +from openapi_core.deserializing.styles import style_deserializers_factory __all__ = ["media_type_deserializers_factory"] -media_type_deserializers: MediaTypeDeserializersDict = { - "text/html": plain_loads, - "text/plain": plain_loads, - "application/json": json_loads, - "application/vnd.api+json": json_loads, - "application/xml": xml_loads, - "application/xhtml+xml": xml_loads, - "application/x-www-form-urlencoded": urlencoded_form_loads, - "multipart/form-data": data_form_loads, -} +media_type_deserializers: MediaTypeDeserializersDict = defaultdict( + lambda: binary_loads, + **{ + "text/html": plain_loads, + "text/plain": plain_loads, + "application/octet-stream": binary_loads, + "application/json": json_loads, + "application/vnd.api+json": json_loads, + "application/xml": xml_loads, + "application/xhtml+xml": xml_loads, + "application/x-www-form-urlencoded": urlencoded_form_loads, + "multipart/form-data": data_form_loads, + } +) media_type_deserializers_factory = MediaTypeDeserializersFactory( + style_deserializers_factory, media_type_deserializers=media_type_deserializers, ) diff --git a/openapi_core/deserializing/media_types/deserializers.py b/openapi_core/deserializing/media_types/deserializers.py index 2bdef976..0fc6b0ba 100644 --- a/openapi_core/deserializing/media_types/deserializers.py +++ b/openapi_core/deserializing/media_types/deserializers.py @@ -1,33 +1,182 @@ import warnings from typing import Any +from typing import Mapping from typing import Optional +from typing import cast from xml.etree.ElementTree import ParseError +from jsonschema_path import SchemaPath + from openapi_core.deserializing.media_types.datatypes import ( DeserializerCallable, ) +from openapi_core.deserializing.media_types.datatypes import ( + MediaTypeDeserializersDict, +) from openapi_core.deserializing.media_types.exceptions import ( MediaTypeDeserializeError, ) +from openapi_core.deserializing.styles.factories import ( + StyleDeserializersFactory, +) +from openapi_core.schema.encodings import get_content_type +from openapi_core.schema.parameters import get_style_and_explode +from openapi_core.schema.protocols import SuportsGetAll +from openapi_core.schema.protocols import SuportsGetList +from openapi_core.schema.schemas import get_properties + + +class MediaTypesDeserializer: + def __init__( + self, + media_type_deserializers: Optional[MediaTypeDeserializersDict] = None, + extra_media_type_deserializers: Optional[ + MediaTypeDeserializersDict + ] = None, + ): + if media_type_deserializers is None: + media_type_deserializers = {} + self.media_type_deserializers = media_type_deserializers + if extra_media_type_deserializers is None: + extra_media_type_deserializers = {} + self.extra_media_type_deserializers = extra_media_type_deserializers + + def deserialize(self, mimetype: str, value: Any, **parameters: str) -> Any: + deserializer_callable = self.get_deserializer_callable(mimetype) + + try: + return deserializer_callable(value, **parameters) + except (ParseError, ValueError, TypeError, AttributeError): + raise MediaTypeDeserializeError(mimetype, value) + + def get_deserializer_callable( + self, + mimetype: str, + ) -> DeserializerCallable: + if mimetype in self.extra_media_type_deserializers: + return self.extra_media_type_deserializers[mimetype] + return self.media_type_deserializers[mimetype] -class CallableMediaTypeDeserializer: +class MediaTypeDeserializer: def __init__( self, + style_deserializers_factory: StyleDeserializersFactory, + media_types_deserializer: MediaTypesDeserializer, mimetype: str, - deserializer_callable: Optional[DeserializerCallable] = None, + schema: Optional[SchemaPath] = None, + encoding: Optional[SchemaPath] = None, **parameters: str, ): + self.style_deserializers_factory = style_deserializers_factory + self.media_types_deserializer = media_types_deserializer self.mimetype = mimetype - self.deserializer_callable = deserializer_callable + self.schema = schema + self.encoding = encoding self.parameters = parameters def deserialize(self, value: Any) -> Any: - if self.deserializer_callable is None: - warnings.warn(f"Unsupported {self.mimetype} mimetype") - return value + deserialized = self.media_types_deserializer.deserialize( + self.mimetype, value, **self.parameters + ) - try: - return self.deserializer_callable(value, **self.parameters) - except (ParseError, ValueError, TypeError, AttributeError): - raise MediaTypeDeserializeError(self.mimetype, value) + if ( + self.mimetype != "application/x-www-form-urlencoded" + and not self.mimetype.startswith("multipart") + ): + return deserialized + + # decode multipart request bodies + return self.decode(deserialized) + + def evolve( + self, mimetype: str, schema: Optional[SchemaPath] + ) -> "MediaTypeDeserializer": + cls = self.__class__ + + return cls( + self.style_deserializers_factory, + self.media_types_deserializer, + mimetype, + schema=schema, + ) + + def decode(self, location: Mapping[str, Any]) -> Mapping[str, Any]: + # schema is required for multipart + assert self.schema is not None + schema_props = self.schema.get("properties") + properties = {} + for prop_name, prop_schema in get_properties(self.schema).items(): + try: + properties[prop_name] = self.decode_property( + prop_name, prop_schema, location + ) + except KeyError: + if "default" not in prop_schema: + continue + properties[prop_name] = prop_schema["default"] + + return properties + + def decode_property( + self, + prop_name: str, + prop_schema: SchemaPath, + location: Mapping[str, Any], + ) -> Any: + if self.encoding is None or prop_name not in self.encoding: + return self.decode_property_content_type( + prop_name, prop_schema, location + ) + + prep_encoding = self.encoding / prop_name + if ( + "style" not in prep_encoding + and "explode" not in prep_encoding + and "allowReserved" not in prep_encoding + ): + return self.decode_property_content_type( + prop_name, prop_schema, location, prep_encoding + ) + + return self.decode_property_style( + prop_name, prop_schema, location, prep_encoding + ) + + def decode_property_style( + self, + prop_name: str, + prop_schema: SchemaPath, + location: Mapping[str, Any], + prep_encoding: SchemaPath, + ) -> Any: + prop_style, prop_explode = get_style_and_explode( + prep_encoding, default_location="query" + ) + prop_deserializer = self.style_deserializers_factory.create( + prop_style, prop_explode, prop_schema, name=prop_name + ) + return prop_deserializer.deserialize(location) + + def decode_property_content_type( + self, + prop_name: str, + prop_schema: SchemaPath, + location: Mapping[str, Any], + prep_encoding: Optional[SchemaPath] = None, + ) -> Any: + prop_content_type = get_content_type(prop_schema, prep_encoding) + prop_deserializer = self.evolve( + prop_content_type, + prop_schema, + ) + prop_schema_type = prop_schema.getkey("type", "") + if prop_schema_type == "array": + if isinstance(location, SuportsGetAll): + value = location.getall(prop_name) + if isinstance(location, SuportsGetList): + value = location.getlist(prop_name) + return list(map(prop_deserializer.deserialize, value)) + else: + value = location[prop_name] + return prop_deserializer.deserialize(value) diff --git a/openapi_core/deserializing/media_types/factories.py b/openapi_core/deserializing/media_types/factories.py index 9087c6b1..b39d65a5 100644 --- a/openapi_core/deserializing/media_types/factories.py +++ b/openapi_core/deserializing/media_types/factories.py @@ -1,6 +1,8 @@ from typing import Mapping from typing import Optional +from jsonschema_path import SchemaPath + from openapi_core.deserializing.media_types.datatypes import ( DeserializerCallable, ) @@ -8,15 +10,23 @@ MediaTypeDeserializersDict, ) from openapi_core.deserializing.media_types.deserializers import ( - CallableMediaTypeDeserializer, + MediaTypeDeserializer, +) +from openapi_core.deserializing.media_types.deserializers import ( + MediaTypesDeserializer, +) +from openapi_core.deserializing.styles.factories import ( + StyleDeserializersFactory, ) class MediaTypeDeserializersFactory: def __init__( self, + style_deserializers_factory: StyleDeserializersFactory, media_type_deserializers: Optional[MediaTypeDeserializersDict] = None, ): + self.style_deserializers_factory = style_deserializers_factory if media_type_deserializers is None: media_type_deserializers = {} self.media_type_deserializers = media_type_deserializers @@ -24,29 +34,27 @@ def __init__( def create( self, mimetype: str, + schema: Optional[SchemaPath] = None, parameters: Optional[Mapping[str, str]] = None, + encoding: Optional[SchemaPath] = None, extra_media_type_deserializers: Optional[ MediaTypeDeserializersDict ] = None, - ) -> CallableMediaTypeDeserializer: + ) -> MediaTypeDeserializer: if parameters is None: parameters = {} if extra_media_type_deserializers is None: extra_media_type_deserializers = {} - deserialize_callable = self.get_deserializer_callable( - mimetype, - extra_media_type_deserializers=extra_media_type_deserializers, + media_types_deserializer = MediaTypesDeserializer( + self.media_type_deserializers, + extra_media_type_deserializers, ) - return CallableMediaTypeDeserializer( - mimetype, deserialize_callable, **parameters + return MediaTypeDeserializer( + self.style_deserializers_factory, + media_types_deserializer, + mimetype, + schema=schema, + encoding=encoding, + **parameters, ) - - def get_deserializer_callable( - self, - mimetype: str, - extra_media_type_deserializers: MediaTypeDeserializersDict, - ) -> Optional[DeserializerCallable]: - if mimetype in extra_media_type_deserializers: - return extra_media_type_deserializers[mimetype] - return self.media_type_deserializers.get(mimetype) diff --git a/openapi_core/deserializing/media_types/util.py b/openapi_core/deserializing/media_types/util.py index c73315d7..aa3c333c 100644 --- a/openapi_core/deserializing/media_types/util.py +++ b/openapi_core/deserializing/media_types/util.py @@ -1,8 +1,22 @@ from email.parser import Parser +from json import loads from typing import Any -from typing import Dict +from typing import Mapping from typing import Union from urllib.parse import parse_qsl +from xml.etree.ElementTree import Element +from xml.etree.ElementTree import fromstring + +from werkzeug.datastructures import ImmutableMultiDict + + +def binary_loads(value: Union[str, bytes], **parameters: str) -> bytes: + charset = "utf-8" + if "charset" in parameters: + charset = parameters["charset"] + if isinstance(value, str): + return value.encode(charset) + return value def plain_loads(value: Union[str, bytes], **parameters: str) -> str: @@ -18,20 +32,37 @@ def plain_loads(value: Union[str, bytes], **parameters: str) -> str: return value -def urlencoded_form_loads(value: Any, **parameters: str) -> Dict[str, Any]: +def json_loads(value: Union[str, bytes], **parameters: str) -> Any: + return loads(value) + + +def xml_loads(value: Union[str, bytes], **parameters: str) -> Element: + return fromstring(value) + + +def urlencoded_form_loads(value: Any, **parameters: str) -> Mapping[str, Any]: return dict(parse_qsl(value)) def data_form_loads( value: Union[str, bytes], **parameters: str -) -> Dict[str, Any]: +) -> Mapping[str, Any]: if isinstance(value, bytes): value = value.decode("ASCII", errors="surrogateescape") + boundary = "" + if "boundary" in parameters: + boundary = parameters["boundary"] parser = Parser() - parts = parser.parsestr(value, headersonly=False) - return { - part.get_param("name", header="content-disposition"): part.get_payload( - decode=True - ) - for part in parts.get_payload() - } + mimetype = "multipart/form-data" + header = f'Content-Type: {mimetype}; boundary="{boundary}"' + text = "\n\n".join([header, value]) + parts = parser.parsestr(text, headersonly=False) + return ImmutableMultiDict( + [ + ( + part.get_param("name", header="content-disposition"), + part.get_payload(decode=True), + ) + for part in parts.get_payload() + ] + ) diff --git a/openapi_core/deserializing/styles/factories.py b/openapi_core/deserializing/styles/factories.py index 26a5f61e..cfacb2ce 100644 --- a/openapi_core/deserializing/styles/factories.py +++ b/openapi_core/deserializing/styles/factories.py @@ -10,9 +10,6 @@ from openapi_core.deserializing.styles.datatypes import DeserializerCallable from openapi_core.deserializing.styles.datatypes import StyleDeserializersDict from openapi_core.deserializing.styles.deserializers import StyleDeserializer -from openapi_core.deserializing.styles.util import split -from openapi_core.schema.parameters import get_explode -from openapi_core.schema.parameters import get_style class StyleDeserializersFactory: @@ -25,12 +22,12 @@ def __init__( self.style_deserializers = style_deserializers def create( - self, param_or_header: SchemaPath, name: Optional[str] = None + self, + style: str, + explode: bool, + schema: SchemaPath, + name: str, ) -> StyleDeserializer: - name = name or param_or_header["name"] - style = get_style(param_or_header) - explode = get_explode(param_or_header) - schema = param_or_header / "schema" schema_type = schema.getkey("type", "") deserialize_callable = self.style_deserializers.get(style) diff --git a/openapi_core/schema/encodings.py b/openapi_core/schema/encodings.py new file mode 100644 index 00000000..2dd3d9fa --- /dev/null +++ b/openapi_core/schema/encodings.py @@ -0,0 +1,40 @@ +from typing import Optional +from typing import cast + +from jsonschema_path import SchemaPath + + +def get_content_type( + prop_schema: SchemaPath, encoding: Optional[SchemaPath] +) -> str: + if encoding is None: + return get_default_content_type(prop_schema, encoding=False) + + if "contentType" not in encoding: + return get_default_content_type(prop_schema, encoding=True) + + return cast(str, encoding["contentType"]) + + +def get_default_content_type( + prop_schema: Optional[SchemaPath], encoding: bool = False +) -> str: + if prop_schema is None: + return "text/plain" + + prop_type = prop_schema.getkey("type") + if prop_type is None: + return "text/plain" if encoding else "application/octet-stream" + + prop_format = prop_schema.getkey("format") + if prop_type == "string" and prop_format in ["binary", "base64"]: + return "application/octet-stream" + + if prop_type == "object": + return "application/json" + + if prop_type == "array": + prop_items = prop_schema / "items" + return get_default_content_type(prop_items, encoding=encoding) + + return "text/plain" diff --git a/openapi_core/schema/parameters.py b/openapi_core/schema/parameters.py index da1a5f16..4f43ea05 100644 --- a/openapi_core/schema/parameters.py +++ b/openapi_core/schema/parameters.py @@ -2,6 +2,7 @@ from typing import Dict from typing import Mapping from typing import Optional +from typing import Tuple from jsonschema_path import SchemaPath @@ -9,14 +10,15 @@ from openapi_core.schema.protocols import SuportsGetList -def get_style(param_or_header: SchemaPath) -> str: +def get_style( + param_or_header: SchemaPath, default_location: str = "header" +) -> str: """Checks parameter/header style for simpler scenarios""" if "style" in param_or_header: assert isinstance(param_or_header["style"], str) return param_or_header["style"] - # if "in" not defined then it's a Header - location = param_or_header.getkey("in", "header") + location = param_or_header.getkey("in", default_location) # determine default return "simple" if location in ["path", "header"] else "form" @@ -31,3 +33,15 @@ def get_explode(param_or_header: SchemaPath) -> bool: # determine default style = get_style(param_or_header) return style == "form" + + +def get_style_and_explode( + param_or_header: SchemaPath, default_location: str = "header" +) -> Tuple[str, bool]: + """Checks parameter/header explode for simpler scenarios""" + style = get_style(param_or_header, default_location=default_location) + if "explode" in param_or_header: + assert isinstance(param_or_header["explode"], bool) + return style, param_or_header["explode"] + + return style, style == "form" diff --git a/openapi_core/validation/validators.py b/openapi_core/validation/validators.py index f1a34a63..3494dad1 100644 --- a/openapi_core/validation/validators.py +++ b/openapi_core/validation/validators.py @@ -31,6 +31,7 @@ ) from openapi_core.protocols import Request from openapi_core.protocols import WebhookRequest +from openapi_core.schema.parameters import get_style_and_explode from openapi_core.templating.media_types.datatypes import MediaType from openapi_core.templating.paths.datatypes import PathOperationServer from openapi_core.templating.paths.finders import APICallPathFinder @@ -97,12 +98,22 @@ def _find_media_type( return finder.find(mimetype) def _deserialise_media_type( - self, mimetype: str, parameters: Mapping[str, str], value: Any + self, + media_type: SchemaPath, + mimetype: str, + parameters: Mapping[str, str], + value: Any, ) -> Any: + schema = media_type.get("schema") + encoding = None + if "encoding" in media_type: + encoding = media_type.get("encoding") deserializer = self.media_type_deserializers_factory.create( mimetype, - extra_media_type_deserializers=self.extra_media_type_deserializers, + schema=schema, parameters=parameters, + encoding=encoding, + extra_media_type_deserializers=self.extra_media_type_deserializers, ) return deserializer.deserialize(value) @@ -112,8 +123,11 @@ def _deserialise_style( location: Mapping[str, Any], name: Optional[str] = None, ) -> Any: + name = name or param_or_header["name"] + style, explode = get_style_and_explode(param_or_header) + schema = param_or_header / "schema" deserializer = self.style_deserializers_factory.create( - param_or_header, name=name + style, explode, schema, name=name ) return deserializer.deserialize(location) @@ -213,7 +227,9 @@ def _get_content_schema_value_and_schema( ) # no point to catch KetError # in complex scenrios schema doesn't exist - deserialised = self._deserialise_media_type(mime_type, parameters, raw) + deserialised = self._deserialise_media_type( + media_type, mime_type, parameters, raw + ) casted = self._cast(media_type, deserialised) if "schema" not in media_type: diff --git a/tests/integration/test_petstore.py b/tests/integration/test_petstore.py index 81a78e68..164dab31 100644 --- a/tests/integration/test_petstore.py +++ b/tests/integration/test_petstore.py @@ -1374,8 +1374,7 @@ def test_get_pet_wildcard(self, spec): data = b"imagedata" response = MockResponse(data, mimetype="image/png") - with pytest.warns(UserWarning): - response_result = unmarshal_response(request, response, spec=spec) + response_result = unmarshal_response(request, response, spec=spec) assert response_result.errors == [] assert response_result.data == data diff --git a/tests/unit/deserializing/test_media_types_deserializers.py b/tests/unit/deserializing/test_media_types_deserializers.py index 28279f93..1d099a3d 100644 --- a/tests/unit/deserializing/test_media_types_deserializers.py +++ b/tests/unit/deserializing/test_media_types_deserializers.py @@ -1,55 +1,40 @@ from xml.etree.ElementTree import Element import pytest +from jsonschema_path import SchemaPath from openapi_core.deserializing.exceptions import DeserializeError from openapi_core.deserializing.media_types import media_type_deserializers from openapi_core.deserializing.media_types.factories import ( MediaTypeDeserializersFactory, ) +from openapi_core.deserializing.styles import style_deserializers_factory class TestMediaTypeDeserializer: @pytest.fixture def deserializer_factory(self): def create_deserializer( - media_type, + mimetype, + schema=None, + encoding=None, parameters=None, media_type_deserializers=media_type_deserializers, extra_media_type_deserializers=None, ): return MediaTypeDeserializersFactory( + style_deserializers_factory, media_type_deserializers, ).create( - media_type, + mimetype, + schema=schema, parameters=parameters, + encoding=encoding, extra_media_type_deserializers=extra_media_type_deserializers, ) return create_deserializer - def test_unsupported(self, deserializer_factory): - mimetype = "application/unsupported" - deserializer = deserializer_factory(mimetype) - value = "" - - with pytest.warns(UserWarning): - result = deserializer.deserialize(value) - - assert result == value - - def test_no_deserializer(self, deserializer_factory): - mimetype = "application/json" - deserializer = deserializer_factory( - mimetype, media_type_deserializers=None - ) - value = "{}" - - with pytest.warns(UserWarning): - result = deserializer.deserialize(value) - - assert result == value - @pytest.mark.parametrize( "mimetype,parameters,value,expected", [ @@ -79,6 +64,23 @@ def test_plain_valid( assert result == expected + @pytest.mark.parametrize( + "mimetype", + [ + "application/json", + "application/vnd.api+json", + ], + ) + def test_json_valid(self, deserializer_factory, mimetype): + parameters = {"charset": "utf-8"} + deserializer = deserializer_factory(mimetype, parameters=parameters) + value = '{"test": "test"}' + + result = deserializer.deserialize(value) + + assert type(result) is dict + assert result == {"test": "test"} + @pytest.mark.parametrize( "mimetype", [ @@ -130,16 +132,54 @@ def test_xml_empty(self, deserializer_factory, mimetype): ], ) def test_xml_valid(self, deserializer_factory, mimetype): - deserializer = deserializer_factory(mimetype) + parameters = {"charset": "utf-8"} + deserializer = deserializer_factory(mimetype, parameters=parameters) value = "text" result = deserializer.deserialize(value) assert type(result) is Element + def test_octet_stream_empty(self, deserializer_factory): + mimetype = "application/octet-stream" + deserializer = deserializer_factory(mimetype) + value = "" + + result = deserializer.deserialize(value) + + assert result == b"" + + @pytest.mark.parametrize( + "mimetype", + [ + "image/gif", + "image/png", + ], + ) + def test_octet_stream_implicit(self, deserializer_factory, mimetype): + deserializer = deserializer_factory(mimetype) + value = b"" + + result = deserializer.deserialize(value) + + assert result == value + + def test_octet_stream_simple(self, deserializer_factory): + mimetype = "application/octet-stream" + schema_dict = {} + schema = SchemaPath.from_dict(schema_dict) + deserializer = deserializer_factory(mimetype, schema=schema) + value = b"test" + + result = deserializer.deserialize(value) + + assert result == b"test" + def test_urlencoded_form_empty(self, deserializer_factory): mimetype = "application/x-www-form-urlencoded" - deserializer = deserializer_factory(mimetype) + schema_dict = {} + schema = SchemaPath.from_dict(schema_dict) + deserializer = deserializer_factory(mimetype, schema=schema) value = "" result = deserializer.deserialize(value) @@ -148,38 +188,165 @@ def test_urlencoded_form_empty(self, deserializer_factory): def test_urlencoded_form_simple(self, deserializer_factory): mimetype = "application/x-www-form-urlencoded" - deserializer = deserializer_factory(mimetype) - value = "param1=test" + schema_dict = { + "type": "object", + "properties": { + "name": { + "type": "string", + }, + }, + } + schema = SchemaPath.from_dict(schema_dict) + encoding_dict = { + "name": { + "style": "form", + }, + } + encoding = SchemaPath.from_dict(encoding_dict) + deserializer = deserializer_factory( + mimetype, schema=schema, encoding=encoding + ) + value = "name=foo+bar" result = deserializer.deserialize(value) - assert result == {"param1": "test"} + assert result == { + "name": "foo bar", + } + + def test_urlencoded_deepobject(self, deserializer_factory): + mimetype = "application/x-www-form-urlencoded" + schema_dict = { + "type": "object", + "properties": { + "color": { + "type": "object", + "properties": { + "R": { + "type": "integer", + }, + "G": { + "type": "integer", + }, + "B": { + "type": "integer", + }, + }, + }, + }, + } + schema = SchemaPath.from_dict(schema_dict) + encoding_dict = { + "color": { + "style": "deepObject", + "explode": True, + }, + } + encoding = SchemaPath.from_dict(encoding_dict) + deserializer = deserializer_factory( + mimetype, schema=schema, encoding=encoding + ) + value = "color[R]=100&color[G]=200&color[B]=150" + + result = deserializer.deserialize(value) + + assert result == { + "color": { + "R": "100", + "G": "200", + "B": "150", + }, + } @pytest.mark.parametrize("value", [b"", ""]) - def test_data_form_empty(self, deserializer_factory, value): + def test_multipart_form_empty(self, deserializer_factory, value): mimetype = "multipart/form-data" - deserializer = deserializer_factory(mimetype) + schema_dict = {} + schema = SchemaPath.from_dict(schema_dict) + deserializer = deserializer_factory(mimetype, schema=schema) result = deserializer.deserialize(value) assert result == {} - def test_data_form_simple(self, deserializer_factory): + def test_multipart_form_simple(self, deserializer_factory): mimetype = "multipart/form-data" - deserializer = deserializer_factory(mimetype) + schema_dict = { + "type": "object", + "properties": { + "param1": { + "type": "string", + "format": "binary", + }, + "param2": { + "type": "string", + "format": "binary", + }, + }, + } + schema = SchemaPath.from_dict(schema_dict) + encoding_dict = { + "param1": { + "contentType": "application/octet-stream", + }, + } + encoding = SchemaPath.from_dict(encoding_dict) + parameters = { + "boundary": "===============2872712225071193122==", + } + deserializer = deserializer_factory( + mimetype, schema=schema, parameters=parameters, encoding=encoding + ) value = ( - b'Content-Type: multipart/form-data; boundary="' - b'===============2872712225071193122=="\n' - b"MIME-Version: 1.0\n\n" b"--===============2872712225071193122==\n" b"Content-Type: text/plain\nMIME-Version: 1.0\n" b'Content-Disposition: form-data; name="param1"\n\ntest\n' + b"--===============2872712225071193122==\n" + b"Content-Type: text/plain\nMIME-Version: 1.0\n" + b'Content-Disposition: form-data; name="param2"\n\ntest2\n' b"--===============2872712225071193122==--\n" ) result = deserializer.deserialize(value) - assert result == {"param1": b"test"} + assert result == { + "param1": b"test", + "param2": b"test2", + } + + def test_multipart_form_array(self, deserializer_factory): + mimetype = "multipart/form-data" + schema_dict = { + "type": "object", + "properties": { + "file": { + "type": "array", + "items": {}, + }, + }, + } + schema = SchemaPath.from_dict(schema_dict) + parameters = { + "boundary": "===============2872712225071193122==", + } + deserializer = deserializer_factory( + mimetype, schema=schema, parameters=parameters + ) + value = ( + b"--===============2872712225071193122==\n" + b"Content-Type: text/plain\nMIME-Version: 1.0\n" + b'Content-Disposition: form-data; name="file"\n\ntest\n' + b"--===============2872712225071193122==\n" + b"Content-Type: text/plain\nMIME-Version: 1.0\n" + b'Content-Disposition: form-data; name="file"\n\ntest2\n' + b"--===============2872712225071193122==--\n" + ) + + result = deserializer.deserialize(value) + + assert result == { + "file": [b"test", b"test2"], + } def test_custom_simple(self, deserializer_factory): deserialized = "x-custom" diff --git a/tests/unit/deserializing/test_styles_deserializers.py b/tests/unit/deserializing/test_styles_deserializers.py index a6895a3a..3c516143 100644 --- a/tests/unit/deserializing/test_styles_deserializers.py +++ b/tests/unit/deserializing/test_styles_deserializers.py @@ -7,13 +7,19 @@ from openapi_core.deserializing.styles.exceptions import ( EmptyQueryParameterValue, ) +from openapi_core.schema.parameters import get_style_and_explode class TestParameterStyleDeserializer: @pytest.fixture def deserializer_factory(self): def create_deserializer(param, name=None): - return style_deserializers_factory.create(param, name=name) + name = name or param["name"] + style, explode = get_style_and_explode(param) + schema = param / "schema" + return style_deserializers_factory.create( + style, explode, schema, name=name + ) return create_deserializer