diff --git a/docs/customizations.rst b/docs/customizations.rst index d4dd4174..28a38de3 100644 --- a/docs/customizations.rst +++ b/docs/customizations.rst @@ -14,32 +14,26 @@ If you know you have a valid specification already, disabling the validator can spec = Spec.from_dict(spec_dict, validator=None) -Deserializers -------------- +Media type deserializers +------------------------ -Pass custom defined media type deserializers dictionary with supported mimetypes as a key to `MediaTypeDeserializersFactory` and then pass it to `RequestValidator` or `ResponseValidator` constructor: +Pass custom defined media type deserializers dictionary with supported mimetypes as a key to `unmarshal_response` function: .. code-block:: python - from openapi_core.deserializing.media_types.factories import MediaTypeDeserializersFactory - def protobuf_deserializer(message): feature = route_guide_pb2.Feature() feature.ParseFromString(message) return feature - custom_media_type_deserializers = { + extra_media_type_deserializers = { 'application/protobuf': protobuf_deserializer, } - media_type_deserializers_factory = MediaTypeDeserializersFactory( - custom_deserializers=custom_media_type_deserializers, - ) - result = validate_response( + result = unmarshal_response( request, response, spec=spec, - cls=V30ResponseValidator, - media_type_deserializers_factory=media_type_deserializers_factory, + extra_media_type_deserializers=extra_media_type_deserializers, ) Format validators diff --git a/openapi_core/deserializing/media_types/__init__.py b/openapi_core/deserializing/media_types/__init__.py index 5017ac49..b8aef87f 100644 --- a/openapi_core/deserializing/media_types/__init__.py +++ b/openapi_core/deserializing/media_types/__init__.py @@ -1,7 +1,22 @@ +from json import loads + +from openapi_core.deserializing.media_types.datatypes import ( + MediaTypeDeserializersDict, +) from openapi_core.deserializing.media_types.factories import ( MediaTypeDeserializersFactory, ) +from openapi_core.deserializing.media_types.util import data_form_loads +from openapi_core.deserializing.media_types.util import urlencoded_form_loads __all__ = ["media_type_deserializers_factory"] -media_type_deserializers_factory = MediaTypeDeserializersFactory() +media_type_deserializers: MediaTypeDeserializersDict = { + "application/json": loads, + "application/x-www-form-urlencoded": urlencoded_form_loads, + "multipart/form-data": data_form_loads, +} + +media_type_deserializers_factory = MediaTypeDeserializersFactory( + media_type_deserializers=media_type_deserializers, +) diff --git a/openapi_core/deserializing/media_types/datatypes.py b/openapi_core/deserializing/media_types/datatypes.py index 3d45ab69..db226cfe 100644 --- a/openapi_core/deserializing/media_types/datatypes.py +++ b/openapi_core/deserializing/media_types/datatypes.py @@ -1,4 +1,6 @@ from typing import Any from typing import Callable +from typing import Dict DeserializerCallable = Callable[[Any], Any] +MediaTypeDeserializersDict = Dict[str, DeserializerCallable] diff --git a/openapi_core/deserializing/media_types/deserializers.py b/openapi_core/deserializing/media_types/deserializers.py index bac900d4..4ba040cf 100644 --- a/openapi_core/deserializing/media_types/deserializers.py +++ b/openapi_core/deserializing/media_types/deserializers.py @@ -1,6 +1,6 @@ import warnings from typing import Any -from typing import Callable +from typing import Optional from openapi_core.deserializing.media_types.datatypes import ( DeserializerCallable, @@ -10,28 +10,20 @@ ) -class BaseMediaTypeDeserializer: - def __init__(self, mimetype: str): - self.mimetype = mimetype - - def __call__(self, value: Any) -> Any: - raise NotImplementedError - - -class UnsupportedMimetypeDeserializer(BaseMediaTypeDeserializer): - def __call__(self, value: Any) -> Any: - warnings.warn(f"Unsupported {self.mimetype} mimetype") - return value - - -class CallableMediaTypeDeserializer(BaseMediaTypeDeserializer): +class CallableMediaTypeDeserializer: def __init__( - self, mimetype: str, deserializer_callable: DeserializerCallable + self, + mimetype: str, + deserializer_callable: Optional[DeserializerCallable] = None, ): self.mimetype = mimetype self.deserializer_callable = deserializer_callable - def __call__(self, value: Any) -> Any: + def deserialize(self, value: Any) -> Any: + if self.deserializer_callable is None: + warnings.warn(f"Unsupported {self.mimetype} mimetype") + return value + try: return self.deserializer_callable(value) except (ValueError, TypeError, AttributeError): diff --git a/openapi_core/deserializing/media_types/factories.py b/openapi_core/deserializing/media_types/factories.py index b5114757..2008c54c 100644 --- a/openapi_core/deserializing/media_types/factories.py +++ b/openapi_core/deserializing/media_types/factories.py @@ -1,51 +1,60 @@ -from json import loads -from typing import Any -from typing import Callable +import warnings from typing import Dict from typing import Optional from openapi_core.deserializing.media_types.datatypes import ( DeserializerCallable, ) -from openapi_core.deserializing.media_types.deserializers import ( - BaseMediaTypeDeserializer, +from openapi_core.deserializing.media_types.datatypes import ( + MediaTypeDeserializersDict, ) from openapi_core.deserializing.media_types.deserializers import ( CallableMediaTypeDeserializer, ) -from openapi_core.deserializing.media_types.deserializers import ( - UnsupportedMimetypeDeserializer, -) -from openapi_core.deserializing.media_types.util import data_form_loads -from openapi_core.deserializing.media_types.util import urlencoded_form_loads class MediaTypeDeserializersFactory: - MEDIA_TYPE_DESERIALIZERS: Dict[str, DeserializerCallable] = { - "application/json": loads, - "application/x-www-form-urlencoded": urlencoded_form_loads, - "multipart/form-data": data_form_loads, - } - def __init__( self, - custom_deserializers: Optional[Dict[str, DeserializerCallable]] = None, + media_type_deserializers: Optional[MediaTypeDeserializersDict] = None, + custom_deserializers: Optional[MediaTypeDeserializersDict] = None, ): + if media_type_deserializers is None: + media_type_deserializers = {} + self.media_type_deserializers = media_type_deserializers if custom_deserializers is None: custom_deserializers = {} + else: + warnings.warn( + "custom_deserializers is deprecated. " + "Use extra_media_type_deserializers.", + DeprecationWarning, + ) self.custom_deserializers = custom_deserializers - def create(self, mimetype: str) -> BaseMediaTypeDeserializer: - deserialize_callable = self.get_deserializer_callable(mimetype) - - if deserialize_callable is None: - return UnsupportedMimetypeDeserializer(mimetype) + def create( + self, + mimetype: str, + extra_media_type_deserializers: Optional[ + MediaTypeDeserializersDict + ] = None, + ) -> CallableMediaTypeDeserializer: + if extra_media_type_deserializers is None: + extra_media_type_deserializers = {} + deserialize_callable = self.get_deserializer_callable( + mimetype, + extra_media_type_deserializers=extra_media_type_deserializers, + ) return CallableMediaTypeDeserializer(mimetype, deserialize_callable) def get_deserializer_callable( - self, mimetype: str + self, + mimetype: str, + extra_media_type_deserializers: MediaTypeDeserializersDict, ) -> Optional[DeserializerCallable]: if mimetype in self.custom_deserializers: return self.custom_deserializers[mimetype] - return self.MEDIA_TYPE_DESERIALIZERS.get(mimetype) + if mimetype in extra_media_type_deserializers: + return extra_media_type_deserializers[mimetype] + return self.media_type_deserializers.get(mimetype) diff --git a/openapi_core/deserializing/parameters/deserializers.py b/openapi_core/deserializing/parameters/deserializers.py index 22906c0e..ae93b718 100644 --- a/openapi_core/deserializing/parameters/deserializers.py +++ b/openapi_core/deserializing/parameters/deserializers.py @@ -2,6 +2,7 @@ from typing import Any from typing import Callable from typing import List +from typing import Optional from openapi_core.deserializing.exceptions import DeserializeError from openapi_core.deserializing.parameters.datatypes import ( @@ -15,35 +16,25 @@ from openapi_core.spec import Spec -class BaseParameterDeserializer: - def __init__(self, param_or_header: Spec, style: str): - self.param_or_header = param_or_header - self.style = style - - def __call__(self, value: Any) -> Any: - raise NotImplementedError - - -class UnsupportedStyleDeserializer(BaseParameterDeserializer): - def __call__(self, value: Any) -> Any: - warnings.warn(f"Unsupported {self.style} style") - return value - - -class CallableParameterDeserializer(BaseParameterDeserializer): +class CallableParameterDeserializer: def __init__( self, param_or_header: Spec, style: str, - deserializer_callable: DeserializerCallable, + deserializer_callable: Optional[DeserializerCallable] = None, ): - super().__init__(param_or_header, style) + self.param_or_header = param_or_header + self.style = style self.deserializer_callable = deserializer_callable self.aslist = get_aslist(self.param_or_header) self.explode = get_explode(self.param_or_header) - def __call__(self, value: Any) -> Any: + def deserialize(self, value: Any) -> Any: + if self.deserializer_callable is None: + warnings.warn(f"Unsupported {self.style} style") + return value + # if "in" not defined then it's a Header if "allowEmptyValue" in self.param_or_header: warnings.warn( diff --git a/openapi_core/deserializing/parameters/factories.py b/openapi_core/deserializing/parameters/factories.py index 07011bf7..e0f559d2 100644 --- a/openapi_core/deserializing/parameters/factories.py +++ b/openapi_core/deserializing/parameters/factories.py @@ -5,15 +5,9 @@ from openapi_core.deserializing.parameters.datatypes import ( DeserializerCallable, ) -from openapi_core.deserializing.parameters.deserializers import ( - BaseParameterDeserializer, -) from openapi_core.deserializing.parameters.deserializers import ( CallableParameterDeserializer, ) -from openapi_core.deserializing.parameters.deserializers import ( - UnsupportedStyleDeserializer, -) from openapi_core.deserializing.parameters.util import split from openapi_core.schema.parameters import get_style from openapi_core.spec import Spec @@ -28,13 +22,10 @@ class ParameterDeserializersFactory: "deepObject": partial(re.split, pattern=r"\[|\]"), } - def create(self, param_or_header: Spec) -> BaseParameterDeserializer: + def create(self, param_or_header: Spec) -> CallableParameterDeserializer: style = get_style(param_or_header) - if style not in self.PARAMETER_STYLE_DESERIALIZERS: - return UnsupportedStyleDeserializer(param_or_header, style) - - deserialize_callable = self.PARAMETER_STYLE_DESERIALIZERS[style] + deserialize_callable = self.PARAMETER_STYLE_DESERIALIZERS.get(style) return CallableParameterDeserializer( param_or_header, style, deserialize_callable ) diff --git a/openapi_core/unmarshalling/request/unmarshallers.py b/openapi_core/unmarshalling/request/unmarshallers.py index 2983d082..ccd7cc9a 100644 --- a/openapi_core/unmarshalling/request/unmarshallers.py +++ b/openapi_core/unmarshalling/request/unmarshallers.py @@ -6,6 +6,9 @@ from openapi_core.deserializing.media_types import ( media_type_deserializers_factory, ) +from openapi_core.deserializing.media_types.datatypes import ( + MediaTypeDeserializersDict, +) from openapi_core.deserializing.media_types.factories import ( MediaTypeDeserializersFactory, ) @@ -90,6 +93,9 @@ def __init__( schema_validators_factory: Optional[SchemaValidatorsFactory] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, + extra_media_type_deserializers: Optional[ + MediaTypeDeserializersDict + ] = None, security_provider_factory: SecurityProviderFactory = security_provider_factory, schema_unmarshallers_factory: Optional[ SchemaUnmarshallersFactory @@ -107,6 +113,7 @@ def __init__( schema_validators_factory=schema_validators_factory, format_validators=format_validators, extra_format_validators=extra_format_validators, + extra_media_type_deserializers=extra_media_type_deserializers, schema_unmarshallers_factory=schema_unmarshallers_factory, format_unmarshallers=format_unmarshallers, extra_format_unmarshallers=extra_format_unmarshallers, @@ -121,6 +128,7 @@ def __init__( schema_validators_factory=schema_validators_factory, format_validators=format_validators, extra_format_validators=extra_format_validators, + extra_media_type_deserializers=extra_media_type_deserializers, security_provider_factory=security_provider_factory, ) diff --git a/openapi_core/unmarshalling/unmarshallers.py b/openapi_core/unmarshalling/unmarshallers.py index 61ae6fd7..af857906 100644 --- a/openapi_core/unmarshalling/unmarshallers.py +++ b/openapi_core/unmarshalling/unmarshallers.py @@ -8,6 +8,9 @@ from openapi_core.deserializing.media_types import ( media_type_deserializers_factory, ) +from openapi_core.deserializing.media_types.datatypes import ( + MediaTypeDeserializersDict, +) from openapi_core.deserializing.media_types.factories import ( MediaTypeDeserializersFactory, ) @@ -42,6 +45,9 @@ def __init__( schema_validators_factory: Optional[SchemaValidatorsFactory] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, + extra_media_type_deserializers: Optional[ + MediaTypeDeserializersDict + ] = None, schema_unmarshallers_factory: Optional[ SchemaUnmarshallersFactory ] = None, @@ -61,6 +67,7 @@ def __init__( schema_validators_factory=schema_validators_factory, format_validators=format_validators, extra_format_validators=extra_format_validators, + extra_media_type_deserializers=extra_media_type_deserializers, ) self.schema_unmarshallers_factory = ( schema_unmarshallers_factory or self.schema_unmarshallers_factory diff --git a/openapi_core/validation/request/validators.py b/openapi_core/validation/request/validators.py index b25246a9..c8224a09 100644 --- a/openapi_core/validation/request/validators.py +++ b/openapi_core/validation/request/validators.py @@ -12,6 +12,9 @@ from openapi_core.deserializing.media_types import ( media_type_deserializers_factory, ) +from openapi_core.deserializing.media_types.datatypes import ( + MediaTypeDeserializersDict, +) from openapi_core.deserializing.media_types.factories import ( MediaTypeDeserializersFactory, ) @@ -68,6 +71,9 @@ def __init__( schema_validators_factory: Optional[SchemaValidatorsFactory] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, + extra_media_type_deserializers: Optional[ + MediaTypeDeserializersDict + ] = None, security_provider_factory: SecurityProviderFactory = security_provider_factory, ): super().__init__( @@ -79,6 +85,7 @@ def __init__( schema_validators_factory=schema_validators_factory, format_validators=format_validators, extra_format_validators=extra_format_validators, + extra_media_type_deserializers=extra_media_type_deserializers, ) self.security_provider_factory = security_provider_factory diff --git a/openapi_core/validation/validators.py b/openapi_core/validation/validators.py index fc3e93bd..b307d97c 100644 --- a/openapi_core/validation/validators.py +++ b/openapi_core/validation/validators.py @@ -15,6 +15,9 @@ from openapi_core.deserializing.media_types import ( media_type_deserializers_factory, ) +from openapi_core.deserializing.media_types.datatypes import ( + MediaTypeDeserializersDict, +) from openapi_core.deserializing.media_types.factories import ( MediaTypeDeserializersFactory, ) @@ -50,6 +53,9 @@ def __init__( schema_validators_factory: Optional[SchemaValidatorsFactory] = None, format_validators: Optional[FormatValidatorsDict] = None, extra_format_validators: Optional[FormatValidatorsDict] = None, + extra_media_type_deserializers: Optional[ + MediaTypeDeserializersDict + ] = None, ): self.spec = spec self.base_url = base_url @@ -68,6 +74,7 @@ def __init__( ) self.format_validators = format_validators self.extra_format_validators = extra_format_validators + self.extra_media_type_deserializers = extra_media_type_deserializers def _get_media_type(self, content: Spec, mimetype: str) -> MediaType: from openapi_core.templating.media_types.finders import MediaTypeFinder @@ -76,12 +83,15 @@ def _get_media_type(self, content: Spec, mimetype: str) -> MediaType: return finder.find(mimetype) def _deserialise_data(self, mimetype: str, value: Any) -> Any: - deserializer = self.media_type_deserializers_factory.create(mimetype) - return deserializer(value) + deserializer = self.media_type_deserializers_factory.create( + mimetype, + extra_media_type_deserializers=self.extra_media_type_deserializers, + ) + return deserializer.deserialize(value) def _deserialise_parameter(self, param: Spec, value: Any) -> Any: deserializer = self.parameter_deserializers_factory.create(param) - return deserializer(value) + return deserializer.deserialize(value) def _cast(self, schema: Spec, value: Any) -> Any: caster = self.schema_casters_factory.create(schema) diff --git a/tests/unit/deserializing/test_media_types_deserializers.py b/tests/unit/deserializing/test_media_types_deserializers.py index 246f656d..40960651 100644 --- a/tests/unit/deserializing/test_media_types_deserializers.py +++ b/tests/unit/deserializing/test_media_types_deserializers.py @@ -1,6 +1,7 @@ import pytest from openapi_core.deserializing.exceptions import DeserializeError +from openapi_core.deserializing.media_types import media_type_deserializers from openapi_core.deserializing.media_types.factories import ( MediaTypeDeserializersFactory, ) @@ -9,63 +10,91 @@ class TestMediaTypeDeserializer: @pytest.fixture def deserializer_factory(self): - def create_deserializer(media_type, custom_deserializers=None): + def create_deserializer( + media_type, + media_type_deserializers=media_type_deserializers, + extra_media_type_deserializers=None, + custom_deserializers=None, + ): return MediaTypeDeserializersFactory( - custom_deserializers=custom_deserializers - ).create(media_type) + media_type_deserializers, + custom_deserializers=custom_deserializers, + ).create( + media_type, + extra_media_type_deserializers=extra_media_type_deserializers, + ) return create_deserializer def test_unsupported(self, deserializer_factory): mimetype = "application/unsupported" + deserializer = deserializer_factory(mimetype) value = "" with pytest.warns(UserWarning): - result = deserializer_factory(mimetype)(value) + result = deserializer.deserialize(value) + + assert result == value + + def test_no_deserializer(self, deserializer_factory): + mimetype = "application/json" + deserializer = deserializer_factory( + mimetype, media_type_deserializers=None + ) + value = "{}" + + with pytest.warns(UserWarning): + result = deserializer.deserialize(value) assert result == value def test_json_empty(self, deserializer_factory): mimetype = "application/json" + deserializer = deserializer_factory(mimetype) value = "" with pytest.raises(DeserializeError): - deserializer_factory(mimetype)(value) + deserializer.deserialize(value) def test_json_empty_object(self, deserializer_factory): mimetype = "application/json" + deserializer = deserializer_factory(mimetype) value = "{}" - result = deserializer_factory(mimetype)(value) + result = deserializer.deserialize(value) assert result == {} def test_urlencoded_form_empty(self, deserializer_factory): mimetype = "application/x-www-form-urlencoded" + deserializer = deserializer_factory(mimetype) value = "" - result = deserializer_factory(mimetype)(value) + result = deserializer.deserialize(value) assert result == {} def test_urlencoded_form_simple(self, deserializer_factory): mimetype = "application/x-www-form-urlencoded" + deserializer = deserializer_factory(mimetype) value = "param1=test" - result = deserializer_factory(mimetype)(value) + result = deserializer.deserialize(value) assert result == {"param1": "test"} @pytest.mark.parametrize("value", [b"", ""]) def test_data_form_empty(self, deserializer_factory, value): mimetype = "multipart/form-data" + deserializer = deserializer_factory(mimetype) - result = deserializer_factory(mimetype)(value) + result = deserializer.deserialize(value) assert result == {} def test_data_form_simple(self, deserializer_factory): mimetype = "multipart/form-data" + deserializer = deserializer_factory(mimetype) value = ( b'Content-Type: multipart/form-data; boundary="' b'===============2872712225071193122=="\n' @@ -76,23 +105,48 @@ def test_data_form_simple(self, deserializer_factory): b"--===============2872712225071193122==--\n" ) - result = deserializer_factory(mimetype)(value) + result = deserializer.deserialize(value) assert result == {"param1": b"test"} - def test_custom_simple(self, deserializer_factory): + def test_custom_deserializer(self, deserializer_factory): + deserialized = "x-custom" + + def custom_deserializer(value): + return deserialized + custom_mimetype = "application/custom" + custom_deserializers = { + custom_mimetype: custom_deserializer, + } + with pytest.warns(DeprecationWarning): + deserializer = deserializer_factory( + custom_mimetype, custom_deserializers=custom_deserializers + ) value = "{}" + result = deserializer.deserialize(value) + + assert result == deserialized + + def test_custom_simple(self, deserializer_factory): + deserialized = "x-custom" + def custom_deserializer(value): - return "custom" + return deserialized - custom_deserializers = { + custom_mimetype = "application/custom" + extra_media_type_deserializers = { custom_mimetype: custom_deserializer, } + deserializer = deserializer_factory( + custom_mimetype, + extra_media_type_deserializers=extra_media_type_deserializers, + ) + value = "{}" - result = deserializer_factory( - custom_mimetype, custom_deserializers=custom_deserializers - )(value) + result = deserializer.deserialize( + value, + ) - assert result == "custom" + assert result == deserialized diff --git a/tests/unit/deserializing/test_parameters_deserializers.py b/tests/unit/deserializing/test_parameters_deserializers.py index 4e2ffd88..2247dea4 100644 --- a/tests/unit/deserializing/test_parameters_deserializers.py +++ b/tests/unit/deserializing/test_parameters_deserializers.py @@ -20,10 +20,11 @@ def create_deserializer(param): def test_unsupported(self, deserializer_factory): spec = {"name": "param", "in": "header", "style": "unsupported"} param = Spec.from_dict(spec, validator=None) + deserializer = deserializer_factory(param) value = "" with pytest.warns(UserWarning): - result = deserializer_factory(param)(value) + result = deserializer.deserialize(value) assert result == value @@ -33,10 +34,11 @@ def test_query_empty(self, deserializer_factory): "in": "query", } param = Spec.from_dict(spec, validator=None) + deserializer = deserializer_factory(param) value = "" with pytest.raises(EmptyQueryParameterValue): - deserializer_factory(param)(value) + deserializer.deserialize(value) def test_query_valid(self, deserializer_factory): spec = { @@ -44,8 +46,9 @@ def test_query_valid(self, deserializer_factory): "in": "query", } param = Spec.from_dict(spec, validator=None) + deserializer = deserializer_factory(param) value = "test" - result = deserializer_factory(param)(value) + result = deserializer.deserialize(value) assert result == value