|
1 | 1 | import warnings
|
2 | 2 | from typing import Any
|
| 3 | +from typing import Dict |
3 | 4 | from typing import Optional
|
4 | 5 | from xml.etree.ElementTree import ParseError
|
5 | 6 |
|
6 | 7 | from openapi_core.deserializing.media_types.datatypes import (
|
7 | 8 | DeserializerCallable,
|
8 | 9 | )
|
| 10 | +from openapi_core.deserializing.media_types.datatypes import ( |
| 11 | + MediaTypeDeserializersDict, |
| 12 | +) |
9 | 13 | from openapi_core.deserializing.media_types.exceptions import (
|
10 | 14 | MediaTypeDeserializeError,
|
11 | 15 | )
|
| 16 | +from openapi_core.schema.encodings import get_encoding_default_content_type |
| 17 | +from openapi_core.spec import Spec |
12 | 18 |
|
13 | 19 |
|
14 |
| -class CallableMediaTypeDeserializer: |
| 20 | +class ContentTypesDeserializer: |
15 | 21 | def __init__(
|
16 | 22 | self,
|
17 |
| - mimetype: str, |
18 |
| - deserializer_callable: Optional[DeserializerCallable] = None, |
| 23 | + media_type_deserializers: Optional[MediaTypeDeserializersDict] = None, |
| 24 | + extra_media_type_deserializers: Optional[MediaTypeDeserializersDict] = None, |
19 | 25 | ):
|
20 |
| - self.mimetype = mimetype |
21 |
| - self.deserializer_callable = deserializer_callable |
| 26 | + if media_type_deserializers is None: |
| 27 | + media_type_deserializers = {} |
| 28 | + self.media_type_deserializers = media_type_deserializers |
| 29 | + if extra_media_type_deserializers is None: |
| 30 | + extra_media_type_deserializers = {} |
| 31 | + self.extra_media_type_deserializers = extra_media_type_deserializers |
22 | 32 |
|
23 |
| - def deserialize(self, value: Any) -> Any: |
24 |
| - if self.deserializer_callable is None: |
25 |
| - warnings.warn(f"Unsupported {self.mimetype} mimetype") |
| 33 | + def deserialize(self, mimetype: str, value: Any) -> Any: |
| 34 | + deserializer_callable = self.get_deserializer_callable(mimetype) |
| 35 | + if deserializer_callable is None: |
| 36 | + warnings.warn(f"Unsupported {mimetype} mimetype") |
26 | 37 | return value
|
27 | 38 |
|
28 | 39 | try:
|
29 |
| - return self.deserializer_callable(value) |
| 40 | + return deserializer_callable(value) |
30 | 41 | except (ParseError, ValueError, TypeError, AttributeError):
|
31 |
| - raise MediaTypeDeserializeError(self.mimetype, value) |
| 42 | + raise MediaTypeDeserializeError(mimetype, value) |
| 43 | + |
| 44 | + def get_deserializer_callable( |
| 45 | + self, |
| 46 | + mimetype: str, |
| 47 | + ) -> Optional[DeserializerCallable]: |
| 48 | + if mimetype in self.extra_media_type_deserializers: |
| 49 | + return self.extra_media_type_deserializers[mimetype] |
| 50 | + return self.media_type_deserializers.get(mimetype) |
| 51 | + |
| 52 | + |
| 53 | +class MediaTypeDeserializer(ContentTypesDeserializer): |
| 54 | + def __init__( |
| 55 | + self, |
| 56 | + schema: Spec, |
| 57 | + mimetype: str, |
| 58 | + content_types_deserializers: ContentTypesDeserializer, |
| 59 | + encoding: Optional[Spec] = None, |
| 60 | + ): |
| 61 | + self.schema = schema |
| 62 | + self.mimetype = mimetype |
| 63 | + self.content_types_deserializers = content_types_deserializers |
| 64 | + self.encoding = encoding |
| 65 | + |
| 66 | + def deserialize(self, value: Any) -> Any: |
| 67 | + deserialized = self.content_types_deserializers.deserialize(self.mimetype, value) |
| 68 | + |
| 69 | + if self.mimetype != "application/x-www-form-urlencoded" and not self.mimetype.startswith("multipart"): |
| 70 | + return deserialized |
| 71 | + |
| 72 | + return self.decode(deserialized) |
| 73 | + |
| 74 | + def evolve(self, schema: Spec, mimetype: str) -> "MediaTypeDeserializer": |
| 75 | + cls = self.__class__ |
| 76 | + |
| 77 | + return cls( |
| 78 | + schema, |
| 79 | + mimetype, |
| 80 | + self.content_types_deserializers, |
| 81 | + ) |
| 82 | + |
| 83 | + def decode(self, value: Dict[str, Any]) -> Dict[str, Any]: |
| 84 | + return { |
| 85 | + prop_name: self.decode_property(prop_name, prop_value) |
| 86 | + for prop_name, prop_value in value.items() |
| 87 | + } |
| 88 | + |
| 89 | + def decode_property(self, prop_name: str, value: Any) -> Any: |
| 90 | + schema_props = self.schema.get("properties") |
| 91 | + prop_schema = None |
| 92 | + if schema_props is not None and prop_name in schema_props: |
| 93 | + prop_schema = self.schema.get(prop_name) |
| 94 | + prop_content_type = self.get_property_content_type(prop_name, prop_schema) |
| 95 | + prop_deserializer = self.evolve( |
| 96 | + schema=prop_schema, |
| 97 | + mimetype=prop_content_type, |
| 98 | + ) |
| 99 | + return prop_deserializer.deserialize(value) |
| 100 | + |
| 101 | + def get_property_content_type(self, prop_name: str, prop_schema: Optional[Spec] = None) -> str: |
| 102 | + if self.encoding is None: |
| 103 | + return get_encoding_default_content_type(prop_schema) |
| 104 | + |
| 105 | + if prop_name not in self.encoding: |
| 106 | + return get_encoding_default_content_type(prop_schema) |
| 107 | + |
| 108 | + prep_encoding = self.encoding.get(prop_name) |
| 109 | + prop_content_type = prep_encoding.getkey("contentType") |
| 110 | + if prop_content_type is None: |
| 111 | + return get_encoding_default_content_type(prop_schema) |
| 112 | + |
| 113 | + return prop_content_type |
0 commit comments