From dc5a4e6c2a788c58fbf5ef4d891dc26dbc38024d Mon Sep 17 00:00:00 2001 From: p1c2u Date: Sun, 12 Feb 2023 10:27:33 +0000 Subject: [PATCH 1/4] Unmarshallers format validators refactor --- openapi_core/__init__.py | 12 - .../unmarshalling/schemas/__init__.py | 111 ++- .../unmarshalling/schemas/datatypes.py | 5 + .../unmarshalling/schemas/exceptions.py | 35 +- .../unmarshalling/schemas/factories.py | 168 ++-- .../unmarshalling/schemas/unmarshallers.py | 489 +++++------ openapi_core/unmarshalling/schemas/util.py | 26 +- openapi_core/validation/__init__.py | 12 - openapi_core/validation/decorators.py | 2 +- openapi_core/validation/request/__init__.py | 78 +- openapi_core/validation/request/exceptions.py | 2 +- openapi_core/validation/request/proxies.py | 34 +- openapi_core/validation/request/validators.py | 62 +- openapi_core/validation/response/__init__.py | 55 +- .../validation/response/exceptions.py | 2 +- openapi_core/validation/response/proxies.py | 34 +- .../validation/response/validators.py | 49 +- openapi_core/validation/schemas/__init__.py | 26 + openapi_core/validation/schemas/datatypes.py | 4 + openapi_core/validation/schemas/exceptions.py | 23 + openapi_core/validation/schemas/factories.py | 62 ++ openapi_core/validation/schemas/util.py | 27 + openapi_core/validation/schemas/validators.py | 137 +++ openapi_core/validation/validators.py | 52 +- poetry.lock | 2 +- pyproject.toml | 2 +- tests/integration/conftest.py | 12 + .../unmarshalling/test_unmarshallers.py | 15 +- tests/integration/validation/test_petstore.py | 15 +- .../validation/test_request_validator.py | 413 +++++++++ .../validation/test_response_validator.py | 192 +++++ .../integration/validation/test_validators.py | 783 ------------------ tests/unit/unmarshalling/test_unmarshal.py | 36 +- tests/unit/unmarshalling/test_validate.py | 13 +- .../test_request_response_validators.py | 5 +- 35 files changed, 1451 insertions(+), 1544 deletions(-) create mode 100644 openapi_core/validation/schemas/__init__.py create mode 100644 openapi_core/validation/schemas/datatypes.py create mode 100644 openapi_core/validation/schemas/exceptions.py create mode 100644 openapi_core/validation/schemas/factories.py create mode 100644 openapi_core/validation/schemas/util.py create mode 100644 openapi_core/validation/schemas/validators.py create mode 100644 tests/integration/validation/test_request_validator.py create mode 100644 tests/integration/validation/test_response_validator.py delete mode 100644 tests/integration/validation/test_validators.py diff --git a/openapi_core/__init__.py b/openapi_core/__init__.py index df667774..4d8953b0 100644 --- a/openapi_core/__init__.py +++ b/openapi_core/__init__.py @@ -5,11 +5,6 @@ from openapi_core.validation.request import V30RequestValidator from openapi_core.validation.request import V31RequestValidator from openapi_core.validation.request import V31WebhookRequestValidator -from openapi_core.validation.request import openapi_request_body_validator -from openapi_core.validation.request import ( - openapi_request_parameters_validator, -) -from openapi_core.validation.request import openapi_request_security_validator from openapi_core.validation.request import openapi_request_validator from openapi_core.validation.request import openapi_v3_request_validator from openapi_core.validation.request import openapi_v30_request_validator @@ -19,8 +14,6 @@ from openapi_core.validation.response import V30ResponseValidator from openapi_core.validation.response import V31ResponseValidator from openapi_core.validation.response import V31WebhookResponseValidator -from openapi_core.validation.response import openapi_response_data_validator -from openapi_core.validation.response import openapi_response_headers_validator from openapi_core.validation.response import openapi_response_validator from openapi_core.validation.response import openapi_v3_response_validator from openapi_core.validation.response import openapi_v30_response_validator @@ -51,14 +44,9 @@ "openapi_v3_request_validator", "openapi_v30_request_validator", "openapi_v31_request_validator", - "openapi_request_body_validator", - "openapi_request_parameters_validator", - "openapi_request_security_validator", "openapi_request_validator", "openapi_v3_response_validator", "openapi_v30_response_validator", "openapi_v31_response_validator", - "openapi_response_data_validator", - "openapi_response_headers_validator", "openapi_response_validator", ] diff --git a/openapi_core/unmarshalling/schemas/__init__.py b/openapi_core/unmarshalling/schemas/__init__.py index 47b40055..d74e2eb9 100644 --- a/openapi_core/unmarshalling/schemas/__init__.py +++ b/openapi_core/unmarshalling/schemas/__init__.py @@ -1,31 +1,116 @@ -from openapi_schema_validator import OAS30Validator -from openapi_schema_validator import OAS31Validator +from collections import OrderedDict +from functools import partial + +from isodate.isodatetime import parse_datetime from openapi_core.unmarshalling.schemas.enums import ValidationContext from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) +from openapi_core.unmarshalling.schemas.unmarshallers import AnyUnmarshaller +from openapi_core.unmarshalling.schemas.unmarshallers import ArrayUnmarshaller +from openapi_core.unmarshalling.schemas.unmarshallers import ( + MultiTypeUnmarshaller, +) +from openapi_core.unmarshalling.schemas.unmarshallers import ( + ObjectReadUnmarshaller, +) +from openapi_core.unmarshalling.schemas.unmarshallers import ObjectUnmarshaller +from openapi_core.unmarshalling.schemas.unmarshallers import ( + ObjectWriteUnmarshaller, +) +from openapi_core.unmarshalling.schemas.unmarshallers import ( + PrimitiveUnmarshaller, +) +from openapi_core.unmarshalling.schemas.unmarshallers import TypesUnmarshaller +from openapi_core.unmarshalling.schemas.util import format_byte +from openapi_core.unmarshalling.schemas.util import format_date +from openapi_core.unmarshalling.schemas.util import format_uuid +from openapi_core.validation.schemas import ( + oas30_read_schema_validators_factory, +) +from openapi_core.validation.schemas import ( + oas30_write_schema_validators_factory, +) +from openapi_core.validation.schemas import oas31_schema_validators_factory __all__ = [ - "oas30_request_schema_unmarshallers_factory", - "oas30_response_schema_unmarshallers_factory", - "oas31_request_schema_unmarshallers_factory", - "oas31_response_schema_unmarshallers_factory", + "oas30_format_unmarshallers", + "oas31_format_unmarshallers", + "oas30_write_schema_unmarshallers_factory", + "oas30_read_schema_unmarshallers_factory", "oas31_schema_unmarshallers_factory", ] -oas30_request_schema_unmarshallers_factory = SchemaUnmarshallersFactory( - OAS30Validator, - context=ValidationContext.REQUEST, +oas30_unmarshallers_dict = OrderedDict( + [ + ("string", PrimitiveUnmarshaller), + ("integer", PrimitiveUnmarshaller), + ("number", PrimitiveUnmarshaller), + ("boolean", PrimitiveUnmarshaller), + ("array", ArrayUnmarshaller), + ("object", ObjectUnmarshaller), + ] +) +oas30_write_unmarshallers_dict = oas30_unmarshallers_dict.copy() +oas30_write_unmarshallers_dict.update( + { + "object": ObjectWriteUnmarshaller, + } +) +oas30_read_unmarshallers_dict = oas30_unmarshallers_dict.copy() +oas30_read_unmarshallers_dict.update( + { + "object": ObjectReadUnmarshaller, + } +) +oas31_unmarshallers_dict = oas30_unmarshallers_dict.copy() +oas31_unmarshallers_dict.update( + { + "null": PrimitiveUnmarshaller, + } +) + +oas30_write_types_unmarshaller = TypesUnmarshaller( + oas30_unmarshallers_dict, + AnyUnmarshaller, +) +oas30_read_types_unmarshaller = TypesUnmarshaller( + oas30_unmarshallers_dict, + AnyUnmarshaller, +) +oas31_types_unmarshaller = TypesUnmarshaller( + oas31_unmarshallers_dict, + AnyUnmarshaller, + multi=MultiTypeUnmarshaller, +) + +oas30_format_unmarshallers = { + # string compatible + "date": format_date, + "date-time": parse_datetime, + "binary": bytes, + "uuid": format_uuid, + "byte": format_byte, +} +oas31_format_unmarshallers = oas30_format_unmarshallers + +oas30_write_schema_unmarshallers_factory = SchemaUnmarshallersFactory( + oas30_write_schema_validators_factory, + oas30_write_types_unmarshaller, + format_unmarshallers=oas30_format_unmarshallers, ) -oas30_response_schema_unmarshallers_factory = SchemaUnmarshallersFactory( - OAS30Validator, - context=ValidationContext.RESPONSE, +oas30_read_schema_unmarshallers_factory = SchemaUnmarshallersFactory( + oas30_read_schema_validators_factory, + oas30_read_types_unmarshaller, + format_unmarshallers=oas30_format_unmarshallers, ) oas31_schema_unmarshallers_factory = SchemaUnmarshallersFactory( - OAS31Validator, + oas31_schema_validators_factory, + oas31_types_unmarshaller, + format_unmarshallers=oas31_format_unmarshallers, ) # alias to v31 version (request/response are the same bcs no context needed) diff --git a/openapi_core/unmarshalling/schemas/datatypes.py b/openapi_core/unmarshalling/schemas/datatypes.py index 96008373..23e0eb0c 100644 --- a/openapi_core/unmarshalling/schemas/datatypes.py +++ b/openapi_core/unmarshalling/schemas/datatypes.py @@ -1,7 +1,12 @@ +from typing import Any +from typing import Callable from typing import Dict from typing import Optional from openapi_core.unmarshalling.schemas.formatters import Formatter +FormatUnmarshaller = Callable[[Any], Any] + CustomFormattersDict = Dict[str, Formatter] FormattersDict = Dict[Optional[str], Formatter] +UnmarshallersDict = Dict[str, Callable[[Any], Any]] diff --git a/openapi_core/unmarshalling/schemas/exceptions.py b/openapi_core/unmarshalling/schemas/exceptions.py index 2d6fafad..43aaa2e2 100644 --- a/openapi_core/unmarshalling/schemas/exceptions.py +++ b/openapi_core/unmarshalling/schemas/exceptions.py @@ -1,6 +1,4 @@ from dataclasses import dataclass -from dataclasses import field -from typing import Iterable from openapi_core.exceptions import OpenAPIError @@ -9,29 +7,22 @@ class UnmarshalError(OpenAPIError): """Schema unmarshal operation error""" -class ValidateError(UnmarshalError): - """Schema validate operation error""" - - class UnmarshallerError(UnmarshalError): """Unmarshaller error""" @dataclass -class InvalidSchemaValue(ValidateError): - value: str - type: str - schema_errors: Iterable[Exception] = field(default_factory=list) +class FormatterNotFoundError(UnmarshallerError): + """Formatter not found to unmarshal""" + + type_format: str def __str__(self) -> str: - return ( - "Value {value} not valid for schema of type {type}: {errors}" - ).format(value=self.value, type=self.type, errors=self.schema_errors) + return f"Formatter not found for {self.type_format} format" -@dataclass -class InvalidSchemaFormatValue(UnmarshallerError): - """Value failed to format with formatter""" +class FormatUnmarshalError(UnmarshallerError): + """Unable to unmarshal value for format""" value: str type: str @@ -39,19 +30,9 @@ class InvalidSchemaFormatValue(UnmarshallerError): def __str__(self) -> str: return ( - "Failed to format value {value} to format {type}: {exception}" + "Unable to unmarshal value {value} for format {type}: {exception}" ).format( value=self.value, type=self.type, exception=self.original_exception, ) - - -@dataclass -class FormatterNotFoundError(UnmarshallerError): - """Formatter not found to unmarshal""" - - type_format: str - - def __str__(self) -> str: - return f"Formatter not found for {self.type_format} format" diff --git a/openapi_core/unmarshalling/schemas/factories.py b/openapi_core/unmarshalling/schemas/factories.py index bc847685..9cce1ce7 100644 --- a/openapi_core/unmarshalling/schemas/factories.py +++ b/openapi_core/unmarshalling/schemas/factories.py @@ -1,127 +1,85 @@ import sys import warnings -from typing import Any -from typing import Dict -from typing import Iterable from typing import Optional -from typing import Type -from typing import Union if sys.version_info >= (3, 8): from functools import cached_property else: from backports.cached_property import cached_property -from jsonschema.protocols import Validator -from openapi_schema_validator import OAS30Validator from openapi_core.spec import Spec from openapi_core.unmarshalling.schemas.datatypes import CustomFormattersDict -from openapi_core.unmarshalling.schemas.datatypes import FormattersDict +from openapi_core.unmarshalling.schemas.datatypes import FormatUnmarshaller +from openapi_core.unmarshalling.schemas.datatypes import UnmarshallersDict from openapi_core.unmarshalling.schemas.enums import ValidationContext from openapi_core.unmarshalling.schemas.exceptions import ( FormatterNotFoundError, ) -from openapi_core.unmarshalling.schemas.formatters import Formatter -from openapi_core.unmarshalling.schemas.unmarshallers import AnyUnmarshaller -from openapi_core.unmarshalling.schemas.unmarshallers import ArrayUnmarshaller -from openapi_core.unmarshalling.schemas.unmarshallers import ( - BaseSchemaUnmarshaller, -) -from openapi_core.unmarshalling.schemas.unmarshallers import ( - BooleanUnmarshaller, -) -from openapi_core.unmarshalling.schemas.unmarshallers import ( - ComplexUnmarshaller, -) -from openapi_core.unmarshalling.schemas.unmarshallers import ( - IntegerUnmarshaller, -) -from openapi_core.unmarshalling.schemas.unmarshallers import ( - MultiTypeUnmarshaller, -) -from openapi_core.unmarshalling.schemas.unmarshallers import NullUnmarshaller -from openapi_core.unmarshalling.schemas.unmarshallers import NumberUnmarshaller -from openapi_core.unmarshalling.schemas.unmarshallers import ObjectUnmarshaller -from openapi_core.unmarshalling.schemas.unmarshallers import StringUnmarshaller -from openapi_core.unmarshalling.schemas.util import build_format_checker +from openapi_core.unmarshalling.schemas.unmarshallers import SchemaUnmarshaller +from openapi_core.unmarshalling.schemas.unmarshallers import TypesUnmarshaller +from openapi_core.validation.schemas.factories import SchemaValidatorsFactory -class SchemaValidatorsFactory: - CONTEXTS = { - ValidationContext.REQUEST: "write", - ValidationContext.RESPONSE: "read", - } - +class SchemaFormatUnmarshallersFactory: def __init__( self, - schema_validator_class: Type[Validator], + schema_validators_factory: SchemaValidatorsFactory, + format_unmarshallers: Optional[UnmarshallersDict] = None, custom_formatters: Optional[CustomFormattersDict] = None, - context: Optional[ValidationContext] = None, ): - self.schema_validator_class = schema_validator_class + self.schema_validators_factory = schema_validators_factory + if format_unmarshallers is None: + format_unmarshallers = {} + self.format_unmarshallers = format_unmarshallers if custom_formatters is None: custom_formatters = {} self.custom_formatters = custom_formatters - self.context = context - def create(self, schema: Spec) -> Validator: - resolver = schema.accessor.resolver # type: ignore - custom_format_checks = { - name: formatter.validate - for name, formatter in self.custom_formatters.items() - } - format_checker = build_format_checker(**custom_format_checks) - kwargs = { - "resolver": resolver, - "format_checker": format_checker, - } - if self.context is not None: - kwargs[self.CONTEXTS[self.context]] = True - with schema.open() as schema_dict: - return self.schema_validator_class(schema_dict, **kwargs) + def create(self, schema_format: str) -> Optional[FormatUnmarshaller]: + if schema_format in self.custom_formatters: + formatter = self.custom_formatters[schema_format] + return formatter.format + if schema_format in self.format_unmarshallers: + return self.format_unmarshallers[schema_format] + return None -class SchemaUnmarshallersFactory: - UNMARSHALLERS: Dict[str, Type[BaseSchemaUnmarshaller]] = { - "string": StringUnmarshaller, - "integer": IntegerUnmarshaller, - "number": NumberUnmarshaller, - "boolean": BooleanUnmarshaller, - "array": ArrayUnmarshaller, - "object": ObjectUnmarshaller, - "null": NullUnmarshaller, - "any": AnyUnmarshaller, - } - - COMPLEX_UNMARSHALLERS: Dict[str, Type[ComplexUnmarshaller]] = { - "array": ArrayUnmarshaller, - "object": ObjectUnmarshaller, - "any": AnyUnmarshaller, - } +class SchemaUnmarshallersFactory: def __init__( self, - schema_validator_class: Type[Validator], + schema_validators_factory: SchemaValidatorsFactory, + types_unmarshaller: TypesUnmarshaller, + format_unmarshallers: Optional[UnmarshallersDict] = None, custom_formatters: Optional[CustomFormattersDict] = None, - context: Optional[ValidationContext] = None, ): - self.schema_validator_class = schema_validator_class + self.schema_validators_factory = schema_validators_factory + self.types_unmarshaller = types_unmarshaller if custom_formatters is None: custom_formatters = {} + else: + warnings.warn( + "custom_formatters is deprecated. " + "Register new checks to FormatChecker to validate custom formats " + "and add format_unmarshallers to unmarshal custom formats.", + DeprecationWarning, + ) + if format_unmarshallers is None: + format_unmarshallers = {} + self.format_unmarshallers = format_unmarshallers self.custom_formatters = custom_formatters - self.context = context @cached_property - def validators_factory(self) -> SchemaValidatorsFactory: - return SchemaValidatorsFactory( - self.schema_validator_class, + def format_unmarshallers_factory(self) -> SchemaFormatUnmarshallersFactory: + return SchemaFormatUnmarshallersFactory( + self.schema_validators_factory, + self.format_unmarshallers, self.custom_formatters, - self.context, ) def create( self, schema: Spec, type_override: Optional[str] = None - ) -> BaseSchemaUnmarshaller: + ) -> SchemaUnmarshaller: """Create unmarshaller from the schema.""" if schema is None: raise TypeError("Invalid schema") @@ -129,39 +87,29 @@ def create( if schema.getkey("deprecated", False): warnings.warn("The schema is deprecated", DeprecationWarning) - validator = self.validators_factory.create(schema) + formatters_checks = { + name: formatter.validate + for name, formatter in self.custom_formatters.items() + } + self.schema_validators_factory.add_checks(**formatters_checks) + + schema_validator = self.schema_validators_factory.create(schema) schema_format = schema.getkey("format") - formatter = self.custom_formatters.get(schema_format) - schema_type = type_override or schema.getkey("type", "any") - if isinstance(schema_type, Iterable) and not isinstance( - schema_type, str + # FIXME: don;t raise exception on unknown format + if ( + schema_format + and schema_format + not in self.schema_validators_factory.format_checker.checkers + and schema_format not in self.custom_formatters ): - return MultiTypeUnmarshaller( - schema, - validator, - formatter, - self.validators_factory, - self, - context=self.context, - ) - if schema_type in self.COMPLEX_UNMARSHALLERS: - complex_klass = self.COMPLEX_UNMARSHALLERS[schema_type] - return complex_klass( - schema, - validator, - formatter, - self.validators_factory, - self, - context=self.context, - ) + raise FormatterNotFoundError(schema_format) - klass = self.UNMARSHALLERS[schema_type] - return klass( + return SchemaUnmarshaller( schema, - validator, - formatter, - self.validators_factory, + schema_validator, self, + self.format_unmarshallers_factory, + self.types_unmarshaller, ) diff --git a/openapi_core/unmarshalling/schemas/unmarshallers.py b/openapi_core/unmarshalling/schemas/unmarshallers.py index e9a21ced..f5ae678a 100644 --- a/openapi_core/unmarshalling/schemas/unmarshallers.py +++ b/openapi_core/unmarshalling/schemas/unmarshallers.py @@ -1,48 +1,30 @@ import logging +import warnings from functools import partial from typing import TYPE_CHECKING from typing import Any from typing import Iterable from typing import Iterator from typing import List +from typing import Mapping from typing import Optional +from typing import Type +from typing import Union from typing import cast -from isodate.isodatetime import parse_datetime -from jsonschema._types import is_array -from jsonschema._types import is_bool -from jsonschema._types import is_integer -from jsonschema._types import is_null -from jsonschema._types import is_number -from jsonschema._types import is_object -from jsonschema.exceptions import ValidationError -from jsonschema.protocols import Validator -from openapi_schema_validator._format import oas30_format_checker -from openapi_schema_validator._types import is_string - from openapi_core.extensions.models.factories import ModelPathFactory from openapi_core.schema.schemas import get_properties from openapi_core.spec import Spec -from openapi_core.unmarshalling.schemas.datatypes import FormattersDict from openapi_core.unmarshalling.schemas.enums import ValidationContext -from openapi_core.unmarshalling.schemas.exceptions import ( - FormatterNotFoundError, -) -from openapi_core.unmarshalling.schemas.exceptions import ( - InvalidSchemaFormatValue, -) -from openapi_core.unmarshalling.schemas.exceptions import InvalidSchemaValue -from openapi_core.unmarshalling.schemas.exceptions import UnmarshalError +from openapi_core.unmarshalling.schemas.exceptions import FormatUnmarshalError from openapi_core.unmarshalling.schemas.exceptions import UnmarshallerError -from openapi_core.unmarshalling.schemas.exceptions import ValidateError -from openapi_core.unmarshalling.schemas.formatters import Formatter -from openapi_core.unmarshalling.schemas.util import format_byte -from openapi_core.unmarshalling.schemas.util import format_date -from openapi_core.unmarshalling.schemas.util import format_number -from openapi_core.unmarshalling.schemas.util import format_uuid -from openapi_core.util import forcebool +from openapi_core.validation.schemas.exceptions import ValidateError +from openapi_core.validation.schemas.validators import SchemaValidator if TYPE_CHECKING: + from openapi_core.unmarshalling.schemas.factories import ( + SchemaFormatUnmarshallersFactory, + ) from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) @@ -53,275 +35,94 @@ log = logging.getLogger(__name__) -class BaseSchemaUnmarshaller: - FORMATTERS: FormattersDict = { - None: Formatter(), - } - +class PrimitiveUnmarshaller: def __init__( self, - schema: Spec, - validator: Validator, - formatter: Optional[Formatter], - validators_factory: "SchemaValidatorsFactory", - unmarshallers_factory: "SchemaUnmarshallersFactory", - ): + schema, + schema_validator, + schema_unmarshaller, + schema_unmarshallers_factory, + ) -> None: self.schema = schema - self.validator = validator - self.schema_format = schema.getkey("format") + self.schema_validator = schema_validator + self.schema_unmarshaller = schema_unmarshaller + self.schema_unmarshallers_factory = schema_unmarshallers_factory - if formatter is None: - if self.schema_format not in self.FORMATTERS: - raise FormatterNotFoundError(self.schema_format) - self.formatter = self.FORMATTERS[self.schema_format] - else: - self.formatter = formatter - - self.validators_factory = validators_factory - self.unmarshallers_factory = unmarshallers_factory - - def __call__(self, value: Any) -> Any: - self.validate(value) + self.schema_format = schema.getkey("format") - # skip unmarshalling for nullable in OpenAPI 3.0 - if value is None and self.schema.getkey("nullable", False): + def __call__(self, value: Any, subschemas: bool = True) -> Any: + best_format = self._get_format(value, subschemas=subschemas) + format_unmarshaller = self.schema_unmarshallers_factory.format_unmarshallers_factory.create( + best_format + ) + if format_unmarshaller is None: return value - - return self.unmarshal(value) - - def _validate_format(self, value: Any) -> None: - result = self.formatter.validate(value) - if not result: - schema_type = self.schema.getkey("type", "any") - raise InvalidSchemaValue(value, schema_type) - - def validate(self, value: Any) -> None: - errors_iter = self.validator.iter_errors(value) - errors = tuple(errors_iter) - if errors: - schema_type = self.schema.getkey("type", "any") - raise InvalidSchemaValue(value, schema_type, schema_errors=errors) - - def format(self, value: Any) -> Any: try: - return self.formatter.format(value) + return format_unmarshaller(value) except (ValueError, TypeError) as exc: - raise InvalidSchemaFormatValue(value, self.schema_format, exc) - - def _get_best_unmarshaller(self, value: Any) -> "BaseSchemaUnmarshaller": - if "format" not in self.schema: - one_of_schema = self._get_one_of_schema(value) - if one_of_schema is not None and "format" in one_of_schema: - one_of_unmarshaller = self.unmarshallers_factory.create( - one_of_schema - ) - return one_of_unmarshaller - - any_of_schemas = self._iter_any_of_schemas(value) - for any_of_schema in any_of_schemas: - if "format" in any_of_schema: - any_of_unmarshaller = self.unmarshallers_factory.create( - any_of_schema - ) - return any_of_unmarshaller - - all_of_schemas = self._iter_all_of_schemas(value) - for all_of_schema in all_of_schemas: - if "format" in all_of_schema: - all_of_unmarshaller = self.unmarshallers_factory.create( - all_of_schema - ) - return all_of_unmarshaller - - return self - - def unmarshal(self, value: Any) -> Any: - unmarshaller = self._get_best_unmarshaller(value) - return unmarshaller.format(value) + raise FormatUnmarshalError(value, self.schema_format, exc) - def _get_one_of_schema( - self, - value: Any, - ) -> Optional[Spec]: - if "oneOf" not in self.schema: - return None + def _get_format( + self, value: Any, subschemas: bool = True + ) -> Optional[str]: + if "format" in self.schema: + return self.schema.getkey("format") - one_of_schemas = self.schema / "oneOf" - for subschema in one_of_schemas: - validator = self.validators_factory.create(subschema) - try: - validator.validate(value) - except ValidationError: - continue - else: - return subschema + if subschemas is False: + return None - log.warning("valid oneOf schema not found") - return None + one_of_schema = self.schema_validator.get_one_of_schema(value) + if one_of_schema is not None and "format" in one_of_schema: + return one_of_schema.getkey("format") - def _iter_any_of_schemas( - self, - value: Any, - ) -> Iterator[Spec]: - if "anyOf" not in self.schema: - return - - any_of_schemas = self.schema / "anyOf" - for subschema in any_of_schemas: - validator = self.validators_factory.create(subschema) - try: - validator.validate(value) - except ValidationError: - continue - else: - yield subschema + any_of_schemas = self.schema_validator.iter_any_of_schemas(value) + for any_of_schema in any_of_schemas: + if "format" in any_of_schema: + return any_of_schema.getkey("format") - def _iter_all_of_schemas( - self, - value: Any, - ) -> Iterator[Spec]: - if "allOf" not in self.schema: - return - - all_of_schemas = self.schema / "allOf" - for subschema in all_of_schemas: - if "type" not in subschema: - continue - validator = self.validators_factory.create(subschema) - try: - validator.validate(value) - except ValidationError: - log.warning("invalid allOf schema found") - else: - yield subschema - - -class StringUnmarshaller(BaseSchemaUnmarshaller): - FORMATTERS: FormattersDict = { - None: Formatter.from_callables(partial(is_string, None), str), - "password": Formatter.from_callables( - partial(oas30_format_checker.check, format="password"), str - ), - "date": Formatter.from_callables( - partial(oas30_format_checker.check, format="date"), format_date - ), - "date-time": Formatter.from_callables( - partial(oas30_format_checker.check, format="date-time"), - parse_datetime, - ), - "binary": Formatter.from_callables( - partial(oas30_format_checker.check, format="binary"), bytes - ), - "uuid": Formatter.from_callables( - partial(oas30_format_checker.check, format="uuid"), format_uuid - ), - "byte": Formatter.from_callables( - partial(oas30_format_checker.check, format="byte"), format_byte - ), - } - - -class IntegerUnmarshaller(BaseSchemaUnmarshaller): - FORMATTERS: FormattersDict = { - None: Formatter.from_callables(partial(is_integer, None), int), - "int32": Formatter.from_callables( - partial(oas30_format_checker.check, format="int32"), int - ), - "int64": Formatter.from_callables( - partial(oas30_format_checker.check, format="int64"), int - ), - } - - -class NumberUnmarshaller(BaseSchemaUnmarshaller): - FORMATTERS: FormattersDict = { - None: Formatter.from_callables( - partial(is_number, None), format_number - ), - "float": Formatter.from_callables( - partial(oas30_format_checker.check, format="float"), float - ), - "double": Formatter.from_callables( - partial(oas30_format_checker.check, format="double"), float - ), - } - - -class BooleanUnmarshaller(BaseSchemaUnmarshaller): - FORMATTERS: FormattersDict = { - None: Formatter.from_callables(partial(is_bool, None), forcebool), - } - - -class NullUnmarshaller(BaseSchemaUnmarshaller): - FORMATTERS: FormattersDict = { - None: Formatter.from_callables(partial(is_null, None), None), - } - - -class ComplexUnmarshaller(BaseSchemaUnmarshaller): - def __init__( - self, - schema: Spec, - validator: Validator, - formatter: Optional[Formatter], - validators_factory: "SchemaValidatorsFactory", - unmarshallers_factory: "SchemaUnmarshallersFactory", - context: Optional[ValidationContext] = None, - ): - super().__init__( - schema, - validator, - formatter, - validators_factory, - unmarshallers_factory, - ) - self.context = context + all_of_schemas = self.schema_validator.iter_all_of_schemas(value) + for all_of_schema in all_of_schemas: + if "format" in all_of_schema: + return all_of_schema.getkey("format") + return None -class ArrayUnmarshaller(ComplexUnmarshaller): - FORMATTERS: FormattersDict = { - None: Formatter.from_callables(partial(is_array, None), list), - } +class ArrayUnmarshaller(PrimitiveUnmarshaller): @property - def items_unmarshaller(self) -> "BaseSchemaUnmarshaller": + def items_unmarshaller(self) -> "PrimitiveUnmarshaller": # sometimes we don't have any schema i.e. free-form objects items_schema = self.schema.get( "items", Spec.from_dict({}, validator=None) ) - return self.unmarshallers_factory.create(items_schema) + return self.schema_unmarshaller.evolve(items_schema) - def unmarshal(self, value: Any) -> Optional[List[Any]]: - value = super().unmarshal(value) - return list(map(self.items_unmarshaller, value)) + def __call__(self, value: Any) -> Optional[List[Any]]: + return list(map(self.items_unmarshaller.unmarshal, value)) -class ObjectUnmarshaller(ComplexUnmarshaller): - FORMATTERS: FormattersDict = { - None: Formatter.from_callables(partial(is_object, None), dict), - } +class ObjectUnmarshaller(PrimitiveUnmarshaller): + context = NotImplemented @property def object_class_factory(self) -> ModelPathFactory: return ModelPathFactory() - def unmarshal(self, value: Any) -> Any: - properties = self.format(value) + def __call__(self, value: Any) -> Any: + properties = self._unmarshal_raw(value) fields: Iterable[str] = properties and properties.keys() or [] object_class = self.object_class_factory.create(self.schema, fields) return object_class(**properties) - def format(self, value: Any, schema_only: bool = False) -> Any: - formatted = super().format(value) + def _unmarshal_raw(self, value: Any, schema_only: bool = False) -> Any: + formatted = super().__call__(value) return self._unmarshal_properties(formatted, schema_only=schema_only) - def _clone(self, schema: Spec) -> "ObjectUnmarshaller": - return cast( - "ObjectUnmarshaller", - self.unmarshallers_factory.create(schema, type_override="object"), + def evolve(self, schema: Spec) -> "ObjectUnmarshaller": + return self.schema_unmarshaller.evolve(schema).get_unmarshaller( + "object" ) def _unmarshal_properties( @@ -329,34 +130,36 @@ def _unmarshal_properties( ) -> Any: properties = {} - one_of_schema = self._get_one_of_schema(value) + one_of_schema = self.schema_validator.get_one_of_schema(value) if one_of_schema is not None: - one_of_properties = self._clone(one_of_schema).format( + one_of_properties = self.evolve(one_of_schema)._unmarshal_raw( value, schema_only=True ) properties.update(one_of_properties) - any_of_schemas = self._iter_any_of_schemas(value) + any_of_schemas = self.schema_validator.iter_any_of_schemas(value) for any_of_schema in any_of_schemas: - any_of_properties = self._clone(any_of_schema).format( + any_of_properties = self.evolve(any_of_schema)._unmarshal_raw( value, schema_only=True ) properties.update(any_of_properties) - all_of_schemas = self._iter_all_of_schemas(value) + all_of_schemas = self.schema_validator.iter_all_of_schemas(value) for all_of_schema in all_of_schemas: - all_of_properties = self._clone(all_of_schema).format( + all_of_properties = self.evolve(all_of_schema)._unmarshal_raw( value, schema_only=True ) properties.update(all_of_properties) for prop_name, prop_schema in get_properties(self.schema).items(): - read_only = prop_schema.getkey("readOnly", False) - if self.context == ValidationContext.REQUEST and read_only: - continue - write_only = prop_schema.getkey("writeOnly", False) - if self.context == ValidationContext.RESPONSE and write_only: - continue + # check for context in OpenAPI 3.0 + if self.context is not NotImplemented: + read_only = prop_schema.getkey("readOnly", False) + if self.context == ValidationContext.REQUEST and read_only: + continue + write_only = prop_schema.getkey("writeOnly", False) + if self.context == ValidationContext.RESPONSE and write_only: + continue try: prop_value = value[prop_name] except KeyError: @@ -364,9 +167,9 @@ def _unmarshal_properties( continue prop_value = prop_schema["default"] - properties[prop_name] = self.unmarshallers_factory.create( + properties[prop_name] = self.schema_unmarshallers_factory.create( prop_schema - )(prop_value) + ).unmarshal(prop_value) if schema_only: return properties @@ -383,51 +186,53 @@ def _unmarshal_properties( # defined schema else: additional_prop_schema = self.schema / "additionalProperties" - additional_prop_unmarshaler = self.unmarshallers_factory.create( - additional_prop_schema + additional_prop_unmarshaler = ( + self.schema_unmarshallers_factory.create( + additional_prop_schema + ) ) for prop_name, prop_value in value.items(): if prop_name in properties: continue - properties[prop_name] = additional_prop_unmarshaler(prop_value) + properties[prop_name] = additional_prop_unmarshaler.unmarshal( + prop_value + ) return properties -class MultiTypeUnmarshaller(ComplexUnmarshaller): - @property - def types_unmarshallers(self) -> List["BaseSchemaUnmarshaller"]: - types = self.schema.getkey("type", ["any"]) - unmarshaller = partial(self.unmarshallers_factory.create, self.schema) - return list(map(unmarshaller, types)) +class ObjectReadUnmarshaller(ObjectUnmarshaller): + context = ValidationContext.RESPONSE + + +class ObjectWriteUnmarshaller(ObjectUnmarshaller): + context = ValidationContext.REQUEST + +class MultiTypeUnmarshaller(PrimitiveUnmarshaller): @property def type(self) -> List[str]: types = self.schema.getkey("type", ["any"]) assert isinstance(types, list) return types - def _get_unmarshallers_iter(self) -> Iterator["BaseSchemaUnmarshaller"]: + def _get_best_unmarshaller(self, value: Any) -> "PrimitiveUnmarshaller": for schema_type in self.type: - yield self.unmarshallers_factory.create( - self.schema, type_override=schema_type + result = self.schema_validator.type_validator( + value, type_override=schema_type ) - - def _get_best_unmarshaller(self, value: Any) -> "BaseSchemaUnmarshaller": - for unmarshaller in self._get_unmarshallers_iter(): - # validate with validator of formatter (usualy type validator) - try: - unmarshaller._validate_format(value) - except ValidateError: + if not result: continue - else: - return unmarshaller + result = self.schema_validator.format_validator(value) + if not result: + continue + return self.schema_unmarshaller.get_unmarshaller(schema_type) raise UnmarshallerError("Unmarshaller not found for type(s)") - def unmarshal(self, value: Any) -> Any: + def __call__(self, value: Any) -> Any: unmarshaller = self._get_best_unmarshaller(value) - return unmarshaller.unmarshal(value) + return unmarshaller(value) class AnyUnmarshaller(MultiTypeUnmarshaller): @@ -443,3 +248,93 @@ class AnyUnmarshaller(MultiTypeUnmarshaller): @property def type(self) -> List[str]: return self.SCHEMA_TYPES_ORDER + + +class TypesUnmarshaller: + unmarshallers: Mapping[str, Type[PrimitiveUnmarshaller]] = {} + multi: Optional[Type[PrimitiveUnmarshaller]] = None + + def __init__( + self, + unmarshallers: Mapping[str, Type[PrimitiveUnmarshaller]], + default: Type[PrimitiveUnmarshaller], + multi: bool = False, + ): + self.unmarshallers = unmarshallers + self.default = default + self.multi = multi + + def get_type_unmarshaller( + self, + schema_type: Optional[Union[Iterable, str]], + ) -> Type["PrimitiveUnmarshaller"]: + if schema_type is None: + return self.default + if isinstance(schema_type, Iterable) and not isinstance( + schema_type, str + ): + if self.multi is None: + raise TypeError("Unmarshaller does not accept multiple types") + return self.multi + + return self.unmarshallers[schema_type] + + +class SchemaUnmarshaller: + def __init__( + self, + schema: Spec, + schema_validator: SchemaValidator, + schema_unmarshallers_factory: "SchemaUnmarshallersFactory", + format_unmarshallers_factory: "SchemaFormatUnmarshallersFactory", + types_unmarshaller: TypesUnmarshaller, + ): + self.schema = schema + self.schema_validator = schema_validator + + self.schema_unmarshallers_factory = schema_unmarshallers_factory + self.format_unmarshallers_factory = format_unmarshallers_factory + + self.types_unmarshaller = types_unmarshaller + + def __call__(self, value: Any) -> Any: + warnings.warn( + "Calling unmarshaller itself is deprecated. " + "Use unmarshal method instead.", + DeprecationWarning, + ) + return self.unmarshal(value) + + def unmarshal(self, value: Any, subschemas: bool = True) -> Any: + self.schema_validator.validate(value) + + # skip unmarshalling for nullable in OpenAPI 3.0 + if value is None and self.schema.getkey("nullable", False): + return value + + schema_type = self.schema.getkey("type") + unmarshaller = self.get_unmarshaller(schema_type) + return unmarshaller(value) + + def get_unmarshaller( + self, + schema_type: Optional[Union[Iterable, str]], + ): + klass = self.types_unmarshaller.get_type_unmarshaller(schema_type) + return klass( + self.schema, + self.schema_validator, + self, + self.schema_unmarshallers_factory, + ) + + def evolve(self, schema: Spec) -> "SchemaUnmarshaller": + cls = self.__class__ + + return cls( + schema, + self.schema_validator.evolve(schema), + self.schema_unmarshallers_factory, + self.format_unmarshallers_factory, + self.types_unmarshaller, + ) diff --git a/openapi_core/unmarshalling/schemas/util.py b/openapi_core/unmarshalling/schemas/util.py index ca240f48..91ae690e 100644 --- a/openapi_core/unmarshalling/schemas/util.py +++ b/openapi_core/unmarshalling/schemas/util.py @@ -1,16 +1,16 @@ """OpenAPI core schemas util module""" from base64 import b64decode -from copy import copy from datetime import date from datetime import datetime -from functools import lru_cache +from typing import TYPE_CHECKING from typing import Any -from typing import Callable -from typing import Optional from typing import Union from uuid import UUID -from openapi_schema_validator import oas30_format_checker +if TYPE_CHECKING: + StaticMethod = staticmethod[Any] +else: + StaticMethod = staticmethod def format_date(value: str) -> date: @@ -34,12 +34,12 @@ def format_number(value: str) -> Union[int, float]: return float(value) -@lru_cache() -def build_format_checker(**custom_format_checks: Callable[[Any], Any]) -> Any: - if not custom_format_checks: - return oas30_format_checker +class callable_staticmethod(StaticMethod): + """Callable version of staticmethod. - fc = copy(oas30_format_checker) - for name, check in custom_format_checks.items(): - fc.checks(name)(check) - return fc + Prior to Python 3.10, staticmethods are not directly callable + from inside the class. + """ + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.__func__(*args, **kwargs) diff --git a/openapi_core/validation/__init__.py b/openapi_core/validation/__init__.py index 52d41ee2..96f8098f 100644 --- a/openapi_core/validation/__init__.py +++ b/openapi_core/validation/__init__.py @@ -1,20 +1,8 @@ """OpenAPI core validation module""" -from openapi_core.validation.request import openapi_request_body_validator -from openapi_core.validation.request import ( - openapi_request_parameters_validator, -) -from openapi_core.validation.request import openapi_request_security_validator from openapi_core.validation.request import openapi_request_validator -from openapi_core.validation.response import openapi_response_data_validator -from openapi_core.validation.response import openapi_response_headers_validator from openapi_core.validation.response import openapi_response_validator __all__ = [ - "openapi_request_body_validator", - "openapi_request_parameters_validator", - "openapi_request_security_validator", "openapi_request_validator", - "openapi_response_data_validator", - "openapi_response_headers_validator", "openapi_response_validator", ] diff --git a/openapi_core/validation/decorators.py b/openapi_core/validation/decorators.py index 9707483c..fbf50b5a 100644 --- a/openapi_core/validation/decorators.py +++ b/openapi_core/validation/decorators.py @@ -6,7 +6,7 @@ from typing import Type from openapi_core.exceptions import OpenAPIError -from openapi_core.unmarshalling.schemas.exceptions import ValidateError +from openapi_core.validation.schemas.exceptions import ValidateError OpenAPIErrorType = Type[OpenAPIError] diff --git a/openapi_core/validation/request/__init__.py b/openapi_core/validation/request/__init__.py index 5df11a56..27828cb0 100644 --- a/openapi_core/validation/request/__init__.py +++ b/openapi_core/validation/request/__init__.py @@ -2,7 +2,7 @@ from functools import partial from openapi_core.unmarshalling.schemas import ( - oas30_request_schema_unmarshallers_factory, + oas30_write_schema_unmarshallers_factory, ) from openapi_core.unmarshalling.schemas import ( oas31_schema_unmarshallers_factory, @@ -11,16 +11,6 @@ DetectSpecRequestValidatorProxy, ) from openapi_core.validation.request.proxies import SpecRequestValidatorProxy -from openapi_core.validation.request.validators import ( - APICallRequestBodyValidator, -) -from openapi_core.validation.request.validators import ( - APICallRequestParametersValidator, -) -from openapi_core.validation.request.validators import ( - APICallRequestSecurityValidator, -) -from openapi_core.validation.request.validators import APICallRequestValidator from openapi_core.validation.request.validators import V30RequestBodyValidator from openapi_core.validation.request.validators import ( V30RequestParametersValidator, @@ -65,21 +55,9 @@ "V31WebhookRequestValidator", "V3RequestValidator", "V3WebhookRequestValidator", - "openapi_v30_request_body_validator", - "openapi_v30_request_parameters_validator", - "openapi_v30_request_security_validator", "openapi_v30_request_validator", - "openapi_v31_request_body_validator", - "openapi_v31_request_parameters_validator", - "openapi_v31_request_security_validator", "openapi_v31_request_validator", - "openapi_v3_request_body_validator", - "openapi_v3_request_parameters_validator", - "openapi_v3_request_security_validator", "openapi_v3_request_validator", - "openapi_request_body_validator", - "openapi_request_parameters_validator", - "openapi_request_security_validator", "openapi_request_validator", ] @@ -88,71 +66,23 @@ V3WebhookRequestValidator = V31WebhookRequestValidator # spec validators -openapi_v30_request_body_validator = SpecRequestValidatorProxy( - APICallRequestBodyValidator, - schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory, -) -openapi_v30_request_parameters_validator = SpecRequestValidatorProxy( - APICallRequestParametersValidator, - schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory, -) -openapi_v30_request_security_validator = SpecRequestValidatorProxy( - APICallRequestSecurityValidator, - schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory, -) openapi_v30_request_validator = SpecRequestValidatorProxy( - APICallRequestValidator, - schema_unmarshallers_factory=oas30_request_schema_unmarshallers_factory, + "APICallRequestUnmarshaller", + schema_unmarshallers_factory=oas30_write_schema_unmarshallers_factory, deprecated="openapi_v30_request_validator", use="V30RequestValidator", ) - -openapi_v31_request_body_validator = SpecRequestValidatorProxy( - APICallRequestBodyValidator, - schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, -) -openapi_v31_request_parameters_validator = SpecRequestValidatorProxy( - APICallRequestParametersValidator, - schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, -) -openapi_v31_request_security_validator = SpecRequestValidatorProxy( - APICallRequestSecurityValidator, - schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, -) openapi_v31_request_validator = SpecRequestValidatorProxy( - APICallRequestValidator, + "APICallRequestUnmarshaller", schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, deprecated="openapi_v31_request_validator", use="V31RequestValidator", ) # spec validators alias to the latest v3 version -openapi_v3_request_body_validator = openapi_v31_request_body_validator -openapi_v3_request_parameters_validator = ( - openapi_v31_request_parameters_validator -) -openapi_v3_request_security_validator = openapi_v31_request_security_validator openapi_v3_request_validator = openapi_v31_request_validator # detect version spec -openapi_request_body_validator = DetectSpecRequestValidatorProxy( - { - ("openapi", "3.0"): openapi_v30_request_body_validator, - ("openapi", "3.1"): openapi_v31_request_body_validator, - }, -) -openapi_request_parameters_validator = DetectSpecRequestValidatorProxy( - { - ("openapi", "3.0"): openapi_v30_request_parameters_validator, - ("openapi", "3.1"): openapi_v31_request_parameters_validator, - }, -) -openapi_request_security_validator = DetectSpecRequestValidatorProxy( - { - ("openapi", "3.0"): openapi_v30_request_security_validator, - ("openapi", "3.1"): openapi_v31_request_security_validator, - }, -) openapi_request_validator = DetectSpecRequestValidatorProxy( { ("openapi", "3.0"): openapi_v30_request_validator, diff --git a/openapi_core/validation/request/exceptions.py b/openapi_core/validation/request/exceptions.py index f141c351..54a02b8a 100644 --- a/openapi_core/validation/request/exceptions.py +++ b/openapi_core/validation/request/exceptions.py @@ -4,9 +4,9 @@ from openapi_core.exceptions import OpenAPIError from openapi_core.spec import Spec -from openapi_core.unmarshalling.schemas.exceptions import ValidateError from openapi_core.validation.exceptions import ValidationError from openapi_core.validation.request.datatypes import Parameters +from openapi_core.validation.schemas.exceptions import ValidateError @dataclass diff --git a/openapi_core/validation/request/proxies.py b/openapi_core/validation/request/proxies.py index e4d97604..bb6f49ec 100644 --- a/openapi_core/validation/request/proxies.py +++ b/openapi_core/validation/request/proxies.py @@ -22,16 +22,22 @@ class SpecRequestValidatorProxy: def __init__( self, - validator_cls: Type["BaseAPICallRequestValidator"], + unmarshaller_cls_name: str, deprecated: str = "RequestValidator", use: Optional[str] = None, - **validator_kwargs: Any, + **unmarshaller_kwargs: Any, ): - self.validator_cls = validator_cls - self.validator_kwargs = validator_kwargs + self.unmarshaller_cls_name = unmarshaller_cls_name + self.unmarshaller_kwargs = unmarshaller_kwargs self.deprecated = deprecated - self.use = use or self.validator_cls.__name__ + self.use = use or self.unmarshaller_cls_name + + @property + def unmarshaller_cls(self) -> Type["BaseAPICallRequestValidator"]: + from openapi_core.unmarshalling.request import unmarshallers + + return getattr(unmarshallers, self.unmarshaller_cls_name) def validate( self, @@ -43,10 +49,10 @@ def validate( f"{self.deprecated} is deprecated. Use {self.use} instead.", DeprecationWarning, ) - validator = self.validator_cls( - spec, base_url=base_url, **self.validator_kwargs + unmarshaller = self.unmarshaller_cls( + spec, base_url=base_url, **self.unmarshaller_kwargs ) - return validator.validate(request) + return unmarshaller.validate(request) def is_valid( self, @@ -54,10 +60,10 @@ def is_valid( request: Request, base_url: Optional[str] = None, ) -> bool: - validator = self.validator_cls( - spec, base_url=base_url, **self.validator_kwargs + unmarshaller = self.unmarshaller_cls( + spec, base_url=base_url, **self.unmarshaller_kwargs ) - error = next(validator.iter_errors(request), None) + error = next(unmarshaller.iter_errors(request), None) return error is None def iter_errors( @@ -66,10 +72,10 @@ def iter_errors( request: Request, base_url: Optional[str] = None, ) -> Iterator[Exception]: - validator = self.validator_cls( - spec, base_url=base_url, **self.validator_kwargs + unmarshaller = self.unmarshaller_cls( + spec, base_url=base_url, **self.unmarshaller_kwargs ) - yield from validator.iter_errors(request) + yield from unmarshaller.iter_errors(request) class DetectSpecRequestValidatorProxy: diff --git a/openapi_core/validation/request/validators.py b/openapi_core/validation/request/validators.py index 9547cbf3..80d823f1 100644 --- a/openapi_core/validation/request/validators.py +++ b/openapi_core/validation/request/validators.py @@ -4,6 +4,7 @@ from typing import Dict from typing import Iterator from typing import Optional +from typing import Tuple from urllib.parse import urljoin from openapi_core.casting.schemas import schema_casters_factory @@ -34,13 +35,12 @@ from openapi_core.templating.paths.finders import WebhookPathFinder from openapi_core.templating.security.exceptions import SecurityNotFound from openapi_core.unmarshalling.schemas import ( - oas30_request_schema_unmarshallers_factory, + oas30_write_schema_unmarshallers_factory, ) from openapi_core.unmarshalling.schemas import ( oas31_schema_unmarshallers_factory, ) from openapi_core.unmarshalling.schemas.exceptions import UnmarshalError -from openapi_core.unmarshalling.schemas.exceptions import ValidateError from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) @@ -66,6 +66,11 @@ from openapi_core.validation.request.protocols import Request from openapi_core.validation.request.protocols import WebhookRequest from openapi_core.validation.request.proxies import SpecRequestValidatorProxy +from openapi_core.validation.schemas import ( + oas30_write_schema_validators_factory, +) +from openapi_core.validation.schemas import oas31_schema_validators_factory +from openapi_core.validation.schemas.factories import SchemaValidatorsFactory from openapi_core.validation.validators import BaseAPICallValidator from openapi_core.validation.validators import BaseValidator from openapi_core.validation.validators import BaseWebhookValidator @@ -76,21 +81,19 @@ def __init__( self, spec: Spec, base_url: Optional[str] = None, - schema_unmarshallers_factory: Optional[ - SchemaUnmarshallersFactory - ] = None, schema_casters_factory: SchemaCastersFactory = schema_casters_factory, parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory, media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, + schema_validators_factory: Optional[SchemaValidatorsFactory] = None, security_provider_factory: SecurityProviderFactory = security_provider_factory, ): super().__init__( spec, base_url=base_url, - schema_unmarshallers_factory=schema_unmarshallers_factory, schema_casters_factory=schema_casters_factory, parameter_deserializers_factory=parameter_deserializers_factory, media_type_deserializers_factory=media_type_deserializers_factory, + schema_validators_factory=schema_validators_factory, ) self.security_provider_factory = security_provider_factory @@ -283,25 +286,19 @@ def _get_security_value( @ValidationErrorWrapper(RequestBodyError, InvalidRequestBody) def _get_body( self, body: Optional[str], mimetype: str, operation: Spec - ) -> Any: + ) -> Tuple[Any, Optional[Spec]]: if "requestBody" not in operation: return None + # TODO: implement required flag checking request_body = operation / "requestBody" + content = request_body / "content" raw_body = self._get_body_value(body, request_body) - media_type, mimetype = self._get_media_type( - request_body / "content", mimetype + casted, _ = self._get_content_value_and_schema( + raw_body, mimetype, content ) - deserialised = self._deserialise_data(mimetype, raw_body) - casted = self._cast(media_type, deserialised) - - if "schema" not in media_type: - return casted - - schema = media_type / "schema" - unmarshalled = self._unmarshal(schema, casted) - return unmarshalled + return casted def _get_body_value(self, body: Optional[str], request_body: Spec) -> Any: if not body: @@ -428,55 +425,55 @@ def validate(self, request: WebhookRequest) -> RequestValidationResult: class V30RequestBodyValidator(APICallRequestBodyValidator): - schema_unmarshallers_factory = oas30_request_schema_unmarshallers_factory + schema_validators_factory = oas30_write_schema_validators_factory class V30RequestParametersValidator(APICallRequestParametersValidator): - schema_unmarshallers_factory = oas30_request_schema_unmarshallers_factory + schema_validators_factory = oas30_write_schema_validators_factory class V30RequestSecurityValidator(APICallRequestSecurityValidator): - schema_unmarshallers_factory = oas30_request_schema_unmarshallers_factory + schema_validators_factory = oas30_write_schema_validators_factory class V30RequestValidator(APICallRequestValidator): - schema_unmarshallers_factory = oas30_request_schema_unmarshallers_factory + schema_validators_factory = oas30_write_schema_validators_factory class V31RequestBodyValidator(APICallRequestBodyValidator): - schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + schema_validators_factory = oas31_schema_validators_factory class V31RequestParametersValidator(APICallRequestParametersValidator): - schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + schema_validators_factory = oas31_schema_validators_factory class V31RequestSecurityValidator(APICallRequestSecurityValidator): - schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + schema_validators_factory = oas31_schema_validators_factory class V31RequestValidator(APICallRequestValidator): - schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + schema_validators_factory = oas31_schema_validators_factory path_finder_cls = WebhookPathFinder class V31WebhookRequestBodyValidator(WebhookRequestBodyValidator): - schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + schema_validators_factory = oas31_schema_validators_factory path_finder_cls = WebhookPathFinder class V31WebhookRequestParametersValidator(WebhookRequestParametersValidator): - schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + schema_validators_factory = oas31_schema_validators_factory path_finder_cls = WebhookPathFinder class V31WebhookRequestSecurityValidator(WebhookRequestSecurityValidator): - schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + schema_validators_factory = oas31_schema_validators_factory path_finder_cls = WebhookPathFinder class V31WebhookRequestValidator(WebhookRequestValidator): - schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + schema_validators_factory = oas31_schema_validators_factory path_finder_cls = WebhookPathFinder @@ -488,7 +485,10 @@ def __init__( **kwargs: Any, ): super().__init__( - APICallRequestValidator, + "APICallRequestUnmarshaller", + schema_validators_factory=( + schema_unmarshallers_factory.schema_validators_factory + ), schema_unmarshallers_factory=schema_unmarshallers_factory, **kwargs, ) diff --git a/openapi_core/validation/response/__init__.py b/openapi_core/validation/response/__init__.py index fcf265b0..fcb2b036 100644 --- a/openapi_core/validation/response/__init__.py +++ b/openapi_core/validation/response/__init__.py @@ -2,7 +2,7 @@ from functools import partial from openapi_core.unmarshalling.schemas import ( - oas30_response_schema_unmarshallers_factory, + oas30_read_schema_unmarshallers_factory, ) from openapi_core.unmarshalling.schemas import ( oas31_schema_unmarshallers_factory, @@ -11,15 +11,6 @@ DetectResponseValidatorProxy, ) from openapi_core.validation.response.proxies import SpecResponseValidatorProxy -from openapi_core.validation.response.validators import ( - APICallResponseDataValidator, -) -from openapi_core.validation.response.validators import ( - APICallResponseHeadersValidator, -) -from openapi_core.validation.response.validators import ( - APICallResponseValidator, -) from openapi_core.validation.response.validators import ( V30ResponseDataValidator, ) @@ -56,17 +47,9 @@ "V31WebhookResponseValidator", "V3ResponseValidator", "V3WebhookResponseValidator", - "openapi_v30_response_data_validator", - "openapi_v30_response_headers_validator", "openapi_v30_response_validator", - "openapi_v31_response_data_validator", - "openapi_v31_response_headers_validator", "openapi_v31_response_validator", - "openapi_v3_response_data_validator", - "openapi_v3_response_headers_validator", "openapi_v3_response_validator", - "openapi_response_data_validator", - "openapi_response_headers_validator", "openapi_response_validator", ] @@ -75,54 +58,24 @@ V3WebhookResponseValidator = V31WebhookResponseValidator # spec validators -openapi_v30_response_data_validator = SpecResponseValidatorProxy( - APICallResponseDataValidator, - schema_unmarshallers_factory=oas30_response_schema_unmarshallers_factory, -) -openapi_v30_response_headers_validator = SpecResponseValidatorProxy( - APICallResponseHeadersValidator, - schema_unmarshallers_factory=oas30_response_schema_unmarshallers_factory, -) openapi_v30_response_validator = SpecResponseValidatorProxy( - APICallResponseValidator, - schema_unmarshallers_factory=oas30_response_schema_unmarshallers_factory, + "APICallResponseUnmarshaller", + schema_unmarshallers_factory=oas30_read_schema_unmarshallers_factory, deprecated="openapi_v30_response_validator", use="V30ResponseValidator", ) -openapi_v31_response_data_validator = SpecResponseValidatorProxy( - APICallResponseDataValidator, - schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, -) -openapi_v31_response_headers_validator = SpecResponseValidatorProxy( - APICallResponseHeadersValidator, - schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, -) openapi_v31_response_validator = SpecResponseValidatorProxy( - APICallResponseValidator, + "APICallResponseUnmarshaller", schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, deprecated="openapi_v31_response_validator", use="V31ResponseValidator", ) # spec validators alias to the latest v3 version -openapi_v3_response_data_validator = openapi_v31_response_data_validator -openapi_v3_response_headers_validator = openapi_v31_response_headers_validator openapi_v3_response_validator = openapi_v31_response_validator # detect version spec -openapi_response_data_validator = DetectResponseValidatorProxy( - { - ("openapi", "3.0"): openapi_v30_response_data_validator, - ("openapi", "3.1"): openapi_v31_response_data_validator, - }, -) -openapi_response_headers_validator = DetectResponseValidatorProxy( - { - ("openapi", "3.0"): openapi_v30_response_headers_validator, - ("openapi", "3.1"): openapi_v31_response_headers_validator, - }, -) openapi_response_validator = DetectResponseValidatorProxy( { ("openapi", "3.0"): openapi_v30_response_validator, diff --git a/openapi_core/validation/response/exceptions.py b/openapi_core/validation/response/exceptions.py index 078ec9b8..4f3b3e89 100644 --- a/openapi_core/validation/response/exceptions.py +++ b/openapi_core/validation/response/exceptions.py @@ -4,8 +4,8 @@ from typing import Iterable from openapi_core.exceptions import OpenAPIError -from openapi_core.unmarshalling.schemas.exceptions import ValidateError from openapi_core.validation.exceptions import ValidationError +from openapi_core.validation.schemas.exceptions import ValidateError @dataclass diff --git a/openapi_core/validation/response/proxies.py b/openapi_core/validation/response/proxies.py index b4e99469..1221fe22 100644 --- a/openapi_core/validation/response/proxies.py +++ b/openapi_core/validation/response/proxies.py @@ -23,16 +23,22 @@ class SpecResponseValidatorProxy: def __init__( self, - validator_cls: Type["BaseAPICallResponseValidator"], + unmarshaller_cls_name: Type["BaseAPICallResponseValidator"], deprecated: str = "ResponseValidator", use: Optional[str] = None, - **validator_kwargs: Any, + **unmarshaller_kwargs: Any, ): - self.validator_cls = validator_cls - self.validator_kwargs = validator_kwargs + self.unmarshaller_cls_name = unmarshaller_cls_name + self.unmarshaller_kwargs = unmarshaller_kwargs self.deprecated = deprecated - self.use = use or self.validator_cls.__name__ + self.use = use or self.unmarshaller_cls_name + + @property + def unmarshaller_cls(self) -> Type["BaseAPICallResponseValidator"]: + from openapi_core.unmarshalling.response import unmarshallers + + return getattr(unmarshallers, self.unmarshaller_cls_name) def validate( self, @@ -45,10 +51,10 @@ def validate( f"{self.deprecated} is deprecated. Use {self.use} instead.", DeprecationWarning, ) - validator = self.validator_cls( - spec, base_url=base_url, **self.validator_kwargs + unmarshaller = self.unmarshaller_cls( + spec, base_url=base_url, **self.unmarshaller_kwargs ) - return validator.validate(request, response) + return unmarshaller.validate(request, response) def is_valid( self, @@ -57,11 +63,11 @@ def is_valid( response: Response, base_url: Optional[str] = None, ) -> bool: - validator = self.validator_cls( - spec, base_url=base_url, **self.validator_kwargs + unmarshaller = self.unmarshaller_cls( + spec, base_url=base_url, **self.unmarshaller_kwargs ) error = next( - validator.iter_errors(request, response), + unmarshaller.iter_errors(request, response), None, ) return error is None @@ -73,10 +79,10 @@ def iter_errors( response: Response, base_url: Optional[str] = None, ) -> Iterator[Exception]: - validator = self.validator_cls( - spec, base_url=base_url, **self.validator_kwargs + unmarshaller = self.unmarshaller_cls( + spec, base_url=base_url, **self.unmarshaller_kwargs ) - yield from validator.iter_errors(request, response) + yield from unmarshaller.iter_errors(request, response) class DetectResponseValidatorProxy: diff --git a/openapi_core/validation/response/validators.py b/openapi_core/validation/response/validators.py index d04e9daa..b80e22c4 100644 --- a/openapi_core/validation/response/validators.py +++ b/openapi_core/validation/response/validators.py @@ -19,13 +19,12 @@ from openapi_core.templating.paths.finders import WebhookPathFinder from openapi_core.templating.responses.exceptions import ResponseFinderError from openapi_core.unmarshalling.schemas import ( - oas30_response_schema_unmarshallers_factory, + oas30_read_schema_unmarshallers_factory, ) from openapi_core.unmarshalling.schemas import ( oas31_schema_unmarshallers_factory, ) from openapi_core.unmarshalling.schemas.exceptions import UnmarshalError -from openapi_core.unmarshalling.schemas.exceptions import ValidateError from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) @@ -45,6 +44,10 @@ from openapi_core.validation.response.exceptions import MissingRequiredHeader from openapi_core.validation.response.protocols import Response from openapi_core.validation.response.proxies import SpecResponseValidatorProxy +from openapi_core.validation.schemas import ( + oas30_read_schema_validators_factory, +) +from openapi_core.validation.schemas import oas31_schema_validators_factory from openapi_core.validation.validators import BaseAPICallValidator from openapi_core.validation.validators import BaseValidator from openapi_core.validation.validators import BaseWebhookValidator @@ -155,20 +158,13 @@ def _get_data( if "content" not in operation_response: return None - media_type, mimetype = self._get_media_type( - operation_response / "content", mimetype - ) - raw_data = self._get_data_value(data) - deserialised = self._deserialise_data(mimetype, raw_data) - casted = self._cast(media_type, deserialised) - - if "schema" not in media_type: - return casted + content = operation_response / "content" - schema = media_type / "schema" - data = self._unmarshal(schema, casted) - - return data + raw_data = self._get_data_value(data) + casted, _ = self._get_content_value_and_schema( + raw_data, mimetype, content + ) + return casted def _get_data_value(self, data: str) -> Any: if not data: @@ -374,39 +370,39 @@ def validate( class V30ResponseDataValidator(APICallResponseDataValidator): - schema_unmarshallers_factory = oas30_response_schema_unmarshallers_factory + schema_validators_factory = oas30_read_schema_validators_factory class V30ResponseHeadersValidator(APICallResponseHeadersValidator): - schema_unmarshallers_factory = oas30_response_schema_unmarshallers_factory + schema_validators_factory = oas30_read_schema_validators_factory class V30ResponseValidator(APICallResponseValidator): - schema_unmarshallers_factory = oas30_response_schema_unmarshallers_factory + schema_validators_factory = oas30_read_schema_validators_factory class V31ResponseDataValidator(APICallResponseDataValidator): - schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + schema_validators_factory = oas31_schema_validators_factory class V31ResponseHeadersValidator(APICallResponseHeadersValidator): - schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + schema_validators_factory = oas31_schema_validators_factory class V31ResponseValidator(APICallResponseValidator): - schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + schema_validators_factory = oas31_schema_validators_factory class V31WebhookResponseDataValidator(WebhookResponseDataValidator): - schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + schema_validators_factory = oas31_schema_validators_factory class V31WebhookResponseHeadersValidator(WebhookResponseHeadersValidator): - schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + schema_validators_factory = oas31_schema_validators_factory class V31WebhookResponseValidator(WebhookResponseValidator): - schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + schema_validators_factory = oas31_schema_validators_factory # backward compatibility @@ -417,7 +413,10 @@ def __init__( **kwargs: Any, ): super().__init__( - APICallResponseValidator, + "APICallResponseUnmarshaller", + schema_validators_factory=( + schema_unmarshallers_factory.schema_validators_factory + ), schema_unmarshallers_factory=schema_unmarshallers_factory, **kwargs, ) diff --git a/openapi_core/validation/schemas/__init__.py b/openapi_core/validation/schemas/__init__.py new file mode 100644 index 00000000..9d7e3143 --- /dev/null +++ b/openapi_core/validation/schemas/__init__.py @@ -0,0 +1,26 @@ +from openapi_schema_validator import OAS30ReadValidator +from openapi_schema_validator import OAS30WriteValidator +from openapi_schema_validator import OAS31Validator + +from openapi_core.validation.schemas.factories import SchemaValidatorsFactory + +__all__ = [ + "oas30_write_schema_validators_factory", + "oas30_read_schema_validators_factory", + "oas31_schema_validators_factory", +] + +oas30_write_schema_validators_factory = SchemaValidatorsFactory( + OAS30WriteValidator, +) + +oas30_read_schema_validators_factory = SchemaValidatorsFactory( + OAS30ReadValidator, +) + +oas31_schema_validators_factory = SchemaValidatorsFactory( + OAS31Validator, + # FIXME: OpenAPI 3.1 schema validator uses OpenAPI 3.0 format checker. + # See https://github.com/p1c2u/openapi-core/issues/506 + format_checker=OAS30ReadValidator.FORMAT_CHECKER, +) diff --git a/openapi_core/validation/schemas/datatypes.py b/openapi_core/validation/schemas/datatypes.py new file mode 100644 index 00000000..b3e398f9 --- /dev/null +++ b/openapi_core/validation/schemas/datatypes.py @@ -0,0 +1,4 @@ +from typing import Any +from typing import Callable + +FormatValidator = Callable[[Any], bool] diff --git a/openapi_core/validation/schemas/exceptions.py b/openapi_core/validation/schemas/exceptions.py new file mode 100644 index 00000000..437a273c --- /dev/null +++ b/openapi_core/validation/schemas/exceptions.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from dataclasses import field +from typing import Iterable + +from openapi_core.exceptions import OpenAPIError + + +class ValidateError(OpenAPIError): + """Schema validate operation error""" + + +@dataclass +class InvalidSchemaValue(ValidateError): + """Value not valid for schema""" + + value: str + type: str + schema_errors: Iterable[Exception] = field(default_factory=list) + + def __str__(self) -> str: + return ( + "Value {value} not valid for schema of type {type}: {errors}" + ).format(value=self.value, type=self.type, errors=self.schema_errors) diff --git a/openapi_core/validation/schemas/factories.py b/openapi_core/validation/schemas/factories.py new file mode 100644 index 00000000..3a0e9984 --- /dev/null +++ b/openapi_core/validation/schemas/factories.py @@ -0,0 +1,62 @@ +import sys +from copy import deepcopy +from functools import partial +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional +from typing import Type + +if sys.version_info >= (3, 8): + from functools import cached_property +else: + from backports.cached_property import cached_property +from jsonschema._format import FormatChecker +from jsonschema.protocols import Validator + +from openapi_core.spec import Spec +from openapi_core.unmarshalling.schemas.datatypes import CustomFormattersDict +from openapi_core.validation.schemas.datatypes import FormatValidator +from openapi_core.validation.schemas.util import build_format_checker +from openapi_core.validation.schemas.validators import SchemaValidator + + +class SchemaValidatorsFactory: + def __init__( + self, + schema_validator_class: Type[Validator], + format_checker: Optional[FormatChecker] = None, + formatters: Optional[CustomFormattersDict] = None, + custom_formatters: Optional[CustomFormattersDict] = None, + ): + self.schema_validator_class = schema_validator_class + if format_checker is None: + format_checker = self.schema_validator_class.FORMAT_CHECKER + self.format_checker = deepcopy(format_checker) + if formatters is None: + formatters = {} + self.formatters = formatters + if custom_formatters is None: + custom_formatters = {} + self.custom_formatters = custom_formatters + + def add_checks(self, **format_checks) -> None: + for name, check in format_checks.items(): + self.format_checker.checks(name)(check) + + def get_checker(self, name: str) -> FormatValidator: + if name in self.format_checker.checkers: + return partial(self.format_checker.check, format=name) + + return lambda x: True + + def create(self, schema: Spec) -> Validator: + resolver = schema.accessor.resolver # type: ignore + with schema.open() as schema_dict: + jsonschema_validator = self.schema_validator_class( + schema_dict, + resolver=resolver, + format_checker=self.format_checker, + ) + + return SchemaValidator(schema, jsonschema_validator) diff --git a/openapi_core/validation/schemas/util.py b/openapi_core/validation/schemas/util.py new file mode 100644 index 00000000..3290f0e3 --- /dev/null +++ b/openapi_core/validation/schemas/util.py @@ -0,0 +1,27 @@ +"""OpenAPI core validation schemas util module""" +from copy import deepcopy +from functools import lru_cache +from typing import Any +from typing import Callable +from typing import Optional + +from jsonschema._format import FormatChecker + + +@lru_cache() +def build_format_checker( + format_checker: Optional[FormatChecker] = None, + **format_checks: Callable[[Any], Any], +) -> Any: + if format_checker is None: + fc = FormatChecker() + else: + if not format_checks: + return format_checker + fc = deepcopy(format_checker) + + for name, check in format_checks.items(): + if name in fc.checkers: + continue + fc.checks(name)(check) + return fc diff --git a/openapi_core/validation/schemas/validators.py b/openapi_core/validation/schemas/validators.py new file mode 100644 index 00000000..b6866e96 --- /dev/null +++ b/openapi_core/validation/schemas/validators.py @@ -0,0 +1,137 @@ +import logging +import sys +from functools import partial +from typing import Any +from typing import Iterator +from typing import Optional + +from jsonschema.exceptions import FormatError +from jsonschema.protocols import Validator + +if sys.version_info >= (3, 8): + from functools import cached_property +else: + from backports.cached_property import cached_property + +from openapi_core import Spec +from openapi_core.validation.schemas.datatypes import FormatValidator +from openapi_core.validation.schemas.exceptions import InvalidSchemaValue +from openapi_core.validation.schemas.exceptions import ValidateError + +log = logging.getLogger(__name__) + + +class SchemaValidator: + def __init__( + self, + schema: Spec, + validator: Validator, + ): + self.schema = schema + self.validator = validator + + def validate(self, value: Any) -> None: + errors_iter = self.validator.iter_errors(value) + errors = tuple(errors_iter) + if errors: + schema_type = self.schema.getkey("type", "any") + raise InvalidSchemaValue(value, schema_type, schema_errors=errors) + + def evolve(self, schema: Spec) -> "SchemaValidator": + cls = self.__class__ + + with schema.open() as schema_dict: + return cls(schema, self.validator.evolve(schema=schema_dict)) + + def type_validator( + self, value: Any, type_override: Optional[str] = None + ) -> bool: + callable = self.get_type_validator_callable( + type_override=type_override + ) + return callable(value) + + def format_validator(self, value: Any) -> bool: + try: + self.format_validator_callable(value) + except FormatError: + return False + else: + return True + + def get_type_validator_callable( + self, type_override: Optional[str] = None + ) -> FormatValidator: + schema_type = type_override or self.schema.getkey("type") + if schema_type in self.validator.TYPE_CHECKER._type_checkers: + return partial( + self.validator.TYPE_CHECKER.is_type, type=schema_type + ) + + return lambda x: True + + @cached_property + def format_validator_callable(self) -> FormatValidator: + schema_format = self.schema.getkey("format") + if schema_format in self.validator.format_checker.checkers: + return partial( + self.validator.format_checker.check, format=schema_format + ) + + return lambda x: True + + def get_one_of_schema( + self, + value: Any, + ) -> Optional[Spec]: + if "oneOf" not in self.schema: + return None + + one_of_schemas = self.schema / "oneOf" + for subschema in one_of_schemas: + validator = self.evolve(subschema) + try: + validator.validate(value) + except ValidateError: + continue + else: + return subschema + + log.warning("valid oneOf schema not found") + return None + + def iter_any_of_schemas( + self, + value: Any, + ) -> Iterator[Spec]: + if "anyOf" not in self.schema: + return + + any_of_schemas = self.schema / "anyOf" + for subschema in any_of_schemas: + validator = self.evolve(subschema) + try: + validator.validate(value) + except ValidateError: + continue + else: + yield subschema + + def iter_all_of_schemas( + self, + value: Any, + ) -> Iterator[Spec]: + if "allOf" not in self.schema: + return + + all_of_schemas = self.schema / "allOf" + for subschema in all_of_schemas: + if "type" not in subschema: + continue + validator = self.evolve(subschema) + try: + validator.validate(value) + except ValidateError: + log.warning("invalid allOf schema found") + else: + yield subschema diff --git a/openapi_core/validation/validators.py b/openapi_core/validation/validators.py index 8e39c865..d656f377 100644 --- a/openapi_core/validation/validators.py +++ b/openapi_core/validation/validators.py @@ -39,38 +39,37 @@ from openapi_core.validation.request.protocols import Request from openapi_core.validation.request.protocols import SupportsPathPattern from openapi_core.validation.request.protocols import WebhookRequest +from openapi_core.validation.schemas.factories import SchemaValidatorsFactory class BaseValidator: + schema_validators_factory: SchemaValidatorsFactory = NotImplemented schema_unmarshallers_factory: SchemaUnmarshallersFactory = NotImplemented def __init__( self, spec: Spec, base_url: Optional[str] = None, - schema_unmarshallers_factory: Optional[ - SchemaUnmarshallersFactory - ] = None, schema_casters_factory: SchemaCastersFactory = schema_casters_factory, parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory, media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, + schema_validators_factory: Optional[SchemaValidatorsFactory] = None, ): self.spec = spec self.base_url = base_url - self.schema_unmarshallers_factory = ( - schema_unmarshallers_factory or self.schema_unmarshallers_factory - ) - if self.schema_unmarshallers_factory is NotImplemented: - raise NotImplementedError( - "schema_unmarshallers_factory is not assigned" - ) - self.schema_casters_factory = schema_casters_factory self.parameter_deserializers_factory = parameter_deserializers_factory self.media_type_deserializers_factory = ( media_type_deserializers_factory ) + self.schema_validators_factory = ( + schema_validators_factory or self.schema_validators_factory + ) + if self.schema_validators_factory is NotImplemented: + raise NotImplementedError( + "schema_validators_factory is not assigned" + ) def _get_media_type(self, content: Spec, mimetype: str) -> MediaType: from openapi_core.templating.media_types.finders import MediaTypeFinder @@ -90,16 +89,23 @@ def _cast(self, schema: Spec, value: Any) -> Any: caster = self.schema_casters_factory.create(schema) return caster(value) - def _unmarshal(self, schema: Spec, value: Any) -> Any: - unmarshaller = self.schema_unmarshallers_factory.create(schema) - return unmarshaller(value) - def _get_param_or_header_value( self, param_or_header: Spec, location: Mapping[str, Any], name: Optional[str] = None, ) -> Any: + casted, _ = self._get_param_or_header_value_and_schema( + param_or_header, location, name + ) + return casted + + def _get_param_or_header_value_and_schema( + self, + param_or_header: Spec, + location: Mapping[str, Any], + name: Optional[str] = None, + ) -> Tuple[Any, Spec]: try: raw_value = get_value(param_or_header, location, name=name) except KeyError: @@ -123,8 +129,20 @@ def _get_param_or_header_value( deserialised = self._deserialise_data(mimetype, raw_value) schema = media_type / "schema" casted = self._cast(schema, deserialised) - unmarshalled = self._unmarshal(schema, casted) - return unmarshalled + return casted, schema + + def _get_content_value_and_schema( + self, raw: Any, mimetype: str, content: Spec + ) -> Tuple[Any, Optional[Spec]]: + media_type, mimetype = self._get_media_type(content, mimetype) + deserialised = self._deserialise_data(mimetype, raw) + casted = self._cast(media_type, deserialised) + + if "schema" not in media_type: + return casted, None + + schema = media_type / "schema" + return casted, schema class BaseAPICallValidator(BaseValidator): diff --git a/poetry.lock b/poetry.lock index f3850319..d2b34bb0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1813,4 +1813,4 @@ starlette = [] [metadata] lock-version = "2.0" python-versions = "^3.7.0" -content-hash = "733c6dcfea0d5f94a264a3e954cfc8034d0a3c7be17bd88968ea387c8b19dff9" +content-hash = "12999bd418fc1271f5956ac8cfd48f3f09f3c3d09defabc658eb7ad9b4f338af" diff --git a/pyproject.toml b/pyproject.toml index 6c07c4b9..10ace2bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ flask = {version = "*", optional = true} isodate = "*" more-itertools = "*" parse = "*" -openapi-schema-validator = ">=0.3.0,<0.5" +openapi-schema-validator = "^0.4.2" openapi-spec-validator = "^0.5.0" requests = {version = "*", optional = true} werkzeug = "*" diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 0fe4a4ba..29a31981 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -37,3 +37,15 @@ def factory(): spec_from_file=spec_from_file, spec_from_url=spec_from_url, ) + + +@pytest.fixture(scope="session") +def v30_petstore_content(factory): + content, _ = factory.content_from_file("data/v3.0/petstore.yaml") + return content + + +@pytest.fixture(scope="session") +def v30_petstore_spec(v30_petstore_content): + spec_url = "file://tests/integration/data/v3.0/petstore.yaml" + return Spec.from_dict(v30_petstore_content, spec_url=spec_url) diff --git a/tests/integration/unmarshalling/test_unmarshallers.py b/tests/integration/unmarshalling/test_unmarshallers.py index 2c3b6b65..4dc10dec 100644 --- a/tests/integration/unmarshalling/test_unmarshallers.py +++ b/tests/integration/unmarshalling/test_unmarshallers.py @@ -12,10 +12,10 @@ from openapi_core import Spec from openapi_core.unmarshalling.schemas import ( - oas30_request_schema_unmarshallers_factory, + oas30_read_schema_unmarshallers_factory, ) from openapi_core.unmarshalling.schemas import ( - oas30_response_schema_unmarshallers_factory, + oas30_write_schema_unmarshallers_factory, ) from openapi_core.unmarshalling.schemas import ( oas31_schema_unmarshallers_factory, @@ -23,11 +23,8 @@ from openapi_core.unmarshalling.schemas.exceptions import ( FormatterNotFoundError, ) -from openapi_core.unmarshalling.schemas.exceptions import ( - InvalidSchemaFormatValue, -) -from openapi_core.unmarshalling.schemas.exceptions import InvalidSchemaValue from openapi_core.unmarshalling.schemas.exceptions import UnmarshalError +from openapi_core.validation.schemas.exceptions import InvalidSchemaValue class BaseTestOASSchemaUnmarshallersFactoryCall: @@ -762,7 +759,7 @@ def test_object_any_of_invalid(self, unmarshallers_factory): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - with pytest.raises(UnmarshalError): + with pytest.raises(InvalidSchemaValue): unmarshaller({"someint": "1"}) def test_object_one_of_default(self, unmarshallers_factory): @@ -1846,7 +1843,7 @@ class TestOAS30RequestSchemaUnmarshallersFactory( ): @pytest.fixture def unmarshallers_factory(self): - return oas30_request_schema_unmarshallers_factory + return oas30_write_schema_unmarshallers_factory def test_write_only_properties(self, unmarshallers_factory): schema = { @@ -1894,7 +1891,7 @@ class TestOAS30ResponseSchemaUnmarshallersFactory( ): @pytest.fixture def unmarshallers_factory(self): - return oas30_response_schema_unmarshallers_factory + return oas30_read_schema_unmarshallers_factory def test_read_only_properties(self, unmarshallers_factory): schema = { diff --git a/tests/integration/validation/test_petstore.py b/tests/integration/validation/test_petstore.py index 41b12ea5..0f6a14bd 100644 --- a/tests/integration/validation/test_petstore.py +++ b/tests/integration/validation/test_petstore.py @@ -21,7 +21,6 @@ from openapi_core.templating.paths.exceptions import ServerNotFound from openapi_core.testing import MockRequest from openapi_core.testing import MockResponse -from openapi_core.unmarshalling.schemas.exceptions import InvalidSchemaValue from openapi_core.validation.request.datatypes import Parameters from openapi_core.validation.request.exceptions import MissingRequiredParameter from openapi_core.validation.request.exceptions import ParameterError @@ -41,6 +40,7 @@ from openapi_core.validation.response.validators import ( V30ResponseHeadersValidator, ) +from openapi_core.validation.schemas.exceptions import InvalidSchemaValue class TestPetstore: @@ -53,17 +53,12 @@ def api_key_encoded(self): return str(api_key_bytes_enc, "utf8") @pytest.fixture(scope="module") - def spec_uri(self): - return "file://tests/integration/data/v3.0/petstore.yaml" + def spec_dict(self, v30_petstore_content): + return v30_petstore_content @pytest.fixture(scope="module") - def spec_dict(self, factory): - content, _ = factory.content_from_file("data/v3.0/petstore.yaml") - return content - - @pytest.fixture(scope="module") - def spec(self, spec_dict, spec_uri): - return Spec.from_dict(spec_dict, spec_url=spec_uri) + def spec(self, v30_petstore_spec): + return v30_petstore_spec @pytest.fixture(scope="module") def request_parameters_validator(self, spec): diff --git a/tests/integration/validation/test_request_validator.py b/tests/integration/validation/test_request_validator.py new file mode 100644 index 00000000..3051d51f --- /dev/null +++ b/tests/integration/validation/test_request_validator.py @@ -0,0 +1,413 @@ +import json +from base64 import b64encode + +import pytest + +from openapi_core import V30RequestValidator +from openapi_core.templating.media_types.exceptions import MediaTypeNotFound +from openapi_core.templating.paths.exceptions import OperationNotFound +from openapi_core.templating.paths.exceptions import PathNotFound +from openapi_core.templating.security.exceptions import SecurityNotFound +from openapi_core.testing import MockRequest +from openapi_core.validation.request.datatypes import Parameters +from openapi_core.validation.request.exceptions import InvalidParameter +from openapi_core.validation.request.exceptions import MissingRequiredParameter +from openapi_core.validation.request.exceptions import ( + MissingRequiredRequestBody, +) +from openapi_core.validation.request.exceptions import RequestBodyError +from openapi_core.validation.request.exceptions import SecurityError + + +class TestRequestValidator: + host_url = "http://petstore.swagger.io" + + api_key = "12345" + + @property + def api_key_encoded(self): + api_key_bytes = self.api_key.encode("utf8") + api_key_bytes_enc = b64encode(api_key_bytes) + return str(api_key_bytes_enc, "utf8") + + @pytest.fixture(scope="session") + def spec_dict(self, v30_petstore_content): + return v30_petstore_content + + @pytest.fixture(scope="session") + def spec(self, v30_petstore_spec): + return v30_petstore_spec + + @pytest.fixture(scope="session") + def request_validator(self, spec): + return V30RequestValidator(spec) + + def test_request_server_error(self, request_validator): + request = MockRequest("http://petstore.invalid.net/v1", "get", "/") + + result = request_validator.validate(request) + + assert len(result.errors) == 1 + assert type(result.errors[0]) == PathNotFound + assert result.body is None + assert result.parameters == Parameters() + + def test_invalid_path(self, request_validator): + request = MockRequest(self.host_url, "get", "/v1") + + result = request_validator.validate(request) + + assert len(result.errors) == 1 + assert type(result.errors[0]) == PathNotFound + assert result.body is None + assert result.parameters == Parameters() + + def test_invalid_operation(self, request_validator): + request = MockRequest(self.host_url, "patch", "/v1/pets") + + result = request_validator.validate(request) + + assert len(result.errors) == 1 + assert type(result.errors[0]) == OperationNotFound + assert result.body is None + assert result.parameters == Parameters() + + def test_missing_parameter(self, request_validator): + request = MockRequest(self.host_url, "get", "/v1/pets") + + with pytest.warns(DeprecationWarning): + result = request_validator.validate(request) + + assert type(result.errors[0]) == MissingRequiredParameter + assert result.body is None + assert result.parameters == Parameters( + query={ + "page": 1, + "search": "", + }, + ) + + def test_get_pets(self, request_validator): + args = {"limit": "10", "ids": ["1", "2"], "api_key": self.api_key} + request = MockRequest( + self.host_url, + "get", + "/v1/pets", + path_pattern="/v1/pets", + args=args, + ) + + with pytest.warns(DeprecationWarning): + result = request_validator.validate(request) + + assert result.errors == [] + assert result.body is None + assert result.parameters == Parameters( + query={ + "limit": 10, + "page": 1, + "search": "", + "ids": [1, 2], + }, + ) + assert result.security == { + "api_key": self.api_key, + } + + def test_get_pets_webob(self, request_validator): + from webob.multidict import GetDict + + request = MockRequest( + self.host_url, + "get", + "/v1/pets", + path_pattern="/v1/pets", + ) + request.parameters.query = GetDict( + [("limit", "5"), ("ids", "1"), ("ids", "2")], {} + ) + + with pytest.warns(DeprecationWarning): + result = request_validator.validate(request) + + assert result.errors == [] + assert result.body is None + assert result.parameters == Parameters( + query={ + "limit": 5, + "page": 1, + "search": "", + "ids": [1, 2], + }, + ) + + def test_missing_body(self, request_validator): + headers = { + "api-key": self.api_key_encoded, + } + cookies = { + "user": "123", + } + request = MockRequest( + "https://development.gigantic-server.com", + "post", + "/v1/pets", + path_pattern="/v1/pets", + headers=headers, + cookies=cookies, + ) + + result = request_validator.validate(request) + + assert len(result.errors) == 1 + assert type(result.errors[0]) == MissingRequiredRequestBody + assert result.body is None + assert result.parameters == Parameters( + header={ + "api-key": self.api_key, + }, + cookie={ + "user": 123, + }, + ) + + def test_invalid_content_type(self, request_validator): + data = "csv,data" + headers = { + "api-key": self.api_key_encoded, + } + cookies = { + "user": "123", + } + request = MockRequest( + "https://development.gigantic-server.com", + "post", + "/v1/pets", + path_pattern="/v1/pets", + mimetype="text/csv", + data=data, + headers=headers, + cookies=cookies, + ) + + result = request_validator.validate(request) + + assert len(result.errors) == 1 + assert type(result.errors[0]) == RequestBodyError + assert result.errors[0].__cause__ == MediaTypeNotFound( + mimetype="text/csv", + availableMimetypes=["application/json", "text/plain"], + ) + assert result.body is None + assert result.parameters == Parameters( + header={ + "api-key": self.api_key, + }, + cookie={ + "user": 123, + }, + ) + + def test_invalid_complex_parameter(self, request_validator, spec_dict): + pet_name = "Cat" + pet_tag = "cats" + pet_street = "Piekna" + pet_city = "Warsaw" + data_json = { + "name": pet_name, + "tag": pet_tag, + "position": 2, + "address": { + "street": pet_street, + "city": pet_city, + }, + "ears": { + "healthy": True, + }, + } + data = json.dumps(data_json) + headers = { + "api-key": self.api_key_encoded, + } + userdata = { + "name": 1, + } + userdata_json = json.dumps(userdata) + cookies = { + "user": "123", + "userdata": userdata_json, + } + request = MockRequest( + "https://development.gigantic-server.com", + "post", + "/v1/pets", + path_pattern="/v1/pets", + data=data, + headers=headers, + cookies=cookies, + ) + + result = request_validator.validate(request) + + assert result.errors == [ + InvalidParameter(name="userdata", location="cookie") + ] + assert result.parameters == Parameters( + header={ + "api-key": self.api_key, + }, + cookie={ + "user": 123, + }, + ) + assert result.security == {} + + schemas = spec_dict["components"]["schemas"] + pet_model = schemas["PetCreate"]["x-model"] + address_model = schemas["Address"]["x-model"] + assert result.body.__class__.__name__ == pet_model + assert result.body.name == pet_name + assert result.body.tag == pet_tag + assert result.body.position == 2 + assert result.body.address.__class__.__name__ == address_model + assert result.body.address.street == pet_street + assert result.body.address.city == pet_city + + def test_post_pets(self, request_validator, spec_dict): + pet_name = "Cat" + pet_tag = "cats" + pet_street = "Piekna" + pet_city = "Warsaw" + data_json = { + "name": pet_name, + "tag": pet_tag, + "position": 2, + "address": { + "street": pet_street, + "city": pet_city, + }, + "ears": { + "healthy": True, + }, + } + data = json.dumps(data_json) + headers = { + "api-key": self.api_key_encoded, + } + cookies = { + "user": "123", + } + request = MockRequest( + "https://development.gigantic-server.com", + "post", + "/v1/pets", + path_pattern="/v1/pets", + data=data, + headers=headers, + cookies=cookies, + ) + + result = request_validator.validate(request) + + assert result.errors == [] + assert result.parameters == Parameters( + header={ + "api-key": self.api_key, + }, + cookie={ + "user": 123, + }, + ) + assert result.security == {} + + schemas = spec_dict["components"]["schemas"] + pet_model = schemas["PetCreate"]["x-model"] + address_model = schemas["Address"]["x-model"] + assert result.body.__class__.__name__ == pet_model + assert result.body.name == pet_name + assert result.body.tag == pet_tag + assert result.body.position == 2 + assert result.body.address.__class__.__name__ == address_model + assert result.body.address.street == pet_street + assert result.body.address.city == pet_city + + def test_post_pets_plain_no_schema(self, request_validator): + data = "plain text" + headers = { + "api-key": self.api_key_encoded, + } + cookies = { + "user": "123", + } + request = MockRequest( + "https://development.gigantic-server.com", + "post", + "/v1/pets", + path_pattern="/v1/pets", + data=data, + headers=headers, + cookies=cookies, + mimetype="text/plain", + ) + + with pytest.warns(UserWarning): + result = request_validator.validate(request) + + assert result.errors == [] + assert result.parameters == Parameters( + header={ + "api-key": self.api_key, + }, + cookie={ + "user": 123, + }, + ) + assert result.security == {} + assert result.body == data + + def test_get_pet_unauthorized(self, request_validator): + request = MockRequest( + self.host_url, + "get", + "/v1/pets/1", + path_pattern="/v1/pets/{petId}", + view_args={"petId": "1"}, + ) + + result = request_validator.validate(request) + + assert len(result.errors) == 1 + assert type(result.errors[0]) is SecurityError + assert result.errors[0].__cause__ == SecurityNotFound( + [["petstore_auth"]] + ) + assert result.body is None + assert result.parameters == Parameters() + assert result.security is None + + def test_get_pet(self, request_validator): + authorization = "Basic " + self.api_key_encoded + headers = { + "Authorization": authorization, + } + request = MockRequest( + self.host_url, + "get", + "/v1/pets/1", + path_pattern="/v1/pets/{petId}", + view_args={"petId": "1"}, + headers=headers, + ) + + result = request_validator.validate(request) + + assert result.errors == [] + assert result.body is None + assert result.parameters == Parameters( + path={ + "petId": 1, + }, + ) + assert result.security == { + "petstore_auth": self.api_key_encoded, + } diff --git a/tests/integration/validation/test_response_validator.py b/tests/integration/validation/test_response_validator.py new file mode 100644 index 00000000..fd1bf01c --- /dev/null +++ b/tests/integration/validation/test_response_validator.py @@ -0,0 +1,192 @@ +import json +from dataclasses import is_dataclass + +import pytest + +from openapi_core import V30ResponseValidator +from openapi_core.deserializing.media_types.exceptions import ( + MediaTypeDeserializeError, +) +from openapi_core.templating.media_types.exceptions import MediaTypeNotFound +from openapi_core.templating.paths.exceptions import OperationNotFound +from openapi_core.templating.paths.exceptions import PathNotFound +from openapi_core.templating.responses.exceptions import ResponseNotFound +from openapi_core.testing import MockRequest +from openapi_core.testing import MockResponse +from openapi_core.validation.response.exceptions import DataError +from openapi_core.validation.response.exceptions import InvalidData +from openapi_core.validation.response.exceptions import InvalidHeader +from openapi_core.validation.response.exceptions import MissingData +from openapi_core.validation.schemas.exceptions import InvalidSchemaValue + + +class TestResponseValidator: + host_url = "http://petstore.swagger.io" + + @pytest.fixture(scope="session") + def spec_dict(self, v30_petstore_content): + return v30_petstore_content + + @pytest.fixture(scope="session") + def spec(self, v30_petstore_spec): + return v30_petstore_spec + + @pytest.fixture(scope="session") + def response_validator(self, spec): + return V30ResponseValidator(spec) + + def test_invalid_server(self, response_validator): + request = MockRequest("http://petstore.invalid.net/v1", "get", "/") + response = MockResponse("Not Found", status_code=404) + + result = response_validator.validate(request, response) + + assert len(result.errors) == 1 + assert type(result.errors[0]) == PathNotFound + assert result.data is None + assert result.headers == {} + + def test_invalid_operation(self, response_validator): + request = MockRequest(self.host_url, "patch", "/v1/pets") + response = MockResponse("Not Found", status_code=404) + + result = response_validator.validate(request, response) + + assert len(result.errors) == 1 + assert type(result.errors[0]) == OperationNotFound + assert result.data is None + assert result.headers == {} + + def test_invalid_response(self, response_validator): + request = MockRequest(self.host_url, "get", "/v1/pets") + response = MockResponse("Not Found", status_code=409) + + result = response_validator.validate(request, response) + + assert len(result.errors) == 1 + assert type(result.errors[0]) == ResponseNotFound + assert result.data is None + assert result.headers == {} + + def test_invalid_content_type(self, response_validator): + request = MockRequest(self.host_url, "get", "/v1/pets") + response = MockResponse("Not Found", mimetype="text/csv") + + result = response_validator.validate(request, response) + + assert result.errors == [DataError()] + assert type(result.errors[0].__cause__) == MediaTypeNotFound + assert result.data is None + assert result.headers == {} + + def test_missing_body(self, response_validator): + request = MockRequest(self.host_url, "get", "/v1/pets") + response = MockResponse(None) + + result = response_validator.validate(request, response) + + assert result.errors == [MissingData()] + assert result.data is None + assert result.headers == {} + + def test_invalid_media_type(self, response_validator): + request = MockRequest(self.host_url, "get", "/v1/pets") + response = MockResponse("abcde") + + result = response_validator.validate(request, response) + + assert result.errors == [DataError()] + assert result.errors[0].__cause__ == MediaTypeDeserializeError( + mimetype="application/json", value="abcde" + ) + assert result.data is None + assert result.headers == {} + + def test_invalid_media_type_value(self, response_validator): + request = MockRequest(self.host_url, "get", "/v1/pets") + response = MockResponse("{}") + + result = response_validator.validate(request, response) + + assert result.errors == [InvalidData()] + assert type(result.errors[0].__cause__) == InvalidSchemaValue + assert result.data is None + assert result.headers == {} + + def test_invalid_value(self, response_validator): + request = MockRequest(self.host_url, "get", "/v1/tags") + response_json = { + "data": [ + {"id": 1, "name": "Sparky"}, + ], + } + response_data = json.dumps(response_json) + response = MockResponse(response_data) + + result = response_validator.validate(request, response) + + assert result.errors == [InvalidData()] + assert type(result.errors[0].__cause__) == InvalidSchemaValue + assert result.data is None + assert result.headers == {} + + def test_invalid_header(self, response_validator): + userdata = { + "name": 1, + } + userdata_json = json.dumps(userdata) + request = MockRequest( + self.host_url, + "delete", + "/v1/tags", + path_pattern="/v1/tags", + ) + response_json = { + "data": [ + { + "id": 1, + "name": "Sparky", + "ears": { + "healthy": True, + }, + }, + ], + } + response_data = json.dumps(response_json) + headers = { + "x-delete-confirm": "true", + "x-delete-date": "today", + } + response = MockResponse(response_data, headers=headers) + + with pytest.warns(DeprecationWarning): + result = response_validator.validate(request, response) + + assert result.errors == [InvalidHeader(name="x-delete-date")] + assert result.data is None + assert result.headers == {"x-delete-confirm": True} + + def test_get_pets(self, response_validator): + request = MockRequest(self.host_url, "get", "/v1/pets") + response_json = { + "data": [ + { + "id": 1, + "name": "Sparky", + "ears": { + "healthy": True, + }, + }, + ], + } + response_data = json.dumps(response_json) + response = MockResponse(response_data) + + result = response_validator.validate(request, response) + + assert result.errors == [] + assert is_dataclass(result.data) + assert len(result.data.data) == 1 + assert result.data.data[0].id == 1 + assert result.data.data[0].name == "Sparky" + assert result.headers == {} diff --git a/tests/integration/validation/test_validators.py b/tests/integration/validation/test_validators.py deleted file mode 100644 index 4149f2c6..00000000 --- a/tests/integration/validation/test_validators.py +++ /dev/null @@ -1,783 +0,0 @@ -import json -from base64 import b64encode -from dataclasses import is_dataclass - -import pytest - -from openapi_core import Spec -from openapi_core import V30RequestValidator -from openapi_core import V30ResponseValidator -from openapi_core import openapi_request_validator -from openapi_core.casting.schemas.exceptions import CastError -from openapi_core.deserializing.media_types.exceptions import ( - MediaTypeDeserializeError, -) -from openapi_core.templating.media_types.exceptions import MediaTypeNotFound -from openapi_core.templating.paths.exceptions import OperationNotFound -from openapi_core.templating.paths.exceptions import PathNotFound -from openapi_core.templating.responses.exceptions import ResponseNotFound -from openapi_core.templating.security.exceptions import SecurityNotFound -from openapi_core.testing import MockRequest -from openapi_core.testing import MockResponse -from openapi_core.unmarshalling.schemas.exceptions import InvalidSchemaValue -from openapi_core.validation.request.datatypes import Parameters -from openapi_core.validation.request.exceptions import InvalidParameter -from openapi_core.validation.request.exceptions import MissingRequiredParameter -from openapi_core.validation.request.exceptions import ( - MissingRequiredRequestBody, -) -from openapi_core.validation.request.exceptions import ParameterError -from openapi_core.validation.request.exceptions import RequestBodyError -from openapi_core.validation.request.exceptions import SecurityError -from openapi_core.validation.response.exceptions import DataError -from openapi_core.validation.response.exceptions import InvalidData -from openapi_core.validation.response.exceptions import InvalidHeader -from openapi_core.validation.response.exceptions import MissingData - - -class TestRequestValidator: - host_url = "http://petstore.swagger.io" - - api_key = "12345" - - @property - def api_key_encoded(self): - api_key_bytes = self.api_key.encode("utf8") - api_key_bytes_enc = b64encode(api_key_bytes) - return str(api_key_bytes_enc, "utf8") - - @pytest.fixture(scope="session") - def spec_dict(self, factory): - content, _ = factory.content_from_file("data/v3.0/petstore.yaml") - return content - - @pytest.fixture(scope="session") - def spec(self, spec_dict): - return Spec.from_dict(spec_dict) - - @pytest.fixture(scope="session") - def request_validator(self, spec): - return V30RequestValidator(spec) - - @pytest.fixture(scope="session") - def response_validator(self, spec): - return V30ResponseValidator(spec) - - def test_request_server_error(self, request_validator): - request = MockRequest("http://petstore.invalid.net/v1", "get", "/") - - result = request_validator.validate(request) - - assert len(result.errors) == 1 - assert type(result.errors[0]) == PathNotFound - assert result.body is None - assert result.parameters == Parameters() - - def test_invalid_path(self, request_validator): - request = MockRequest(self.host_url, "get", "/v1") - - result = request_validator.validate(request) - - assert len(result.errors) == 1 - assert type(result.errors[0]) == PathNotFound - assert result.body is None - assert result.parameters == Parameters() - - def test_invalid_operation(self, request_validator): - request = MockRequest(self.host_url, "patch", "/v1/pets") - - result = request_validator.validate(request) - - assert len(result.errors) == 1 - assert type(result.errors[0]) == OperationNotFound - assert result.body is None - assert result.parameters == Parameters() - - def test_missing_parameter(self, request_validator): - request = MockRequest(self.host_url, "get", "/v1/pets") - - with pytest.warns(DeprecationWarning): - result = request_validator.validate(request) - - assert type(result.errors[0]) == MissingRequiredParameter - assert result.body is None - assert result.parameters == Parameters( - query={ - "page": 1, - "search": "", - }, - ) - - def test_get_pets(self, request_validator): - args = {"limit": "10", "ids": ["1", "2"], "api_key": self.api_key} - request = MockRequest( - self.host_url, - "get", - "/v1/pets", - path_pattern="/v1/pets", - args=args, - ) - - with pytest.warns(DeprecationWarning): - result = request_validator.validate(request) - - assert result.errors == [] - assert result.body is None - assert result.parameters == Parameters( - query={ - "limit": 10, - "page": 1, - "search": "", - "ids": [1, 2], - }, - ) - assert result.security == { - "api_key": self.api_key, - } - - def test_get_pets_webob(self, request_validator): - from webob.multidict import GetDict - - request = MockRequest( - self.host_url, - "get", - "/v1/pets", - path_pattern="/v1/pets", - ) - request.parameters.query = GetDict( - [("limit", "5"), ("ids", "1"), ("ids", "2")], {} - ) - - with pytest.warns(DeprecationWarning): - result = request_validator.validate(request) - - assert result.errors == [] - assert result.body is None - assert result.parameters == Parameters( - query={ - "limit": 5, - "page": 1, - "search": "", - "ids": [1, 2], - }, - ) - - def test_missing_body(self, request_validator): - headers = { - "api-key": self.api_key_encoded, - } - cookies = { - "user": "123", - } - request = MockRequest( - "https://development.gigantic-server.com", - "post", - "/v1/pets", - path_pattern="/v1/pets", - headers=headers, - cookies=cookies, - ) - - with pytest.warns(DeprecationWarning): - result = request_validator.validate(request) - - assert len(result.errors) == 1 - assert type(result.errors[0]) == MissingRequiredRequestBody - assert result.body is None - assert result.parameters == Parameters( - header={ - "api-key": self.api_key, - }, - cookie={ - "user": 123, - }, - ) - - def test_invalid_content_type(self, request_validator): - data = "csv,data" - headers = { - "api-key": self.api_key_encoded, - } - cookies = { - "user": "123", - } - request = MockRequest( - "https://development.gigantic-server.com", - "post", - "/v1/pets", - path_pattern="/v1/pets", - mimetype="text/csv", - data=data, - headers=headers, - cookies=cookies, - ) - - with pytest.warns(DeprecationWarning): - result = request_validator.validate(request) - - assert len(result.errors) == 1 - assert type(result.errors[0]) == RequestBodyError - assert result.errors[0].__cause__ == MediaTypeNotFound( - mimetype="text/csv", - availableMimetypes=["application/json", "text/plain"], - ) - assert result.body is None - assert result.parameters == Parameters( - header={ - "api-key": self.api_key, - }, - cookie={ - "user": 123, - }, - ) - - def test_invalid_complex_parameter(self, request_validator, spec_dict): - pet_name = "Cat" - pet_tag = "cats" - pet_street = "Piekna" - pet_city = "Warsaw" - data_json = { - "name": pet_name, - "tag": pet_tag, - "position": 2, - "address": { - "street": pet_street, - "city": pet_city, - }, - "ears": { - "healthy": True, - }, - } - data = json.dumps(data_json) - headers = { - "api-key": self.api_key_encoded, - } - userdata = { - "name": 1, - } - userdata_json = json.dumps(userdata) - cookies = { - "user": "123", - "userdata": userdata_json, - } - request = MockRequest( - "https://development.gigantic-server.com", - "post", - "/v1/pets", - path_pattern="/v1/pets", - data=data, - headers=headers, - cookies=cookies, - ) - - with pytest.warns(DeprecationWarning): - result = request_validator.validate(request) - - assert result.errors == [ - InvalidParameter(name="userdata", location="cookie") - ] - assert result.parameters == Parameters( - header={ - "api-key": self.api_key, - }, - cookie={ - "user": 123, - }, - ) - assert result.security == {} - - schemas = spec_dict["components"]["schemas"] - pet_model = schemas["PetCreate"]["x-model"] - address_model = schemas["Address"]["x-model"] - assert result.body.__class__.__name__ == pet_model - assert result.body.name == pet_name - assert result.body.tag == pet_tag - assert result.body.position == 2 - assert result.body.address.__class__.__name__ == address_model - assert result.body.address.street == pet_street - assert result.body.address.city == pet_city - - def test_post_pets(self, request_validator, spec_dict): - pet_name = "Cat" - pet_tag = "cats" - pet_street = "Piekna" - pet_city = "Warsaw" - data_json = { - "name": pet_name, - "tag": pet_tag, - "position": 2, - "address": { - "street": pet_street, - "city": pet_city, - }, - "ears": { - "healthy": True, - }, - } - data = json.dumps(data_json) - headers = { - "api-key": self.api_key_encoded, - } - cookies = { - "user": "123", - } - request = MockRequest( - "https://development.gigantic-server.com", - "post", - "/v1/pets", - path_pattern="/v1/pets", - data=data, - headers=headers, - cookies=cookies, - ) - - with pytest.warns(DeprecationWarning): - result = request_validator.validate(request) - - assert result.errors == [] - assert result.parameters == Parameters( - header={ - "api-key": self.api_key, - }, - cookie={ - "user": 123, - }, - ) - assert result.security == {} - - schemas = spec_dict["components"]["schemas"] - pet_model = schemas["PetCreate"]["x-model"] - address_model = schemas["Address"]["x-model"] - assert result.body.__class__.__name__ == pet_model - assert result.body.name == pet_name - assert result.body.tag == pet_tag - assert result.body.position == 2 - assert result.body.address.__class__.__name__ == address_model - assert result.body.address.street == pet_street - assert result.body.address.city == pet_city - - def test_post_pets_plain_no_schema(self, request_validator): - data = "plain text" - headers = { - "api-key": self.api_key_encoded, - } - cookies = { - "user": "123", - } - request = MockRequest( - "https://development.gigantic-server.com", - "post", - "/v1/pets", - path_pattern="/v1/pets", - data=data, - headers=headers, - cookies=cookies, - mimetype="text/plain", - ) - - with pytest.warns(UserWarning): - result = request_validator.validate(request) - - assert result.errors == [] - assert result.parameters == Parameters( - header={ - "api-key": self.api_key, - }, - cookie={ - "user": 123, - }, - ) - assert result.security == {} - assert result.body == data - - def test_get_pet_unauthorized(self, request_validator): - request = MockRequest( - self.host_url, - "get", - "/v1/pets/1", - path_pattern="/v1/pets/{petId}", - view_args={"petId": "1"}, - ) - - result = request_validator.validate(request) - - assert len(result.errors) == 1 - assert type(result.errors[0]) is SecurityError - assert result.errors[0].__cause__ == SecurityNotFound( - [["petstore_auth"]] - ) - assert result.body is None - assert result.parameters == Parameters() - assert result.security is None - - def test_get_pet(self, request_validator): - authorization = "Basic " + self.api_key_encoded - headers = { - "Authorization": authorization, - } - request = MockRequest( - self.host_url, - "get", - "/v1/pets/1", - path_pattern="/v1/pets/{petId}", - view_args={"petId": "1"}, - headers=headers, - ) - - with pytest.warns(DeprecationWarning): - result = request_validator.validate(request) - - assert result.errors == [] - assert result.body is None - assert result.parameters == Parameters( - path={ - "petId": 1, - }, - ) - assert result.security == { - "petstore_auth": self.api_key_encoded, - } - - -class TestPathItemParamsValidator: - @pytest.fixture(scope="session") - def spec_dict(self): - return { - "openapi": "3.0.0", - "info": { - "title": "Test path item parameter validation", - "version": "0.1", - }, - "paths": { - "/resource": { - "parameters": [ - { - "name": "resId", - "in": "query", - "required": True, - "schema": { - "type": "integer", - }, - }, - ], - "get": { - "responses": { - "default": {"description": "Return the resource."} - } - }, - } - }, - } - - @pytest.fixture(scope="session") - def spec(self, spec_dict): - return Spec.from_dict(spec_dict) - - @pytest.fixture(scope="session") - def request_validator(self, spec): - return V30RequestValidator(spec) - - def test_request_missing_param(self, request_validator): - request = MockRequest("http://example.com", "get", "/resource") - - result = request_validator.validate(request) - - assert len(result.errors) == 1 - assert type(result.errors[0]) == MissingRequiredParameter - assert result.body is None - assert result.parameters == Parameters() - - def test_request_invalid_param(self, request_validator): - request = MockRequest( - "http://example.com", - "get", - "/resource", - args={"resId": "invalid"}, - ) - - result = request_validator.validate(request) - - assert result.errors == [ - ParameterError(name="resId", location="query") - ] - assert type(result.errors[0].__cause__) is CastError - assert result.body is None - assert result.parameters == Parameters() - - def test_request_valid_param(self, request_validator): - request = MockRequest( - "http://example.com", - "get", - "/resource", - args={"resId": "10"}, - ) - - with pytest.warns(DeprecationWarning): - result = request_validator.validate(request) - - assert len(result.errors) == 0 - assert result.body is None - assert result.parameters == Parameters(query={"resId": 10}) - - def test_request_override_param(self, spec, spec_dict): - # override path parameter on operation - spec_dict["paths"]["/resource"]["get"]["parameters"] = [ - { - # full valid parameter object required - "name": "resId", - "in": "query", - "required": False, - "schema": { - "type": "integer", - }, - } - ] - request = MockRequest("http://example.com", "get", "/resource") - with pytest.warns(DeprecationWarning): - result = openapi_request_validator.validate( - spec, request, base_url="http://example.com" - ) - - assert len(result.errors) == 0 - assert result.body is None - assert result.parameters == Parameters() - - def test_request_override_param_uniqueness(self, spec, spec_dict): - # add parameter on operation with same name as on path but - # different location - spec_dict["paths"]["/resource"]["get"]["parameters"] = [ - { - # full valid parameter object required - "name": "resId", - "in": "header", - "required": False, - "schema": { - "type": "integer", - }, - } - ] - request = MockRequest("http://example.com", "get", "/resource") - with pytest.warns(DeprecationWarning): - result = openapi_request_validator.validate( - spec, request, base_url="http://example.com" - ) - - assert len(result.errors) == 1 - assert type(result.errors[0]) == MissingRequiredParameter - assert result.body is None - assert result.parameters == Parameters() - - def test_request_object_deep_object_params(self, spec, spec_dict): - # override path parameter on operation - spec_dict["paths"]["/resource"]["parameters"] = [ - { - # full valid parameter object required - "name": "paramObj", - "in": "query", - "required": True, - "schema": { - "x-model": "paramObj", - "type": "object", - "properties": { - "count": {"type": "integer"}, - "name": {"type": "string"}, - }, - }, - "explode": True, - "style": "deepObject", - } - ] - - request = MockRequest( - "http://example.com", - "get", - "/resource", - args={"paramObj[count]": 2, "paramObj[name]": "John"}, - ) - with pytest.warns(DeprecationWarning): - result = openapi_request_validator.validate( - spec, request, base_url="http://example.com" - ) - - assert len(result.errors) == 0 - assert result.body is None - assert len(result.parameters.query) == 1 - assert is_dataclass(result.parameters.query["paramObj"]) - assert result.parameters.query["paramObj"].count == 2 - assert result.parameters.query["paramObj"].name == "John" - - -class TestResponseValidator: - host_url = "http://petstore.swagger.io" - - @pytest.fixture(scope="session") - def spec_dict(self, factory): - content, _ = factory.content_from_file("data/v3.0/petstore.yaml") - return content - - @pytest.fixture(scope="session") - def spec(self, spec_dict): - return Spec.from_dict(spec_dict) - - @pytest.fixture(scope="session") - def response_validator(self, spec): - return V30ResponseValidator(spec) - - def test_invalid_server(self, response_validator): - request = MockRequest("http://petstore.invalid.net/v1", "get", "/") - response = MockResponse("Not Found", status_code=404) - - result = response_validator.validate(request, response) - - assert len(result.errors) == 1 - assert type(result.errors[0]) == PathNotFound - assert result.data is None - assert result.headers == {} - - def test_invalid_operation(self, response_validator): - request = MockRequest(self.host_url, "patch", "/v1/pets") - response = MockResponse("Not Found", status_code=404) - - result = response_validator.validate(request, response) - - assert len(result.errors) == 1 - assert type(result.errors[0]) == OperationNotFound - assert result.data is None - assert result.headers == {} - - def test_invalid_response(self, response_validator): - request = MockRequest(self.host_url, "get", "/v1/pets") - response = MockResponse("Not Found", status_code=409) - - result = response_validator.validate(request, response) - - assert len(result.errors) == 1 - assert type(result.errors[0]) == ResponseNotFound - assert result.data is None - assert result.headers == {} - - def test_invalid_content_type(self, response_validator): - request = MockRequest(self.host_url, "get", "/v1/pets") - response = MockResponse("Not Found", mimetype="text/csv") - - result = response_validator.validate(request, response) - - assert result.errors == [DataError()] - assert type(result.errors[0].__cause__) == MediaTypeNotFound - assert result.data is None - assert result.headers == {} - - def test_missing_body(self, response_validator): - request = MockRequest(self.host_url, "get", "/v1/pets") - response = MockResponse(None) - - result = response_validator.validate(request, response) - - assert result.errors == [MissingData()] - assert result.data is None - assert result.headers == {} - - def test_invalid_media_type(self, response_validator): - request = MockRequest(self.host_url, "get", "/v1/pets") - response = MockResponse("abcde") - - result = response_validator.validate(request, response) - - assert result.errors == [DataError()] - assert result.errors[0].__cause__ == MediaTypeDeserializeError( - mimetype="application/json", value="abcde" - ) - assert result.data is None - assert result.headers == {} - - def test_invalid_media_type_value(self, response_validator): - request = MockRequest(self.host_url, "get", "/v1/pets") - response = MockResponse("{}") - - with pytest.warns(DeprecationWarning): - result = response_validator.validate(request, response) - - assert result.errors == [InvalidData()] - assert type(result.errors[0].__cause__) == InvalidSchemaValue - assert result.data is None - assert result.headers == {} - - def test_invalid_value(self, response_validator): - request = MockRequest(self.host_url, "get", "/v1/tags") - response_json = { - "data": [ - {"id": 1, "name": "Sparky"}, - ], - } - response_data = json.dumps(response_json) - response = MockResponse(response_data) - - with pytest.warns(DeprecationWarning): - result = response_validator.validate(request, response) - - assert result.errors == [InvalidData()] - assert type(result.errors[0].__cause__) == InvalidSchemaValue - assert result.data is None - assert result.headers == {} - - def test_invalid_header(self, response_validator): - userdata = { - "name": 1, - } - userdata_json = json.dumps(userdata) - request = MockRequest( - self.host_url, - "delete", - "/v1/tags", - path_pattern="/v1/tags", - ) - response_json = { - "data": [ - { - "id": 1, - "name": "Sparky", - "ears": { - "healthy": True, - }, - }, - ], - } - response_data = json.dumps(response_json) - headers = { - "x-delete-confirm": "true", - "x-delete-date": "today", - } - response = MockResponse(response_data, headers=headers) - - with pytest.warns(DeprecationWarning): - result = response_validator.validate(request, response) - - assert result.errors == [InvalidHeader(name="x-delete-date")] - assert result.data is None - assert result.headers == {"x-delete-confirm": True} - - def test_get_pets(self, response_validator): - request = MockRequest(self.host_url, "get", "/v1/pets") - response_json = { - "data": [ - { - "id": 1, - "name": "Sparky", - "ears": { - "healthy": True, - }, - }, - ], - } - response_data = json.dumps(response_json) - response = MockResponse(response_data) - - with pytest.warns(DeprecationWarning): - result = response_validator.validate(request, response) - - assert result.errors == [] - assert is_dataclass(result.data) - assert len(result.data.data) == 1 - assert result.data.data[0].id == 1 - assert result.data.data[0].name == "Sparky" - assert result.headers == {} diff --git a/tests/unit/unmarshalling/test_unmarshal.py b/tests/unit/unmarshalling/test_unmarshal.py index a512512d..15a604c0 100644 --- a/tests/unit/unmarshalling/test_unmarshal.py +++ b/tests/unit/unmarshalling/test_unmarshal.py @@ -1,40 +1,44 @@ from functools import partial import pytest -from openapi_schema_validator import OAS30Validator from openapi_core.spec.paths import Spec +from openapi_core.unmarshalling.schemas import oas30_write_types_unmarshaller from openapi_core.unmarshalling.schemas.exceptions import ( FormatterNotFoundError, ) -from openapi_core.unmarshalling.schemas.exceptions import ( - InvalidSchemaFormatValue, -) +from openapi_core.unmarshalling.schemas.exceptions import FormatUnmarshalError from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) from openapi_core.unmarshalling.schemas.formatters import Formatter +from openapi_core.validation.schemas import ( + oas30_write_schema_validators_factory, +) @pytest.fixture def schema_unmarshaller_factory(): def create_unmarshaller( - validator, schema, custom_formatters=None, context=None + validators_factory, schema, custom_formatters=None ): custom_formatters = custom_formatters or {} return SchemaUnmarshallersFactory( - validator, + validators_factory, + oas30_write_types_unmarshaller, custom_formatters=custom_formatters, - context=context, ).create(schema) return create_unmarshaller -class TestOAS30SchemaUnmarshallerUnmarshal: +class TestOAS30SchemaUnmarshallerCall: @pytest.fixture def unmarshaller_factory(self, schema_unmarshaller_factory): - return partial(schema_unmarshaller_factory, OAS30Validator) + return partial( + schema_unmarshaller_factory, + oas30_write_schema_validators_factory, + ) def test_schema_custom_format_invalid(self, unmarshaller_factory): class CustomFormatter(Formatter): @@ -51,19 +55,13 @@ def format(self, value): "format": "custom", } spec = Spec.from_dict(schema, validator=None) - value = "test" + value = "x" - with pytest.raises(InvalidSchemaFormatValue): + with pytest.raises(FormatUnmarshalError): unmarshaller_factory( spec, custom_formatters=custom_formatters, - ).unmarshal(value) - - -class TestOAS30SchemaUnmarshallerCall: - @pytest.fixture - def unmarshaller_factory(self, schema_unmarshaller_factory): - return partial(schema_unmarshaller_factory, OAS30Validator) + )(value) def test_string_format_custom(self, unmarshaller_factory): formatted = "x-custom" @@ -133,7 +131,7 @@ def format(self, value): custom_format: formatter, } - with pytest.raises(InvalidSchemaFormatValue): + with pytest.raises(FormatUnmarshalError): unmarshaller_factory(spec, custom_formatters=custom_formatters)( value ) diff --git a/tests/unit/unmarshalling/test_validate.py b/tests/unit/unmarshalling/test_validate.py index e5976f60..afaeac56 100644 --- a/tests/unit/unmarshalling/test_validate.py +++ b/tests/unit/unmarshalling/test_validate.py @@ -3,20 +3,20 @@ import pytest from openapi_core.spec.paths import Spec -from openapi_core.unmarshalling.schemas import ( - oas30_request_schema_unmarshallers_factory, -) from openapi_core.unmarshalling.schemas.exceptions import ( FormatterNotFoundError, ) -from openapi_core.unmarshalling.schemas.exceptions import InvalidSchemaValue +from openapi_core.validation.schemas import ( + oas30_write_schema_validators_factory, +) +from openapi_core.validation.schemas.exceptions import InvalidSchemaValue class TestSchemaValidate: @pytest.fixture def validator_factory(self): def create_validator(schema): - return oas30_request_schema_unmarshallers_factory.create(schema) + return oas30_write_schema_validators_factory.create(schema) return create_validator @@ -29,8 +29,7 @@ def test_string_format_custom_missing(self, validator_factory): spec = Spec.from_dict(schema, validator=None) value = "x" - with pytest.raises(FormatterNotFoundError): - validator_factory(spec).validate(value) + validator_factory(spec).validate(value) @pytest.mark.parametrize("value", [0, 1, 2]) def test_integer_minimum_invalid(self, value, validator_factory): diff --git a/tests/unit/validation/test_request_response_validators.py b/tests/unit/validation/test_request_response_validators.py index fc5a0b15..352e1e88 100644 --- a/tests/unit/validation/test_request_response_validators.py +++ b/tests/unit/validation/test_request_response_validators.py @@ -3,6 +3,7 @@ import pytest from openapi_schema_validator import OAS31Validator +from openapi_core.unmarshalling.schemas import oas31_types_unmarshaller from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) @@ -11,6 +12,7 @@ from openapi_core.validation import openapi_response_validator from openapi_core.validation.request.validators import RequestValidator from openapi_core.validation.response.validators import ResponseValidator +from openapi_core.validation.schemas import oas31_schema_validators_factory class BaseTestValidate: @@ -18,7 +20,8 @@ class BaseTestValidate: def schema_unmarshallers_factory(self): CUSTOM_FORMATTERS = {"custom": Formatter.from_callables()} return SchemaUnmarshallersFactory( - OAS31Validator, + oas31_schema_validators_factory, + oas31_types_unmarshaller, custom_formatters=CUSTOM_FORMATTERS, ) From a38426b001a56f0c92ddc1ef5b293621ca263394 Mon Sep 17 00:00:00 2001 From: p1c2u Date: Wed, 15 Feb 2023 12:33:52 +0000 Subject: [PATCH 2/4] unmarshallers and validators restructure --- docs/customizations.rst | 1 - openapi_core/__init__.py | 54 +- openapi_core/contrib/django/middlewares.py | 18 +- openapi_core/contrib/django/requests.py | 2 +- openapi_core/contrib/falcon/middlewares.py | 34 +- openapi_core/contrib/falcon/requests.py | 2 +- openapi_core/contrib/flask/decorators.py | 36 +- openapi_core/contrib/flask/requests.py | 2 +- openapi_core/contrib/requests/requests.py | 2 +- openapi_core/contrib/starlette/requests.py | 2 +- openapi_core/contrib/werkzeug/requests.py | 2 +- openapi_core/datatypes.py | 42 ++ openapi_core/exceptions.py | 4 + openapi_core/finders.py | 49 ++ openapi_core/protocols.py | 146 +++++ openapi_core/security/providers.py | 2 +- openapi_core/shortcuts.py | 314 ++++++++++ openapi_core/testing/datatypes.py | 2 +- openapi_core/testing/requests.py | 2 +- .../datatypes.py | 2 +- openapi_core/unmarshalling/processors.py | 43 ++ .../unmarshalling/request/__init__.py | 66 ++ .../unmarshalling/request/datatypes.py | 15 + .../unmarshalling/request/protocols.py | 39 ++ .../request/proxies.py | 30 +- openapi_core/unmarshalling/request/types.py | 13 + .../unmarshalling/request/unmarshallers.py | 423 +++++++++++++ .../unmarshalling/response/__init__.py | 67 ++ .../response/datatypes.py | 6 +- .../unmarshalling/response/protocols.py | 46 ++ .../response/proxies.py | 34 +- openapi_core/unmarshalling/response/types.py | 13 + .../unmarshalling/response/unmarshallers.py | 328 ++++++++++ .../unmarshalling/schemas/__init__.py | 30 +- openapi_core/unmarshalling/schemas/enums.py | 7 - .../unmarshalling/schemas/factories.py | 54 +- .../unmarshalling/schemas/unmarshallers.py | 234 ++++--- openapi_core/unmarshalling/schemas/util.py | 17 - openapi_core/unmarshalling/unmarshallers.py | 88 +++ openapi_core/validation/__init__.py | 7 - openapi_core/validation/exceptions.py | 4 - openapi_core/validation/processors.py | 68 +-- openapi_core/validation/request/__init__.py | 41 -- openapi_core/validation/request/datatypes.py | 59 +- openapi_core/validation/request/exceptions.py | 2 +- openapi_core/validation/request/protocols.py | 121 +--- openapi_core/validation/request/types.py | 11 + openapi_core/validation/request/validators.py | 209 +++---- openapi_core/validation/response/__init__.py | 42 -- openapi_core/validation/response/protocols.py | 59 +- openapi_core/validation/response/types.py | 11 + .../validation/response/validators.py | 185 ++---- openapi_core/validation/schemas/factories.py | 43 +- openapi_core/validation/schemas/util.py | 27 - openapi_core/validation/schemas/validators.py | 2 +- openapi_core/validation/shortcuts.py | 178 ------ openapi_core/validation/validators.py | 31 +- .../contrib/flask/test_flask_decorator.py | 2 +- .../contrib/flask/test_flask_validator.py | 6 +- .../requests/test_requests_validation.py | 52 +- .../werkzeug/test_werkzeug_validation.py | 16 +- tests/integration/schema/test_spec.py | 6 +- .../{validation => }/test_minimal.py | 0 .../{validation => }/test_petstore.py | 240 ++++---- .../test_read_only_write_only.py | 32 +- .../test_request_unmarshaller.py} | 62 +- .../test_response_unmarshaller.py} | 50 +- .../test_security_override.py | 28 +- .../unmarshalling/test_unmarshallers.py | 192 +++--- .../validation/test_request_validators.py | 132 ++++ .../validation/test_response_validators.py | 160 +++++ tests/unit/{validation => }/conftest.py | 0 tests/unit/contrib/django/test_django.py | 2 +- .../unit/contrib/flask/test_flask_requests.py | 2 +- .../requests/test_requests_requests.py | 2 +- tests/unit/test_shortcuts.py | 570 ++++++++++++++++++ .../test_path_item_params_validator.py | 179 ++++++ ...arshal.py => test_schema_unmarshallers.py} | 115 ++-- .../test_request_response_validators.py | 48 +- .../test_schema_validators.py} | 5 - tests/unit/validation/test_shortcuts.py | 288 --------- 81 files changed, 3702 insertions(+), 1858 deletions(-) create mode 100644 openapi_core/datatypes.py create mode 100644 openapi_core/finders.py create mode 100644 openapi_core/protocols.py create mode 100644 openapi_core/shortcuts.py rename openapi_core/{validation => unmarshalling}/datatypes.py (91%) create mode 100644 openapi_core/unmarshalling/processors.py create mode 100644 openapi_core/unmarshalling/request/__init__.py create mode 100644 openapi_core/unmarshalling/request/datatypes.py create mode 100644 openapi_core/unmarshalling/request/protocols.py rename openapi_core/{validation => unmarshalling}/request/proxies.py (76%) create mode 100644 openapi_core/unmarshalling/request/types.py create mode 100644 openapi_core/unmarshalling/request/unmarshallers.py create mode 100644 openapi_core/unmarshalling/response/__init__.py rename openapi_core/{validation => unmarshalling}/response/datatypes.py (57%) create mode 100644 openapi_core/unmarshalling/response/protocols.py rename openapi_core/{validation => unmarshalling}/response/proxies.py (76%) create mode 100644 openapi_core/unmarshalling/response/types.py create mode 100644 openapi_core/unmarshalling/response/unmarshallers.py delete mode 100644 openapi_core/unmarshalling/schemas/enums.py create mode 100644 openapi_core/unmarshalling/unmarshallers.py create mode 100644 openapi_core/validation/request/types.py create mode 100644 openapi_core/validation/response/types.py delete mode 100644 openapi_core/validation/schemas/util.py delete mode 100644 openapi_core/validation/shortcuts.py rename tests/integration/{validation => }/test_minimal.py (100%) rename tests/integration/{validation => }/test_petstore.py (88%) rename tests/integration/{validation => unmarshalling}/test_read_only_write_only.py (71%) rename tests/integration/{validation/test_request_validator.py => unmarshalling/test_request_unmarshaller.py} (86%) rename tests/integration/{validation/test_response_validator.py => unmarshalling/test_response_unmarshaller.py} (78%) rename tests/integration/{validation => unmarshalling}/test_security_override.py (73%) create mode 100644 tests/integration/validation/test_request_validators.py create mode 100644 tests/integration/validation/test_response_validators.py rename tests/unit/{validation => }/conftest.py (100%) create mode 100644 tests/unit/test_shortcuts.py create mode 100644 tests/unit/unmarshalling/test_path_item_params_validator.py rename tests/unit/unmarshalling/{test_unmarshal.py => test_schema_unmarshallers.py} (75%) rename tests/unit/{unmarshalling/test_validate.py => validation/test_schema_validators.py} (98%) delete mode 100644 tests/unit/validation/test_shortcuts.py diff --git a/docs/customizations.rst b/docs/customizations.rst index 6d77de2e..70c12c9d 100644 --- a/docs/customizations.rst +++ b/docs/customizations.rst @@ -72,7 +72,6 @@ Here's how you could add support for a ``usdate`` format that handles dates of t schema_unmarshallers_factory = SchemaUnmarshallersFactory( OAS30Validator, custom_formatters=custom_formatters, - context=ValidationContext.RESPONSE, ) result = validate_response( diff --git a/openapi_core/__init__.py b/openapi_core/__init__.py index 4d8953b0..75f382c8 100644 --- a/openapi_core/__init__.py +++ b/openapi_core/__init__.py @@ -1,25 +1,41 @@ """OpenAPI core module""" +from openapi_core.shortcuts import unmarshal_request +from openapi_core.shortcuts import unmarshal_response +from openapi_core.shortcuts import unmarshal_webhook_request +from openapi_core.shortcuts import unmarshal_webhook_response +from openapi_core.shortcuts import validate_request +from openapi_core.shortcuts import validate_response from openapi_core.spec import Spec +from openapi_core.unmarshalling.request import RequestValidator +from openapi_core.unmarshalling.request import V3RequestUnmarshaller +from openapi_core.unmarshalling.request import V3WebhookRequestUnmarshaller +from openapi_core.unmarshalling.request import V30RequestUnmarshaller +from openapi_core.unmarshalling.request import V31RequestUnmarshaller +from openapi_core.unmarshalling.request import V31WebhookRequestUnmarshaller +from openapi_core.unmarshalling.request import openapi_request_validator +from openapi_core.unmarshalling.request import openapi_v3_request_validator +from openapi_core.unmarshalling.request import openapi_v30_request_validator +from openapi_core.unmarshalling.request import openapi_v31_request_validator +from openapi_core.unmarshalling.response import ResponseValidator +from openapi_core.unmarshalling.response import V3ResponseUnmarshaller +from openapi_core.unmarshalling.response import V3WebhookResponseUnmarshaller +from openapi_core.unmarshalling.response import V30ResponseUnmarshaller +from openapi_core.unmarshalling.response import V31ResponseUnmarshaller +from openapi_core.unmarshalling.response import V31WebhookResponseUnmarshaller +from openapi_core.unmarshalling.response import openapi_response_validator +from openapi_core.unmarshalling.response import openapi_v3_response_validator +from openapi_core.unmarshalling.response import openapi_v30_response_validator +from openapi_core.unmarshalling.response import openapi_v31_response_validator from openapi_core.validation.request import V3RequestValidator from openapi_core.validation.request import V3WebhookRequestValidator from openapi_core.validation.request import V30RequestValidator from openapi_core.validation.request import V31RequestValidator from openapi_core.validation.request import V31WebhookRequestValidator -from openapi_core.validation.request import openapi_request_validator -from openapi_core.validation.request import openapi_v3_request_validator -from openapi_core.validation.request import openapi_v30_request_validator -from openapi_core.validation.request import openapi_v31_request_validator from openapi_core.validation.response import V3ResponseValidator from openapi_core.validation.response import V3WebhookResponseValidator from openapi_core.validation.response import V30ResponseValidator from openapi_core.validation.response import V31ResponseValidator from openapi_core.validation.response import V31WebhookResponseValidator -from openapi_core.validation.response import openapi_response_validator -from openapi_core.validation.response import openapi_v3_response_validator -from openapi_core.validation.response import openapi_v30_response_validator -from openapi_core.validation.response import openapi_v31_response_validator -from openapi_core.validation.shortcuts import validate_request -from openapi_core.validation.shortcuts import validate_response __author__ = "Artur Maciag" __email__ = "maciag.artur@gmail.com" @@ -29,11 +45,25 @@ __all__ = [ "Spec", + "unmarshal_request", + "unmarshal_response", + "unmarshal_webhook_request", + "unmarshal_webhook_response", "validate_request", "validate_response", + "V30RequestUnmarshaller", + "V30ResponseUnmarshaller", + "V31RequestUnmarshaller", + "V31ResponseUnmarshaller", + "V31WebhookRequestUnmarshaller", + "V31WebhookResponseUnmarshaller", + "V3RequestUnmarshaller", + "V3ResponseUnmarshaller", + "V3WebhookRequestUnmarshaller", + "V3WebhookResponseUnmarshaller", "V30RequestValidator", - "V31RequestValidator", "V30ResponseValidator", + "V31RequestValidator", "V31ResponseValidator", "V31WebhookRequestValidator", "V31WebhookResponseValidator", @@ -41,6 +71,8 @@ "V3ResponseValidator", "V3WebhookRequestValidator", "V3WebhookResponseValidator", + "RequestValidator", + "ResponseValidator", "openapi_v3_request_validator", "openapi_v30_request_validator", "openapi_v31_request_validator", diff --git a/openapi_core/contrib/django/middlewares.py b/openapi_core/contrib/django/middlewares.py index 280fdacb..5950cff6 100644 --- a/openapi_core/contrib/django/middlewares.py +++ b/openapi_core/contrib/django/middlewares.py @@ -10,9 +10,11 @@ from openapi_core.contrib.django.handlers import DjangoOpenAPIErrorsHandler from openapi_core.contrib.django.requests import DjangoOpenAPIRequest from openapi_core.contrib.django.responses import DjangoOpenAPIResponse -from openapi_core.validation.processors import OpenAPIProcessor -from openapi_core.validation.request.datatypes import RequestValidationResult -from openapi_core.validation.response.datatypes import ResponseValidationResult +from openapi_core.unmarshalling.processors import UnmarshallingProcessor +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult +from openapi_core.unmarshalling.response.datatypes import ( + ResponseUnmarshalResult, +) class DjangoOpenAPIMiddleware: @@ -26,11 +28,11 @@ def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): if not hasattr(settings, "OPENAPI_SPEC"): raise ImproperlyConfigured("OPENAPI_SPEC not defined in settings") - self.validation_processor = OpenAPIProcessor(settings.OPENAPI_SPEC) + self.processor = UnmarshallingProcessor(settings.OPENAPI_SPEC) def __call__(self, request: HttpRequest) -> HttpResponse: openapi_request = self._get_openapi_request(request) - req_result = self.validation_processor.process_request(openapi_request) + req_result = self.processor.process_request(openapi_request) if req_result.errors: response = self._handle_request_errors(req_result, request) else: @@ -38,7 +40,7 @@ def __call__(self, request: HttpRequest) -> HttpResponse: response = self.get_response(request) openapi_response = self._get_openapi_response(response) - resp_result = self.validation_processor.process_response( + resp_result = self.processor.process_response( openapi_request, openapi_response ) if resp_result.errors: @@ -47,13 +49,13 @@ def __call__(self, request: HttpRequest) -> HttpResponse: return response def _handle_request_errors( - self, request_result: RequestValidationResult, req: HttpRequest + self, request_result: RequestUnmarshalResult, req: HttpRequest ) -> JsonResponse: return self.errors_handler.handle(request_result.errors, req, None) def _handle_response_errors( self, - response_result: ResponseValidationResult, + response_result: ResponseUnmarshalResult, req: HttpRequest, resp: HttpResponse, ) -> JsonResponse: diff --git a/openapi_core/contrib/django/requests.py b/openapi_core/contrib/django/requests.py index ac98d5d7..dffe0387 100644 --- a/openapi_core/contrib/django/requests.py +++ b/openapi_core/contrib/django/requests.py @@ -6,7 +6,7 @@ from werkzeug.datastructures import Headers from werkzeug.datastructures import ImmutableMultiDict -from openapi_core.validation.request.datatypes import RequestParameters +from openapi_core.datatypes import RequestParameters # https://docs.djangoproject.com/en/stable/topics/http/urls/ # diff --git a/openapi_core/contrib/falcon/middlewares.py b/openapi_core/contrib/falcon/middlewares.py index bb44e03f..287ea5a9 100644 --- a/openapi_core/contrib/falcon/middlewares.py +++ b/openapi_core/contrib/falcon/middlewares.py @@ -10,14 +10,16 @@ from openapi_core.contrib.falcon.requests import FalconOpenAPIRequest from openapi_core.contrib.falcon.responses import FalconOpenAPIResponse from openapi_core.spec import Spec -from openapi_core.validation.processors import OpenAPIProcessor -from openapi_core.validation.request.datatypes import RequestValidationResult -from openapi_core.validation.request.protocols import RequestValidator -from openapi_core.validation.response.datatypes import ResponseValidationResult -from openapi_core.validation.response.protocols import ResponseValidator +from openapi_core.unmarshalling.processors import UnmarshallingProcessor +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult +from openapi_core.unmarshalling.request.types import RequestUnmarshallerType +from openapi_core.unmarshalling.response.datatypes import ( + ResponseUnmarshalResult, +) +from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType -class FalconOpenAPIMiddleware(OpenAPIProcessor): +class FalconOpenAPIMiddleware(UnmarshallingProcessor): request_class = FalconOpenAPIRequest response_class = FalconOpenAPIResponse errors_handler = FalconOpenAPIErrorsHandler() @@ -25,16 +27,16 @@ class FalconOpenAPIMiddleware(OpenAPIProcessor): def __init__( self, spec: Spec, - request_validator_cls: Optional[Type[RequestValidator]] = None, - response_validator_cls: Optional[Type[ResponseValidator]] = None, + request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, + response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, request_class: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, response_class: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, errors_handler: Optional[FalconOpenAPIErrorsHandler] = None, ): super().__init__( spec, - request_validator_cls=request_validator_cls, - response_validator_cls=response_validator_cls, + request_unmarshaller_cls=request_unmarshaller_cls, + response_unmarshaller_cls=response_unmarshaller_cls, ) self.request_class = request_class or self.request_class self.response_class = response_class or self.response_class @@ -44,16 +46,16 @@ def __init__( def from_spec( cls, spec: Spec, - request_validator_cls: Optional[Type[RequestValidator]] = None, - response_validator_cls: Optional[Type[ResponseValidator]] = None, + request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, + response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, request_class: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, response_class: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, errors_handler: Optional[FalconOpenAPIErrorsHandler] = None, ) -> "FalconOpenAPIMiddleware": return cls( spec, - request_validator_cls=request_validator_cls, - response_validator_cls=response_validator_cls, + request_unmarshaller_cls=request_unmarshaller_cls, + response_unmarshaller_cls=response_unmarshaller_cls, request_class=request_class, response_class=response_class, errors_handler=errors_handler, @@ -82,7 +84,7 @@ def _handle_request_errors( self, req: Request, resp: Response, - request_result: RequestValidationResult, + request_result: RequestUnmarshalResult, ) -> None: return self.errors_handler.handle(req, resp, request_result.errors) @@ -90,7 +92,7 @@ def _handle_response_errors( self, req: Request, resp: Response, - response_result: ResponseValidationResult, + response_result: ResponseUnmarshalResult, ) -> None: return self.errors_handler.handle(req, resp, response_result.errors) diff --git a/openapi_core/contrib/falcon/requests.py b/openapi_core/contrib/falcon/requests.py index bb23586e..51d34ef0 100644 --- a/openapi_core/contrib/falcon/requests.py +++ b/openapi_core/contrib/falcon/requests.py @@ -9,7 +9,7 @@ from werkzeug.datastructures import Headers from werkzeug.datastructures import ImmutableMultiDict -from openapi_core.validation.request.datatypes import RequestParameters +from openapi_core.datatypes import RequestParameters class FalconOpenAPIRequest: diff --git a/openapi_core/contrib/flask/decorators.py b/openapi_core/contrib/flask/decorators.py index 81778ca2..1da178ac 100644 --- a/openapi_core/contrib/flask/decorators.py +++ b/openapi_core/contrib/flask/decorators.py @@ -15,19 +15,21 @@ from openapi_core.contrib.flask.requests import FlaskOpenAPIRequest from openapi_core.contrib.flask.responses import FlaskOpenAPIResponse from openapi_core.spec import Spec -from openapi_core.validation.processors import OpenAPIProcessor -from openapi_core.validation.request.datatypes import RequestValidationResult -from openapi_core.validation.request.protocols import RequestValidator -from openapi_core.validation.response.datatypes import ResponseValidationResult -from openapi_core.validation.response.protocols import ResponseValidator +from openapi_core.unmarshalling.processors import UnmarshallingProcessor +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult +from openapi_core.unmarshalling.request.types import RequestUnmarshallerType +from openapi_core.unmarshalling.response.datatypes import ( + ResponseUnmarshalResult, +) +from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType -class FlaskOpenAPIViewDecorator(OpenAPIProcessor): +class FlaskOpenAPIViewDecorator(UnmarshallingProcessor): def __init__( self, spec: Spec, - request_validator_cls: Optional[Type[RequestValidator]] = None, - response_validator_cls: Optional[Type[ResponseValidator]] = None, + request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, + response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, request_class: Type[FlaskOpenAPIRequest] = FlaskOpenAPIRequest, response_class: Type[FlaskOpenAPIResponse] = FlaskOpenAPIResponse, request_provider: Type[FlaskRequestProvider] = FlaskRequestProvider, @@ -37,8 +39,8 @@ def __init__( ): super().__init__( spec, - request_validator_cls=request_validator_cls, - response_validator_cls=response_validator_cls, + request_unmarshaller_cls=request_unmarshaller_cls, + response_unmarshaller_cls=response_unmarshaller_cls, ) self.request_class = request_class self.response_class = response_class @@ -68,7 +70,7 @@ def decorated(*args: Any, **kwargs: Any) -> Response: def _handle_request_view( self, - request_result: RequestValidationResult, + request_result: RequestUnmarshalResult, view: Callable[[Any], Response], *args: Any, **kwargs: Any @@ -79,12 +81,12 @@ def _handle_request_view( return make_response(rv) def _handle_request_errors( - self, request_result: RequestValidationResult + self, request_result: RequestUnmarshalResult ) -> Response: return self.openapi_errors_handler.handle(request_result.errors) def _handle_response_errors( - self, response_result: ResponseValidationResult + self, response_result: ResponseUnmarshalResult ) -> Response: return self.openapi_errors_handler.handle(response_result.errors) @@ -103,8 +105,8 @@ def _get_openapi_response( def from_spec( cls, spec: Spec, - request_validator_cls: Optional[Type[RequestValidator]] = None, - response_validator_cls: Optional[Type[ResponseValidator]] = None, + request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, + response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, request_class: Type[FlaskOpenAPIRequest] = FlaskOpenAPIRequest, response_class: Type[FlaskOpenAPIResponse] = FlaskOpenAPIResponse, request_provider: Type[FlaskRequestProvider] = FlaskRequestProvider, @@ -114,8 +116,8 @@ def from_spec( ) -> "FlaskOpenAPIViewDecorator": return cls( spec, - request_validator_cls=request_validator_cls, - response_validator_cls=response_validator_cls, + request_unmarshaller_cls=request_unmarshaller_cls, + response_unmarshaller_cls=response_unmarshaller_cls, request_class=request_class, response_class=response_class, request_provider=request_provider, diff --git a/openapi_core/contrib/flask/requests.py b/openapi_core/contrib/flask/requests.py index 656ad9b6..dfc21bdd 100644 --- a/openapi_core/contrib/flask/requests.py +++ b/openapi_core/contrib/flask/requests.py @@ -4,7 +4,7 @@ from werkzeug.datastructures import ImmutableMultiDict from openapi_core.contrib.werkzeug.requests import WerkzeugOpenAPIRequest -from openapi_core.validation.request.datatypes import RequestParameters +from openapi_core.datatypes import RequestParameters class FlaskOpenAPIRequest(WerkzeugOpenAPIRequest): diff --git a/openapi_core/contrib/requests/requests.py b/openapi_core/contrib/requests/requests.py index 90feaad8..70ae3fd2 100644 --- a/openapi_core/contrib/requests/requests.py +++ b/openapi_core/contrib/requests/requests.py @@ -11,7 +11,7 @@ from werkzeug.datastructures import ImmutableMultiDict from openapi_core.contrib.requests.protocols import SupportsCookieJar -from openapi_core.validation.request.datatypes import RequestParameters +from openapi_core.datatypes import RequestParameters class RequestsOpenAPIRequest: diff --git a/openapi_core/contrib/starlette/requests.py b/openapi_core/contrib/starlette/requests.py index 4073003d..fa9c8b4d 100644 --- a/openapi_core/contrib/starlette/requests.py +++ b/openapi_core/contrib/starlette/requests.py @@ -4,7 +4,7 @@ from asgiref.sync import AsyncToSync from starlette.requests import Request -from openapi_core.validation.request.datatypes import RequestParameters +from openapi_core.datatypes import RequestParameters class StarletteOpenAPIRequest: diff --git a/openapi_core/contrib/werkzeug/requests.py b/openapi_core/contrib/werkzeug/requests.py index 5bd726cc..1765c360 100644 --- a/openapi_core/contrib/werkzeug/requests.py +++ b/openapi_core/contrib/werkzeug/requests.py @@ -6,7 +6,7 @@ from werkzeug.datastructures import ImmutableMultiDict from werkzeug.wrappers import Request -from openapi_core.validation.request.datatypes import RequestParameters +from openapi_core.datatypes import RequestParameters # http://flask.pocoo.org/docs/1.0/quickstart/#variable-rules PATH_PARAMETER_PATTERN = r"<(?:(?:string|int|float|path|uuid):)?(\w+)>" diff --git a/openapi_core/datatypes.py b/openapi_core/datatypes.py new file mode 100644 index 00000000..d3ed7500 --- /dev/null +++ b/openapi_core/datatypes.py @@ -0,0 +1,42 @@ +"""OpenAPI core validation request datatypes module""" +from __future__ import annotations + +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import Mapping + +from werkzeug.datastructures import Headers +from werkzeug.datastructures import ImmutableMultiDict + + +@dataclass +class RequestParameters: + """OpenAPI request parameters dataclass. + + Attributes: + query + Query string parameters as MultiDict. Must support getlist method. + header + Request headers as Headers. + cookie + Request cookies as MultiDict. + path + Path parameters as dict. Gets resolved against spec if empty. + """ + + query: Mapping[str, Any] = field(default_factory=ImmutableMultiDict) + header: Mapping[str, Any] = field(default_factory=Headers) + cookie: Mapping[str, Any] = field(default_factory=ImmutableMultiDict) + path: Mapping[str, Any] = field(default_factory=dict) + + def __getitem__(self, location: str) -> Any: + return getattr(self, location) + + +@dataclass +class Parameters: + query: Mapping[str, Any] = field(default_factory=dict) + header: Mapping[str, Any] = field(default_factory=dict) + cookie: Mapping[str, Any] = field(default_factory=dict) + path: Mapping[str, Any] = field(default_factory=dict) diff --git a/openapi_core/exceptions.py b/openapi_core/exceptions.py index 504173c5..707b2ae1 100644 --- a/openapi_core/exceptions.py +++ b/openapi_core/exceptions.py @@ -3,3 +3,7 @@ class OpenAPIError(Exception): pass + + +class SpecError(OpenAPIError): + pass diff --git a/openapi_core/finders.py b/openapi_core/finders.py new file mode 100644 index 00000000..9fbef8a1 --- /dev/null +++ b/openapi_core/finders.py @@ -0,0 +1,49 @@ +from typing import Mapping +from typing import NamedTuple +from typing import Optional +from typing import Type + +from openapi_core.exceptions import SpecError +from openapi_core.spec import Spec +from openapi_core.unmarshalling.request.types import RequestUnmarshallerType +from openapi_core.unmarshalling.request.types import ( + WebhookRequestUnmarshallerType, +) +from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType +from openapi_core.unmarshalling.response.types import ( + WebhookResponseUnmarshallerType, +) +from openapi_core.validation.request.types import RequestValidatorType +from openapi_core.validation.request.types import WebhookRequestValidatorType +from openapi_core.validation.response.types import ResponseValidatorType +from openapi_core.validation.response.types import WebhookResponseValidatorType +from openapi_core.validation.validators import BaseValidator + + +class SpecVersion(NamedTuple): + name: str + version: str + + +class SpecClasses(NamedTuple): + request_validator_cls: RequestValidatorType + response_validator_cls: ResponseValidatorType + webhook_request_validator_cls: Optional[WebhookRequestValidatorType] + webhook_response_validator_cls: Optional[WebhookResponseValidatorType] + request_unmarshaller_cls: RequestUnmarshallerType + response_unmarshaller_cls: ResponseUnmarshallerType + webhook_request_unmarshaller_cls: Optional[WebhookRequestUnmarshallerType] + webhook_response_unmarshaller_cls: Optional[ + WebhookResponseUnmarshallerType + ] + + +class SpecFinder: + def __init__(self, specs: Mapping[SpecVersion, SpecClasses]) -> None: + self.specs = specs + + def get_classes(self, spec: Spec) -> SpecClasses: + for v, classes in self.specs.items(): + if v.name in spec and spec[v.name].startswith(v.version): + return classes + raise SpecError("Spec schema version not detected") diff --git a/openapi_core/protocols.py b/openapi_core/protocols.py new file mode 100644 index 00000000..98015762 --- /dev/null +++ b/openapi_core/protocols.py @@ -0,0 +1,146 @@ +"""OpenAPI core protocols module""" +import sys +from typing import Any +from typing import Mapping +from typing import Optional + +if sys.version_info >= (3, 8): + from typing import Protocol + from typing import runtime_checkable +else: + from typing_extensions import Protocol + from typing_extensions import runtime_checkable + +from openapi_core.datatypes import RequestParameters + + +@runtime_checkable +class BaseRequest(Protocol): + parameters: RequestParameters + + @property + def method(self) -> str: + ... + + @property + def body(self) -> Optional[str]: + ... + + @property + def mimetype(self) -> str: + ... + + +@runtime_checkable +class Request(BaseRequest, Protocol): + """Request attributes protocol. + + Attributes: + host_url + Url with scheme and host + For example: + https://localhost:8000 + path + Request path + full_url_pattern + The matched url with scheme, host and path pattern. + For example: + https://localhost:8000/api/v1/pets + https://localhost:8000/api/v1/pets/{pet_id} + method + The request method, as lowercase string. + parameters + A RequestParameters object. Needs to supports path attribute setter + to write resolved path parameters. + body + The request body, as string. + mimetype + Like content type, but without parameters (eg, without charset, + type etc.) and always lowercase. + For example if the content type is "text/HTML; charset=utf-8" + the mimetype would be "text/html". + """ + + @property + def host_url(self) -> str: + ... + + @property + def path(self) -> str: + ... + + +@runtime_checkable +class WebhookRequest(BaseRequest, Protocol): + """Webhook request attributes protocol. + + Attributes: + name + Webhook name + method + The request method, as lowercase string. + parameters + A RequestParameters object. Needs to supports path attribute setter + to write resolved path parameters. + body + The request body, as string. + mimetype + Like content type, but without parameters (eg, without charset, + type etc.) and always lowercase. + For example if the content type is "text/HTML; charset=utf-8" + the mimetype would be "text/html". + """ + + @property + def name(self) -> str: + ... + + +@runtime_checkable +class SupportsPathPattern(Protocol): + """Supports path_pattern attribute protocol. + + You also need to provide path variables in RequestParameters. + + Attributes: + path_pattern + The matched path pattern. + For example: + /api/v1/pets/{pet_id} + """ + + @property + def path_pattern(self) -> str: + ... + + +@runtime_checkable +class Response(Protocol): + """Response protocol. + + Attributes: + data + The response body, as string. + status_code + The status code as integer. + headers + Response headers as Headers. + mimetype + Lowercase content type without charset. + """ + + @property + def data(self) -> str: + ... + + @property + def status_code(self) -> int: + ... + + @property + def mimetype(self) -> str: + ... + + @property + def headers(self) -> Mapping[str, Any]: + ... diff --git a/openapi_core/security/providers.py b/openapi_core/security/providers.py index 93aa465e..3864682b 100644 --- a/openapi_core/security/providers.py +++ b/openapi_core/security/providers.py @@ -1,9 +1,9 @@ import warnings from typing import Any +from openapi_core.datatypes import RequestParameters from openapi_core.security.exceptions import SecurityProviderError from openapi_core.spec import Spec -from openapi_core.validation.request.datatypes import RequestParameters class BaseProvider: diff --git a/openapi_core/shortcuts.py b/openapi_core/shortcuts.py new file mode 100644 index 00000000..91a5fc3e --- /dev/null +++ b/openapi_core/shortcuts.py @@ -0,0 +1,314 @@ +"""OpenAPI core validation shortcuts module""" +import warnings +from typing import Any +from typing import Dict +from typing import Optional +from typing import Union + +from openapi_core.exceptions import SpecError +from openapi_core.finders import SpecClasses +from openapi_core.finders import SpecFinder +from openapi_core.finders import SpecVersion +from openapi_core.protocols import Request +from openapi_core.protocols import Response +from openapi_core.protocols import WebhookRequest +from openapi_core.spec import Spec +from openapi_core.unmarshalling.request import V30RequestUnmarshaller +from openapi_core.unmarshalling.request import V31RequestUnmarshaller +from openapi_core.unmarshalling.request import V31WebhookRequestUnmarshaller +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult +from openapi_core.unmarshalling.request.protocols import RequestUnmarshaller +from openapi_core.unmarshalling.request.protocols import ( + WebhookRequestUnmarshaller, +) +from openapi_core.unmarshalling.request.proxies import ( + SpecRequestValidatorProxy, +) +from openapi_core.unmarshalling.request.types import AnyRequestUnmarshallerType +from openapi_core.unmarshalling.request.types import RequestUnmarshallerType +from openapi_core.unmarshalling.request.types import ( + WebhookRequestUnmarshallerType, +) +from openapi_core.unmarshalling.request.unmarshallers import ( + BaseAPICallRequestUnmarshaller, +) +from openapi_core.unmarshalling.request.unmarshallers import ( + BaseWebhookRequestUnmarshaller, +) +from openapi_core.unmarshalling.response import V30ResponseUnmarshaller +from openapi_core.unmarshalling.response import V31ResponseUnmarshaller +from openapi_core.unmarshalling.response import V31WebhookResponseUnmarshaller +from openapi_core.unmarshalling.response.datatypes import ( + ResponseUnmarshalResult, +) +from openapi_core.unmarshalling.response.protocols import ResponseUnmarshaller +from openapi_core.unmarshalling.response.protocols import ( + WebhookResponseUnmarshaller, +) +from openapi_core.unmarshalling.response.proxies import ( + SpecResponseValidatorProxy, +) +from openapi_core.unmarshalling.response.types import ( + AnyResponseUnmarshallerType, +) +from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType +from openapi_core.unmarshalling.response.types import ( + WebhookResponseUnmarshallerType, +) +from openapi_core.validation.request import V30RequestValidator +from openapi_core.validation.request import V31RequestValidator +from openapi_core.validation.request import V31WebhookRequestValidator +from openapi_core.validation.response import V30ResponseValidator +from openapi_core.validation.response import V31ResponseValidator +from openapi_core.validation.response import V31WebhookResponseValidator + +AnyRequest = Union[Request, WebhookRequest] + +SPECS: Dict[SpecVersion, SpecClasses] = { + SpecVersion("openapi", "3.0"): SpecClasses( + V30RequestValidator, + V30ResponseValidator, + None, + None, + V30RequestUnmarshaller, + V30ResponseUnmarshaller, + None, + None, + ), + SpecVersion("openapi", "3.1"): SpecClasses( + V31RequestValidator, + V31ResponseValidator, + V31WebhookRequestValidator, + V31WebhookResponseValidator, + V31RequestUnmarshaller, + V31ResponseUnmarshaller, + V31WebhookRequestUnmarshaller, + V31WebhookResponseUnmarshaller, + ), +} + + +def get_classes(spec: Spec) -> SpecClasses: + return SpecFinder(SPECS).get_classes(spec) + + +def unmarshal_request( + request: Request, + spec: Spec, + base_url: Optional[str] = None, + cls: Optional[RequestUnmarshallerType] = None, + **unmarshaller_kwargs: Any, +) -> RequestUnmarshalResult: + if not isinstance(request, Request): + raise TypeError("'request' argument is not type of Request") + if not isinstance(spec, Spec): + raise TypeError("'spec' argument is not type of Spec") + if cls is None: + classes = get_classes(spec) + cls = classes.request_unmarshaller_cls + if not issubclass(cls, RequestUnmarshaller): + raise TypeError("'cls' argument is not type of RequestUnmarshaller") + v = cls(spec, base_url=base_url, **unmarshaller_kwargs) + result = v.unmarshal(request) + result.raise_for_errors() + return result + + +def unmarshal_webhook_request( + request: WebhookRequest, + spec: Spec, + base_url: Optional[str] = None, + cls: Optional[WebhookRequestUnmarshallerType] = None, + **unmarshaller_kwargs: Any, +) -> RequestUnmarshalResult: + if not isinstance(request, WebhookRequest): + raise TypeError("'request' argument is not type of WebhookRequest") + if not isinstance(spec, Spec): + raise TypeError("'spec' argument is not type of Spec") + if cls is None: + classes = get_classes(spec) + cls = classes.webhook_request_unmarshaller_cls + if cls is None: + raise SpecError("Unmarshaller class not found") + if not issubclass(cls, WebhookRequestUnmarshaller): + raise TypeError( + "'cls' argument is not type of WebhookRequestUnmarshaller" + ) + v = cls(spec, base_url=base_url, **unmarshaller_kwargs) + result = v.unmarshal(request) + result.raise_for_errors() + return result + + +def unmarshal_response( + request: Request, + response: Response, + spec: Spec, + base_url: Optional[str] = None, + cls: Optional[ResponseUnmarshallerType] = None, + **unmarshaller_kwargs: Any, +) -> ResponseUnmarshalResult: + if not isinstance(request, Request): + raise TypeError("'request' argument is not type of Request") + if not isinstance(response, Response): + raise TypeError("'response' argument is not type of Response") + if not isinstance(spec, Spec): + raise TypeError("'spec' argument is not type of Spec") + if cls is None: + classes = get_classes(spec) + cls = classes.response_unmarshaller_cls + if not issubclass(cls, ResponseUnmarshaller): + raise TypeError("'cls' argument is not type of ResponseUnmarshaller") + v = cls(spec, base_url=base_url, **unmarshaller_kwargs) + result = v.unmarshal(request, response) + result.raise_for_errors() + return result + + +def unmarshal_webhook_response( + request: WebhookRequest, + response: Response, + spec: Spec, + base_url: Optional[str] = None, + cls: Optional[WebhookResponseUnmarshallerType] = None, + **unmarshaller_kwargs: Any, +) -> ResponseUnmarshalResult: + if not isinstance(request, WebhookRequest): + raise TypeError("'request' argument is not type of WebhookRequest") + if not isinstance(response, Response): + raise TypeError("'response' argument is not type of Response") + if not isinstance(spec, Spec): + raise TypeError("'spec' argument is not type of Spec") + if cls is None: + classes = get_classes(spec) + cls = classes.webhook_response_unmarshaller_cls + if cls is None: + raise SpecError("Unmarshaller class not found") + if not issubclass(cls, WebhookResponseUnmarshaller): + raise TypeError( + "'cls' argument is not type of WebhookResponseUnmarshaller" + ) + v = cls(spec, base_url=base_url, **unmarshaller_kwargs) + result = v.unmarshal(request, response) + result.raise_for_errors() + return result + + +def validate_request( + request: AnyRequest, + spec: Spec, + base_url: Optional[str] = None, + validator: Optional[SpecRequestValidatorProxy] = None, + cls: Optional[AnyRequestUnmarshallerType] = None, + **validator_kwargs: Any, +) -> RequestUnmarshalResult: + if isinstance(spec, (Request, WebhookRequest)) and isinstance( + request, Spec + ): + warnings.warn( + "spec parameter as a first argument is deprecated. " + "Move it to second argument instead.", + DeprecationWarning, + ) + request, spec = spec, request + + if not isinstance(request, (Request, WebhookRequest)): + raise TypeError("'request' argument is not type of (Webhook)Request") + if not isinstance(spec, Spec): + raise TypeError("'spec' argument is not type of Spec") + + if validator is not None and isinstance(request, Request): + warnings.warn( + "validator parameter is deprecated. Use cls instead.", + DeprecationWarning, + ) + result = validator.validate(spec, request, base_url=base_url) + result.raise_for_errors() + return result + + if isinstance(request, WebhookRequest): + if cls is None or issubclass(cls, WebhookRequestUnmarshaller): + return unmarshal_webhook_request( + request, spec, base_url=base_url, cls=cls, **validator_kwargs + ) + else: + raise TypeError( + "'cls' argument is not type of WebhookRequestUnmarshaller" + ) + elif isinstance(request, Request): + if cls is None or issubclass(cls, RequestUnmarshaller): + return unmarshal_request( + request, spec, base_url=base_url, cls=cls, **validator_kwargs + ) + else: + raise TypeError( + "'cls' argument is not type of RequestUnmarshaller" + ) + + +def validate_response( + request: Union[Request, WebhookRequest, Spec], + response: Union[Response, Request, WebhookRequest], + spec: Union[Spec, Response], + base_url: Optional[str] = None, + validator: Optional[SpecResponseValidatorProxy] = None, + cls: Optional[AnyResponseUnmarshallerType] = None, + **validator_kwargs: Any, +) -> ResponseUnmarshalResult: + if ( + isinstance(request, Spec) + and isinstance(response, (Request, WebhookRequest)) + and isinstance(spec, Response) + ): + warnings.warn( + "spec parameter as a first argument is deprecated. " + "Move it to third argument instead.", + DeprecationWarning, + ) + args = request, response, spec + spec, request, response = args + + if not isinstance(request, (Request, WebhookRequest)): + raise TypeError("'request' argument is not type of (Webhook)Request") + if not isinstance(response, Response): + raise TypeError("'response' argument is not type of Response") + if not isinstance(spec, Spec): + raise TypeError("'spec' argument is not type of Spec") + + if validator is not None and isinstance(request, Request): + warnings.warn( + "validator parameter is deprecated. Use cls instead.", + DeprecationWarning, + ) + result = validator.validate(spec, request, response, base_url=base_url) + result.raise_for_errors() + return result + + if isinstance(request, WebhookRequest): + if cls is None or issubclass(cls, WebhookResponseUnmarshaller): + return unmarshal_webhook_response( + request, + response, + spec, + base_url=base_url, + cls=cls, + **validator_kwargs, + ) + else: + raise TypeError( + "'cls' argument is not type of WebhookResponseUnmarshaller" + ) + elif isinstance(request, Request): + if cls is None or issubclass(cls, ResponseUnmarshaller): + return unmarshal_response( + request, + response, + spec, + base_url=base_url, + cls=cls, + **validator_kwargs, + ) + else: + raise TypeError( + "'cls' argument is not type of ResponseUnmarshaller" + ) diff --git a/openapi_core/testing/datatypes.py b/openapi_core/testing/datatypes.py index 7bdc3a0e..8f4ee138 100644 --- a/openapi_core/testing/datatypes.py +++ b/openapi_core/testing/datatypes.py @@ -1,6 +1,6 @@ from typing import Optional -from openapi_core.validation.request.datatypes import Parameters +from openapi_core.datatypes import Parameters class ResultMock: diff --git a/openapi_core/testing/requests.py b/openapi_core/testing/requests.py index 9df4827c..49357fda 100644 --- a/openapi_core/testing/requests.py +++ b/openapi_core/testing/requests.py @@ -6,7 +6,7 @@ from werkzeug.datastructures import Headers from werkzeug.datastructures import ImmutableMultiDict -from openapi_core.validation.request.datatypes import RequestParameters +from openapi_core.datatypes import RequestParameters class MockRequest: diff --git a/openapi_core/validation/datatypes.py b/openapi_core/unmarshalling/datatypes.py similarity index 91% rename from openapi_core/validation/datatypes.py rename to openapi_core/unmarshalling/datatypes.py index 4bece8f5..78036dda 100644 --- a/openapi_core/validation/datatypes.py +++ b/openapi_core/unmarshalling/datatypes.py @@ -6,7 +6,7 @@ @dataclass -class BaseValidationResult: +class BaseUnmarshalResult: errors: Iterable[OpenAPIError] def raise_for_errors(self) -> None: diff --git a/openapi_core/unmarshalling/processors.py b/openapi_core/unmarshalling/processors.py new file mode 100644 index 00000000..b2200a90 --- /dev/null +++ b/openapi_core/unmarshalling/processors.py @@ -0,0 +1,43 @@ +"""OpenAPI core unmarshalling processors module""" +from typing import Optional +from typing import Type + +from openapi_core.protocols import Request +from openapi_core.protocols import Response +from openapi_core.shortcuts import get_classes +from openapi_core.spec import Spec +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult +from openapi_core.unmarshalling.request.types import RequestUnmarshallerType +from openapi_core.unmarshalling.response.datatypes import ( + ResponseUnmarshalResult, +) +from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType + + +class UnmarshallingProcessor: + def __init__( + self, + spec: Spec, + request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, + response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, + ): + self.spec = spec + if ( + request_unmarshaller_cls is None + or response_unmarshaller_cls is None + ): + classes = get_classes(self.spec) + if request_unmarshaller_cls is None: + request_unmarshaller_cls = classes.request_unmarshaller_cls + if response_unmarshaller_cls is None: + response_unmarshaller_cls = classes.response_unmarshaller_cls + self.request_unmarshaller = request_unmarshaller_cls(self.spec) + self.response_unmarshaller = response_unmarshaller_cls(self.spec) + + def process_request(self, request: Request) -> RequestUnmarshalResult: + return self.request_unmarshaller.unmarshal(request) + + def process_response( + self, request: Request, response: Response + ) -> ResponseUnmarshalResult: + return self.response_unmarshaller.unmarshal(request, response) diff --git a/openapi_core/unmarshalling/request/__init__.py b/openapi_core/unmarshalling/request/__init__.py new file mode 100644 index 00000000..710f17df --- /dev/null +++ b/openapi_core/unmarshalling/request/__init__.py @@ -0,0 +1,66 @@ +"""OpenAPI core unmarshalling request module""" +from openapi_core.unmarshalling.request.proxies import ( + DetectSpecRequestValidatorProxy, +) +from openapi_core.unmarshalling.request.proxies import ( + SpecRequestValidatorProxy, +) +from openapi_core.unmarshalling.request.unmarshallers import ( + APICallRequestUnmarshaller, +) +from openapi_core.unmarshalling.request.unmarshallers import RequestValidator +from openapi_core.unmarshalling.request.unmarshallers import ( + V30RequestUnmarshaller, +) +from openapi_core.unmarshalling.request.unmarshallers import ( + V31RequestUnmarshaller, +) +from openapi_core.unmarshalling.request.unmarshallers import ( + V31WebhookRequestUnmarshaller, +) +from openapi_core.unmarshalling.schemas import ( + oas30_write_schema_unmarshallers_factory, +) +from openapi_core.unmarshalling.schemas import ( + oas31_schema_unmarshallers_factory, +) + +__all__ = [ + "V30RequestUnmarshaller", + "V31RequestUnmarshaller", + "V31WebhookRequestUnmarshaller", + "RequestValidator", + "openapi_v30_request_validator", + "openapi_v31_request_validator", + "openapi_v3_request_validator", + "openapi_request_validator", +] + +# alias to the latest v3 version +V3RequestUnmarshaller = V31RequestUnmarshaller +V3WebhookRequestUnmarshaller = V31WebhookRequestUnmarshaller + +# spec validators +openapi_v30_request_validator = SpecRequestValidatorProxy( + APICallRequestUnmarshaller, + schema_unmarshallers_factory=oas30_write_schema_unmarshallers_factory, + deprecated="openapi_v30_request_validator", + use="V30RequestValidator", +) +openapi_v31_request_validator = SpecRequestValidatorProxy( + APICallRequestUnmarshaller, + schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, + deprecated="openapi_v31_request_validator", + use="V31RequestValidator", +) + +# spec validators alias to the latest v3 version +openapi_v3_request_validator = openapi_v31_request_validator + +# detect version spec +openapi_request_validator = DetectSpecRequestValidatorProxy( + { + ("openapi", "3.0"): openapi_v30_request_validator, + ("openapi", "3.1"): openapi_v31_request_validator, + }, +) diff --git a/openapi_core/unmarshalling/request/datatypes.py b/openapi_core/unmarshalling/request/datatypes.py new file mode 100644 index 00000000..739d2bf8 --- /dev/null +++ b/openapi_core/unmarshalling/request/datatypes.py @@ -0,0 +1,15 @@ +"""OpenAPI core unmarshalling request datatypes module""" +from __future__ import annotations + +from dataclasses import dataclass +from dataclasses import field + +from openapi_core.datatypes import Parameters +from openapi_core.unmarshalling.datatypes import BaseUnmarshalResult + + +@dataclass +class RequestUnmarshalResult(BaseUnmarshalResult): + body: str | None = None + parameters: Parameters = field(default_factory=Parameters) + security: dict[str, str] | None = None diff --git a/openapi_core/unmarshalling/request/protocols.py b/openapi_core/unmarshalling/request/protocols.py new file mode 100644 index 00000000..2fee6437 --- /dev/null +++ b/openapi_core/unmarshalling/request/protocols.py @@ -0,0 +1,39 @@ +"""OpenAPI core validation request protocols module""" +import sys +from typing import Optional + +if sys.version_info >= (3, 8): + from typing import Protocol + from typing import runtime_checkable +else: + from typing_extensions import Protocol + from typing_extensions import runtime_checkable + +from openapi_core.protocols import Request +from openapi_core.protocols import WebhookRequest +from openapi_core.spec import Spec +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult + + +@runtime_checkable +class RequestUnmarshaller(Protocol): + def __init__(self, spec: Spec, base_url: Optional[str] = None): + ... + + def unmarshal( + self, + request: Request, + ) -> RequestUnmarshalResult: + ... + + +@runtime_checkable +class WebhookRequestUnmarshaller(Protocol): + def __init__(self, spec: Spec, base_url: Optional[str] = None): + ... + + def unmarshal( + self, + request: WebhookRequest, + ) -> RequestUnmarshalResult: + ... diff --git a/openapi_core/validation/request/proxies.py b/openapi_core/unmarshalling/request/proxies.py similarity index 76% rename from openapi_core/validation/request/proxies.py rename to openapi_core/unmarshalling/request/proxies.py index bb6f49ec..04024c1a 100644 --- a/openapi_core/validation/request/proxies.py +++ b/openapi_core/unmarshalling/request/proxies.py @@ -8,43 +8,37 @@ from typing import Tuple from typing import Type +from openapi_core.exceptions import SpecError +from openapi_core.protocols import Request from openapi_core.spec import Spec -from openapi_core.validation.exceptions import ValidatorDetectError -from openapi_core.validation.request.datatypes import RequestValidationResult -from openapi_core.validation.request.protocols import Request +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult if TYPE_CHECKING: - from openapi_core.validation.request.validators import ( - BaseAPICallRequestValidator, + from openapi_core.unmarshalling.request.unmarshallers import ( + APICallRequestUnmarshaller, ) class SpecRequestValidatorProxy: def __init__( self, - unmarshaller_cls_name: str, + unmarshaller_cls: Type["APICallRequestUnmarshaller"], deprecated: str = "RequestValidator", use: Optional[str] = None, **unmarshaller_kwargs: Any, ): - self.unmarshaller_cls_name = unmarshaller_cls_name + self.unmarshaller_cls = unmarshaller_cls self.unmarshaller_kwargs = unmarshaller_kwargs self.deprecated = deprecated - self.use = use or self.unmarshaller_cls_name - - @property - def unmarshaller_cls(self) -> Type["BaseAPICallRequestValidator"]: - from openapi_core.unmarshalling.request import unmarshallers - - return getattr(unmarshallers, self.unmarshaller_cls_name) + self.use = use or self.unmarshaller_cls.__name__ def validate( self, spec: Spec, request: Request, base_url: Optional[str] = None, - ) -> RequestValidationResult: + ) -> "RequestUnmarshalResult": warnings.warn( f"{self.deprecated} is deprecated. Use {self.use} instead.", DeprecationWarning, @@ -52,7 +46,7 @@ def validate( unmarshaller = self.unmarshaller_cls( spec, base_url=base_url, **self.unmarshaller_kwargs ) - return unmarshaller.validate(request) + return unmarshaller.unmarshal(request) def is_valid( self, @@ -88,14 +82,14 @@ def detect(self, spec: Spec) -> SpecRequestValidatorProxy: for (key, value), validator in self.choices.items(): if key in spec and spec[key].startswith(value): return validator - raise ValidatorDetectError("Spec schema version not detected") + raise SpecError("Spec schema version not detected") def validate( self, spec: Spec, request: Request, base_url: Optional[str] = None, - ) -> RequestValidationResult: + ) -> "RequestUnmarshalResult": validator = self.detect(spec) return validator.validate(spec, request, base_url=base_url) diff --git a/openapi_core/unmarshalling/request/types.py b/openapi_core/unmarshalling/request/types.py new file mode 100644 index 00000000..e889bfec --- /dev/null +++ b/openapi_core/unmarshalling/request/types.py @@ -0,0 +1,13 @@ +from typing import Type +from typing import Union + +from openapi_core.unmarshalling.request.protocols import RequestUnmarshaller +from openapi_core.unmarshalling.request.protocols import ( + WebhookRequestUnmarshaller, +) + +RequestUnmarshallerType = Type[RequestUnmarshaller] +WebhookRequestUnmarshallerType = Type[WebhookRequestUnmarshaller] +AnyRequestUnmarshallerType = Union[ + RequestUnmarshallerType, WebhookRequestUnmarshallerType +] diff --git a/openapi_core/unmarshalling/request/unmarshallers.py b/openapi_core/unmarshalling/request/unmarshallers.py new file mode 100644 index 00000000..e828d8a6 --- /dev/null +++ b/openapi_core/unmarshalling/request/unmarshallers.py @@ -0,0 +1,423 @@ +from typing import Any +from typing import Optional + +from openapi_core.casting.schemas import schema_casters_factory +from openapi_core.casting.schemas.factories import SchemaCastersFactory +from openapi_core.deserializing.media_types import ( + media_type_deserializers_factory, +) +from openapi_core.deserializing.media_types.factories import ( + MediaTypeDeserializersFactory, +) +from openapi_core.deserializing.parameters import ( + parameter_deserializers_factory, +) +from openapi_core.deserializing.parameters.factories import ( + ParameterDeserializersFactory, +) +from openapi_core.protocols import BaseRequest +from openapi_core.protocols import Request +from openapi_core.protocols import WebhookRequest +from openapi_core.security import security_provider_factory +from openapi_core.security.factories import SecurityProviderFactory +from openapi_core.spec import Spec +from openapi_core.templating.paths.exceptions import PathError +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult +from openapi_core.unmarshalling.request.proxies import ( + SpecRequestValidatorProxy, +) +from openapi_core.unmarshalling.schemas import ( + oas30_write_schema_unmarshallers_factory, +) +from openapi_core.unmarshalling.schemas import ( + oas31_schema_unmarshallers_factory, +) +from openapi_core.unmarshalling.schemas.factories import ( + SchemaUnmarshallersFactory, +) +from openapi_core.unmarshalling.unmarshallers import BaseUnmarshaller +from openapi_core.util import chainiters +from openapi_core.validation.request.exceptions import MissingRequestBody +from openapi_core.validation.request.exceptions import ParametersError +from openapi_core.validation.request.exceptions import RequestBodyError +from openapi_core.validation.request.exceptions import SecurityError +from openapi_core.validation.request.validators import APICallRequestValidator +from openapi_core.validation.request.validators import BaseRequestValidator +from openapi_core.validation.request.validators import V30RequestBodyValidator +from openapi_core.validation.request.validators import ( + V30RequestParametersValidator, +) +from openapi_core.validation.request.validators import ( + V30RequestSecurityValidator, +) +from openapi_core.validation.request.validators import V30RequestValidator +from openapi_core.validation.request.validators import V31RequestBodyValidator +from openapi_core.validation.request.validators import ( + V31RequestParametersValidator, +) +from openapi_core.validation.request.validators import ( + V31RequestSecurityValidator, +) +from openapi_core.validation.request.validators import V31RequestValidator +from openapi_core.validation.request.validators import ( + V31WebhookRequestBodyValidator, +) +from openapi_core.validation.request.validators import ( + V31WebhookRequestParametersValidator, +) +from openapi_core.validation.request.validators import ( + V31WebhookRequestSecurityValidator, +) +from openapi_core.validation.request.validators import ( + V31WebhookRequestValidator, +) +from openapi_core.validation.request.validators import WebhookRequestValidator +from openapi_core.validation.schemas.factories import SchemaValidatorsFactory + + +class BaseRequestUnmarshaller(BaseRequestValidator, BaseUnmarshaller): + def __init__( + self, + spec: Spec, + base_url: Optional[str] = None, + schema_casters_factory: SchemaCastersFactory = schema_casters_factory, + parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory, + media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, + schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + security_provider_factory: SecurityProviderFactory = security_provider_factory, + schema_unmarshallers_factory: Optional[ + SchemaUnmarshallersFactory + ] = None, + ): + BaseUnmarshaller.__init__( + self, + spec, + base_url=base_url, + schema_casters_factory=schema_casters_factory, + parameter_deserializers_factory=parameter_deserializers_factory, + media_type_deserializers_factory=media_type_deserializers_factory, + schema_validators_factory=schema_validators_factory, + schema_unmarshallers_factory=schema_unmarshallers_factory, + ) + BaseRequestValidator.__init__( + self, + spec, + base_url=base_url, + schema_casters_factory=schema_casters_factory, + parameter_deserializers_factory=parameter_deserializers_factory, + media_type_deserializers_factory=media_type_deserializers_factory, + schema_validators_factory=schema_validators_factory, + security_provider_factory=security_provider_factory, + ) + + def _unmarshal( + self, request: BaseRequest, operation: Spec, path: Spec + ) -> RequestUnmarshalResult: + try: + security = self._get_security(request.parameters, operation) + except SecurityError as exc: + return RequestUnmarshalResult(errors=[exc]) + + try: + params = self._get_parameters(request.parameters, operation, path) + except ParametersError as exc: + params = exc.parameters + params_errors = exc.errors + else: + params_errors = [] + + try: + body = self._get_body(request.body, request.mimetype, operation) + except MissingRequestBody: + body = None + body_errors = [] + except RequestBodyError as exc: + body = None + body_errors = [exc] + else: + body_errors = [] + + errors = list(chainiters(params_errors, body_errors)) + return RequestUnmarshalResult( + errors=errors, + body=body, + parameters=params, + security=security, + ) + + def _unmarshal_body( + self, request: BaseRequest, operation: Spec, path: Spec + ) -> RequestUnmarshalResult: + try: + body = self._get_body(request.body, request.mimetype, operation) + except MissingRequestBody: + body = None + errors = [] + except RequestBodyError as exc: + body = None + errors = [exc] + else: + errors = [] + + return RequestUnmarshalResult( + errors=errors, + body=body, + ) + + def _unmarshal_parameters( + self, request: BaseRequest, operation: Spec, path: Spec + ) -> RequestUnmarshalResult: + try: + params = self._get_parameters(request.parameters, path, operation) + except ParametersError as exc: + params = exc.parameters + params_errors = exc.errors + else: + params_errors = [] + + return RequestUnmarshalResult( + errors=params_errors, + parameters=params, + ) + + def _unmarshal_security( + self, request: BaseRequest, operation: Spec, path: Spec + ) -> RequestUnmarshalResult: + try: + security = self._get_security(request.parameters, operation) + except SecurityError as exc: + return RequestUnmarshalResult(errors=[exc]) + + return RequestUnmarshalResult( + errors=[], + security=security, + ) + + +class BaseAPICallRequestUnmarshaller(BaseRequestUnmarshaller): + pass + + +class BaseWebhookRequestUnmarshaller(BaseRequestUnmarshaller): + pass + + +class APICallRequestUnmarshaller( + APICallRequestValidator, BaseAPICallRequestUnmarshaller +): + def unmarshal(self, request: Request) -> RequestUnmarshalResult: + try: + path, operation, _, path_result, _ = self._find_path(request) + # don't process if operation errors + except PathError as exc: + return RequestUnmarshalResult(errors=[exc]) + + request.parameters.path = ( + request.parameters.path or path_result.variables + ) + + return self._unmarshal(request, operation, path) + + +class APICallRequestBodyUnmarshaller( + APICallRequestValidator, BaseAPICallRequestUnmarshaller +): + def unmarshal(self, request: Request) -> RequestUnmarshalResult: + try: + path, operation, _, path_result, _ = self._find_path(request) + # don't process if operation errors + except PathError as exc: + return RequestUnmarshalResult(errors=[exc]) + + request.parameters.path = ( + request.parameters.path or path_result.variables + ) + + return self._unmarshal_body(request, operation, path) + + +class APICallRequestParametersUnmarshaller( + APICallRequestValidator, BaseAPICallRequestUnmarshaller +): + def unmarshal(self, request: Request) -> RequestUnmarshalResult: + try: + path, operation, _, path_result, _ = self._find_path(request) + # don't process if operation errors + except PathError as exc: + return RequestUnmarshalResult(errors=[exc]) + + request.parameters.path = ( + request.parameters.path or path_result.variables + ) + + return self._unmarshal_parameters(request, operation, path) + + +class APICallRequestSecurityUnmarshaller( + APICallRequestValidator, BaseAPICallRequestUnmarshaller +): + def unmarshal(self, request: Request) -> RequestUnmarshalResult: + try: + path, operation, _, path_result, _ = self._find_path(request) + # don't process if operation errors + except PathError as exc: + return RequestUnmarshalResult(errors=[exc]) + + request.parameters.path = ( + request.parameters.path or path_result.variables + ) + + return self._unmarshal_security(request, operation, path) + + +class WebhookRequestUnmarshaller( + WebhookRequestValidator, BaseWebhookRequestUnmarshaller +): + def unmarshal(self, request: WebhookRequest) -> RequestUnmarshalResult: + try: + path, operation, _, path_result, _ = self._find_path(request) + # don't process if operation errors + except PathError as exc: + return RequestUnmarshalResult(errors=[exc]) + + request.parameters.path = ( + request.parameters.path or path_result.variables + ) + + return self._unmarshal(request, operation, path) + + +class WebhookRequestBodyUnmarshaller( + WebhookRequestValidator, BaseWebhookRequestUnmarshaller +): + def unmarshal(self, request: WebhookRequest) -> RequestUnmarshalResult: + try: + path, operation, _, path_result, _ = self._find_path(request) + # don't process if operation errors + except PathError as exc: + return RequestUnmarshalResult(errors=[exc]) + + request.parameters.path = ( + request.parameters.path or path_result.variables + ) + + return self._unmarshal_body(request, operation, path) + + +class WebhookRequestParametersUnmarshaller( + WebhookRequestValidator, BaseWebhookRequestUnmarshaller +): + def unmarshal(self, request: WebhookRequest) -> RequestUnmarshalResult: + try: + path, operation, _, path_result, _ = self._find_path(request) + # don't process if operation errors + except PathError as exc: + return RequestUnmarshalResult(errors=[exc]) + + request.parameters.path = ( + request.parameters.path or path_result.variables + ) + + return self._unmarshal_parameters(request, operation, path) + + +class WebhookRequestSecuritysUnmarshaller( + WebhookRequestValidator, BaseWebhookRequestUnmarshaller +): + def unmarshal(self, request: WebhookRequest) -> RequestUnmarshalResult: + try: + path, operation, _, path_result, _ = self._find_path(request) + # don't process if operation errors + except PathError as exc: + return RequestUnmarshalResult(errors=[exc]) + + request.parameters.path = ( + request.parameters.path or path_result.variables + ) + + return self._unmarshal_security(request, operation, path) + + +class V30RequestBodyUnmarshaller( + V30RequestBodyValidator, APICallRequestBodyUnmarshaller +): + schema_unmarshallers_factory = oas30_write_schema_unmarshallers_factory + + +class V30RequestParametersUnmarshaller( + V30RequestParametersValidator, APICallRequestParametersUnmarshaller +): + schema_unmarshallers_factory = oas30_write_schema_unmarshallers_factory + + +class V30RequestSecurityUnmarshaller( + V30RequestSecurityValidator, APICallRequestSecurityUnmarshaller +): + schema_unmarshallers_factory = oas30_write_schema_unmarshallers_factory + + +class V30RequestUnmarshaller(V30RequestValidator, APICallRequestUnmarshaller): + schema_unmarshallers_factory = oas30_write_schema_unmarshallers_factory + + +class V31RequestBodyUnmarshaller( + V31RequestBodyValidator, APICallRequestBodyUnmarshaller +): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31RequestParametersUnmarshaller( + V31RequestParametersValidator, APICallRequestParametersUnmarshaller +): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31RequestSecurityUnmarshaller( + V31RequestSecurityValidator, APICallRequestSecurityUnmarshaller +): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31RequestUnmarshaller(V31RequestValidator, APICallRequestUnmarshaller): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31WebhookRequestBodyUnmarshaller( + V31WebhookRequestBodyValidator, WebhookRequestBodyUnmarshaller +): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31WebhookRequestParametersUnmarshaller( + V31WebhookRequestParametersValidator, WebhookRequestParametersUnmarshaller +): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31WebhookRequestSecurityUnmarshaller( + V31WebhookRequestSecurityValidator, WebhookRequestSecuritysUnmarshaller +): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31WebhookRequestUnmarshaller( + V31WebhookRequestValidator, WebhookRequestUnmarshaller +): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +# backward compatibility +class RequestValidator(SpecRequestValidatorProxy): + def __init__( + self, + schema_unmarshallers_factory: "SchemaUnmarshallersFactory", + **kwargs: Any, + ): + super().__init__( + APICallRequestUnmarshaller, + schema_validators_factory=( + schema_unmarshallers_factory.schema_validators_factory + ), + schema_unmarshallers_factory=schema_unmarshallers_factory, + **kwargs, + ) diff --git a/openapi_core/unmarshalling/response/__init__.py b/openapi_core/unmarshalling/response/__init__.py new file mode 100644 index 00000000..60ec202f --- /dev/null +++ b/openapi_core/unmarshalling/response/__init__.py @@ -0,0 +1,67 @@ +"""OpenAPI core unmarshalling response module""" +from openapi_core.unmarshalling.response.proxies import ( + DetectResponseValidatorProxy, +) +from openapi_core.unmarshalling.response.proxies import ( + SpecResponseValidatorProxy, +) +from openapi_core.unmarshalling.response.unmarshallers import ( + APICallResponseUnmarshaller, +) +from openapi_core.unmarshalling.response.unmarshallers import ResponseValidator +from openapi_core.unmarshalling.response.unmarshallers import ( + V30ResponseUnmarshaller, +) +from openapi_core.unmarshalling.response.unmarshallers import ( + V31ResponseUnmarshaller, +) +from openapi_core.unmarshalling.response.unmarshallers import ( + V31WebhookResponseUnmarshaller, +) +from openapi_core.unmarshalling.schemas import ( + oas30_read_schema_unmarshallers_factory, +) +from openapi_core.unmarshalling.schemas import ( + oas31_schema_unmarshallers_factory, +) + +__all__ = [ + "V30ResponseUnmarshaller", + "V31ResponseUnmarshaller", + "V31WebhookResponseUnmarshaller", + "ResponseValidator", + "openapi_v30_response_validator", + "openapi_v31_response_validator", + "openapi_v3_response_validator", + "openapi_response_validator", +] + +# alias to the latest v3 version +V3ResponseUnmarshaller = V31ResponseUnmarshaller +V3WebhookResponseUnmarshaller = V31WebhookResponseUnmarshaller + +# spec validators +openapi_v30_response_validator = SpecResponseValidatorProxy( + APICallResponseUnmarshaller, + schema_unmarshallers_factory=oas30_read_schema_unmarshallers_factory, + deprecated="openapi_v30_response_validator", + use="V30ResponseUnmarshaller", +) + +openapi_v31_response_validator = SpecResponseValidatorProxy( + APICallResponseUnmarshaller, + schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, + deprecated="openapi_v31_response_validator", + use="V31ResponseUnmarshaller", +) + +# spec validators alias to the latest v3 version +openapi_v3_response_validator = openapi_v31_response_validator + +# detect version spec +openapi_response_validator = DetectResponseValidatorProxy( + { + ("openapi", "3.0"): openapi_v30_response_validator, + ("openapi", "3.1"): openapi_v31_response_validator, + }, +) diff --git a/openapi_core/validation/response/datatypes.py b/openapi_core/unmarshalling/response/datatypes.py similarity index 57% rename from openapi_core/validation/response/datatypes.py rename to openapi_core/unmarshalling/response/datatypes.py index f820936b..5a27d1fa 100644 --- a/openapi_core/validation/response/datatypes.py +++ b/openapi_core/unmarshalling/response/datatypes.py @@ -1,14 +1,14 @@ -"""OpenAPI core validation response datatypes module""" +"""OpenAPI core unmarshalling response datatypes module""" from dataclasses import dataclass from dataclasses import field from typing import Any from typing import Dict from typing import Optional -from openapi_core.validation.datatypes import BaseValidationResult +from openapi_core.unmarshalling.datatypes import BaseUnmarshalResult @dataclass -class ResponseValidationResult(BaseValidationResult): +class ResponseUnmarshalResult(BaseUnmarshalResult): data: Optional[str] = None headers: Dict[str, Any] = field(default_factory=dict) diff --git a/openapi_core/unmarshalling/response/protocols.py b/openapi_core/unmarshalling/response/protocols.py new file mode 100644 index 00000000..6c382865 --- /dev/null +++ b/openapi_core/unmarshalling/response/protocols.py @@ -0,0 +1,46 @@ +"""OpenAPI core validation response protocols module""" +import sys +from typing import Any +from typing import Mapping +from typing import Optional + +if sys.version_info >= (3, 8): + from typing import Protocol + from typing import runtime_checkable +else: + from typing_extensions import Protocol + from typing_extensions import runtime_checkable + +from openapi_core.protocols import Request +from openapi_core.protocols import Response +from openapi_core.protocols import WebhookRequest +from openapi_core.spec import Spec +from openapi_core.unmarshalling.response.datatypes import ( + ResponseUnmarshalResult, +) + + +@runtime_checkable +class ResponseUnmarshaller(Protocol): + def __init__(self, spec: Spec, base_url: Optional[str] = None): + ... + + def unmarshal( + self, + request: Request, + response: Response, + ) -> ResponseUnmarshalResult: + ... + + +@runtime_checkable +class WebhookResponseUnmarshaller(Protocol): + def __init__(self, spec: Spec, base_url: Optional[str] = None): + ... + + def unmarshal( + self, + request: WebhookRequest, + response: Response, + ) -> ResponseUnmarshalResult: + ... diff --git a/openapi_core/validation/response/proxies.py b/openapi_core/unmarshalling/response/proxies.py similarity index 76% rename from openapi_core/validation/response/proxies.py rename to openapi_core/unmarshalling/response/proxies.py index 1221fe22..5d364386 100644 --- a/openapi_core/validation/response/proxies.py +++ b/openapi_core/unmarshalling/response/proxies.py @@ -8,37 +8,33 @@ from typing import Tuple from typing import Type +from openapi_core.exceptions import SpecError +from openapi_core.protocols import Request +from openapi_core.protocols import Response from openapi_core.spec import Spec -from openapi_core.validation.exceptions import ValidatorDetectError -from openapi_core.validation.request.protocols import Request -from openapi_core.validation.response.datatypes import ResponseValidationResult -from openapi_core.validation.response.protocols import Response +from openapi_core.unmarshalling.response.datatypes import ( + ResponseUnmarshalResult, +) if TYPE_CHECKING: - from openapi_core.validation.response.validators import ( - BaseAPICallResponseValidator, + from openapi_core.unmarshalling.response.unmarshallers import ( + APICallResponseUnmarshaller, ) class SpecResponseValidatorProxy: def __init__( self, - unmarshaller_cls_name: Type["BaseAPICallResponseValidator"], + unmarshaller_cls: Type["APICallResponseUnmarshaller"], deprecated: str = "ResponseValidator", use: Optional[str] = None, **unmarshaller_kwargs: Any, ): - self.unmarshaller_cls_name = unmarshaller_cls_name + self.unmarshaller_cls = unmarshaller_cls self.unmarshaller_kwargs = unmarshaller_kwargs self.deprecated = deprecated - self.use = use or self.unmarshaller_cls_name - - @property - def unmarshaller_cls(self) -> Type["BaseAPICallResponseValidator"]: - from openapi_core.unmarshalling.response import unmarshallers - - return getattr(unmarshallers, self.unmarshaller_cls_name) + self.use = use or self.unmarshaller_cls.__name__ def validate( self, @@ -46,7 +42,7 @@ def validate( request: Request, response: Response, base_url: Optional[str] = None, - ) -> ResponseValidationResult: + ) -> "ResponseUnmarshalResult": warnings.warn( f"{self.deprecated} is deprecated. Use {self.use} instead.", DeprecationWarning, @@ -54,7 +50,7 @@ def validate( unmarshaller = self.unmarshaller_cls( spec, base_url=base_url, **self.unmarshaller_kwargs ) - return unmarshaller.validate(request, response) + return unmarshaller.unmarshal(request, response) def is_valid( self, @@ -95,7 +91,7 @@ def detect(self, spec: Spec) -> SpecResponseValidatorProxy: for (key, value), validator in self.choices.items(): if key in spec and spec[key].startswith(value): return validator - raise ValidatorDetectError("Spec schema version not detected") + raise SpecError("Spec schema version not detected") def validate( self, @@ -103,7 +99,7 @@ def validate( request: Request, response: Response, base_url: Optional[str] = None, - ) -> ResponseValidationResult: + ) -> "ResponseUnmarshalResult": validator = self.detect(spec) return validator.validate(spec, request, response, base_url=base_url) diff --git a/openapi_core/unmarshalling/response/types.py b/openapi_core/unmarshalling/response/types.py new file mode 100644 index 00000000..bc3e004e --- /dev/null +++ b/openapi_core/unmarshalling/response/types.py @@ -0,0 +1,13 @@ +from typing import Type +from typing import Union + +from openapi_core.unmarshalling.response.protocols import ResponseUnmarshaller +from openapi_core.unmarshalling.response.protocols import ( + WebhookResponseUnmarshaller, +) + +ResponseUnmarshallerType = Type[ResponseUnmarshaller] +WebhookResponseUnmarshallerType = Type[WebhookResponseUnmarshaller] +AnyResponseUnmarshallerType = Union[ + ResponseUnmarshallerType, WebhookResponseUnmarshallerType +] diff --git a/openapi_core/unmarshalling/response/unmarshallers.py b/openapi_core/unmarshalling/response/unmarshallers.py new file mode 100644 index 00000000..dfcb33d1 --- /dev/null +++ b/openapi_core/unmarshalling/response/unmarshallers.py @@ -0,0 +1,328 @@ +from typing import Any +from typing import Mapping + +from openapi_core.protocols import BaseRequest +from openapi_core.protocols import Request +from openapi_core.protocols import Response +from openapi_core.protocols import WebhookRequest +from openapi_core.spec import Spec +from openapi_core.templating.paths.exceptions import PathError +from openapi_core.templating.responses.exceptions import ResponseFinderError +from openapi_core.unmarshalling.response.datatypes import ( + ResponseUnmarshalResult, +) +from openapi_core.unmarshalling.response.proxies import ( + SpecResponseValidatorProxy, +) +from openapi_core.unmarshalling.schemas import ( + oas30_read_schema_unmarshallers_factory, +) +from openapi_core.unmarshalling.schemas import ( + oas31_schema_unmarshallers_factory, +) +from openapi_core.unmarshalling.schemas.factories import ( + SchemaUnmarshallersFactory, +) +from openapi_core.unmarshalling.unmarshallers import BaseUnmarshaller +from openapi_core.util import chainiters +from openapi_core.validation.response.exceptions import DataError +from openapi_core.validation.response.exceptions import HeadersError +from openapi_core.validation.response.validators import ( + APICallResponseValidator, +) +from openapi_core.validation.response.validators import BaseResponseValidator +from openapi_core.validation.response.validators import ( + V30ResponseDataValidator, +) +from openapi_core.validation.response.validators import ( + V30ResponseHeadersValidator, +) +from openapi_core.validation.response.validators import V30ResponseValidator +from openapi_core.validation.response.validators import ( + V31ResponseDataValidator, +) +from openapi_core.validation.response.validators import ( + V31ResponseHeadersValidator, +) +from openapi_core.validation.response.validators import V31ResponseValidator +from openapi_core.validation.response.validators import ( + V31WebhookResponseDataValidator, +) +from openapi_core.validation.response.validators import ( + V31WebhookResponseHeadersValidator, +) +from openapi_core.validation.response.validators import ( + V31WebhookResponseValidator, +) +from openapi_core.validation.response.validators import ( + WebhookResponseValidator, +) + + +class BaseResponseUnmarshaller(BaseResponseValidator, BaseUnmarshaller): + def _unmarshal( + self, + response: Response, + operation: Spec, + ) -> ResponseUnmarshalResult: + try: + operation_response = self._get_operation_response( + response.status_code, operation + ) + # don't process if operation errors + except ResponseFinderError as exc: + return ResponseUnmarshalResult(errors=[exc]) + + try: + validated_data = self._get_data( + response.data, response.mimetype, operation_response + ) + except DataError as exc: + validated_data = None + data_errors = [exc] + else: + data_errors = [] + + try: + validated_headers = self._get_headers( + response.headers, operation_response + ) + except HeadersError as exc: + validated_headers = exc.headers + headers_errors = exc.context + else: + headers_errors = [] + + errors = list(chainiters(data_errors, headers_errors)) + return ResponseUnmarshalResult( + errors=errors, + data=validated_data, + headers=validated_headers, + ) + + def _unmarshal_data( + self, + response: Response, + operation: Spec, + ) -> ResponseUnmarshalResult: + try: + operation_response = self._get_operation_response( + response.status_code, operation + ) + # don't process if operation errors + except ResponseFinderError as exc: + return ResponseUnmarshalResult(errors=[exc]) + + try: + validated = self._get_data( + response.data, response.mimetype, operation_response + ) + except DataError as exc: + validated = None + data_errors = [exc] + else: + data_errors = [] + + return ResponseUnmarshalResult( + errors=data_errors, + data=validated, + ) + + def _unmarshal_headers( + self, + response: Response, + operation: Spec, + ) -> ResponseUnmarshalResult: + try: + operation_response = self._get_operation_response( + response.status_code, operation + ) + # don't process if operation errors + except ResponseFinderError as exc: + return ResponseUnmarshalResult(errors=[exc]) + + try: + validated = self._get_headers(response.headers, operation_response) + except HeadersError as exc: + validated = exc.headers + headers_errors = exc.context + else: + headers_errors = [] + + return ResponseUnmarshalResult( + errors=headers_errors, + headers=validated, + ) + + +class APICallResponseUnmarshaller( + APICallResponseValidator, BaseResponseUnmarshaller +): + def unmarshal( + self, + request: Request, + response: Response, + ) -> ResponseUnmarshalResult: + try: + _, operation, _, _, _ = self._find_path(request) + # don't process if operation errors + except PathError as exc: + return ResponseUnmarshalResult(errors=[exc]) + + return self._unmarshal(response, operation) + + +class APICallResponseDataUnmarshaller( + APICallResponseValidator, BaseResponseUnmarshaller +): + def unmarshal( + self, + request: Request, + response: Response, + ) -> ResponseUnmarshalResult: + try: + _, operation, _, _, _ = self._find_path(request) + # don't process if operation errors + except PathError as exc: + return ResponseUnmarshalResult(errors=[exc]) + + return self._unmarshal_data(response, operation) + + +class APICallResponseHeadersUnmarshaller( + APICallResponseValidator, BaseResponseUnmarshaller +): + def unmarshal( + self, + request: Request, + response: Response, + ) -> ResponseUnmarshalResult: + try: + _, operation, _, _, _ = self._find_path(request) + # don't process if operation errors + except PathError as exc: + return ResponseUnmarshalResult(errors=[exc]) + + return self._unmarshal_headers(response, operation) + + +class WebhookResponseUnmarshaller( + WebhookResponseValidator, BaseResponseUnmarshaller +): + def unmarshal( + self, + request: WebhookRequest, + response: Response, + ) -> ResponseUnmarshalResult: + try: + _, operation, _, _, _ = self._find_path(request) + # don't process if operation errors + except PathError as exc: + return ResponseUnmarshalResult(errors=[exc]) + + return self._unmarshal(response, operation) + + +class WebhookResponseDataUnmarshaller( + WebhookResponseValidator, BaseResponseUnmarshaller +): + def unmarshal( + self, + request: WebhookRequest, + response: Response, + ) -> ResponseUnmarshalResult: + try: + _, operation, _, _, _ = self._find_path(request) + # don't process if operation errors + except PathError as exc: + return ResponseUnmarshalResult(errors=[exc]) + + return self._unmarshal_data(response, operation) + + +class WebhookResponseHeadersUnmarshaller( + WebhookResponseValidator, BaseResponseUnmarshaller +): + def unmarshal( + self, + request: WebhookRequest, + response: Response, + ) -> ResponseUnmarshalResult: + try: + _, operation, _, _, _ = self._find_path(request) + # don't process if operation errors + except PathError as exc: + return ResponseUnmarshalResult(errors=[exc]) + + return self._unmarshal_headers(response, operation) + + +class V30ResponseDataUnmarshaller( + V30ResponseDataValidator, APICallResponseDataUnmarshaller +): + schema_unmarshallers_factory = oas30_read_schema_unmarshallers_factory + + +class V30ResponseHeadersUnmarshaller( + V30ResponseHeadersValidator, APICallResponseHeadersUnmarshaller +): + schema_unmarshallers_factory = oas30_read_schema_unmarshallers_factory + + +class V30ResponseUnmarshaller( + V30ResponseValidator, APICallResponseUnmarshaller +): + schema_unmarshallers_factory = oas30_read_schema_unmarshallers_factory + + +class V31ResponseDataUnmarshaller( + V31ResponseDataValidator, APICallResponseDataUnmarshaller +): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31ResponseHeadersUnmarshaller( + V31ResponseHeadersValidator, APICallResponseHeadersUnmarshaller +): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31ResponseUnmarshaller( + V31ResponseValidator, APICallResponseUnmarshaller +): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31WebhookResponseDataUnmarshaller( + V31WebhookResponseDataValidator, WebhookResponseDataUnmarshaller +): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31WebhookResponseHeadersUnmarshaller( + V31WebhookResponseHeadersValidator, WebhookResponseHeadersUnmarshaller +): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +class V31WebhookResponseUnmarshaller( + V31WebhookResponseValidator, WebhookResponseUnmarshaller +): + schema_unmarshallers_factory = oas31_schema_unmarshallers_factory + + +# backward compatibility +class ResponseValidator(SpecResponseValidatorProxy): + def __init__( + self, + schema_unmarshallers_factory: "SchemaUnmarshallersFactory", + **kwargs: Any, + ): + super().__init__( + APICallResponseUnmarshaller, + schema_validators_factory=( + schema_unmarshallers_factory.schema_validators_factory + ), + schema_unmarshallers_factory=schema_unmarshallers_factory, + **kwargs, + ) diff --git a/openapi_core/unmarshalling/schemas/__init__.py b/openapi_core/unmarshalling/schemas/__init__.py index d74e2eb9..9011bcc3 100644 --- a/openapi_core/unmarshalling/schemas/__init__.py +++ b/openapi_core/unmarshalling/schemas/__init__.py @@ -1,9 +1,7 @@ from collections import OrderedDict -from functools import partial from isodate.isodatetime import parse_datetime -from openapi_core.unmarshalling.schemas.enums import ValidationContext from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) @@ -12,13 +10,7 @@ from openapi_core.unmarshalling.schemas.unmarshallers import ( MultiTypeUnmarshaller, ) -from openapi_core.unmarshalling.schemas.unmarshallers import ( - ObjectReadUnmarshaller, -) from openapi_core.unmarshalling.schemas.unmarshallers import ObjectUnmarshaller -from openapi_core.unmarshalling.schemas.unmarshallers import ( - ObjectWriteUnmarshaller, -) from openapi_core.unmarshalling.schemas.unmarshallers import ( PrimitiveUnmarshaller, ) @@ -52,18 +44,6 @@ ("object", ObjectUnmarshaller), ] ) -oas30_write_unmarshallers_dict = oas30_unmarshallers_dict.copy() -oas30_write_unmarshallers_dict.update( - { - "object": ObjectWriteUnmarshaller, - } -) -oas30_read_unmarshallers_dict = oas30_unmarshallers_dict.copy() -oas30_read_unmarshallers_dict.update( - { - "object": ObjectReadUnmarshaller, - } -) oas31_unmarshallers_dict = oas30_unmarshallers_dict.copy() oas31_unmarshallers_dict.update( { @@ -71,11 +51,7 @@ } ) -oas30_write_types_unmarshaller = TypesUnmarshaller( - oas30_unmarshallers_dict, - AnyUnmarshaller, -) -oas30_read_types_unmarshaller = TypesUnmarshaller( +oas30_types_unmarshaller = TypesUnmarshaller( oas30_unmarshallers_dict, AnyUnmarshaller, ) @@ -97,13 +73,13 @@ oas30_write_schema_unmarshallers_factory = SchemaUnmarshallersFactory( oas30_write_schema_validators_factory, - oas30_write_types_unmarshaller, + oas30_types_unmarshaller, format_unmarshallers=oas30_format_unmarshallers, ) oas30_read_schema_unmarshallers_factory = SchemaUnmarshallersFactory( oas30_read_schema_validators_factory, - oas30_read_types_unmarshaller, + oas30_types_unmarshaller, format_unmarshallers=oas30_format_unmarshallers, ) diff --git a/openapi_core/unmarshalling/schemas/enums.py b/openapi_core/unmarshalling/schemas/enums.py deleted file mode 100644 index 2f8d88f2..00000000 --- a/openapi_core/unmarshalling/schemas/enums.py +++ /dev/null @@ -1,7 +0,0 @@ -"""OpenAPI core unmarshalling schemas enums module""" -from enum import Enum - - -class ValidationContext(Enum): - REQUEST = "request" - RESPONSE = "response" diff --git a/openapi_core/unmarshalling/schemas/factories.py b/openapi_core/unmarshalling/schemas/factories.py index 9cce1ce7..a3c36243 100644 --- a/openapi_core/unmarshalling/schemas/factories.py +++ b/openapi_core/unmarshalling/schemas/factories.py @@ -11,40 +11,17 @@ from openapi_core.unmarshalling.schemas.datatypes import CustomFormattersDict from openapi_core.unmarshalling.schemas.datatypes import FormatUnmarshaller from openapi_core.unmarshalling.schemas.datatypes import UnmarshallersDict -from openapi_core.unmarshalling.schemas.enums import ValidationContext from openapi_core.unmarshalling.schemas.exceptions import ( FormatterNotFoundError, ) +from openapi_core.unmarshalling.schemas.unmarshallers import ( + FormatsUnmarshaller, +) from openapi_core.unmarshalling.schemas.unmarshallers import SchemaUnmarshaller from openapi_core.unmarshalling.schemas.unmarshallers import TypesUnmarshaller from openapi_core.validation.schemas.factories import SchemaValidatorsFactory -class SchemaFormatUnmarshallersFactory: - def __init__( - self, - schema_validators_factory: SchemaValidatorsFactory, - format_unmarshallers: Optional[UnmarshallersDict] = None, - custom_formatters: Optional[CustomFormattersDict] = None, - ): - self.schema_validators_factory = schema_validators_factory - if format_unmarshallers is None: - format_unmarshallers = {} - self.format_unmarshallers = format_unmarshallers - if custom_formatters is None: - custom_formatters = {} - self.custom_formatters = custom_formatters - - def create(self, schema_format: str) -> Optional[FormatUnmarshaller]: - if schema_format in self.custom_formatters: - formatter = self.custom_formatters[schema_format] - return formatter.format - if schema_format in self.format_unmarshallers: - return self.format_unmarshallers[schema_format] - - return None - - class SchemaUnmarshallersFactory: def __init__( self, @@ -57,29 +34,19 @@ def __init__( self.types_unmarshaller = types_unmarshaller if custom_formatters is None: custom_formatters = {} - else: - warnings.warn( - "custom_formatters is deprecated. " - "Register new checks to FormatChecker to validate custom formats " - "and add format_unmarshallers to unmarshal custom formats.", - DeprecationWarning, - ) if format_unmarshallers is None: format_unmarshallers = {} self.format_unmarshallers = format_unmarshallers self.custom_formatters = custom_formatters @cached_property - def format_unmarshallers_factory(self) -> SchemaFormatUnmarshallersFactory: - return SchemaFormatUnmarshallersFactory( - self.schema_validators_factory, + def formats_unmarshaller(self) -> FormatsUnmarshaller: + return FormatsUnmarshaller( self.format_unmarshallers, self.custom_formatters, ) - def create( - self, schema: Spec, type_override: Optional[str] = None - ) -> SchemaUnmarshaller: + def create(self, schema: Spec) -> SchemaUnmarshaller: """Create unmarshaller from the schema.""" if schema is None: raise TypeError("Invalid schema") @@ -91,9 +58,9 @@ def create( name: formatter.validate for name, formatter in self.custom_formatters.items() } - self.schema_validators_factory.add_checks(**formatters_checks) - - schema_validator = self.schema_validators_factory.create(schema) + schema_validator = self.schema_validators_factory.create( + schema, **formatters_checks + ) schema_format = schema.getkey("format") @@ -109,7 +76,6 @@ def create( return SchemaUnmarshaller( schema, schema_validator, - self, - self.format_unmarshallers_factory, self.types_unmarshaller, + self.formats_unmarshaller, ) diff --git a/openapi_core/unmarshalling/schemas/unmarshallers.py b/openapi_core/unmarshalling/schemas/unmarshallers.py index f5ae678a..353e50d9 100644 --- a/openapi_core/unmarshalling/schemas/unmarshallers.py +++ b/openapi_core/unmarshalling/schemas/unmarshallers.py @@ -1,7 +1,5 @@ import logging import warnings -from functools import partial -from typing import TYPE_CHECKING from typing import Any from typing import Iterable from typing import Iterator @@ -10,119 +8,68 @@ from typing import Optional from typing import Type from typing import Union -from typing import cast from openapi_core.extensions.models.factories import ModelPathFactory from openapi_core.schema.schemas import get_properties from openapi_core.spec import Spec -from openapi_core.unmarshalling.schemas.enums import ValidationContext +from openapi_core.unmarshalling.schemas.datatypes import CustomFormattersDict +from openapi_core.unmarshalling.schemas.datatypes import FormatUnmarshaller +from openapi_core.unmarshalling.schemas.datatypes import UnmarshallersDict from openapi_core.unmarshalling.schemas.exceptions import FormatUnmarshalError from openapi_core.unmarshalling.schemas.exceptions import UnmarshallerError -from openapi_core.validation.schemas.exceptions import ValidateError from openapi_core.validation.schemas.validators import SchemaValidator -if TYPE_CHECKING: - from openapi_core.unmarshalling.schemas.factories import ( - SchemaFormatUnmarshallersFactory, - ) - from openapi_core.unmarshalling.schemas.factories import ( - SchemaUnmarshallersFactory, - ) - from openapi_core.unmarshalling.schemas.factories import ( - SchemaValidatorsFactory, - ) - log = logging.getLogger(__name__) class PrimitiveUnmarshaller: def __init__( self, - schema, - schema_validator, - schema_unmarshaller, - schema_unmarshallers_factory, + schema: Spec, + schema_validator: SchemaValidator, + schema_unmarshaller: "SchemaUnmarshaller", ) -> None: self.schema = schema self.schema_validator = schema_validator self.schema_unmarshaller = schema_unmarshaller - self.schema_unmarshallers_factory = schema_unmarshallers_factory - - self.schema_format = schema.getkey("format") - - def __call__(self, value: Any, subschemas: bool = True) -> Any: - best_format = self._get_format(value, subschemas=subschemas) - format_unmarshaller = self.schema_unmarshallers_factory.format_unmarshallers_factory.create( - best_format - ) - if format_unmarshaller is None: - return value - try: - return format_unmarshaller(value) - except (ValueError, TypeError) as exc: - raise FormatUnmarshalError(value, self.schema_format, exc) - def _get_format( - self, value: Any, subschemas: bool = True - ) -> Optional[str]: - if "format" in self.schema: - return self.schema.getkey("format") - - if subschemas is False: - return None - - one_of_schema = self.schema_validator.get_one_of_schema(value) - if one_of_schema is not None and "format" in one_of_schema: - return one_of_schema.getkey("format") - - any_of_schemas = self.schema_validator.iter_any_of_schemas(value) - for any_of_schema in any_of_schemas: - if "format" in any_of_schema: - return any_of_schema.getkey("format") - - all_of_schemas = self.schema_validator.iter_all_of_schemas(value) - for all_of_schema in all_of_schemas: - if "format" in all_of_schema: - return all_of_schema.getkey("format") - - return None + def __call__(self, value: Any) -> Any: + return value class ArrayUnmarshaller(PrimitiveUnmarshaller): + def __call__(self, value: Any) -> Optional[List[Any]]: + return list(map(self.items_unmarshaller.unmarshal, value)) + @property - def items_unmarshaller(self) -> "PrimitiveUnmarshaller": + def items_unmarshaller(self) -> "SchemaUnmarshaller": # sometimes we don't have any schema i.e. free-form objects items_schema = self.schema.get( "items", Spec.from_dict({}, validator=None) ) return self.schema_unmarshaller.evolve(items_schema) - def __call__(self, value: Any) -> Optional[List[Any]]: - return list(map(self.items_unmarshaller.unmarshal, value)) - class ObjectUnmarshaller(PrimitiveUnmarshaller): - context = NotImplemented - - @property - def object_class_factory(self) -> ModelPathFactory: - return ModelPathFactory() - def __call__(self, value: Any) -> Any: - properties = self._unmarshal_raw(value) + properties = self._unmarshal_properties(value) fields: Iterable[str] = properties and properties.keys() or [] object_class = self.object_class_factory.create(self.schema, fields) return object_class(**properties) - def _unmarshal_raw(self, value: Any, schema_only: bool = False) -> Any: - formatted = super().__call__(value) - return self._unmarshal_properties(formatted, schema_only=schema_only) + @property + def object_class_factory(self) -> ModelPathFactory: + return ModelPathFactory() def evolve(self, schema: Spec) -> "ObjectUnmarshaller": - return self.schema_unmarshaller.evolve(schema).get_unmarshaller( - "object" + cls = self.__class__ + + return cls( + schema, + self.schema_validator.evolve(schema), + self.schema_unmarshaller, ) def _unmarshal_properties( @@ -132,34 +79,26 @@ def _unmarshal_properties( one_of_schema = self.schema_validator.get_one_of_schema(value) if one_of_schema is not None: - one_of_properties = self.evolve(one_of_schema)._unmarshal_raw( - value, schema_only=True - ) + one_of_properties = self.evolve( + one_of_schema + )._unmarshal_properties(value, schema_only=True) properties.update(one_of_properties) any_of_schemas = self.schema_validator.iter_any_of_schemas(value) for any_of_schema in any_of_schemas: - any_of_properties = self.evolve(any_of_schema)._unmarshal_raw( - value, schema_only=True - ) + any_of_properties = self.evolve( + any_of_schema + )._unmarshal_properties(value, schema_only=True) properties.update(any_of_properties) all_of_schemas = self.schema_validator.iter_all_of_schemas(value) for all_of_schema in all_of_schemas: - all_of_properties = self.evolve(all_of_schema)._unmarshal_raw( - value, schema_only=True - ) + all_of_properties = self.evolve( + all_of_schema + )._unmarshal_properties(value, schema_only=True) properties.update(all_of_properties) for prop_name, prop_schema in get_properties(self.schema).items(): - # check for context in OpenAPI 3.0 - if self.context is not NotImplemented: - read_only = prop_schema.getkey("readOnly", False) - if self.context == ValidationContext.REQUEST and read_only: - continue - write_only = prop_schema.getkey("writeOnly", False) - if self.context == ValidationContext.RESPONSE and write_only: - continue try: prop_value = value[prop_name] except KeyError: @@ -167,7 +106,7 @@ def _unmarshal_properties( continue prop_value = prop_schema["default"] - properties[prop_name] = self.schema_unmarshallers_factory.create( + properties[prop_name] = self.schema_unmarshaller.evolve( prop_schema ).unmarshal(prop_value) @@ -186,10 +125,8 @@ def _unmarshal_properties( # defined schema else: additional_prop_schema = self.schema / "additionalProperties" - additional_prop_unmarshaler = ( - self.schema_unmarshallers_factory.create( - additional_prop_schema - ) + additional_prop_unmarshaler = self.schema_unmarshaller.evolve( + additional_prop_schema ) for prop_name, prop_value in value.items(): if prop_name in properties: @@ -201,15 +138,11 @@ def _unmarshal_properties( return properties -class ObjectReadUnmarshaller(ObjectUnmarshaller): - context = ValidationContext.RESPONSE - - -class ObjectWriteUnmarshaller(ObjectUnmarshaller): - context = ValidationContext.REQUEST - - class MultiTypeUnmarshaller(PrimitiveUnmarshaller): + def __call__(self, value: Any) -> Any: + unmarshaller = self._get_best_unmarshaller(value) + return unmarshaller(value) + @property def type(self) -> List[str]: types = self.schema.getkey("type", ["any"]) @@ -226,14 +159,10 @@ def _get_best_unmarshaller(self, value: Any) -> "PrimitiveUnmarshaller": result = self.schema_validator.format_validator(value) if not result: continue - return self.schema_unmarshaller.get_unmarshaller(schema_type) + return self.schema_unmarshaller.get_type_unmarshaller(schema_type) raise UnmarshallerError("Unmarshaller not found for type(s)") - def __call__(self, value: Any) -> Any: - unmarshaller = self._get_best_unmarshaller(value) - return unmarshaller(value) - class AnyUnmarshaller(MultiTypeUnmarshaller): SCHEMA_TYPES_ORDER = [ @@ -258,15 +187,15 @@ def __init__( self, unmarshallers: Mapping[str, Type[PrimitiveUnmarshaller]], default: Type[PrimitiveUnmarshaller], - multi: bool = False, + multi: Optional[Type[PrimitiveUnmarshaller]] = None, ): self.unmarshallers = unmarshallers self.default = default self.multi = multi - def get_type_unmarshaller( + def get_unmarshaller( self, - schema_type: Optional[Union[Iterable, str]], + schema_type: Optional[Union[Iterable[str], str]], ) -> Type["PrimitiveUnmarshaller"]: if schema_type is None: return self.default @@ -280,22 +209,53 @@ def get_type_unmarshaller( return self.unmarshallers[schema_type] +class FormatsUnmarshaller: + def __init__( + self, + format_unmarshallers: Optional[UnmarshallersDict] = None, + custom_formatters: Optional[CustomFormattersDict] = None, + ): + if format_unmarshallers is None: + format_unmarshallers = {} + self.format_unmarshallers = format_unmarshallers + if custom_formatters is None: + custom_formatters = {} + self.custom_formatters = custom_formatters + + def unmarshal(self, schema_format: str, value: Any) -> Any: + format_unmarshaller = self.get_unmarshaller(schema_format) + if format_unmarshaller is None: + return value + try: + return format_unmarshaller(value) + except (ValueError, TypeError) as exc: + raise FormatUnmarshalError(value, schema_format, exc) + + def get_unmarshaller( + self, schema_format: str + ) -> Optional[FormatUnmarshaller]: + if schema_format in self.custom_formatters: + formatter = self.custom_formatters[schema_format] + return formatter.format + if schema_format in self.format_unmarshallers: + return self.format_unmarshallers[schema_format] + + return None + + class SchemaUnmarshaller: def __init__( self, schema: Spec, schema_validator: SchemaValidator, - schema_unmarshallers_factory: "SchemaUnmarshallersFactory", - format_unmarshallers_factory: "SchemaFormatUnmarshallersFactory", types_unmarshaller: TypesUnmarshaller, + formats_unmarshaller: FormatsUnmarshaller, ): self.schema = schema self.schema_validator = schema_validator - self.schema_unmarshallers_factory = schema_unmarshallers_factory - self.format_unmarshallers_factory = format_unmarshallers_factory - self.types_unmarshaller = types_unmarshaller + self.formats_unmarshaller = formats_unmarshaller def __call__(self, value: Any) -> Any: warnings.warn( @@ -305,7 +265,7 @@ def __call__(self, value: Any) -> Any: ) return self.unmarshal(value) - def unmarshal(self, value: Any, subschemas: bool = True) -> Any: + def unmarshal(self, value: Any) -> Any: self.schema_validator.validate(value) # skip unmarshalling for nullable in OpenAPI 3.0 @@ -313,19 +273,22 @@ def unmarshal(self, value: Any, subschemas: bool = True) -> Any: return value schema_type = self.schema.getkey("type") - unmarshaller = self.get_unmarshaller(schema_type) - return unmarshaller(value) + type_unmarshaller = self.get_type_unmarshaller(schema_type) + typed = type_unmarshaller(value) + schema_format = self.find_format(value) + if schema_format is None: + return typed + return self.formats_unmarshaller.unmarshal(schema_format, typed) - def get_unmarshaller( + def get_type_unmarshaller( self, - schema_type: Optional[Union[Iterable, str]], - ): - klass = self.types_unmarshaller.get_type_unmarshaller(schema_type) + schema_type: Optional[Union[Iterable[str], str]], + ) -> PrimitiveUnmarshaller: + klass = self.types_unmarshaller.get_unmarshaller(schema_type) return klass( self.schema, self.schema_validator, self, - self.schema_unmarshallers_factory, ) def evolve(self, schema: Spec) -> "SchemaUnmarshaller": @@ -334,7 +297,22 @@ def evolve(self, schema: Spec) -> "SchemaUnmarshaller": return cls( schema, self.schema_validator.evolve(schema), - self.schema_unmarshallers_factory, - self.format_unmarshallers_factory, self.types_unmarshaller, + self.formats_unmarshaller, ) + + def find_format(self, value: Any) -> Optional[str]: + for schema in self.iter_valid_schemas(value): + if "format" in schema: + return str(schema.getkey("format")) + return None + + def iter_valid_schemas(self, value: Any) -> Iterator[Spec]: + yield self.schema + + one_of_schema = self.schema_validator.get_one_of_schema(value) + if one_of_schema is not None: + yield one_of_schema + + yield from self.schema_validator.iter_any_of_schemas(value) + yield from self.schema_validator.iter_all_of_schemas(value) diff --git a/openapi_core/unmarshalling/schemas/util.py b/openapi_core/unmarshalling/schemas/util.py index 91ae690e..f0a3138b 100644 --- a/openapi_core/unmarshalling/schemas/util.py +++ b/openapi_core/unmarshalling/schemas/util.py @@ -2,16 +2,10 @@ from base64 import b64decode from datetime import date from datetime import datetime -from typing import TYPE_CHECKING from typing import Any from typing import Union from uuid import UUID -if TYPE_CHECKING: - StaticMethod = staticmethod[Any] -else: - StaticMethod = staticmethod - def format_date(value: str) -> date: return datetime.strptime(value, "%Y-%m-%d").date() @@ -32,14 +26,3 @@ def format_number(value: str) -> Union[int, float]: return value return float(value) - - -class callable_staticmethod(StaticMethod): - """Callable version of staticmethod. - - Prior to Python 3.10, staticmethods are not directly callable - from inside the class. - """ - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - return self.__func__(*args, **kwargs) diff --git a/openapi_core/unmarshalling/unmarshallers.py b/openapi_core/unmarshalling/unmarshallers.py new file mode 100644 index 00000000..f381811b --- /dev/null +++ b/openapi_core/unmarshalling/unmarshallers.py @@ -0,0 +1,88 @@ +from typing import Any +from typing import Mapping +from typing import Optional +from typing import Tuple + +from openapi_core.casting.schemas import schema_casters_factory +from openapi_core.casting.schemas.factories import SchemaCastersFactory +from openapi_core.deserializing.media_types import ( + media_type_deserializers_factory, +) +from openapi_core.deserializing.media_types.factories import ( + MediaTypeDeserializersFactory, +) +from openapi_core.deserializing.parameters import ( + parameter_deserializers_factory, +) +from openapi_core.deserializing.parameters.factories import ( + ParameterDeserializersFactory, +) +from openapi_core.spec import Spec +from openapi_core.unmarshalling.schemas.factories import ( + SchemaUnmarshallersFactory, +) +from openapi_core.validation.schemas.factories import SchemaValidatorsFactory +from openapi_core.validation.validators import BaseValidator + + +class BaseUnmarshaller(BaseValidator): + schema_unmarshallers_factory: SchemaUnmarshallersFactory = NotImplemented + + def __init__( + self, + spec: Spec, + base_url: Optional[str] = None, + schema_casters_factory: SchemaCastersFactory = schema_casters_factory, + parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory, + media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, + schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + schema_unmarshallers_factory: Optional[ + SchemaUnmarshallersFactory + ] = None, + ): + if schema_validators_factory is None and schema_unmarshallers_factory: + schema_validators_factory = ( + schema_unmarshallers_factory.schema_validators_factory + ) + super().__init__( + spec, + base_url=base_url, + schema_casters_factory=schema_casters_factory, + parameter_deserializers_factory=parameter_deserializers_factory, + media_type_deserializers_factory=media_type_deserializers_factory, + schema_validators_factory=schema_validators_factory, + ) + self.schema_unmarshallers_factory = ( + schema_unmarshallers_factory or self.schema_unmarshallers_factory + ) + if self.schema_unmarshallers_factory is NotImplemented: + raise NotImplementedError( + "schema_unmarshallers_factory is not assigned" + ) + + def _unmarshal_schema(self, schema: Spec, value: Any) -> Any: + unmarshaller = self.schema_unmarshallers_factory.create(schema) + return unmarshaller.unmarshal(value) + + def _get_param_or_header_value( + self, + param_or_header: Spec, + location: Mapping[str, Any], + name: Optional[str] = None, + ) -> Any: + casted, schema = self._get_param_or_header_value_and_schema( + param_or_header, location, name + ) + if schema is None: + return casted + return self._unmarshal_schema(schema, casted) + + def _get_content_value( + self, raw: Any, mimetype: str, content: Spec + ) -> Any: + casted, schema = self._get_content_value_and_schema( + raw, mimetype, content + ) + if schema is None: + return casted + return self._unmarshal_schema(schema, casted) diff --git a/openapi_core/validation/__init__.py b/openapi_core/validation/__init__.py index 96f8098f..21c27dda 100644 --- a/openapi_core/validation/__init__.py +++ b/openapi_core/validation/__init__.py @@ -1,8 +1 @@ """OpenAPI core validation module""" -from openapi_core.validation.request import openapi_request_validator -from openapi_core.validation.response import openapi_response_validator - -__all__ = [ - "openapi_request_validator", - "openapi_response_validator", -] diff --git a/openapi_core/validation/exceptions.py b/openapi_core/validation/exceptions.py index e94096a5..229714bd 100644 --- a/openapi_core/validation/exceptions.py +++ b/openapi_core/validation/exceptions.py @@ -4,10 +4,6 @@ from openapi_core.exceptions import OpenAPIError -class ValidatorDetectError(OpenAPIError): - pass - - @dataclass class ValidationError(OpenAPIError): def __str__(self) -> str: diff --git a/openapi_core/validation/processors.py b/openapi_core/validation/processors.py index 8f7eb3df..860b7006 100644 --- a/openapi_core/validation/processors.py +++ b/openapi_core/validation/processors.py @@ -1,62 +1,34 @@ """OpenAPI core validation processors module""" -from typing import Optional -from typing import Type - +from openapi_core.protocols import Request +from openapi_core.protocols import Response from openapi_core.spec import Spec -from openapi_core.validation.request.datatypes import RequestValidationResult -from openapi_core.validation.request.protocols import Request -from openapi_core.validation.request.protocols import RequestValidator -from openapi_core.validation.request.proxies import SpecRequestValidatorProxy -from openapi_core.validation.response.datatypes import ResponseValidationResult -from openapi_core.validation.response.protocols import Response -from openapi_core.validation.response.protocols import ResponseValidator -from openapi_core.validation.response.proxies import SpecResponseValidatorProxy -from openapi_core.validation.shortcuts import get_validators -from openapi_core.validation.shortcuts import validate_request -from openapi_core.validation.shortcuts import validate_response +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult +from openapi_core.unmarshalling.request.proxies import ( + SpecRequestValidatorProxy, +) +from openapi_core.unmarshalling.response.datatypes import ( + ResponseUnmarshalResult, +) +from openapi_core.unmarshalling.response.proxies import ( + SpecResponseValidatorProxy, +) class OpenAPISpecProcessor: def __init__( self, - request_validator: SpecRequestValidatorProxy, - response_validator: SpecResponseValidatorProxy, + request_unmarshaller: SpecRequestValidatorProxy, + response_unmarshaller: SpecResponseValidatorProxy, ): - self.request_validator = request_validator - self.response_validator = response_validator + self.request_unmarshaller = request_unmarshaller + self.response_unmarshaller = response_unmarshaller def process_request( self, spec: Spec, request: Request - ) -> RequestValidationResult: - return self.request_validator.validate(spec, request) + ) -> RequestUnmarshalResult: + return self.request_unmarshaller.validate(spec, request) def process_response( self, spec: Spec, request: Request, response: Response - ) -> ResponseValidationResult: - return self.response_validator.validate(spec, request, response) - - -class OpenAPIProcessor: - def __init__( - self, - spec: Spec, - request_validator_cls: Optional[Type[RequestValidator]] = None, - response_validator_cls: Optional[Type[ResponseValidator]] = None, - ): - self.spec = spec - if request_validator_cls is None or response_validator_cls is None: - validators = get_validators(self.spec) - if request_validator_cls is None: - request_validator_cls = validators.request_cls - if response_validator_cls is None: - response_validator_cls = validators.response_cls - self.request_validator = request_validator_cls(self.spec) - self.response_validator = response_validator_cls(self.spec) - - def process_request(self, request: Request) -> RequestValidationResult: - return self.request_validator.validate(request) - - def process_response( - self, request: Request, response: Response - ) -> ResponseValidationResult: - return self.response_validator.validate(request, response) + ) -> ResponseUnmarshalResult: + return self.response_unmarshaller.validate(spec, request, response) diff --git a/openapi_core/validation/request/__init__.py b/openapi_core/validation/request/__init__.py index 27828cb0..d79102cc 100644 --- a/openapi_core/validation/request/__init__.py +++ b/openapi_core/validation/request/__init__.py @@ -1,16 +1,4 @@ """OpenAPI core validation request module""" -from functools import partial - -from openapi_core.unmarshalling.schemas import ( - oas30_write_schema_unmarshallers_factory, -) -from openapi_core.unmarshalling.schemas import ( - oas31_schema_unmarshallers_factory, -) -from openapi_core.validation.request.proxies import ( - DetectSpecRequestValidatorProxy, -) -from openapi_core.validation.request.proxies import SpecRequestValidatorProxy from openapi_core.validation.request.validators import V30RequestBodyValidator from openapi_core.validation.request.validators import ( V30RequestParametersValidator, @@ -55,37 +43,8 @@ "V31WebhookRequestValidator", "V3RequestValidator", "V3WebhookRequestValidator", - "openapi_v30_request_validator", - "openapi_v31_request_validator", - "openapi_v3_request_validator", - "openapi_request_validator", ] # alias to the latest v3 version V3RequestValidator = V31RequestValidator V3WebhookRequestValidator = V31WebhookRequestValidator - -# spec validators -openapi_v30_request_validator = SpecRequestValidatorProxy( - "APICallRequestUnmarshaller", - schema_unmarshallers_factory=oas30_write_schema_unmarshallers_factory, - deprecated="openapi_v30_request_validator", - use="V30RequestValidator", -) -openapi_v31_request_validator = SpecRequestValidatorProxy( - "APICallRequestUnmarshaller", - schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, - deprecated="openapi_v31_request_validator", - use="V31RequestValidator", -) - -# spec validators alias to the latest v3 version -openapi_v3_request_validator = openapi_v31_request_validator - -# detect version spec -openapi_request_validator = DetectSpecRequestValidatorProxy( - { - ("openapi", "3.0"): openapi_v30_request_validator, - ("openapi", "3.1"): openapi_v31_request_validator, - }, -) diff --git a/openapi_core/validation/request/datatypes.py b/openapi_core/validation/request/datatypes.py index c359ad38..b673410a 100644 --- a/openapi_core/validation/request/datatypes.py +++ b/openapi_core/validation/request/datatypes.py @@ -1,51 +1,8 @@ -"""OpenAPI core validation request datatypes module""" -from __future__ import annotations - -from dataclasses import dataclass -from dataclasses import field -from typing import Any -from typing import Mapping - -from werkzeug.datastructures import Headers -from werkzeug.datastructures import ImmutableMultiDict - -from openapi_core.validation.datatypes import BaseValidationResult - - -@dataclass -class RequestParameters: - """OpenAPI request parameters dataclass. - - Attributes: - query - Query string parameters as MultiDict. Must support getlist method. - header - Request headers as Headers. - cookie - Request cookies as MultiDict. - path - Path parameters as dict. Gets resolved against spec if empty. - """ - - query: Mapping[str, Any] = field(default_factory=ImmutableMultiDict) - header: Mapping[str, Any] = field(default_factory=Headers) - cookie: Mapping[str, Any] = field(default_factory=ImmutableMultiDict) - path: Mapping[str, Any] = field(default_factory=dict) - - def __getitem__(self, location: str) -> Any: - return getattr(self, location) - - -@dataclass -class Parameters: - query: Mapping[str, Any] = field(default_factory=dict) - header: Mapping[str, Any] = field(default_factory=dict) - cookie: Mapping[str, Any] = field(default_factory=dict) - path: Mapping[str, Any] = field(default_factory=dict) - - -@dataclass -class RequestValidationResult(BaseValidationResult): - body: str | None = None - parameters: Parameters = field(default_factory=Parameters) - security: dict[str, str] | None = None +from openapi_core.datatypes import Parameters +from openapi_core.datatypes import RequestParameters + +# Backward compatibility +__all__ = [ + "Parameters", + "RequestParameters", +] diff --git a/openapi_core/validation/request/exceptions.py b/openapi_core/validation/request/exceptions.py index 54a02b8a..678b3105 100644 --- a/openapi_core/validation/request/exceptions.py +++ b/openapi_core/validation/request/exceptions.py @@ -2,10 +2,10 @@ from dataclasses import dataclass from typing import Iterable +from openapi_core.datatypes import Parameters from openapi_core.exceptions import OpenAPIError from openapi_core.spec import Spec from openapi_core.validation.exceptions import ValidationError -from openapi_core.validation.request.datatypes import Parameters from openapi_core.validation.schemas.exceptions import ValidateError diff --git a/openapi_core/validation/request/protocols.py b/openapi_core/validation/request/protocols.py index d0671d36..6e2677fd 100644 --- a/openapi_core/validation/request/protocols.py +++ b/openapi_core/validation/request/protocols.py @@ -1,5 +1,6 @@ """OpenAPI core validation request protocols module""" import sys +from typing import Iterator from typing import Optional if sys.version_info >= (3, 8): @@ -9,109 +10,9 @@ from typing_extensions import Protocol from typing_extensions import runtime_checkable +from openapi_core.protocols import Request +from openapi_core.protocols import WebhookRequest from openapi_core.spec import Spec -from openapi_core.validation.request.datatypes import RequestParameters -from openapi_core.validation.request.datatypes import RequestValidationResult - - -@runtime_checkable -class BaseRequest(Protocol): - parameters: RequestParameters - - @property - def method(self) -> str: - ... - - @property - def body(self) -> Optional[str]: - ... - - @property - def mimetype(self) -> str: - ... - - -@runtime_checkable -class Request(BaseRequest, Protocol): - """Request attributes protocol. - - Attributes: - host_url - Url with scheme and host - For example: - https://localhost:8000 - path - Request path - full_url_pattern - The matched url with scheme, host and path pattern. - For example: - https://localhost:8000/api/v1/pets - https://localhost:8000/api/v1/pets/{pet_id} - method - The request method, as lowercase string. - parameters - A RequestParameters object. Needs to supports path attribute setter - to write resolved path parameters. - body - The request body, as string. - mimetype - Like content type, but without parameters (eg, without charset, - type etc.) and always lowercase. - For example if the content type is "text/HTML; charset=utf-8" - the mimetype would be "text/html". - """ - - @property - def host_url(self) -> str: - ... - - @property - def path(self) -> str: - ... - - -@runtime_checkable -class WebhookRequest(BaseRequest, Protocol): - """Webhook request attributes protocol. - - Attributes: - name - Webhook name - method - The request method, as lowercase string. - parameters - A RequestParameters object. Needs to supports path attribute setter - to write resolved path parameters. - body - The request body, as string. - mimetype - Like content type, but without parameters (eg, without charset, - type etc.) and always lowercase. - For example if the content type is "text/HTML; charset=utf-8" - the mimetype would be "text/html". - """ - - @property - def name(self) -> str: - ... - - -@runtime_checkable -class SupportsPathPattern(Protocol): - """Supports path_pattern attribute protocol. - - You also need to provide path variables in RequestParameters. - - Attributes: - path_pattern - The matched path pattern. - For example: - /api/v1/pets/{pet_id} - """ - - @property - def path_pattern(self) -> str: - ... @runtime_checkable @@ -119,10 +20,16 @@ class RequestValidator(Protocol): def __init__(self, spec: Spec, base_url: Optional[str] = None): ... + def iter_errors( + self, + request: Request, + ) -> Iterator[Exception]: + ... + def validate( self, request: Request, - ) -> RequestValidationResult: + ) -> None: ... @@ -131,8 +38,14 @@ class WebhookRequestValidator(Protocol): def __init__(self, spec: Spec, base_url: Optional[str] = None): ... + def iter_errors( + self, + request: WebhookRequest, + ) -> Iterator[Exception]: + ... + def validate( self, request: WebhookRequest, - ) -> RequestValidationResult: + ) -> None: ... diff --git a/openapi_core/validation/request/types.py b/openapi_core/validation/request/types.py new file mode 100644 index 00000000..068e8cc6 --- /dev/null +++ b/openapi_core/validation/request/types.py @@ -0,0 +1,11 @@ +from typing import Type +from typing import Union + +from openapi_core.validation.request.protocols import RequestValidator +from openapi_core.validation.request.protocols import WebhookRequestValidator + +RequestValidatorType = Type[RequestValidator] +WebhookRequestValidatorType = Type[WebhookRequestValidator] +AnyRequestValidatorType = Union[ + RequestValidatorType, WebhookRequestValidatorType +] diff --git a/openapi_core/validation/request/validators.py b/openapi_core/validation/request/validators.py index 80d823f1..c79bfe3e 100644 --- a/openapi_core/validation/request/validators.py +++ b/openapi_core/validation/request/validators.py @@ -4,13 +4,11 @@ from typing import Dict from typing import Iterator from typing import Optional -from typing import Tuple -from urllib.parse import urljoin from openapi_core.casting.schemas import schema_casters_factory -from openapi_core.casting.schemas.exceptions import CastError from openapi_core.casting.schemas.factories import SchemaCastersFactory -from openapi_core.deserializing.exceptions import DeserializeError +from openapi_core.datatypes import Parameters +from openapi_core.datatypes import RequestParameters from openapi_core.deserializing.media_types import ( media_type_deserializers_factory, ) @@ -23,32 +21,18 @@ from openapi_core.deserializing.parameters.factories import ( ParameterDeserializersFactory, ) -from openapi_core.exceptions import OpenAPIError +from openapi_core.protocols import BaseRequest +from openapi_core.protocols import Request +from openapi_core.protocols import WebhookRequest from openapi_core.security import security_provider_factory from openapi_core.security.exceptions import SecurityProviderError from openapi_core.security.factories import SecurityProviderFactory from openapi_core.spec.paths import Spec -from openapi_core.templating.media_types.exceptions import MediaTypeFinderError -from openapi_core.templating.paths.datatypes import PathOperationServer from openapi_core.templating.paths.exceptions import PathError -from openapi_core.templating.paths.finders import APICallPathFinder from openapi_core.templating.paths.finders import WebhookPathFinder from openapi_core.templating.security.exceptions import SecurityNotFound -from openapi_core.unmarshalling.schemas import ( - oas30_write_schema_unmarshallers_factory, -) -from openapi_core.unmarshalling.schemas import ( - oas31_schema_unmarshallers_factory, -) -from openapi_core.unmarshalling.schemas.exceptions import UnmarshalError -from openapi_core.unmarshalling.schemas.factories import ( - SchemaUnmarshallersFactory, -) from openapi_core.util import chainiters from openapi_core.validation.decorators import ValidationErrorWrapper -from openapi_core.validation.request.datatypes import Parameters -from openapi_core.validation.request.datatypes import RequestParameters -from openapi_core.validation.request.datatypes import RequestValidationResult from openapi_core.validation.request.exceptions import InvalidParameter from openapi_core.validation.request.exceptions import InvalidRequestBody from openapi_core.validation.request.exceptions import InvalidSecurity @@ -62,10 +46,6 @@ from openapi_core.validation.request.exceptions import ParametersError from openapi_core.validation.request.exceptions import RequestBodyError from openapi_core.validation.request.exceptions import SecurityError -from openapi_core.validation.request.protocols import BaseRequest -from openapi_core.validation.request.protocols import Request -from openapi_core.validation.request.protocols import WebhookRequest -from openapi_core.validation.request.proxies import SpecRequestValidatorProxy from openapi_core.validation.schemas import ( oas30_write_schema_validators_factory, ) @@ -97,88 +77,49 @@ def __init__( ) self.security_provider_factory = security_provider_factory - def _validate( + def _iter_errors( self, request: BaseRequest, operation: Spec, path: Spec - ) -> RequestValidationResult: + ) -> Iterator[Exception]: try: - security = self._get_security(request.parameters, operation) + self._get_security(request.parameters, operation) + # don't process if security errors except SecurityError as exc: - return RequestValidationResult(errors=[exc]) + yield exc + return try: - params = self._get_parameters(request.parameters, operation, path) + self._get_parameters(request.parameters, operation, path) except ParametersError as exc: - params = exc.parameters - params_errors = exc.errors - else: - params_errors = [] + yield from exc.errors try: - body = self._get_body(request.body, request.mimetype, operation) - except MissingRequestBody: - body = None - body_errors = [] + self._get_body(request.body, request.mimetype, operation) except RequestBodyError as exc: - body = None - body_errors = [exc] - else: - body_errors = [] - - errors = list(chainiters(params_errors, body_errors)) - return RequestValidationResult( - errors=errors, - body=body, - parameters=params, - security=security, - ) + yield exc - def _validate_body( + def _iter_body_errors( self, request: BaseRequest, operation: Spec - ) -> RequestValidationResult: + ) -> Iterator[Exception]: try: - body = self._get_body(request.body, request.mimetype, operation) - except MissingRequestBody: - body = None - errors = [] + self._get_body(request.body, request.mimetype, operation) except RequestBodyError as exc: - body = None - errors = [exc] - else: - errors = [] - - return RequestValidationResult( - errors=errors, - body=body, - ) + yield exc - def _validate_parameters( + def _iter_parameters_errors( self, request: BaseRequest, operation: Spec, path: Spec - ) -> RequestValidationResult: + ) -> Iterator[Exception]: try: - params = self._get_parameters(request.parameters, path, operation) + self._get_parameters(request.parameters, path, operation) except ParametersError as exc: - params = exc.parameters - params_errors = exc.errors - else: - params_errors = [] - - return RequestValidationResult( - errors=params_errors, - parameters=params, - ) + yield from exc.errors - def _validate_security( + def _iter_security_errors( self, request: BaseRequest, operation: Spec - ) -> RequestValidationResult: + ) -> Iterator[Exception]: try: - security = self._get_security(request.parameters, operation) + self._get_security(request.parameters, operation) except SecurityError as exc: - return RequestValidationResult(errors=[exc]) - - return RequestValidationResult( - errors=[], - security=security, - ) + yield exc def _get_parameters( self, @@ -286,7 +227,7 @@ def _get_security_value( @ValidationErrorWrapper(RequestBodyError, InvalidRequestBody) def _get_body( self, body: Optional[str], mimetype: str, operation: Spec - ) -> Tuple[Any, Optional[Spec]]: + ) -> Any: if "requestBody" not in operation: return None @@ -295,10 +236,7 @@ def _get_body( content = request_body / "content" raw_body = self._get_body_value(body, request_body) - casted, _ = self._get_content_value_and_schema( - raw_body, mimetype, content - ) - return casted + return self._get_content_value(raw_body, mimetype, content) def _get_body_value(self, body: Optional[str], request_body: Spec) -> Any: if not body: @@ -310,118 +248,126 @@ def _get_body_value(self, body: Optional[str], request_body: Spec) -> Any: class BaseAPICallRequestValidator(BaseRequestValidator, BaseAPICallValidator): def iter_errors(self, request: Request) -> Iterator[Exception]: - result = self.validate(request) - yield from result.errors - - def validate(self, request: Request) -> RequestValidationResult: raise NotImplementedError + def validate(self, request: Request) -> None: + for err in self.iter_errors(request): + raise err + class BaseWebhookRequestValidator(BaseRequestValidator, BaseWebhookValidator): def iter_errors(self, request: WebhookRequest) -> Iterator[Exception]: - result = self.validate(request) - yield from result.errors - - def validate(self, request: WebhookRequest) -> RequestValidationResult: raise NotImplementedError + def validate(self, request: WebhookRequest) -> None: + for err in self.iter_errors(request): + raise err + class APICallRequestBodyValidator(BaseAPICallRequestValidator): - def validate(self, request: Request) -> RequestValidationResult: + def iter_errors(self, request: Request) -> Iterator[Exception]: try: _, operation, _, _, _ = self._find_path(request) except PathError as exc: - return RequestValidationResult(errors=[exc]) + yield exc + return - return self._validate_body(request, operation) + yield from self._iter_body_errors(request, operation) class APICallRequestParametersValidator(BaseAPICallRequestValidator): - def validate(self, request: Request) -> RequestValidationResult: + def iter_errors(self, request: Request) -> Iterator[Exception]: try: path, operation, _, path_result, _ = self._find_path(request) except PathError as exc: - return RequestValidationResult(errors=[exc]) + yield exc + return request.parameters.path = ( request.parameters.path or path_result.variables ) - return self._validate_parameters(request, operation, path) + yield from self._iter_parameters_errors(request, operation, path) class APICallRequestSecurityValidator(BaseAPICallRequestValidator): - def validate(self, request: Request) -> RequestValidationResult: + def iter_errors(self, request: Request) -> Iterator[Exception]: try: _, operation, _, _, _ = self._find_path(request) except PathError as exc: - return RequestValidationResult(errors=[exc]) + yield exc + return - return self._validate_security(request, operation) + yield from self._iter_security_errors(request, operation) class APICallRequestValidator(BaseAPICallRequestValidator): - def validate(self, request: Request) -> RequestValidationResult: + def iter_errors(self, request: Request) -> Iterator[Exception]: try: path, operation, _, path_result, _ = self._find_path(request) # don't process if operation errors except PathError as exc: - return RequestValidationResult(errors=[exc]) + yield exc + return request.parameters.path = ( request.parameters.path or path_result.variables ) - return self._validate(request, operation, path) + yield from self._iter_errors(request, operation, path) class WebhookRequestValidator(BaseWebhookRequestValidator): - def validate(self, request: WebhookRequest) -> RequestValidationResult: + def iter_errors(self, request: WebhookRequest) -> Iterator[Exception]: try: path, operation, _, path_result, _ = self._find_path(request) # don't process if operation errors except PathError as exc: - return RequestValidationResult(errors=[exc]) + yield exc + return request.parameters.path = ( request.parameters.path or path_result.variables ) - return self._validate(request, operation, path) + yield from self._iter_errors(request, operation, path) class WebhookRequestBodyValidator(BaseWebhookRequestValidator): - def validate(self, request: WebhookRequest) -> RequestValidationResult: + def iter_errors(self, request: WebhookRequest) -> Iterator[Exception]: try: _, operation, _, _, _ = self._find_path(request) except PathError as exc: - return RequestValidationResult(errors=[exc]) + yield exc + return - return self._validate_body(request, operation) + yield from self._iter_body_errors(request, operation) class WebhookRequestParametersValidator(BaseWebhookRequestValidator): - def validate(self, request: WebhookRequest) -> RequestValidationResult: + def iter_errors(self, request: WebhookRequest) -> Iterator[Exception]: try: path, operation, _, path_result, _ = self._find_path(request) except PathError as exc: - return RequestValidationResult(errors=[exc]) + yield exc + return request.parameters.path = ( request.parameters.path or path_result.variables ) - return self._validate_parameters(request, operation, path) + yield from self._iter_parameters_errors(request, operation, path) class WebhookRequestSecurityValidator(BaseWebhookRequestValidator): - def validate(self, request: WebhookRequest) -> RequestValidationResult: + def iter_errors(self, request: WebhookRequest) -> Iterator[Exception]: try: _, operation, _, _, _ = self._find_path(request) except PathError as exc: - return RequestValidationResult(errors=[exc]) + yield exc + return - return self._validate_security(request, operation) + yield from self._iter_security_errors(request, operation) class V30RequestBodyValidator(APICallRequestBodyValidator): @@ -475,20 +421,3 @@ class V31WebhookRequestSecurityValidator(WebhookRequestSecurityValidator): class V31WebhookRequestValidator(WebhookRequestValidator): schema_validators_factory = oas31_schema_validators_factory path_finder_cls = WebhookPathFinder - - -# backward compatibility -class RequestValidator(SpecRequestValidatorProxy): - def __init__( - self, - schema_unmarshallers_factory: SchemaUnmarshallersFactory, - **kwargs: Any, - ): - super().__init__( - "APICallRequestUnmarshaller", - schema_validators_factory=( - schema_unmarshallers_factory.schema_validators_factory - ), - schema_unmarshallers_factory=schema_unmarshallers_factory, - **kwargs, - ) diff --git a/openapi_core/validation/response/__init__.py b/openapi_core/validation/response/__init__.py index fcb2b036..5c62af3f 100644 --- a/openapi_core/validation/response/__init__.py +++ b/openapi_core/validation/response/__init__.py @@ -1,16 +1,4 @@ """OpenAPI core validation response module""" -from functools import partial - -from openapi_core.unmarshalling.schemas import ( - oas30_read_schema_unmarshallers_factory, -) -from openapi_core.unmarshalling.schemas import ( - oas31_schema_unmarshallers_factory, -) -from openapi_core.validation.response.proxies import ( - DetectResponseValidatorProxy, -) -from openapi_core.validation.response.proxies import SpecResponseValidatorProxy from openapi_core.validation.response.validators import ( V30ResponseDataValidator, ) @@ -47,38 +35,8 @@ "V31WebhookResponseValidator", "V3ResponseValidator", "V3WebhookResponseValidator", - "openapi_v30_response_validator", - "openapi_v31_response_validator", - "openapi_v3_response_validator", - "openapi_response_validator", ] # alias to the latest v3 version V3ResponseValidator = V31ResponseValidator V3WebhookResponseValidator = V31WebhookResponseValidator - -# spec validators -openapi_v30_response_validator = SpecResponseValidatorProxy( - "APICallResponseUnmarshaller", - schema_unmarshallers_factory=oas30_read_schema_unmarshallers_factory, - deprecated="openapi_v30_response_validator", - use="V30ResponseValidator", -) - -openapi_v31_response_validator = SpecResponseValidatorProxy( - "APICallResponseUnmarshaller", - schema_unmarshallers_factory=oas31_schema_unmarshallers_factory, - deprecated="openapi_v31_response_validator", - use="V31ResponseValidator", -) - -# spec validators alias to the latest v3 version -openapi_v3_response_validator = openapi_v31_response_validator - -# detect version spec -openapi_response_validator = DetectResponseValidatorProxy( - { - ("openapi", "3.0"): openapi_v30_response_validator, - ("openapi", "3.1"): openapi_v31_response_validator, - }, -) diff --git a/openapi_core/validation/response/protocols.py b/openapi_core/validation/response/protocols.py index dfcb9a87..d23b7a1a 100644 --- a/openapi_core/validation/response/protocols.py +++ b/openapi_core/validation/response/protocols.py @@ -1,7 +1,6 @@ """OpenAPI core validation response protocols module""" import sys -from typing import Any -from typing import Mapping +from typing import Iterator from typing import Optional if sys.version_info >= (3, 8): @@ -11,42 +10,10 @@ from typing_extensions import Protocol from typing_extensions import runtime_checkable +from openapi_core.protocols import Request +from openapi_core.protocols import Response +from openapi_core.protocols import WebhookRequest from openapi_core.spec import Spec -from openapi_core.validation.request.protocols import Request -from openapi_core.validation.request.protocols import WebhookRequest -from openapi_core.validation.response.datatypes import ResponseValidationResult - - -@runtime_checkable -class Response(Protocol): - """Response protocol. - - Attributes: - data - The response body, as string. - status_code - The status code as integer. - headers - Response headers as Headers. - mimetype - Lowercase content type without charset. - """ - - @property - def data(self) -> str: - ... - - @property - def status_code(self) -> int: - ... - - @property - def mimetype(self) -> str: - ... - - @property - def headers(self) -> Mapping[str, Any]: - ... @runtime_checkable @@ -54,11 +21,18 @@ class ResponseValidator(Protocol): def __init__(self, spec: Spec, base_url: Optional[str] = None): ... + def iter_errors( + self, + request: Request, + response: Response, + ) -> Iterator[Exception]: + ... + def validate( self, request: Request, response: Response, - ) -> ResponseValidationResult: + ) -> None: ... @@ -67,9 +41,16 @@ class WebhookResponseValidator(Protocol): def __init__(self, spec: Spec, base_url: Optional[str] = None): ... + def iter_errors( + self, + request: WebhookRequest, + response: Response, + ) -> Iterator[Exception]: + ... + def validate( self, request: WebhookRequest, response: Response, - ) -> ResponseValidationResult: + ) -> None: ... diff --git a/openapi_core/validation/response/types.py b/openapi_core/validation/response/types.py new file mode 100644 index 00000000..3446dd4d --- /dev/null +++ b/openapi_core/validation/response/types.py @@ -0,0 +1,11 @@ +from typing import Type +from typing import Union + +from openapi_core.validation.response.protocols import ResponseValidator +from openapi_core.validation.response.protocols import WebhookResponseValidator + +ResponseValidatorType = Type[ResponseValidator] +WebhookResponseValidatorType = Type[WebhookResponseValidator] +AnyResponseValidatorType = Union[ + ResponseValidatorType, WebhookResponseValidatorType +] diff --git a/openapi_core/validation/response/validators.py b/openapi_core/validation/response/validators.py index b80e22c4..31a9bdfc 100644 --- a/openapi_core/validation/response/validators.py +++ b/openapi_core/validation/response/validators.py @@ -5,35 +5,16 @@ from typing import Iterator from typing import List from typing import Mapping -from typing import Optional -from urllib.parse import urljoin -from openapi_core.casting.schemas.exceptions import CastError -from openapi_core.deserializing.exceptions import DeserializeError from openapi_core.exceptions import OpenAPIError +from openapi_core.protocols import Request +from openapi_core.protocols import Response +from openapi_core.protocols import WebhookRequest from openapi_core.spec import Spec -from openapi_core.templating.media_types.exceptions import MediaTypeFinderError -from openapi_core.templating.paths.datatypes import PathOperationServer from openapi_core.templating.paths.exceptions import PathError -from openapi_core.templating.paths.finders import APICallPathFinder -from openapi_core.templating.paths.finders import WebhookPathFinder from openapi_core.templating.responses.exceptions import ResponseFinderError -from openapi_core.unmarshalling.schemas import ( - oas30_read_schema_unmarshallers_factory, -) -from openapi_core.unmarshalling.schemas import ( - oas31_schema_unmarshallers_factory, -) -from openapi_core.unmarshalling.schemas.exceptions import UnmarshalError -from openapi_core.unmarshalling.schemas.factories import ( - SchemaUnmarshallersFactory, -) -from openapi_core.util import chainiters from openapi_core.validation.decorators import ValidationErrorWrapper from openapi_core.validation.exceptions import ValidationError -from openapi_core.validation.request.protocols import Request -from openapi_core.validation.request.protocols import WebhookRequest -from openapi_core.validation.response.datatypes import ResponseValidationResult from openapi_core.validation.response.exceptions import DataError from openapi_core.validation.response.exceptions import HeaderError from openapi_core.validation.response.exceptions import HeadersError @@ -42,8 +23,6 @@ from openapi_core.validation.response.exceptions import MissingData from openapi_core.validation.response.exceptions import MissingHeader from openapi_core.validation.response.exceptions import MissingRequiredHeader -from openapi_core.validation.response.protocols import Response -from openapi_core.validation.response.proxies import SpecResponseValidatorProxy from openapi_core.validation.schemas import ( oas30_read_schema_validators_factory, ) @@ -54,92 +33,66 @@ class BaseResponseValidator(BaseValidator): - def _validate( + def _iter_errors( self, status_code: int, data: str, headers: Mapping[str, Any], mimetype: str, operation: Spec, - ) -> ResponseValidationResult: + ) -> Iterator[Exception]: try: operation_response = self._get_operation_response( status_code, operation ) # don't process if operation errors except ResponseFinderError as exc: - return ResponseValidationResult(errors=[exc]) + yield exc + return try: - validated_data = self._get_data(data, mimetype, operation_response) + self._get_data(data, mimetype, operation_response) except DataError as exc: - validated_data = None - data_errors = [exc] - else: - data_errors = [] + yield exc try: - validated_headers = self._get_headers(headers, operation_response) + self._get_headers(headers, operation_response) except HeadersError as exc: - validated_headers = exc.headers - headers_errors = exc.context - else: - headers_errors = [] - - errors = list(chainiters(data_errors, headers_errors)) - return ResponseValidationResult( - errors=errors, - data=validated_data, - headers=validated_headers, - ) + yield from exc.context - def _validate_data( + def _iter_data_errors( self, status_code: int, data: str, mimetype: str, operation: Spec - ) -> ResponseValidationResult: + ) -> Iterator[Exception]: try: operation_response = self._get_operation_response( status_code, operation ) # don't process if operation errors except ResponseFinderError as exc: - return ResponseValidationResult(errors=[exc]) + yield exc + return try: - validated = self._get_data(data, mimetype, operation_response) + self._get_data(data, mimetype, operation_response) except DataError as exc: - validated = None - data_errors = [exc] - else: - data_errors = [] - - return ResponseValidationResult( - errors=data_errors, - data=validated, - ) + yield exc - def _validate_headers( + def _iter_headers_errors( self, status_code: int, headers: Mapping[str, Any], operation: Spec - ) -> ResponseValidationResult: + ) -> Iterator[Exception]: try: operation_response = self._get_operation_response( status_code, operation ) # don't process if operation errors except ResponseFinderError as exc: - return ResponseValidationResult(errors=[exc]) + yield exc + return try: - validated = self._get_headers(headers, operation_response) + self._get_headers(headers, operation_response) except HeadersError as exc: - validated = exc.headers - headers_errors = exc.context - else: - headers_errors = [] - - return ResponseValidationResult( - errors=headers_errors, - headers=validated, - ) + yield from exc.context def _get_operation_response( self, @@ -161,10 +114,7 @@ def _get_data( content = operation_response / "content" raw_data = self._get_data_value(data) - casted, _ = self._get_content_value_and_schema( - raw_data, mimetype, content - ) - return casted + return self._get_content_value(raw_data, mimetype, content) def _get_data_value(self, data: str) -> Any: if not data: @@ -229,15 +179,15 @@ def iter_errors( request: Request, response: Response, ) -> Iterator[Exception]: - result = self.validate(request, response) - yield from result.errors + raise NotImplementedError def validate( self, request: Request, response: Response, - ) -> ResponseValidationResult: - raise NotImplementedError + ) -> None: + for err in self.iter_errors(request, response): + raise err class BaseWebhookResponseValidator( @@ -248,64 +198,67 @@ def iter_errors( request: WebhookRequest, response: Response, ) -> Iterator[Exception]: - result = self.validate(request, response) - yield from result.errors + raise NotImplementedError def validate( self, request: WebhookRequest, response: Response, - ) -> ResponseValidationResult: - raise NotImplementedError + ) -> None: + for err in self.iter_errors(request, response): + raise err class APICallResponseDataValidator(BaseAPICallResponseValidator): - def validate( + def iter_errors( self, request: Request, response: Response, - ) -> ResponseValidationResult: + ) -> Iterator[Exception]: try: _, operation, _, _, _ = self._find_path(request) # don't process if operation errors except PathError as exc: - return ResponseValidationResult(errors=[exc]) + yield exc + return - return self._validate_data( + yield from self._iter_data_errors( response.status_code, response.data, response.mimetype, operation ) class APICallResponseHeadersValidator(BaseAPICallResponseValidator): - def validate( + def iter_errors( self, request: Request, response: Response, - ) -> ResponseValidationResult: + ) -> Iterator[Exception]: try: _, operation, _, _, _ = self._find_path(request) # don't process if operation errors except PathError as exc: - return ResponseValidationResult(errors=[exc]) + yield exc + return - return self._validate_headers( + yield from self._iter_headers_errors( response.status_code, response.headers, operation ) class APICallResponseValidator(BaseAPICallResponseValidator): - def validate( + def iter_errors( self, request: Request, response: Response, - ) -> ResponseValidationResult: + ) -> Iterator[Exception]: try: _, operation, _, _, _ = self._find_path(request) # don't process if operation errors except PathError as exc: - return ResponseValidationResult(errors=[exc]) + yield exc + return - return self._validate( + yield from self._iter_errors( response.status_code, response.data, response.headers, @@ -315,52 +268,55 @@ def validate( class WebhookResponseDataValidator(BaseWebhookResponseValidator): - def validate( + def iter_errors( self, request: WebhookRequest, response: Response, - ) -> ResponseValidationResult: + ) -> Iterator[Exception]: try: _, operation, _, _, _ = self._find_path(request) # don't process if operation errors except PathError as exc: - return ResponseValidationResult(errors=[exc]) + yield exc + return - return self._validate_data( + yield from self._iter_data_errors( response.status_code, response.data, response.mimetype, operation ) class WebhookResponseHeadersValidator(BaseWebhookResponseValidator): - def validate( + def iter_errors( self, request: WebhookRequest, response: Response, - ) -> ResponseValidationResult: + ) -> Iterator[Exception]: try: _, operation, _, _, _ = self._find_path(request) # don't process if operation errors except PathError as exc: - return ResponseValidationResult(errors=[exc]) + yield exc + return - return self._validate_headers( + yield from self._iter_headers_errors( response.status_code, response.headers, operation ) class WebhookResponseValidator(BaseWebhookResponseValidator): - def validate( + def iter_errors( self, request: WebhookRequest, response: Response, - ) -> ResponseValidationResult: + ) -> Iterator[Exception]: try: _, operation, _, _, _ = self._find_path(request) # don't process if operation errors except PathError as exc: - return ResponseValidationResult(errors=[exc]) + yield exc + return - return self._validate( + yield from self._iter_errors( response.status_code, response.data, response.headers, @@ -403,20 +359,3 @@ class V31WebhookResponseHeadersValidator(WebhookResponseHeadersValidator): class V31WebhookResponseValidator(WebhookResponseValidator): schema_validators_factory = oas31_schema_validators_factory - - -# backward compatibility -class ResponseValidator(SpecResponseValidatorProxy): - def __init__( - self, - schema_unmarshallers_factory: SchemaUnmarshallersFactory, - **kwargs: Any, - ): - super().__init__( - "APICallResponseUnmarshaller", - schema_validators_factory=( - schema_unmarshallers_factory.schema_validators_factory - ), - schema_unmarshallers_factory=schema_unmarshallers_factory, - **kwargs, - ) diff --git a/openapi_core/validation/schemas/factories.py b/openapi_core/validation/schemas/factories.py index 3a0e9984..41122724 100644 --- a/openapi_core/validation/schemas/factories.py +++ b/openapi_core/validation/schemas/factories.py @@ -1,23 +1,13 @@ -import sys from copy import deepcopy -from functools import partial -from typing import Any -from typing import Callable -from typing import Dict +from typing import Mapping from typing import Optional from typing import Type -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from backports.cached_property import cached_property from jsonschema._format import FormatChecker from jsonschema.protocols import Validator from openapi_core.spec import Spec -from openapi_core.unmarshalling.schemas.datatypes import CustomFormattersDict from openapi_core.validation.schemas.datatypes import FormatValidator -from openapi_core.validation.schemas.util import build_format_checker from openapi_core.validation.schemas.validators import SchemaValidator @@ -26,37 +16,30 @@ def __init__( self, schema_validator_class: Type[Validator], format_checker: Optional[FormatChecker] = None, - formatters: Optional[CustomFormattersDict] = None, - custom_formatters: Optional[CustomFormattersDict] = None, ): self.schema_validator_class = schema_validator_class if format_checker is None: format_checker = self.schema_validator_class.FORMAT_CHECKER - self.format_checker = deepcopy(format_checker) - if formatters is None: - formatters = {} - self.formatters = formatters - if custom_formatters is None: - custom_formatters = {} - self.custom_formatters = custom_formatters + self.format_checker = format_checker - def add_checks(self, **format_checks) -> None: + def get_format_checker( + self, **format_checks: FormatValidator + ) -> FormatChecker: + format_checker = deepcopy(self.format_checker) for name, check in format_checks.items(): - self.format_checker.checks(name)(check) + format_checker.checks(name)(check) + return format_checker - def get_checker(self, name: str) -> FormatValidator: - if name in self.format_checker.checkers: - return partial(self.format_checker.check, format=name) - - return lambda x: True - - def create(self, schema: Spec) -> Validator: + def create( + self, schema: Spec, **format_checks: FormatValidator + ) -> Validator: + format_checker = self.get_format_checker(**format_checks) resolver = schema.accessor.resolver # type: ignore with schema.open() as schema_dict: jsonschema_validator = self.schema_validator_class( schema_dict, resolver=resolver, - format_checker=self.format_checker, + format_checker=format_checker, ) return SchemaValidator(schema, jsonschema_validator) diff --git a/openapi_core/validation/schemas/util.py b/openapi_core/validation/schemas/util.py deleted file mode 100644 index 3290f0e3..00000000 --- a/openapi_core/validation/schemas/util.py +++ /dev/null @@ -1,27 +0,0 @@ -"""OpenAPI core validation schemas util module""" -from copy import deepcopy -from functools import lru_cache -from typing import Any -from typing import Callable -from typing import Optional - -from jsonschema._format import FormatChecker - - -@lru_cache() -def build_format_checker( - format_checker: Optional[FormatChecker] = None, - **format_checks: Callable[[Any], Any], -) -> Any: - if format_checker is None: - fc = FormatChecker() - else: - if not format_checks: - return format_checker - fc = deepcopy(format_checker) - - for name, check in format_checks.items(): - if name in fc.checkers: - continue - fc.checks(name)(check) - return fc diff --git a/openapi_core/validation/schemas/validators.py b/openapi_core/validation/schemas/validators.py index b6866e96..e46dad31 100644 --- a/openapi_core/validation/schemas/validators.py +++ b/openapi_core/validation/schemas/validators.py @@ -13,7 +13,7 @@ else: from backports.cached_property import cached_property -from openapi_core import Spec +from openapi_core.spec import Spec from openapi_core.validation.schemas.datatypes import FormatValidator from openapi_core.validation.schemas.exceptions import InvalidSchemaValue from openapi_core.validation.schemas.exceptions import ValidateError diff --git a/openapi_core/validation/shortcuts.py b/openapi_core/validation/shortcuts.py deleted file mode 100644 index 01f846af..00000000 --- a/openapi_core/validation/shortcuts.py +++ /dev/null @@ -1,178 +0,0 @@ -"""OpenAPI core validation shortcuts module""" -import warnings -from typing import Any -from typing import Dict -from typing import NamedTuple -from typing import Optional -from typing import Type -from typing import Union - -from openapi_core.spec import Spec -from openapi_core.validation.exceptions import ValidatorDetectError -from openapi_core.validation.request import V30RequestValidator -from openapi_core.validation.request import V31RequestValidator -from openapi_core.validation.request import V31WebhookRequestValidator -from openapi_core.validation.request.datatypes import RequestValidationResult -from openapi_core.validation.request.protocols import Request -from openapi_core.validation.request.protocols import RequestValidator -from openapi_core.validation.request.protocols import WebhookRequest -from openapi_core.validation.request.protocols import WebhookRequestValidator -from openapi_core.validation.request.proxies import SpecRequestValidatorProxy -from openapi_core.validation.response import V30ResponseValidator -from openapi_core.validation.response import V31ResponseValidator -from openapi_core.validation.response import V31WebhookResponseValidator -from openapi_core.validation.response.datatypes import ResponseValidationResult -from openapi_core.validation.response.protocols import Response -from openapi_core.validation.response.protocols import ResponseValidator -from openapi_core.validation.response.protocols import WebhookResponseValidator -from openapi_core.validation.response.proxies import SpecResponseValidatorProxy - -AnyRequest = Union[Request, WebhookRequest] -RequestValidatorType = Type[RequestValidator] -ResponseValidatorType = Type[ResponseValidator] -WebhookRequestValidatorType = Type[WebhookRequestValidator] -WebhookResponseValidatorType = Type[WebhookResponseValidator] -AnyRequestValidatorType = Union[ - RequestValidatorType, WebhookRequestValidatorType -] -AnyResponseValidatorType = Union[ - ResponseValidatorType, WebhookResponseValidatorType -] - - -class SpecVersion(NamedTuple): - name: str - version: str - - -class SpecValidators(NamedTuple): - request_cls: Type[RequestValidator] - response_cls: Type[ResponseValidator] - webhook_request_cls: Optional[Type[WebhookRequestValidator]] - webhook_response_cls: Optional[Type[WebhookResponseValidator]] - - -SPECS: Dict[SpecVersion, SpecValidators] = { - SpecVersion("openapi", "3.0"): SpecValidators( - V30RequestValidator, - V30ResponseValidator, - None, - None, - ), - SpecVersion("openapi", "3.1"): SpecValidators( - V31RequestValidator, - V31ResponseValidator, - V31WebhookRequestValidator, - V31WebhookResponseValidator, - ), -} - - -def get_validators(spec: Spec) -> SpecValidators: - for v, validators in SPECS.items(): - if v.name in spec and spec[v.name].startswith(v.version): - return validators - raise ValidatorDetectError("Spec schema version not detected") - - -def validate_request( - request: AnyRequest, - spec: Spec, - base_url: Optional[str] = None, - validator: Optional[SpecRequestValidatorProxy] = None, - cls: Optional[AnyRequestValidatorType] = None, - **validator_kwargs: Any, -) -> RequestValidationResult: - if isinstance(spec, (Request, WebhookRequest)) and isinstance( - request, Spec - ): - warnings.warn( - "spec parameter as a first argument is deprecated. " - "Move it to second argument instead.", - DeprecationWarning, - ) - request, spec = spec, request - if not isinstance(request, (Request, WebhookRequest)): - raise TypeError("'request' argument is not type of (Webhook)Request") - if not isinstance(spec, Spec): - raise TypeError("'spec' argument is not type of Spec") - if validator is not None and isinstance(request, Request): - warnings.warn( - "validator parameter is deprecated. Use cls instead.", - DeprecationWarning, - ) - result = validator.validate(spec, request, base_url=base_url) - else: - if cls is None: - validators = get_validators(spec) - if isinstance(request, WebhookRequest): - cls = validators.webhook_request_cls - else: - cls = validators.request_cls - if cls is None: - raise ValidatorDetectError("Validator not found") - assert ( - isinstance(cls, RequestValidator) and isinstance(request, Request) - ) or ( - isinstance(cls, WebhookRequestValidator) - and isinstance(request, WebhookRequest) - ) - v = cls(spec, base_url=base_url, **validator_kwargs) - result = v.validate(request) - result.raise_for_errors() - return result - - -def validate_response( - request: Union[Request, WebhookRequest, Spec], - response: Union[Response, Request, WebhookRequest], - spec: Union[Spec, Response], - base_url: Optional[str] = None, - validator: Optional[SpecResponseValidatorProxy] = None, - cls: Optional[AnyResponseValidatorType] = None, - **validator_kwargs: Any, -) -> ResponseValidationResult: - if ( - isinstance(request, Spec) - and isinstance(response, (Request, WebhookRequest)) - and isinstance(spec, Response) - ): - warnings.warn( - "spec parameter as a first argument is deprecated. " - "Move it to third argument instead.", - DeprecationWarning, - ) - args = request, response, spec - spec, request, response = args - - if not isinstance(request, (Request, WebhookRequest)): - raise TypeError("'request' argument is not type of (Webhook)Request") - if not isinstance(response, Response): - raise TypeError("'response' argument is not type of Response") - if not isinstance(spec, Spec): - raise TypeError("'spec' argument is not type of Spec") - if validator is not None and isinstance(request, Request): - warnings.warn( - "validator parameter is deprecated. Use cls instead.", - DeprecationWarning, - ) - result = validator.validate(spec, request, response, base_url=base_url) - else: - if cls is None: - validators = get_validators(spec) - if isinstance(request, WebhookRequest): - cls = validators.webhook_response_cls - else: - cls = validators.response_cls - if cls is None: - raise ValidatorDetectError("Validator not found") - assert ( - isinstance(cls, ResponseValidator) and isinstance(request, Request) - ) or ( - isinstance(cls, WebhookResponseValidator) - and isinstance(request, WebhookRequest) - ) - v = cls(spec, base_url=base_url, **validator_kwargs) - result = v.validate(request, response) - result.raise_for_errors() - return result diff --git a/openapi_core/validation/validators.py b/openapi_core/validation/validators.py index d656f377..a465c67d 100644 --- a/openapi_core/validation/validators.py +++ b/openapi_core/validation/validators.py @@ -1,11 +1,9 @@ """OpenAPI core validation validators module""" import sys from typing import Any -from typing import Dict from typing import Mapping from typing import Optional from typing import Tuple -from typing import Type from urllib.parse import urljoin if sys.version_info >= (3, 8): @@ -26,6 +24,8 @@ from openapi_core.deserializing.parameters.factories import ( ParameterDeserializersFactory, ) +from openapi_core.protocols import Request +from openapi_core.protocols import WebhookRequest from openapi_core.schema.parameters import get_value from openapi_core.spec import Spec from openapi_core.templating.media_types.datatypes import MediaType @@ -33,18 +33,11 @@ from openapi_core.templating.paths.finders import APICallPathFinder from openapi_core.templating.paths.finders import BasePathFinder from openapi_core.templating.paths.finders import WebhookPathFinder -from openapi_core.unmarshalling.schemas.factories import ( - SchemaUnmarshallersFactory, -) -from openapi_core.validation.request.protocols import Request -from openapi_core.validation.request.protocols import SupportsPathPattern -from openapi_core.validation.request.protocols import WebhookRequest from openapi_core.validation.schemas.factories import SchemaValidatorsFactory class BaseValidator: schema_validators_factory: SchemaValidatorsFactory = NotImplemented - schema_unmarshallers_factory: SchemaUnmarshallersFactory = NotImplemented def __init__( self, @@ -89,15 +82,33 @@ def _cast(self, schema: Spec, value: Any) -> Any: caster = self.schema_casters_factory.create(schema) return caster(value) + def _validate_schema(self, schema: Spec, value: Any) -> None: + validator = self.schema_validators_factory.create(schema) + validator.validate(value) + def _get_param_or_header_value( self, param_or_header: Spec, location: Mapping[str, Any], name: Optional[str] = None, ) -> Any: - casted, _ = self._get_param_or_header_value_and_schema( + casted, schema = self._get_param_or_header_value_and_schema( param_or_header, location, name ) + if schema is None: + return casted + self._validate_schema(schema, casted) + return casted + + def _get_content_value( + self, raw: Any, mimetype: str, content: Spec + ) -> Any: + casted, schema = self._get_content_value_and_schema( + raw, mimetype, content + ) + if schema is None: + return casted + self._validate_schema(schema, casted) return casted def _get_param_or_header_value_and_schema( diff --git a/tests/integration/contrib/flask/test_flask_decorator.py b/tests/integration/contrib/flask/test_flask_decorator.py index e5ea16d9..a8b0c112 100644 --- a/tests/integration/contrib/flask/test_flask_decorator.py +++ b/tests/integration/contrib/flask/test_flask_decorator.py @@ -4,7 +4,7 @@ from flask import make_response from openapi_core.contrib.flask.decorators import FlaskOpenAPIViewDecorator -from openapi_core.validation.request.datatypes import Parameters +from openapi_core.datatypes import Parameters class TestFlaskOpenAPIDecorator: diff --git a/tests/integration/contrib/flask/test_flask_validator.py b/tests/integration/contrib/flask/test_flask_validator.py index 6ccdb3c0..1f4a1a4f 100644 --- a/tests/integration/contrib/flask/test_flask_validator.py +++ b/tests/integration/contrib/flask/test_flask_validator.py @@ -5,7 +5,7 @@ from flask.testing import FlaskClient from flask.wrappers import Response -from openapi_core import V30RequestValidator +from openapi_core import V30RequestUnmarshaller from openapi_core.contrib.flask import FlaskOpenAPIRequest @@ -28,8 +28,8 @@ def datails_browse(id): from flask import request openapi_request = FlaskOpenAPIRequest(request) - validator = V30RequestValidator(spec) - result = validator.validate(openapi_request) + unmarshaller = V30RequestUnmarshaller(spec) + result = unmarshaller.unmarshal(openapi_request) assert not result.errors if request.args.get("q") == "string": diff --git a/tests/integration/contrib/requests/test_requests_validation.py b/tests/integration/contrib/requests/test_requests_validation.py index 4078807e..2e8aee8c 100644 --- a/tests/integration/contrib/requests/test_requests_validation.py +++ b/tests/integration/contrib/requests/test_requests_validation.py @@ -2,10 +2,10 @@ import requests import responses -from openapi_core import V31RequestValidator -from openapi_core import V31ResponseValidator -from openapi_core import V31WebhookRequestValidator -from openapi_core import V31WebhookResponseValidator +from openapi_core import V31RequestUnmarshaller +from openapi_core import V31ResponseUnmarshaller +from openapi_core import V31WebhookRequestUnmarshaller +from openapi_core import V31WebhookResponseUnmarshaller from openapi_core.contrib.requests import RequestsOpenAPIRequest from openapi_core.contrib.requests import RequestsOpenAPIResponse from openapi_core.contrib.requests import RequestsOpenAPIWebhookRequest @@ -18,23 +18,23 @@ def spec(self, factory): return factory.spec_from_file(specfile) @pytest.fixture - def request_validator(self, spec): - return V31RequestValidator(spec) + def request_unmarshaller(self, spec): + return V31RequestUnmarshaller(spec) @pytest.fixture - def response_validator(self, spec): - return V31ResponseValidator(spec) + def response_unmarshaller(self, spec): + return V31ResponseUnmarshaller(spec) @pytest.fixture - def webhook_request_validator(self, spec): - return V31WebhookRequestValidator(spec) + def webhook_request_unmarshaller(self, spec): + return V31WebhookRequestUnmarshaller(spec) @pytest.fixture - def webhook_response_validator(self, spec): - return V31WebhookResponseValidator(spec) + def webhook_response_unmarshaller(self, spec): + return V31WebhookResponseUnmarshaller(spec) @responses.activate - def test_response_validator_path_pattern(self, response_validator): + def test_response_validator_path_pattern(self, response_unmarshaller): responses.add( responses.POST, "http://localhost/browse/12/?q=string", @@ -55,10 +55,12 @@ def test_response_validator_path_pattern(self, response_validator): response = session.send(request_prepared) openapi_request = RequestsOpenAPIRequest(request) openapi_response = RequestsOpenAPIResponse(response) - result = response_validator.validate(openapi_request, openapi_response) + result = response_unmarshaller.unmarshal( + openapi_request, openapi_response + ) assert not result.errors - def test_request_validator_path_pattern(self, request_validator): + def test_request_validator_path_pattern(self, request_unmarshaller): request = requests.Request( "POST", "http://localhost/browse/12/", @@ -67,10 +69,10 @@ def test_request_validator_path_pattern(self, request_validator): json={"param1": 1}, ) openapi_request = RequestsOpenAPIRequest(request) - result = request_validator.validate(openapi_request) + result = request_unmarshaller.unmarshal(openapi_request) assert not result.errors - def test_request_validator_prepared_request(self, request_validator): + def test_request_validator_prepared_request(self, request_unmarshaller): request = requests.Request( "POST", "http://localhost/browse/12/", @@ -80,10 +82,12 @@ def test_request_validator_prepared_request(self, request_validator): ) request_prepared = request.prepare() openapi_request = RequestsOpenAPIRequest(request_prepared) - result = request_validator.validate(openapi_request) + result = request_unmarshaller.unmarshal(openapi_request) assert not result.errors - def test_webhook_request_validator_path(self, webhook_request_validator): + def test_webhook_request_validator_path( + self, webhook_request_unmarshaller + ): request = requests.Request( "POST", "http://otherhost/callback/", @@ -96,11 +100,15 @@ def test_webhook_request_validator_path(self, webhook_request_validator): openapi_webhook_request = RequestsOpenAPIWebhookRequest( request, "resourceAdded" ) - result = webhook_request_validator.validate(openapi_webhook_request) + result = webhook_request_unmarshaller.unmarshal( + openapi_webhook_request + ) assert not result.errors @responses.activate - def test_webhook_response_validator_path(self, webhook_response_validator): + def test_webhook_response_validator_path( + self, webhook_response_unmarshaller + ): responses.add( responses.POST, "http://otherhost/callback/", @@ -123,7 +131,7 @@ def test_webhook_response_validator_path(self, webhook_response_validator): request, "resourceAdded" ) openapi_response = RequestsOpenAPIResponse(response) - result = webhook_response_validator.validate( + result = webhook_response_unmarshaller.unmarshal( openapi_webhook_request, openapi_response ) assert not result.errors diff --git a/tests/integration/contrib/werkzeug/test_werkzeug_validation.py b/tests/integration/contrib/werkzeug/test_werkzeug_validation.py index 0e8fa5b6..a940a500 100644 --- a/tests/integration/contrib/werkzeug/test_werkzeug_validation.py +++ b/tests/integration/contrib/werkzeug/test_werkzeug_validation.py @@ -6,8 +6,8 @@ from werkzeug.wrappers import Request from werkzeug.wrappers import Response -from openapi_core import V30RequestValidator -from openapi_core import V30ResponseValidator +from openapi_core import V30RequestUnmarshaller +from openapi_core import V30ResponseUnmarshaller from openapi_core.contrib.werkzeug import WerkzeugOpenAPIRequest from openapi_core.contrib.werkzeug import WerkzeugOpenAPIResponse @@ -53,8 +53,8 @@ def test_request_validator_root_path(self, client, spec): headers=headers, ) openapi_request = WerkzeugOpenAPIRequest(response.request) - validator = V30RequestValidator(spec) - result = validator.validate(openapi_request) + unmarshaller = V30RequestUnmarshaller(spec) + result = unmarshaller.unmarshal(openapi_request) assert not result.errors def test_request_validator_path_pattern(self, client, spec): @@ -71,8 +71,8 @@ def test_request_validator_path_pattern(self, client, spec): headers=headers, ) openapi_request = WerkzeugOpenAPIRequest(response.request) - validator = V30RequestValidator(spec) - result = validator.validate(openapi_request) + unmarshaller = V30RequestUnmarshaller(spec) + result = unmarshaller.unmarshal(openapi_request) assert not result.errors @responses.activate @@ -91,6 +91,6 @@ def test_response_validator_path_pattern(self, client, spec): ) openapi_request = WerkzeugOpenAPIRequest(response.request) openapi_response = WerkzeugOpenAPIResponse(response) - validator = V30ResponseValidator(spec) - result = validator.validate(openapi_request, openapi_response) + unmarshaller = V30ResponseUnmarshaller(spec) + result = unmarshaller.unmarshal(openapi_request, openapi_response) assert not result.errors diff --git a/tests/integration/schema/test_spec.py b/tests/integration/schema/test_spec.py index daa77db8..7f47cdb1 100644 --- a/tests/integration/schema/test_spec.py +++ b/tests/integration/schema/test_spec.py @@ -4,11 +4,11 @@ from openapi_spec_validator import openapi_v30_spec_validator from openapi_spec_validator import openapi_v31_spec_validator +from openapi_core import RequestValidator +from openapi_core import ResponseValidator +from openapi_core import Spec from openapi_core.schema.servers import get_server_url from openapi_core.schema.specs import get_spec_url -from openapi_core.spec import Spec -from openapi_core.validation.request.validators import RequestValidator -from openapi_core.validation.response.validators import ResponseValidator class TestPetstore: diff --git a/tests/integration/validation/test_minimal.py b/tests/integration/test_minimal.py similarity index 100% rename from tests/integration/validation/test_minimal.py rename to tests/integration/test_minimal.py diff --git a/tests/integration/validation/test_petstore.py b/tests/integration/test_petstore.py similarity index 88% rename from tests/integration/validation/test_petstore.py rename to tests/integration/test_petstore.py index 0f6a14bd..3f2b2781 100644 --- a/tests/integration/validation/test_petstore.py +++ b/tests/integration/test_petstore.py @@ -8,10 +8,10 @@ import pytest from isodate.tzinfo import UTC -from openapi_core import V30ResponseValidator from openapi_core import validate_request from openapi_core import validate_response from openapi_core.casting.schemas.exceptions import CastError +from openapi_core.datatypes import Parameters from openapi_core.deserializing.exceptions import DeserializeError from openapi_core.deserializing.parameters.exceptions import ( EmptyQueryParameterValue, @@ -19,27 +19,33 @@ from openapi_core.spec import Spec from openapi_core.templating.media_types.exceptions import MediaTypeNotFound from openapi_core.templating.paths.exceptions import ServerNotFound +from openapi_core.templating.security.exceptions import SecurityNotFound from openapi_core.testing import MockRequest from openapi_core.testing import MockResponse -from openapi_core.validation.request.datatypes import Parameters +from openapi_core.unmarshalling.request.unmarshallers import ( + V30RequestBodyUnmarshaller, +) +from openapi_core.unmarshalling.request.unmarshallers import ( + V30RequestParametersUnmarshaller, +) +from openapi_core.unmarshalling.request.unmarshallers import ( + V30RequestSecurityUnmarshaller, +) +from openapi_core.unmarshalling.response.unmarshallers import ( + V30ResponseDataUnmarshaller, +) +from openapi_core.unmarshalling.response.unmarshallers import ( + V30ResponseHeadersUnmarshaller, +) +from openapi_core.unmarshalling.response.unmarshallers import ( + V30ResponseUnmarshaller, +) from openapi_core.validation.request.exceptions import MissingRequiredParameter from openapi_core.validation.request.exceptions import ParameterError from openapi_core.validation.request.exceptions import RequestBodyError -from openapi_core.validation.request.validators import V30RequestBodyValidator -from openapi_core.validation.request.validators import ( - V30RequestParametersValidator, -) -from openapi_core.validation.request.validators import ( - V30RequestSecurityValidator, -) +from openapi_core.validation.request.exceptions import SecurityError from openapi_core.validation.response.exceptions import InvalidData from openapi_core.validation.response.exceptions import MissingRequiredHeader -from openapi_core.validation.response.validators import ( - V30ResponseDataValidator, -) -from openapi_core.validation.response.validators import ( - V30ResponseHeadersValidator, -) from openapi_core.validation.schemas.exceptions import InvalidSchemaValue @@ -61,12 +67,8 @@ def spec(self, v30_petstore_spec): return v30_petstore_spec @pytest.fixture(scope="module") - def request_parameters_validator(self, spec): - return V30RequestParametersValidator(spec) - - @pytest.fixture(scope="module") - def response_validator(self, spec): - return V30ResponseValidator(spec) + def response_unmarshaller(self, spec): + return V30ResponseUnmarshaller(spec) def test_get_pets(self, spec): host_url = "http://petstore.swagger.io/v1" @@ -87,7 +89,7 @@ def test_get_pets(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -101,7 +103,7 @@ def test_get_pets(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestBodyValidator, + cls=V30RequestBodyUnmarshaller, ) assert result.body is None @@ -144,7 +146,7 @@ def test_get_pets_response(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -156,7 +158,7 @@ def test_get_pets_response(self, spec): ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -202,7 +204,7 @@ def test_get_pets_response_no_schema(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -214,7 +216,7 @@ def test_get_pets_response_no_schema(self, spec): ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -228,7 +230,7 @@ def test_get_pets_response_no_schema(self, spec): assert response_result.errors == [] assert response_result.data == data - def test_get_pets_invalid_response(self, spec, response_validator): + def test_get_pets_invalid_response(self, spec, response_unmarshaller): host_url = "http://petstore.swagger.io/v1" path_pattern = "/v1/pets" query_params = { @@ -247,7 +249,7 @@ def test_get_pets_invalid_response(self, spec, response_validator): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -259,7 +261,7 @@ def test_get_pets_invalid_response(self, spec, response_validator): ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -282,11 +284,11 @@ def test_get_pets_invalid_response(self, spec, response_validator): request, response, spec=spec, - cls=V30ResponseDataValidator, + cls=V30ResponseDataUnmarshaller, ) assert type(exc_info.value.__cause__) is InvalidSchemaValue - response_result = response_validator.validate(request, response) + response_result = response_unmarshaller.unmarshal(request, response) assert response_result.errors == [InvalidData()] schema_errors = response_result.errors[0].__cause__.schema_errors @@ -317,7 +319,7 @@ def test_get_pets_ids_param(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -330,7 +332,7 @@ def test_get_pets_ids_param(self, spec): ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -367,7 +369,7 @@ def test_get_pets_tags_param(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -380,7 +382,7 @@ def test_get_pets_tags_param(self, spec): ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -418,12 +420,12 @@ def test_get_pets_parameter_deserialization_error(self, spec): validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert type(exc_info.value.__cause__) is DeserializeError result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -448,12 +450,12 @@ def test_get_pets_wrong_parameter_type(self, spec): validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert type(exc_info.value.__cause__) is CastError result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -473,11 +475,11 @@ def test_get_pets_raises_missing_required_param(self, spec): validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -502,12 +504,12 @@ def test_get_pets_empty_value(self, spec): validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert type(exc_info.value.__cause__) is EmptyQueryParameterValue result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -532,7 +534,7 @@ def test_get_pets_allow_empty_value(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -544,7 +546,7 @@ def test_get_pets_allow_empty_value(self, spec): ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -568,7 +570,7 @@ def test_get_pets_none_value(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -580,7 +582,7 @@ def test_get_pets_none_value(self, spec): ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -605,7 +607,7 @@ def test_get_pets_param_order(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -618,7 +620,7 @@ def test_get_pets_param_order(self, spec): ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -647,7 +649,7 @@ def test_get_pets_param_coordinates(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert is_dataclass(result.parameters.query["coordinates"]) @@ -659,7 +661,7 @@ def test_get_pets_param_coordinates(self, spec): assert result.parameters.query["coordinates"].lon == coordinates["lon"] result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -711,7 +713,7 @@ def test_post_birds(self, spec, spec_dict): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert is_dataclass(result.parameters.cookie["userdata"]) @@ -722,7 +724,7 @@ def test_post_birds(self, spec, spec_dict): assert result.parameters.cookie["userdata"].name == "user1" result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) schemas = spec_dict["components"]["schemas"] @@ -740,7 +742,7 @@ def test_post_birds(self, spec, spec_dict): result = validate_request( request, spec=spec, - cls=V30RequestSecurityValidator, + cls=V30RequestSecurityUnmarshaller, ) assert result.security == {} @@ -788,7 +790,7 @@ def test_post_cats(self, spec, spec_dict): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -801,7 +803,7 @@ def test_post_cats(self, spec, spec_dict): ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) schemas = spec_dict["components"]["schemas"] @@ -859,7 +861,7 @@ def test_post_cats_boolean_string(self, spec, spec_dict): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -872,7 +874,7 @@ def test_post_cats_boolean_string(self, spec, spec_dict): ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) schemas = spec_dict["components"]["schemas"] @@ -917,7 +919,7 @@ def test_post_no_one_of_schema(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -933,7 +935,7 @@ def test_post_no_one_of_schema(self, spec): validate_request( request, spec=spec, - cls=V30RequestBodyValidator, + cls=V30RequestBodyUnmarshaller, ) assert type(exc_info.value.__cause__) is InvalidSchemaValue @@ -969,7 +971,7 @@ def test_post_cats_only_required_body(self, spec, spec_dict): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -982,7 +984,7 @@ def test_post_cats_only_required_body(self, spec, spec_dict): ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) schemas = spec_dict["components"]["schemas"] @@ -1021,7 +1023,7 @@ def test_post_pets_raises_invalid_mimetype(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -1037,7 +1039,7 @@ def test_post_pets_raises_invalid_mimetype(self, spec): validate_request( request, spec=spec, - cls=V30RequestBodyValidator, + cls=V30RequestBodyUnmarshaller, ) assert type(exc_info.value.__cause__) is MediaTypeNotFound @@ -1070,11 +1072,11 @@ def test_post_pets_missing_cookie(self, spec, spec_dict): validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) schemas = spec_dict["components"]["schemas"] @@ -1113,11 +1115,11 @@ def test_post_pets_missing_header(self, spec, spec_dict): validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) schemas = spec_dict["components"]["schemas"] @@ -1157,14 +1159,14 @@ def test_post_pets_raises_invalid_server_error(self, spec): validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) with pytest.raises(ServerNotFound): validate_request( request, spec=spec, - cls=V30RequestBodyValidator, + cls=V30RequestBodyUnmarshaller, ) data_id = 1 @@ -1186,9 +1188,35 @@ def test_post_pets_raises_invalid_server_error(self, spec): request, response, spec=spec, - cls=V30ResponseDataValidator, + cls=V30ResponseDataUnmarshaller, + ) + + def test_get_pet_invalid_security(self, spec): + host_url = "http://petstore.swagger.io/v1" + path_pattern = "/v1/pets/{petId}" + view_args = { + "petId": "1", + } + auth = "authuser" + request = MockRequest( + host_url, + "GET", + "/pets/1", + path_pattern=path_pattern, + view_args=view_args, + ) + + with pytest.raises(SecurityError) as exc_info: + validate_request( + request, + spec=spec, + cls=V30RequestSecurityUnmarshaller, ) + assert exc_info.value.__cause__ == SecurityNotFound( + [["petstore_auth"]] + ) + def test_get_pet(self, spec): host_url = "http://petstore.swagger.io/v1" path_pattern = "/v1/pets/{petId}" @@ -1211,7 +1239,7 @@ def test_get_pet(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -1221,7 +1249,7 @@ def test_get_pet(self, spec): ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -1229,7 +1257,7 @@ def test_get_pet(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestSecurityValidator, + cls=V30RequestSecurityUnmarshaller, ) assert result.security == { @@ -1275,7 +1303,7 @@ def test_get_pet_not_found(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -1285,7 +1313,7 @@ def test_get_pet_not_found(self, spec): ) result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -1326,7 +1354,7 @@ def test_get_pet_wildcard(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters( @@ -1338,7 +1366,7 @@ def test_get_pet_wildcard(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestBodyValidator, + cls=V30RequestBodyUnmarshaller, ) assert result.body is None @@ -1366,13 +1394,13 @@ def test_get_tags(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters() result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -1408,7 +1436,7 @@ def test_post_tags_extra_body_properties(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters() @@ -1417,7 +1445,7 @@ def test_post_tags_extra_body_properties(self, spec): validate_request( request, spec=spec, - cls=V30RequestBodyValidator, + cls=V30RequestBodyUnmarshaller, ) assert type(exc_info.value.__cause__) is InvalidSchemaValue @@ -1438,7 +1466,7 @@ def test_post_tags_empty_body(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters() @@ -1447,7 +1475,7 @@ def test_post_tags_empty_body(self, spec): validate_request( request, spec=spec, - cls=V30RequestBodyValidator, + cls=V30RequestBodyUnmarshaller, ) assert type(exc_info.value.__cause__) is InvalidSchemaValue @@ -1468,7 +1496,7 @@ def test_post_tags_wrong_property_type(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters() @@ -1477,7 +1505,7 @@ def test_post_tags_wrong_property_type(self, spec): validate_request( request, spec=spec, - cls=V30RequestBodyValidator, + cls=V30RequestBodyUnmarshaller, ) assert type(exc_info.value.__cause__) is InvalidSchemaValue @@ -1501,13 +1529,13 @@ def test_post_tags_additional_properties(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters() result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert is_dataclass(result.body) @@ -1557,13 +1585,13 @@ def test_post_tags_created_now(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters() result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert is_dataclass(result.body) @@ -1614,13 +1642,13 @@ def test_post_tags_created_datetime(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters() result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert is_dataclass(result.body) @@ -1646,7 +1674,7 @@ def test_post_tags_created_datetime(self, spec): request, response, spec=spec, - cls=V30ResponseDataValidator, + cls=V30ResponseDataUnmarshaller, ) assert is_dataclass(result.data) @@ -1686,7 +1714,7 @@ def test_post_tags_created_invalid_type(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters() @@ -1695,7 +1723,7 @@ def test_post_tags_created_invalid_type(self, spec): validate_request( request, spec=spec, - cls=V30RequestBodyValidator, + cls=V30RequestBodyUnmarshaller, ) assert type(exc_info.value.__cause__) is InvalidSchemaValue @@ -1742,13 +1770,13 @@ def test_delete_tags_with_requestbody(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters() result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert is_dataclass(result.body) @@ -1770,7 +1798,7 @@ def test_delete_tags_with_requestbody(self, spec): request, response, spec=spec, - cls=V30ResponseHeadersValidator, + cls=V30ResponseHeadersUnmarshaller, ) assert result.headers == { @@ -1790,19 +1818,19 @@ def test_delete_tags_no_requestbody(self, spec): result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters() result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None def test_delete_tags_raises_missing_required_response_header( - self, spec, response_validator + self, spec, response_unmarshaller ): host_url = "http://petstore.swagger.io/v1" path_pattern = "/v1/tags" @@ -1816,13 +1844,13 @@ def test_delete_tags_raises_missing_required_response_header( result = validate_request( request, spec=spec, - cls=V30RequestParametersValidator, + cls=V30RequestParametersUnmarshaller, ) assert result.parameters == Parameters() result = validate_request( - request, spec=spec, cls=V30RequestBodyValidator + request, spec=spec, cls=V30RequestBodyUnmarshaller ) assert result.body is None @@ -1831,7 +1859,9 @@ def test_delete_tags_raises_missing_required_response_header( response = MockResponse(data, status_code=200) with pytest.warns(DeprecationWarning): - response_result = response_validator.validate(request, response) + response_result = response_unmarshaller.unmarshal( + request, response + ) assert response_result.errors == [ MissingRequiredHeader(name="x-delete-confirm"), diff --git a/tests/integration/validation/test_read_only_write_only.py b/tests/integration/unmarshalling/test_read_only_write_only.py similarity index 71% rename from tests/integration/validation/test_read_only_write_only.py rename to tests/integration/unmarshalling/test_read_only_write_only.py index c7fd7ad1..3a54636b 100644 --- a/tests/integration/validation/test_read_only_write_only.py +++ b/tests/integration/unmarshalling/test_read_only_write_only.py @@ -3,10 +3,14 @@ import pytest -from openapi_core import V30RequestValidator -from openapi_core import V30ResponseValidator from openapi_core.testing import MockRequest from openapi_core.testing import MockResponse +from openapi_core.unmarshalling.request.unmarshallers import ( + V30RequestUnmarshaller, +) +from openapi_core.unmarshalling.response.unmarshallers import ( + V30ResponseUnmarshaller, +) from openapi_core.validation.request.exceptions import InvalidRequestBody from openapi_core.validation.response.exceptions import InvalidData @@ -17,17 +21,17 @@ def spec(factory): @pytest.fixture(scope="class") -def request_validator(spec): - return V30RequestValidator(spec) +def request_unmarshaller(spec): + return V30RequestUnmarshaller(spec) @pytest.fixture(scope="class") -def response_validator(spec): - return V30ResponseValidator(spec) +def response_unmarshaller(spec): + return V30ResponseUnmarshaller(spec) class TestReadOnly: - def test_write_a_read_only_property(self, request_validator): + def test_write_a_read_only_property(self, request_unmarshaller): data = json.dumps( { "id": 10, @@ -39,13 +43,13 @@ def test_write_a_read_only_property(self, request_validator): host_url="", method="POST", path="/users", data=data ) - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert len(result.errors) == 1 assert type(result.errors[0]) == InvalidRequestBody assert result.body is None - def test_read_only_property_response(self, response_validator): + def test_read_only_property_response(self, response_unmarshaller): data = json.dumps( { "id": 10, @@ -57,7 +61,7 @@ def test_read_only_property_response(self, response_validator): response = MockResponse(data) - result = response_validator.validate(request, response) + result = response_unmarshaller.unmarshal(request, response) assert not result.errors assert is_dataclass(result.data) @@ -67,7 +71,7 @@ def test_read_only_property_response(self, response_validator): class TestWriteOnly: - def test_write_only_property(self, request_validator): + def test_write_only_property(self, request_unmarshaller): data = json.dumps( { "name": "Pedro", @@ -79,7 +83,7 @@ def test_write_only_property(self, request_validator): host_url="", method="POST", path="/users", data=data ) - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert not result.errors assert is_dataclass(result.body) @@ -87,7 +91,7 @@ def test_write_only_property(self, request_validator): assert result.body.name == "Pedro" assert result.body.hidden == False - def test_read_a_write_only_property(self, response_validator): + def test_read_a_write_only_property(self, response_unmarshaller): data = json.dumps( { "id": 10, @@ -99,7 +103,7 @@ def test_read_a_write_only_property(self, response_validator): request = MockRequest(host_url="", method="POST", path="/users") response = MockResponse(data) - result = response_validator.validate(request, response) + result = response_unmarshaller.unmarshal(request, response) assert result.errors == [InvalidData()] assert result.data is None diff --git a/tests/integration/validation/test_request_validator.py b/tests/integration/unmarshalling/test_request_unmarshaller.py similarity index 86% rename from tests/integration/validation/test_request_validator.py rename to tests/integration/unmarshalling/test_request_unmarshaller.py index 3051d51f..ea19f84e 100644 --- a/tests/integration/validation/test_request_validator.py +++ b/tests/integration/unmarshalling/test_request_unmarshaller.py @@ -3,13 +3,13 @@ import pytest -from openapi_core import V30RequestValidator +from openapi_core import V30RequestUnmarshaller +from openapi_core.datatypes import Parameters from openapi_core.templating.media_types.exceptions import MediaTypeNotFound from openapi_core.templating.paths.exceptions import OperationNotFound from openapi_core.templating.paths.exceptions import PathNotFound from openapi_core.templating.security.exceptions import SecurityNotFound from openapi_core.testing import MockRequest -from openapi_core.validation.request.datatypes import Parameters from openapi_core.validation.request.exceptions import InvalidParameter from openapi_core.validation.request.exceptions import MissingRequiredParameter from openapi_core.validation.request.exceptions import ( @@ -19,7 +19,7 @@ from openapi_core.validation.request.exceptions import SecurityError -class TestRequestValidator: +class TestRequestUnmarshaller: host_url = "http://petstore.swagger.io" api_key = "12345" @@ -39,44 +39,44 @@ def spec(self, v30_petstore_spec): return v30_petstore_spec @pytest.fixture(scope="session") - def request_validator(self, spec): - return V30RequestValidator(spec) + def request_unmarshaller(self, spec): + return V30RequestUnmarshaller(spec) - def test_request_server_error(self, request_validator): + def test_request_server_error(self, request_unmarshaller): request = MockRequest("http://petstore.invalid.net/v1", "get", "/") - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert len(result.errors) == 1 assert type(result.errors[0]) == PathNotFound assert result.body is None assert result.parameters == Parameters() - def test_invalid_path(self, request_validator): + def test_invalid_path(self, request_unmarshaller): request = MockRequest(self.host_url, "get", "/v1") - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert len(result.errors) == 1 assert type(result.errors[0]) == PathNotFound assert result.body is None assert result.parameters == Parameters() - def test_invalid_operation(self, request_validator): + def test_invalid_operation(self, request_unmarshaller): request = MockRequest(self.host_url, "patch", "/v1/pets") - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert len(result.errors) == 1 assert type(result.errors[0]) == OperationNotFound assert result.body is None assert result.parameters == Parameters() - def test_missing_parameter(self, request_validator): + def test_missing_parameter(self, request_unmarshaller): request = MockRequest(self.host_url, "get", "/v1/pets") with pytest.warns(DeprecationWarning): - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert type(result.errors[0]) == MissingRequiredParameter assert result.body is None @@ -87,7 +87,7 @@ def test_missing_parameter(self, request_validator): }, ) - def test_get_pets(self, request_validator): + def test_get_pets(self, request_unmarshaller): args = {"limit": "10", "ids": ["1", "2"], "api_key": self.api_key} request = MockRequest( self.host_url, @@ -98,7 +98,7 @@ def test_get_pets(self, request_validator): ) with pytest.warns(DeprecationWarning): - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert result.errors == [] assert result.body is None @@ -114,7 +114,7 @@ def test_get_pets(self, request_validator): "api_key": self.api_key, } - def test_get_pets_webob(self, request_validator): + def test_get_pets_webob(self, request_unmarshaller): from webob.multidict import GetDict request = MockRequest( @@ -128,7 +128,7 @@ def test_get_pets_webob(self, request_validator): ) with pytest.warns(DeprecationWarning): - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert result.errors == [] assert result.body is None @@ -141,7 +141,7 @@ def test_get_pets_webob(self, request_validator): }, ) - def test_missing_body(self, request_validator): + def test_missing_body(self, request_unmarshaller): headers = { "api-key": self.api_key_encoded, } @@ -157,7 +157,7 @@ def test_missing_body(self, request_validator): cookies=cookies, ) - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert len(result.errors) == 1 assert type(result.errors[0]) == MissingRequiredRequestBody @@ -171,7 +171,7 @@ def test_missing_body(self, request_validator): }, ) - def test_invalid_content_type(self, request_validator): + def test_invalid_content_type(self, request_unmarshaller): data = "csv,data" headers = { "api-key": self.api_key_encoded, @@ -190,7 +190,7 @@ def test_invalid_content_type(self, request_validator): cookies=cookies, ) - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert len(result.errors) == 1 assert type(result.errors[0]) == RequestBodyError @@ -208,7 +208,7 @@ def test_invalid_content_type(self, request_validator): }, ) - def test_invalid_complex_parameter(self, request_validator, spec_dict): + def test_invalid_complex_parameter(self, request_unmarshaller, spec_dict): pet_name = "Cat" pet_tag = "cats" pet_street = "Piekna" @@ -247,7 +247,7 @@ def test_invalid_complex_parameter(self, request_validator, spec_dict): cookies=cookies, ) - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert result.errors == [ InvalidParameter(name="userdata", location="cookie") @@ -273,7 +273,7 @@ def test_invalid_complex_parameter(self, request_validator, spec_dict): assert result.body.address.street == pet_street assert result.body.address.city == pet_city - def test_post_pets(self, request_validator, spec_dict): + def test_post_pets(self, request_unmarshaller, spec_dict): pet_name = "Cat" pet_tag = "cats" pet_street = "Piekna" @@ -307,7 +307,7 @@ def test_post_pets(self, request_validator, spec_dict): cookies=cookies, ) - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert result.errors == [] assert result.parameters == Parameters( @@ -331,7 +331,7 @@ def test_post_pets(self, request_validator, spec_dict): assert result.body.address.street == pet_street assert result.body.address.city == pet_city - def test_post_pets_plain_no_schema(self, request_validator): + def test_post_pets_plain_no_schema(self, request_unmarshaller): data = "plain text" headers = { "api-key": self.api_key_encoded, @@ -351,7 +351,7 @@ def test_post_pets_plain_no_schema(self, request_validator): ) with pytest.warns(UserWarning): - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert result.errors == [] assert result.parameters == Parameters( @@ -365,7 +365,7 @@ def test_post_pets_plain_no_schema(self, request_validator): assert result.security == {} assert result.body == data - def test_get_pet_unauthorized(self, request_validator): + def test_get_pet_unauthorized(self, request_unmarshaller): request = MockRequest( self.host_url, "get", @@ -374,7 +374,7 @@ def test_get_pet_unauthorized(self, request_validator): view_args={"petId": "1"}, ) - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert len(result.errors) == 1 assert type(result.errors[0]) is SecurityError @@ -385,7 +385,7 @@ def test_get_pet_unauthorized(self, request_validator): assert result.parameters == Parameters() assert result.security is None - def test_get_pet(self, request_validator): + def test_get_pet(self, request_unmarshaller): authorization = "Basic " + self.api_key_encoded headers = { "Authorization": authorization, @@ -399,7 +399,7 @@ def test_get_pet(self, request_validator): headers=headers, ) - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert result.errors == [] assert result.body is None diff --git a/tests/integration/validation/test_response_validator.py b/tests/integration/unmarshalling/test_response_unmarshaller.py similarity index 78% rename from tests/integration/validation/test_response_validator.py rename to tests/integration/unmarshalling/test_response_unmarshaller.py index fd1bf01c..36de07d9 100644 --- a/tests/integration/validation/test_response_validator.py +++ b/tests/integration/unmarshalling/test_response_unmarshaller.py @@ -3,7 +3,6 @@ import pytest -from openapi_core import V30ResponseValidator from openapi_core.deserializing.media_types.exceptions import ( MediaTypeDeserializeError, ) @@ -13,6 +12,9 @@ from openapi_core.templating.responses.exceptions import ResponseNotFound from openapi_core.testing import MockRequest from openapi_core.testing import MockResponse +from openapi_core.unmarshalling.response.unmarshallers import ( + V30ResponseUnmarshaller, +) from openapi_core.validation.response.exceptions import DataError from openapi_core.validation.response.exceptions import InvalidData from openapi_core.validation.response.exceptions import InvalidHeader @@ -20,7 +22,7 @@ from openapi_core.validation.schemas.exceptions import InvalidSchemaValue -class TestResponseValidator: +class TestResponseUnmarshaller: host_url = "http://petstore.swagger.io" @pytest.fixture(scope="session") @@ -32,68 +34,68 @@ def spec(self, v30_petstore_spec): return v30_petstore_spec @pytest.fixture(scope="session") - def response_validator(self, spec): - return V30ResponseValidator(spec) + def response_unmarshaller(self, spec): + return V30ResponseUnmarshaller(spec) - def test_invalid_server(self, response_validator): + def test_invalid_server(self, response_unmarshaller): request = MockRequest("http://petstore.invalid.net/v1", "get", "/") response = MockResponse("Not Found", status_code=404) - result = response_validator.validate(request, response) + result = response_unmarshaller.unmarshal(request, response) assert len(result.errors) == 1 assert type(result.errors[0]) == PathNotFound assert result.data is None assert result.headers == {} - def test_invalid_operation(self, response_validator): + def test_invalid_operation(self, response_unmarshaller): request = MockRequest(self.host_url, "patch", "/v1/pets") response = MockResponse("Not Found", status_code=404) - result = response_validator.validate(request, response) + result = response_unmarshaller.unmarshal(request, response) assert len(result.errors) == 1 assert type(result.errors[0]) == OperationNotFound assert result.data is None assert result.headers == {} - def test_invalid_response(self, response_validator): + def test_invalid_response(self, response_unmarshaller): request = MockRequest(self.host_url, "get", "/v1/pets") response = MockResponse("Not Found", status_code=409) - result = response_validator.validate(request, response) + result = response_unmarshaller.unmarshal(request, response) assert len(result.errors) == 1 assert type(result.errors[0]) == ResponseNotFound assert result.data is None assert result.headers == {} - def test_invalid_content_type(self, response_validator): + def test_invalid_content_type(self, response_unmarshaller): request = MockRequest(self.host_url, "get", "/v1/pets") response = MockResponse("Not Found", mimetype="text/csv") - result = response_validator.validate(request, response) + result = response_unmarshaller.unmarshal(request, response) assert result.errors == [DataError()] assert type(result.errors[0].__cause__) == MediaTypeNotFound assert result.data is None assert result.headers == {} - def test_missing_body(self, response_validator): + def test_missing_body(self, response_unmarshaller): request = MockRequest(self.host_url, "get", "/v1/pets") response = MockResponse(None) - result = response_validator.validate(request, response) + result = response_unmarshaller.unmarshal(request, response) assert result.errors == [MissingData()] assert result.data is None assert result.headers == {} - def test_invalid_media_type(self, response_validator): + def test_invalid_media_type(self, response_unmarshaller): request = MockRequest(self.host_url, "get", "/v1/pets") response = MockResponse("abcde") - result = response_validator.validate(request, response) + result = response_unmarshaller.unmarshal(request, response) assert result.errors == [DataError()] assert result.errors[0].__cause__ == MediaTypeDeserializeError( @@ -102,18 +104,18 @@ def test_invalid_media_type(self, response_validator): assert result.data is None assert result.headers == {} - def test_invalid_media_type_value(self, response_validator): + def test_invalid_media_type_value(self, response_unmarshaller): request = MockRequest(self.host_url, "get", "/v1/pets") response = MockResponse("{}") - result = response_validator.validate(request, response) + result = response_unmarshaller.unmarshal(request, response) assert result.errors == [InvalidData()] assert type(result.errors[0].__cause__) == InvalidSchemaValue assert result.data is None assert result.headers == {} - def test_invalid_value(self, response_validator): + def test_invalid_value(self, response_unmarshaller): request = MockRequest(self.host_url, "get", "/v1/tags") response_json = { "data": [ @@ -123,14 +125,14 @@ def test_invalid_value(self, response_validator): response_data = json.dumps(response_json) response = MockResponse(response_data) - result = response_validator.validate(request, response) + result = response_unmarshaller.unmarshal(request, response) assert result.errors == [InvalidData()] assert type(result.errors[0].__cause__) == InvalidSchemaValue assert result.data is None assert result.headers == {} - def test_invalid_header(self, response_validator): + def test_invalid_header(self, response_unmarshaller): userdata = { "name": 1, } @@ -160,13 +162,13 @@ def test_invalid_header(self, response_validator): response = MockResponse(response_data, headers=headers) with pytest.warns(DeprecationWarning): - result = response_validator.validate(request, response) + result = response_unmarshaller.unmarshal(request, response) assert result.errors == [InvalidHeader(name="x-delete-date")] assert result.data is None assert result.headers == {"x-delete-confirm": True} - def test_get_pets(self, response_validator): + def test_get_pets(self, response_unmarshaller): request = MockRequest(self.host_url, "get", "/v1/pets") response_json = { "data": [ @@ -182,7 +184,7 @@ def test_get_pets(self, response_validator): response_data = json.dumps(response_json) response = MockResponse(response_data) - result = response_validator.validate(request, response) + result = response_unmarshaller.unmarshal(request, response) assert result.errors == [] assert is_dataclass(result.data) diff --git a/tests/integration/validation/test_security_override.py b/tests/integration/unmarshalling/test_security_override.py similarity index 73% rename from tests/integration/validation/test_security_override.py rename to tests/integration/unmarshalling/test_security_override.py index bb316f8d..a885da99 100644 --- a/tests/integration/validation/test_security_override.py +++ b/tests/integration/unmarshalling/test_security_override.py @@ -2,9 +2,11 @@ import pytest -from openapi_core import V30RequestValidator from openapi_core.templating.security.exceptions import SecurityNotFound from openapi_core.testing import MockRequest +from openapi_core.unmarshalling.request.unmarshallers import ( + V30RequestUnmarshaller, +) from openapi_core.validation.request.exceptions import SecurityError @@ -14,8 +16,8 @@ def spec(factory): @pytest.fixture(scope="class") -def request_validator(spec): - return V30RequestValidator(spec) +def request_unmarshaller(spec): + return V30RequestUnmarshaller(spec) class TestSecurityOverride: @@ -29,28 +31,28 @@ def api_key_encoded(self): api_key_bytes_enc = b64encode(api_key_bytes) return str(api_key_bytes_enc, "utf8") - def test_default(self, request_validator): + def test_default(self, request_unmarshaller): args = {"api_key": self.api_key} request = MockRequest(self.host_url, "get", "/resource/one", args=args) - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert not result.errors assert result.security == { "api_key": self.api_key, } - def test_default_invalid(self, request_validator): + def test_default_invalid(self, request_unmarshaller): request = MockRequest(self.host_url, "get", "/resource/one") - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert len(result.errors) == 1 assert type(result.errors[0]) is SecurityError assert type(result.errors[0].__cause__) is SecurityNotFound assert result.security is None - def test_override(self, request_validator): + def test_override(self, request_unmarshaller): authorization = "Basic " + self.api_key_encoded headers = { "Authorization": authorization, @@ -59,27 +61,27 @@ def test_override(self, request_validator): self.host_url, "post", "/resource/one", headers=headers ) - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert not result.errors assert result.security == { "petstore_auth": self.api_key_encoded, } - def test_override_invalid(self, request_validator): + def test_override_invalid(self, request_unmarshaller): request = MockRequest(self.host_url, "post", "/resource/one") - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert len(result.errors) == 1 assert type(result.errors[0]) is SecurityError assert type(result.errors[0].__cause__) is SecurityNotFound assert result.security is None - def test_remove(self, request_validator): + def test_remove(self, request_unmarshaller): request = MockRequest(self.host_url, "put", "/resource/one") - result = request_validator.validate(request) + result = request_unmarshaller.unmarshal(request) assert not result.errors assert result.security == {} diff --git a/tests/integration/unmarshalling/test_unmarshallers.py b/tests/integration/unmarshalling/test_unmarshallers.py index 4dc10dec..c69d4af7 100644 --- a/tests/integration/unmarshalling/test_unmarshallers.py +++ b/tests/integration/unmarshalling/test_unmarshallers.py @@ -8,7 +8,6 @@ from isodate.tzinfo import FixedOffset from jsonschema.exceptions import SchemaError from jsonschema.exceptions import UnknownType -from jsonschema.exceptions import ValidationError from openapi_core import Spec from openapi_core.unmarshalling.schemas import ( @@ -23,7 +22,6 @@ from openapi_core.unmarshalling.schemas.exceptions import ( FormatterNotFoundError, ) -from openapi_core.unmarshalling.schemas.exceptions import UnmarshalError from openapi_core.validation.schemas.exceptions import InvalidSchemaValue @@ -54,6 +52,28 @@ def test_create_formatter_not_found(self, unmarshallers_factory): ): unmarshallers_factory.create(spec) + @pytest.mark.parametrize( + "value", + [ + "test", + 10, + 10, + 3.12, + ["one", "two"], + True, + False, + ], + ) + def test_call_deprecated(self, unmarshallers_factory, value): + schema = {} + spec = Spec.from_dict(schema, validator=None) + unmarshaller = unmarshallers_factory.create(spec) + + with pytest.warns(DeprecationWarning): + result = unmarshaller(value) + + assert result == value + @pytest.mark.parametrize( "value", [ @@ -71,7 +91,7 @@ def test_no_type(self, unmarshallers_factory, value): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -94,7 +114,7 @@ def test_basic_types(self, unmarshallers_factory, type, value): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -153,7 +173,7 @@ def test_basic_types_invalid(self, unmarshallers_factory, type, value): InvalidSchemaValue, match=f"not valid for schema of type {type}", ) as exc_info: - unmarshaller(value) + unmarshaller.unmarshal(value) assert len(exc_info.value.schema_errors) == 1 assert ( f"is not of type '{type}'" @@ -201,7 +221,7 @@ def test_basic_formats( spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == unmarshalled @@ -244,7 +264,7 @@ def test_basic_type_formats( spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == unmarshalled @@ -269,7 +289,7 @@ def test_basic_type_formats_invalid( unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue) as exc_info: - unmarshaller(value) + unmarshaller.unmarshal(value) assert len(exc_info.value.schema_errors) == 1 assert ( f"is not a '{format}'" in exc_info.value.schema_errors[0].message @@ -289,7 +309,7 @@ def test_string_byte(self, unmarshallers_factory, value, expected): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == expected @@ -302,7 +322,7 @@ def test_string_date(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) value = "2018-01-02" - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == date(2018, 1, 2) @@ -324,7 +344,7 @@ def test_string_datetime(self, unmarshallers_factory, value, expected): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == expected @@ -338,7 +358,7 @@ def test_string_datetime_invalid(self, unmarshallers_factory): value = "2018-01-02T00:00:00" with pytest.raises(InvalidSchemaValue) as exc_info: - unmarshaller(value) + unmarshaller.unmarshal(value) assert len(exc_info.value.schema_errors) == 1 assert ( f"is not a 'date-time'" in exc_info.value.schema_errors[0].message @@ -353,7 +373,7 @@ def test_string_password(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) value = "passwd" - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -366,7 +386,7 @@ def test_string_uuid(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) value = str(uuid4()) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == UUID(value) @@ -380,7 +400,7 @@ def test_string_uuid_invalid(self, unmarshallers_factory): value = "test" with pytest.raises(InvalidSchemaValue) as exc_info: - unmarshaller(value) + unmarshaller.unmarshal(value) assert len(exc_info.value.schema_errors) == 1 assert f"is not a 'uuid'" in exc_info.value.schema_errors[0].message @@ -413,7 +433,7 @@ def test_formats_ignored( spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == expected @@ -426,7 +446,7 @@ def test_string_pattern(self, unmarshallers_factory, value): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -448,7 +468,7 @@ def test_string_pattern_invalid( unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue) as exc_info: - unmarshaller(value) + unmarshaller.unmarshal(value) assert len(exc_info.value.schema_errors) == 1 assert ( f"'{value}' does not match '{pattern}'" @@ -464,7 +484,7 @@ def test_string_min_length(self, unmarshallers_factory, value): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -478,7 +498,7 @@ def test_string_min_length_invalid(self, unmarshallers_factory, value): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue) as exc_info: - unmarshaller(value) + unmarshaller.unmarshal(value) assert len(exc_info.value.schema_errors) == 1 assert ( f"'{value}' is too short" @@ -494,7 +514,7 @@ def test_string_max_length(self, unmarshallers_factory, value): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -508,7 +528,7 @@ def test_string_max_length_invalid(self, unmarshallers_factory, value): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue) as exc_info: - unmarshaller(value) + unmarshaller.unmarshal(value) assert len(exc_info.value.schema_errors) == 1 assert ( f"'{value}' is too long" in exc_info.value.schema_errors[0].message @@ -531,7 +551,7 @@ def test_string_max_length_invalid_schema( unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue): - unmarshaller(value) + unmarshaller.unmarshal(value) def test_integer_enum(self, unmarshallers_factory): schema = { @@ -542,7 +562,7 @@ def test_integer_enum(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) value = 2 - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == int(value) @@ -557,7 +577,7 @@ def test_integer_enum_invalid(self, unmarshallers_factory): value = 12 with pytest.raises(InvalidSchemaValue) as exc_info: - unmarshaller(value) + unmarshaller.unmarshal(value) assert len(exc_info.value.schema_errors) == 1 assert ( f"{value} is not one of {enum}" @@ -587,7 +607,7 @@ def test_array(self, unmarshallers_factory, type, value): unmarshaller = unmarshallers_factory.create(spec) value_list = [value] * 3 - result = unmarshaller(value_list) + result = unmarshaller.unmarshal(value_list) assert result == value_list @@ -613,7 +633,7 @@ def test_array_invalid(self, unmarshallers_factory, type, value): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue) as exc_info: - unmarshaller([value]) + unmarshaller.unmarshal([value]) assert len(exc_info.value.schema_errors) == 1 assert ( f"is not of type '{type}'" @@ -633,7 +653,7 @@ def test_array_min_items_invalid(self, unmarshallers_factory, value): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue) as exc_info: - unmarshaller(value) + unmarshaller.unmarshal(value) assert len(exc_info.value.schema_errors) == 1 assert ( f"{value} is too short" in exc_info.value.schema_errors[0].message @@ -651,7 +671,7 @@ def test_array_min_items(self, unmarshallers_factory, value): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -675,7 +695,7 @@ def test_array_max_items_invalid_schema( unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue): - unmarshaller(value) + unmarshaller.unmarshal(value) @pytest.mark.parametrize("value", [[1, 2], [2, 3, 4]]) def test_array_max_items_invalid(self, unmarshallers_factory, value): @@ -690,7 +710,7 @@ def test_array_max_items_invalid(self, unmarshallers_factory, value): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue) as exc_info: - unmarshaller(value) + unmarshaller.unmarshal(value) assert len(exc_info.value.schema_errors) == 1 assert ( f"{value} is too long" in exc_info.value.schema_errors[0].message @@ -709,7 +729,7 @@ def test_array_unique_items_invalid(self, unmarshallers_factory, value): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue) as exc_info: - unmarshaller(value) + unmarshaller.unmarshal(value) assert len(exc_info.value.schema_errors) == 1 assert ( f"{value} has non-unique elements" @@ -736,7 +756,7 @@ def test_object_any_of(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) value = {"someint": 1} - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -760,7 +780,7 @@ def test_object_any_of_invalid(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue): - unmarshaller({"someint": "1"}) + unmarshaller.unmarshal({"someint": "1"}) def test_object_one_of_default(self, unmarshallers_factory): schema = { @@ -794,7 +814,7 @@ def test_object_one_of_default(self, unmarshallers_factory): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - assert unmarshaller({"someint": 1}) == { + assert unmarshaller.unmarshal({"someint": 1}) == { "someint": 1, "somestr": "defaultstring", } @@ -825,7 +845,7 @@ def test_object_any_of_default(self, unmarshallers_factory): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - assert unmarshaller({"someint": "1"}) == { + assert unmarshaller.unmarshal({"someint": "1"}) == { "someint": "1", "somestr": "defaultstring", } @@ -857,7 +877,7 @@ def test_object_all_of_default(self, unmarshallers_factory): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - assert unmarshaller({}) == { + assert unmarshaller.unmarshal({}) == { "someint": 1, "somestr": "defaultstring", } @@ -892,7 +912,7 @@ def test_object_with_properties(self, unmarshallers_factory, value): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -937,7 +957,7 @@ def test_object_with_properties_invalid( unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue): - unmarshaller(value) + unmarshaller.unmarshal(value) @pytest.mark.parametrize( "value", @@ -958,7 +978,7 @@ def test_object_default_property(self, unmarshallers_factory, value): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == {"prop": "value1"} @@ -979,7 +999,7 @@ def test_object_additional_properties_false( unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue): - unmarshaller(value) + unmarshaller.unmarshal(value) @pytest.mark.parametrize( "value", @@ -1000,7 +1020,7 @@ def test_object_additional_properties_free_form_object( spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -1009,7 +1029,7 @@ def test_object_additional_properties_list(self, unmarshallers_factory): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller({"user_ids": [1, 2, 3, 4]}) + result = unmarshaller.unmarshal({"user_ids": [1, 2, 3, 4]}) assert result == { "user_ids": [1, 2, 3, 4], @@ -1028,7 +1048,7 @@ def test_object_additional_properties(self, unmarshallers_factory, value): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -1051,7 +1071,7 @@ def test_object_additional_properties_object( spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -1072,7 +1092,7 @@ def test_object_min_properties(self, unmarshallers_factory, value): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -1094,7 +1114,7 @@ def test_object_min_properties_invalid(self, unmarshallers_factory, value): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue): - unmarshaller(value) + unmarshaller.unmarshal(value) @pytest.mark.parametrize( "value", @@ -1113,7 +1133,7 @@ def test_object_min_properties_invalid_schema( unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue): - unmarshaller(value) + unmarshaller.unmarshal(value) @pytest.mark.parametrize( "value", @@ -1132,7 +1152,7 @@ def test_object_max_properties(self, unmarshallers_factory, value): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -1154,7 +1174,7 @@ def test_object_max_properties_invalid(self, unmarshallers_factory, value): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue): - unmarshaller(value) + unmarshaller.unmarshal(value) @pytest.mark.parametrize( "value", @@ -1173,7 +1193,7 @@ def test_object_max_properties_invalid_schema( unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue): - unmarshaller(value) + unmarshaller.unmarshal(value) def test_any_one_of(self, unmarshallers_factory): schema = { @@ -1193,7 +1213,7 @@ def test_any_one_of(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) value = ["hello"] - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -1215,7 +1235,7 @@ def test_any_any_of(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) value = ["hello"] - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -1234,7 +1254,7 @@ def test_any_all_of(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) value = ["hello"] - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -1288,7 +1308,7 @@ def test_any_all_of_invalid_properties(self, value, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue): - unmarshaller(value) + unmarshaller.unmarshal(value) @pytest.mark.xfail( reason=( @@ -1310,7 +1330,7 @@ def test_any_format_one_of(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) value = "2018-01-02" - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == date(2018, 1, 2) @@ -1328,7 +1348,7 @@ def test_any_one_of_any(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) value = "2018-01-02" - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == date(2018, 1, 2) @@ -1346,7 +1366,7 @@ def test_any_any_of_any(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) value = "2018-01-02" - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == date(2018, 1, 2) @@ -1364,7 +1384,7 @@ def test_any_all_of_any(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) value = "2018-01-02" - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == date(2018, 1, 2) @@ -1402,7 +1422,7 @@ def test_any_of_no_valid(self, unmarshallers_factory, value): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue): - unmarshaller(value) + unmarshaller.unmarshal(value) @pytest.mark.parametrize( "value", @@ -1442,7 +1462,7 @@ def test_any_one_of_no_valid(self, unmarshallers_factory, value): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue): - unmarshaller(value) + unmarshaller.unmarshal(value) @pytest.mark.parametrize( "value", @@ -1459,7 +1479,7 @@ def test_any_any_of_different_type(self, unmarshallers_factory, value): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue): - unmarshaller(value) + unmarshaller.unmarshal(value) @pytest.mark.parametrize( "value", @@ -1483,7 +1503,7 @@ def test_any_one_of_different_type(self, unmarshallers_factory, value): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue): - unmarshaller(value) + unmarshaller.unmarshal(value) @pytest.mark.parametrize( "value", @@ -1529,7 +1549,7 @@ def test_any_any_of_unambiguous(self, unmarshallers_factory, value): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -1555,7 +1575,7 @@ def test_object_multiple_any_of(self, unmarshallers_factory, value): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -1582,7 +1602,7 @@ def test_object_multiple_one_of(self, unmarshallers_factory, value): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue): - unmarshaller(value) + unmarshaller.unmarshal(value) @pytest.mark.parametrize( "value", @@ -1630,7 +1650,7 @@ def test_any_one_of_unambiguous(self, unmarshallers_factory, value): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -1642,7 +1662,7 @@ def test_null_undefined(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(UnknownType): - unmarshaller(None) + unmarshaller.unmarshal(None) @pytest.mark.parametrize( "type", @@ -1659,7 +1679,7 @@ def test_nullable(self, unmarshallers_factory, type): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(None) + result = unmarshaller.unmarshal(None) assert result is None @@ -1682,7 +1702,7 @@ def test_not_nullable(self, unmarshallers_factory, type): InvalidSchemaValue, match=f"not valid for schema of type {type}", ) as exc_info: - unmarshaller(None) + unmarshaller.unmarshal(None) assert len(exc_info.value.schema_errors) == 2 assert ( "None for not nullable" in exc_info.value.schema_errors[0].message @@ -1709,7 +1729,7 @@ def test_basic_type_oas30_formats( spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == unmarshalled @@ -1734,7 +1754,7 @@ def test_basic_type_oas30_formats_invalid( InvalidSchemaValue, match=f"not valid for schema of type {type}", ) as exc_info: - unmarshaller(value) + unmarshaller.unmarshal(value) assert len(exc_info.value.schema_errors) == 1 assert ( f"is not a '{format}'" in exc_info.value.schema_errors[0].message @@ -1758,7 +1778,7 @@ def test_string_format_binary_invalid(self, unmarshallers_factory): InvalidSchemaValue, match=f"not valid for schema of type {type}", ): - unmarshaller(value) + unmarshaller.unmarshal(value) @pytest.mark.xfail( reason=( @@ -1785,7 +1805,7 @@ def test_nultiple_types_undefined( unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(SchemaError): - unmarshaller(value) + unmarshaller.unmarshal(value) def test_integer_default_nullable(self, unmarshallers_factory): default_value = 123 @@ -1798,7 +1818,7 @@ def test_integer_default_nullable(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) value = None - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result is None @@ -1814,7 +1834,7 @@ def test_array_nullable(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) value = None - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result is None @@ -1832,7 +1852,7 @@ def test_object_property_nullable(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) value = {"foo": None} - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -1861,7 +1881,7 @@ def test_write_only_properties(self, unmarshallers_factory): value = {"id": 10} # readOnly properties may be admitted in a Response context - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -1882,7 +1902,7 @@ def test_read_only_properties_invalid(self, unmarshallers_factory): # readOnly properties are not admitted on a Request context with pytest.raises(InvalidSchemaValue): - unmarshaller(value) + unmarshaller.unmarshal(value) class TestOAS30ResponseSchemaUnmarshallersFactory( @@ -1908,7 +1928,7 @@ def test_read_only_properties(self, unmarshallers_factory): unmarshaller = unmarshallers_factory.create(spec) # readOnly properties may be admitted in a Response context - result = unmarshaller({"id": 10}) + result = unmarshaller.unmarshal({"id": 10}) assert result == { "id": 10, @@ -1930,7 +1950,7 @@ def test_write_only_properties_invalid(self, unmarshallers_factory): # readOnly properties are not admitted on a Request context with pytest.raises(InvalidSchemaValue): - unmarshaller({"id": 10}) + unmarshaller.unmarshal({"id": 10}) class TestOAS31SchemaUnmarshallersFactory( @@ -1987,14 +2007,14 @@ def test_basic_types_invalid(self, unmarshallers_factory, type, value): InvalidSchemaValue, match=f"not valid for schema of type {type}", ): - unmarshaller(value) + unmarshaller.unmarshal(value) def test_null(self, unmarshallers_factory): schema = {"type": "null"} spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(None) + result = unmarshaller.unmarshal(None) assert result is None @@ -2005,7 +2025,7 @@ def test_null_invalid(self, unmarshallers_factory, value): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue) as exc_info: - unmarshaller(value) + unmarshaller.unmarshal(value) assert len(exc_info.value.schema_errors) == 1 assert ( "is not of type 'null'" in exc_info.value.schema_errors[0].message @@ -2027,7 +2047,7 @@ def test_nultiple_types(self, unmarshallers_factory, types, value): spec = Spec.from_dict(schema, validator=None) unmarshaller = unmarshallers_factory.create(spec) - result = unmarshaller(value) + result = unmarshaller.unmarshal(value) assert result == value @@ -2048,6 +2068,6 @@ def test_nultiple_types_invalid(self, unmarshallers_factory, types, value): unmarshaller = unmarshallers_factory.create(spec) with pytest.raises(InvalidSchemaValue) as exc_info: - unmarshaller(value) + unmarshaller.unmarshal(value) assert len(exc_info.value.schema_errors) == 1 assert "is not of type" in exc_info.value.schema_errors[0].message diff --git a/tests/integration/validation/test_request_validators.py b/tests/integration/validation/test_request_validators.py new file mode 100644 index 00000000..48eed5a7 --- /dev/null +++ b/tests/integration/validation/test_request_validators.py @@ -0,0 +1,132 @@ +import json +from base64 import b64encode + +import pytest + +from openapi_core import V30RequestValidator +from openapi_core.datatypes import Parameters +from openapi_core.templating.media_types.exceptions import MediaTypeNotFound +from openapi_core.templating.paths.exceptions import OperationNotFound +from openapi_core.templating.paths.exceptions import PathNotFound +from openapi_core.templating.security.exceptions import SecurityNotFound +from openapi_core.testing import MockRequest +from openapi_core.unmarshalling.request.unmarshallers import ( + V30RequestUnmarshaller, +) +from openapi_core.validation.request.exceptions import InvalidParameter +from openapi_core.validation.request.exceptions import MissingRequiredParameter +from openapi_core.validation.request.exceptions import ( + MissingRequiredRequestBody, +) +from openapi_core.validation.request.exceptions import RequestBodyError +from openapi_core.validation.request.exceptions import SecurityError + + +class TestRequestValidator: + host_url = "http://petstore.swagger.io" + + api_key = "12345" + + @property + def api_key_encoded(self): + api_key_bytes = self.api_key.encode("utf8") + api_key_bytes_enc = b64encode(api_key_bytes) + return str(api_key_bytes_enc, "utf8") + + @pytest.fixture(scope="session") + def spec_dict(self, v30_petstore_content): + return v30_petstore_content + + @pytest.fixture(scope="session") + def spec(self, v30_petstore_spec): + return v30_petstore_spec + + @pytest.fixture(scope="session") + def request_validator(self, spec): + return V30RequestValidator(spec) + + def test_request_server_error(self, request_validator): + request = MockRequest("http://petstore.invalid.net/v1", "get", "/") + + with pytest.raises(PathNotFound): + request_validator.validate(request) + + def test_path_not_found(self, request_validator): + request = MockRequest(self.host_url, "get", "/v1") + + with pytest.raises(PathNotFound): + request_validator.validate(request) + + def test_operation_not_found(self, request_validator): + request = MockRequest(self.host_url, "patch", "/v1/pets") + + with pytest.raises(OperationNotFound): + request_validator.validate(request) + + def test_missing_parameter(self, request_validator): + request = MockRequest(self.host_url, "get", "/v1/pets") + + with pytest.raises(MissingRequiredParameter): + with pytest.warns(DeprecationWarning): + request_validator.validate(request) + + def test_security_not_found(self, request_validator): + request = MockRequest( + self.host_url, + "get", + "/v1/pets/1", + path_pattern="/v1/pets/{petId}", + view_args={"petId": "1"}, + ) + + with pytest.raises(SecurityError) as exc_info: + request_validator.validate(request) + + assert exc_info.value.__cause__ == SecurityNotFound( + [["petstore_auth"]] + ) + + def test_media_type_not_found(self, request_validator): + data = "csv,data" + headers = { + "api-key": self.api_key_encoded, + } + cookies = { + "user": "123", + } + request = MockRequest( + "https://development.gigantic-server.com", + "post", + "/v1/pets", + path_pattern="/v1/pets", + mimetype="text/csv", + data=data, + headers=headers, + cookies=cookies, + ) + + with pytest.raises(RequestBodyError) as exc_info: + request_validator.validate(request) + + assert exc_info.value.__cause__ == MediaTypeNotFound( + mimetype="text/csv", + availableMimetypes=["application/json", "text/plain"], + ) + + def test_valid(self, request_validator): + authorization = "Basic " + self.api_key_encoded + headers = { + "Authorization": authorization, + } + request = MockRequest( + self.host_url, + "get", + "/v1/pets/1", + path_pattern="/v1/pets/{petId}", + view_args={"petId": "1"}, + headers=headers, + ) + + result = request_validator.validate(request) + + assert result is None diff --git a/tests/integration/validation/test_response_validators.py b/tests/integration/validation/test_response_validators.py new file mode 100644 index 00000000..c7d7d2fa --- /dev/null +++ b/tests/integration/validation/test_response_validators.py @@ -0,0 +1,160 @@ +import json +from dataclasses import is_dataclass + +import pytest + +from openapi_core import V30ResponseValidator +from openapi_core.deserializing.media_types.exceptions import ( + MediaTypeDeserializeError, +) +from openapi_core.templating.media_types.exceptions import MediaTypeNotFound +from openapi_core.templating.paths.exceptions import OperationNotFound +from openapi_core.templating.paths.exceptions import PathNotFound +from openapi_core.templating.responses.exceptions import ResponseNotFound +from openapi_core.testing import MockRequest +from openapi_core.testing import MockResponse +from openapi_core.unmarshalling.response.unmarshallers import ( + V30ResponseUnmarshaller, +) +from openapi_core.validation.response.exceptions import DataError +from openapi_core.validation.response.exceptions import InvalidData +from openapi_core.validation.response.exceptions import InvalidHeader +from openapi_core.validation.response.exceptions import MissingData +from openapi_core.validation.schemas.exceptions import InvalidSchemaValue + + +class TestResponseValidator: + host_url = "http://petstore.swagger.io" + + @pytest.fixture(scope="session") + def spec_dict(self, v30_petstore_content): + return v30_petstore_content + + @pytest.fixture(scope="session") + def spec(self, v30_petstore_spec): + return v30_petstore_spec + + @pytest.fixture(scope="session") + def response_validator(self, spec): + return V30ResponseValidator(spec) + + def test_invalid_server(self, response_validator): + request = MockRequest("http://petstore.invalid.net/v1", "get", "/") + response = MockResponse("Not Found", status_code=404) + + with pytest.raises(PathNotFound): + response_validator.validate(request, response) + + def test_invalid_operation(self, response_validator): + request = MockRequest(self.host_url, "patch", "/v1/pets") + response = MockResponse("Not Found", status_code=404) + + with pytest.raises(OperationNotFound): + response_validator.validate(request, response) + + def test_invalid_response(self, response_validator): + request = MockRequest(self.host_url, "get", "/v1/pets") + response = MockResponse("Not Found", status_code=409) + + with pytest.raises(ResponseNotFound): + response_validator.validate(request, response) + + def test_invalid_content_type(self, response_validator): + request = MockRequest(self.host_url, "get", "/v1/pets") + response = MockResponse("Not Found", mimetype="text/csv") + + with pytest.raises(DataError) as exc_info: + response_validator.validate(request, response) + + assert type(exc_info.value.__cause__) == MediaTypeNotFound + + def test_missing_body(self, response_validator): + request = MockRequest(self.host_url, "get", "/v1/pets") + response = MockResponse(None) + + with pytest.raises(MissingData): + response_validator.validate(request, response) + + def test_invalid_media_type(self, response_validator): + request = MockRequest(self.host_url, "get", "/v1/pets") + response = MockResponse("abcde") + + with pytest.raises(DataError) as exc_info: + response_validator.validate(request, response) + + assert exc_info.value.__cause__ == MediaTypeDeserializeError( + mimetype="application/json", value="abcde" + ) + + def test_invalid_media_type_value(self, response_validator): + request = MockRequest(self.host_url, "get", "/v1/pets") + response = MockResponse("{}") + + with pytest.raises(DataError) as exc_info: + response_validator.validate(request, response) + + assert type(exc_info.value.__cause__) == InvalidSchemaValue + + def test_invalid_value(self, response_validator): + request = MockRequest(self.host_url, "get", "/v1/tags") + response_json = { + "data": [ + {"id": 1, "name": "Sparky"}, + ], + } + response_data = json.dumps(response_json) + response = MockResponse(response_data) + + with pytest.raises(InvalidData) as exc_info: + response_validator.validate(request, response) + + assert type(exc_info.value.__cause__) == InvalidSchemaValue + + def test_invalid_header(self, response_validator): + request = MockRequest( + self.host_url, + "delete", + "/v1/tags", + path_pattern="/v1/tags", + ) + response_json = { + "data": [ + { + "id": 1, + "name": "Sparky", + "ears": { + "healthy": True, + }, + }, + ], + } + response_data = json.dumps(response_json) + headers = { + "x-delete-confirm": "true", + "x-delete-date": "today", + } + response = MockResponse(response_data, headers=headers) + + with pytest.raises(InvalidHeader): + with pytest.warns(DeprecationWarning): + response_validator.validate(request, response) + + def test_valid(self, response_validator): + request = MockRequest(self.host_url, "get", "/v1/pets") + response_json = { + "data": [ + { + "id": 1, + "name": "Sparky", + "ears": { + "healthy": True, + }, + }, + ], + } + response_data = json.dumps(response_json) + response = MockResponse(response_data) + + result = response_validator.validate(request, response) + + assert result is None diff --git a/tests/unit/validation/conftest.py b/tests/unit/conftest.py similarity index 100% rename from tests/unit/validation/conftest.py rename to tests/unit/conftest.py diff --git a/tests/unit/contrib/django/test_django.py b/tests/unit/contrib/django/test_django.py index fb4d0316..907875bf 100644 --- a/tests/unit/contrib/django/test_django.py +++ b/tests/unit/contrib/django/test_django.py @@ -4,7 +4,7 @@ from openapi_core.contrib.django import DjangoOpenAPIRequest from openapi_core.contrib.django import DjangoOpenAPIResponse -from openapi_core.validation.request.datatypes import RequestParameters +from openapi_core.datatypes import RequestParameters class BaseTestDjango: diff --git a/tests/unit/contrib/flask/test_flask_requests.py b/tests/unit/contrib/flask/test_flask_requests.py index 80f92181..ca173267 100644 --- a/tests/unit/contrib/flask/test_flask_requests.py +++ b/tests/unit/contrib/flask/test_flask_requests.py @@ -5,7 +5,7 @@ from werkzeug.datastructures import ImmutableMultiDict from openapi_core.contrib.flask import FlaskOpenAPIRequest -from openapi_core.validation.request.datatypes import RequestParameters +from openapi_core.datatypes import RequestParameters class TestFlaskOpenAPIRequest: diff --git a/tests/unit/contrib/requests/test_requests_requests.py b/tests/unit/contrib/requests/test_requests_requests.py index 45bfbdf8..762a115a 100644 --- a/tests/unit/contrib/requests/test_requests_requests.py +++ b/tests/unit/contrib/requests/test_requests_requests.py @@ -3,7 +3,7 @@ from werkzeug.datastructures import ImmutableMultiDict from openapi_core.contrib.requests import RequestsOpenAPIRequest -from openapi_core.validation.request.datatypes import RequestParameters +from openapi_core.datatypes import RequestParameters class TestRequestsOpenAPIRequest: diff --git a/tests/unit/test_shortcuts.py b/tests/unit/test_shortcuts.py new file mode 100644 index 00000000..796722e3 --- /dev/null +++ b/tests/unit/test_shortcuts.py @@ -0,0 +1,570 @@ +from unittest import mock + +import pytest + +from openapi_core import RequestValidator +from openapi_core import ResponseValidator +from openapi_core import unmarshal_request +from openapi_core import unmarshal_response +from openapi_core import unmarshal_webhook_request +from openapi_core import unmarshal_webhook_response +from openapi_core import validate_request +from openapi_core import validate_response +from openapi_core.exceptions import SpecError +from openapi_core.protocols import Request +from openapi_core.protocols import Response +from openapi_core.protocols import WebhookRequest +from openapi_core.testing.datatypes import ResultMock +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult +from openapi_core.unmarshalling.request.unmarshallers import ( + APICallRequestUnmarshaller, +) +from openapi_core.unmarshalling.request.unmarshallers import ( + WebhookRequestUnmarshaller, +) +from openapi_core.unmarshalling.response.unmarshallers import ( + APICallResponseUnmarshaller, +) + + +class MockClass: + schema_validators_factory = None + schema_unmarshallers_factory = None + + unmarshal_calls = [] + return_unmarshal = None + + @classmethod + def setUp(cls, return_unmarshal): + cls.unmarshal_calls = [] + cls.return_unmarshal = return_unmarshal + + +class MockReqClass(MockClass): + assert_request = None + + @classmethod + def setUp(cls, return_unmarshal, assert_request): + super().setUp(return_unmarshal) + cls.assert_request = assert_request + + def unmarshal(self, req): + self.unmarshal_calls.append([req]) + assert req == self.assert_request + return self.return_unmarshal + + +class MockRespClass(MockClass): + assert_request = None + assert_response = None + + @classmethod + def setUp(cls, return_unmarshal, assert_request, assert_response): + super().setUp(return_unmarshal) + cls.assert_request = assert_request + cls.assert_response = assert_response + + def unmarshal(self, req, resp): + self.unmarshal_calls.append([req, resp]) + assert req == self.assert_request + assert resp == self.assert_response + return self.return_unmarshal + + +class TestUnmarshalRequest: + def test_spec_not_detected(self, spec_invalid): + request = mock.Mock(spec=Request) + + with pytest.raises(SpecError): + unmarshal_request(request, spec=spec_invalid) + + def test_request_type_invalid(self, spec_v31): + request = mock.sentinel.request + + with pytest.raises(TypeError): + unmarshal_request(request, spec=spec_v31) + + def test_spec_type_invalid(self): + request = mock.Mock(spec=Request) + spec = mock.sentinel.spec + + with pytest.raises(TypeError): + unmarshal_request(request, spec=spec) + + def test_cls_type_invalid(self, spec_v31): + request = mock.Mock(spec=Request) + + with pytest.raises(TypeError): + unmarshal_request(request, spec=spec_v31, cls=Exception) + + @mock.patch( + "openapi_core.unmarshalling.request.unmarshallers.APICallRequestUnmarshaller." + "unmarshal", + ) + def test_request(self, mock_unmarshal, spec_v31): + request = mock.Mock(spec=Request) + + result = unmarshal_request(request, spec=spec_v31) + + assert result == mock_unmarshal.return_value + mock_unmarshal.assert_called_once_with(request) + + +class TestUnmarshalWebhookRequest: + def test_spec_not_detected(self, spec_invalid): + request = mock.Mock(spec=WebhookRequest) + + with pytest.raises(SpecError): + unmarshal_webhook_request(request, spec=spec_invalid) + + def test_request_type_invalid(self, spec_v31): + request = mock.sentinel.request + + with pytest.raises(TypeError): + unmarshal_webhook_request(request, spec=spec_v31) + + def test_spec_type_invalid(self): + request = mock.Mock(spec=WebhookRequest) + spec = mock.sentinel.spec + + with pytest.raises(TypeError): + unmarshal_webhook_request(request, spec=spec) + + def test_cls_type_invalid(self, spec_v31): + request = mock.Mock(spec=WebhookRequest) + + with pytest.raises(TypeError): + unmarshal_webhook_request(request, spec=spec_v31, cls=Exception) + + @mock.patch( + "openapi_core.unmarshalling.request.unmarshallers.WebhookRequestUnmarshaller." + "unmarshal", + ) + def test_request(self, mock_unmarshal, spec_v31): + request = mock.Mock(spec=WebhookRequest) + + result = unmarshal_webhook_request(request, spec=spec_v31) + + assert result == mock_unmarshal.return_value + mock_unmarshal.assert_called_once_with(request) + + +class TestUnmarshalResponse: + def test_spec_not_detected(self, spec_invalid): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + + with pytest.raises(SpecError): + unmarshal_response(request, response, spec=spec_invalid) + + def test_request_type_invalid(self, spec_v31): + request = mock.sentinel.request + response = mock.Mock(spec=Response) + + with pytest.raises(TypeError): + unmarshal_response(request, response, spec=spec_v31) + + def test_response_type_invalid(self, spec_v31): + request = mock.Mock(spec=Request) + response = mock.sentinel.response + + with pytest.raises(TypeError): + unmarshal_response(request, response, spec=spec_v31) + + def test_spec_type_invalid(self): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + spec = mock.sentinel.spec + + with pytest.raises(TypeError): + unmarshal_response(request, response, spec=spec) + + def test_cls_type_invalid(self, spec_v31): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + + with pytest.raises(TypeError): + unmarshal_response(request, response, spec=spec_v31, cls=Exception) + + @mock.patch( + "openapi_core.unmarshalling.response.unmarshallers.APICallResponseUnmarshaller." + "unmarshal", + ) + def test_request_response(self, mock_unmarshal, spec_v31): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + + result = unmarshal_response(request, response, spec=spec_v31) + + assert result == mock_unmarshal.return_value + mock_unmarshal.assert_called_once_with(request, response) + + +class TestUnmarshalWebhookResponse: + def test_spec_not_detected(self, spec_invalid): + request = mock.Mock(spec=WebhookRequest) + response = mock.Mock(spec=Response) + + with pytest.raises(SpecError): + unmarshal_webhook_response(request, response, spec=spec_invalid) + + def test_request_type_invalid(self, spec_v31): + request = mock.sentinel.request + response = mock.Mock(spec=Response) + + with pytest.raises(TypeError): + unmarshal_webhook_response(request, response, spec=spec_v31) + + def test_response_type_invalid(self, spec_v31): + request = mock.Mock(spec=WebhookRequest) + response = mock.sentinel.response + + with pytest.raises(TypeError): + unmarshal_webhook_response(request, response, spec=spec_v31) + + def test_spec_type_invalid(self): + request = mock.Mock(spec=WebhookRequest) + response = mock.Mock(spec=Response) + spec = mock.sentinel.spec + + with pytest.raises(TypeError): + unmarshal_webhook_response(request, response, spec=spec) + + def test_cls_type_invalid(self, spec_v31): + request = mock.Mock(spec=WebhookRequest) + response = mock.Mock(spec=Response) + + with pytest.raises(TypeError): + unmarshal_webhook_response( + request, response, spec=spec_v31, cls=Exception + ) + + @mock.patch( + "openapi_core.unmarshalling.response.unmarshallers.WebhookResponseUnmarshaller." + "unmarshal", + ) + def test_request_response(self, mock_unmarshal, spec_v31): + request = mock.Mock(spec=WebhookRequest) + response = mock.Mock(spec=Response) + + result = unmarshal_webhook_response(request, response, spec=spec_v31) + + assert result == mock_unmarshal.return_value + mock_unmarshal.assert_called_once_with(request, response) + + +class TestValidateRequest: + def test_spec_not_detected(self, spec_invalid): + request = mock.Mock(spec=Request) + + with pytest.raises(SpecError): + validate_request(request, spec=spec_invalid) + + def test_request_type_invalid(self, spec_v31): + request = mock.sentinel.request + + with pytest.raises(TypeError): + validate_request(request, spec=spec_v31) + + def test_spec_type_invalid(self): + request = mock.Mock(spec=Request) + spec = mock.sentinel.spec + + with pytest.raises(TypeError): + validate_request(request, spec=spec) + + @mock.patch( + "openapi_core.unmarshalling.request.unmarshallers.APICallRequestUnmarshaller." + "unmarshal", + ) + def test_request(self, mock_unmarshal, spec_v31): + request = mock.Mock(spec=Request) + + result = validate_request(request, spec=spec_v31) + + assert result == mock_unmarshal.return_value + mock_unmarshal.assert_called_once_with(request) + + @mock.patch( + "openapi_core.unmarshalling.request.unmarshallers.APICallRequestUnmarshaller." + "unmarshal", + ) + def test_spec_as_first_arg_deprecated(self, mock_unmarshal, spec_v31): + request = mock.Mock(spec=Request) + + with pytest.warns(DeprecationWarning): + result = validate_request(spec_v31, request) + + assert result == mock_unmarshal.return_value + mock_unmarshal.assert_called_once_with(request) + + @mock.patch( + "openapi_core.unmarshalling.request.unmarshallers.APICallRequestUnmarshaller." + "unmarshal", + ) + def test_request_error(self, mock_unmarshal, spec_v31): + request = mock.Mock(spec=Request) + mock_unmarshal.return_value = ResultMock(error_to_raise=ValueError) + + with pytest.raises(ValueError): + validate_request(request, spec=spec_v31) + + mock_unmarshal.assert_called_once_with(request) + + def test_validator(self, spec_v31): + request = mock.Mock(spec=Request) + validator = mock.Mock(spec=RequestValidator) + + with pytest.warns(DeprecationWarning): + result = validate_request( + request, spec=spec_v31, validator=validator + ) + + assert result == validator.validate.return_value + validator.validate.assert_called_once_with( + spec_v31, request, base_url=None + ) + + def test_cls(self, spec_v31): + request = mock.Mock(spec=Request) + unmarshal = mock.Mock(spec=RequestUnmarshalResult) + TestAPICallReq = type( + "TestAPICallReq", + (MockReqClass, APICallRequestUnmarshaller), + {}, + ) + TestAPICallReq.setUp(unmarshal, request) + + result = validate_request(request, spec=spec_v31, cls=TestAPICallReq) + + assert result == unmarshal + assert len(TestAPICallReq.unmarshal_calls) == 1 + + def test_cls_invalid(self, spec_v31): + request = mock.Mock(spec=Request) + + with pytest.raises(TypeError): + validate_request(request, spec=spec_v31, cls=Exception) + + @mock.patch( + "openapi_core.unmarshalling.request.unmarshallers.WebhookRequestUnmarshaller." + "unmarshal", + ) + def test_webhook_request(self, mock_unmarshal, spec_v31): + request = mock.Mock(spec=WebhookRequest) + + result = validate_request(request, spec=spec_v31) + + assert result == mock_unmarshal.return_value + mock_unmarshal.assert_called_once_with(request) + + def test_webhook_request_validator_not_found(self, spec_v30): + request = mock.Mock(spec=WebhookRequest) + + with pytest.raises(SpecError): + validate_request(request, spec=spec_v30) + + @mock.patch( + "openapi_core.unmarshalling.request.unmarshallers.WebhookRequestUnmarshaller." + "unmarshal", + ) + def test_webhook_request_error(self, mock_unmarshal, spec_v31): + request = mock.Mock(spec=WebhookRequest) + mock_unmarshal.return_value = ResultMock(error_to_raise=ValueError) + + with pytest.raises(ValueError): + validate_request(request, spec=spec_v31) + + mock_unmarshal.assert_called_once_with(request) + + def test_webhook_cls(self, spec_v31): + request = mock.Mock(spec=WebhookRequest) + unmarshal = mock.Mock(spec=RequestUnmarshalResult) + TestWebhookReq = type( + "TestWebhookReq", + (MockReqClass, WebhookRequestUnmarshaller), + {}, + ) + TestWebhookReq.setUp(unmarshal, request) + + result = validate_request(request, spec=spec_v31, cls=TestWebhookReq) + + assert result == unmarshal + assert len(TestWebhookReq.unmarshal_calls) == 1 + + def test_webhook_cls_invalid(self, spec_v31): + request = mock.Mock(spec=WebhookRequest) + + with pytest.raises(TypeError): + validate_request(request, spec=spec_v31, cls=Exception) + + +class TestValidateResponse: + def test_spec_not_detected(self, spec_invalid): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + + with pytest.raises(SpecError): + validate_response(request, response, spec=spec_invalid) + + def test_request_type_invalid(self, spec_v31): + request = mock.sentinel.request + response = mock.Mock(spec=Response) + + with pytest.raises(TypeError): + validate_response(request, response, spec=spec_v31) + + def test_response_type_invalid(self, spec_v31): + request = mock.Mock(spec=Request) + response = mock.sentinel.response + + with pytest.raises(TypeError): + validate_response(request, response, spec=spec_v31) + + def test_spec_type_invalid(self): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + spec = mock.sentinel.spec + + with pytest.raises(TypeError): + validate_response(request, response, spec=spec) + + @mock.patch( + "openapi_core.unmarshalling.response.unmarshallers.APICallResponseUnmarshaller." + "unmarshal", + ) + def test_request_response(self, mock_unmarshal, spec_v31): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + + result = validate_response(request, response, spec=spec_v31) + + assert result == mock_unmarshal.return_value + mock_unmarshal.assert_called_once_with(request, response) + + @mock.patch( + "openapi_core.unmarshalling.response.unmarshallers.APICallResponseUnmarshaller." + "unmarshal", + ) + def test_spec_as_first_arg_deprecated(self, mock_unmarshal, spec_v31): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + + with pytest.warns(DeprecationWarning): + result = validate_response(spec_v31, request, response) + + assert result == mock_unmarshal.return_value + mock_unmarshal.assert_called_once_with(request, response) + + @mock.patch( + "openapi_core.unmarshalling.response.unmarshallers.APICallResponseUnmarshaller." + "unmarshal", + ) + def test_request_response_error(self, mock_unmarshal, spec_v31): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + mock_unmarshal.return_value = ResultMock(error_to_raise=ValueError) + + with pytest.raises(ValueError): + validate_response(request, response, spec=spec_v31) + + mock_unmarshal.assert_called_once_with(request, response) + + def test_validator(self, spec_v31): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + validator = mock.Mock(spec=ResponseValidator) + + with pytest.warns(DeprecationWarning): + result = validate_response( + request, response, spec=spec_v31, validator=validator + ) + + assert result == validator.validate.return_value + validator.validate.assert_called_once_with( + spec_v31, request, response, base_url=None + ) + + def test_cls(self, spec_v31): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + unmarshal = mock.Mock(spec=RequestUnmarshalResult) + TestAPICallResp = type( + "TestAPICallResp", + (MockRespClass, APICallResponseUnmarshaller), + {}, + ) + TestAPICallResp.setUp(unmarshal, request, response) + + result = validate_response( + request, response, spec=spec_v31, cls=TestAPICallResp + ) + + assert result == unmarshal + assert len(TestAPICallResp.unmarshal_calls) == 1 + + def test_cls_type_invalid(self, spec_v31): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + + with pytest.raises(TypeError): + validate_response(request, response, spec=spec_v31, cls=Exception) + + def test_webhook_response_validator_not_found(self, spec_v30): + request = mock.Mock(spec=WebhookRequest) + response = mock.Mock(spec=Response) + + with pytest.raises(SpecError): + validate_response(request, response, spec=spec_v30) + + @mock.patch( + "openapi_core.unmarshalling.response.unmarshallers.WebhookResponseUnmarshaller." + "unmarshal", + ) + def test_webhook_request(self, mock_unmarshal, spec_v31): + request = mock.Mock(spec=WebhookRequest) + response = mock.Mock(spec=Response) + + result = validate_response(request, response, spec=spec_v31) + + assert result == mock_unmarshal.return_value + mock_unmarshal.assert_called_once_with(request, response) + + @mock.patch( + "openapi_core.unmarshalling.response.unmarshallers.WebhookResponseUnmarshaller." + "unmarshal", + ) + def test_webhook_request_error(self, mock_unmarshal, spec_v31): + request = mock.Mock(spec=WebhookRequest) + response = mock.Mock(spec=Response) + mock_unmarshal.return_value = ResultMock(error_to_raise=ValueError) + + with pytest.raises(ValueError): + validate_response(request, response, spec=spec_v31) + + mock_unmarshal.assert_called_once_with(request, response) + + def test_webhook_cls(self, spec_v31): + request = mock.Mock(spec=WebhookRequest) + response = mock.Mock(spec=Response) + unmarshal = mock.Mock(spec=RequestUnmarshalResult) + TestWebhookResp = type( + "TestWebhookResp", + (MockRespClass, APICallResponseUnmarshaller), + {}, + ) + TestWebhookResp.setUp(unmarshal, request, response) + + result = validate_response( + request, response, spec=spec_v31, cls=TestWebhookResp + ) + + assert result == unmarshal + assert len(TestWebhookResp.unmarshal_calls) == 1 + + def test_webhook_cls_type_invalid(self, spec_v31): + request = mock.Mock(spec=WebhookRequest) + response = mock.Mock(spec=Response) + + with pytest.raises(TypeError): + validate_response(request, response, spec=spec_v31, cls=Exception) diff --git a/tests/unit/unmarshalling/test_path_item_params_validator.py b/tests/unit/unmarshalling/test_path_item_params_validator.py new file mode 100644 index 00000000..1c0aabf7 --- /dev/null +++ b/tests/unit/unmarshalling/test_path_item_params_validator.py @@ -0,0 +1,179 @@ +from dataclasses import is_dataclass + +import pytest + +from openapi_core import Spec +from openapi_core import V30RequestUnmarshaller +from openapi_core import openapi_request_validator +from openapi_core.casting.schemas.exceptions import CastError +from openapi_core.datatypes import Parameters +from openapi_core.testing import MockRequest +from openapi_core.validation.request.exceptions import MissingRequiredParameter +from openapi_core.validation.request.exceptions import ParameterError + + +class TestPathItemParamsValidator: + @pytest.fixture(scope="session") + def spec_dict(self): + return { + "openapi": "3.0.0", + "info": { + "title": "Test path item parameter validation", + "version": "0.1", + }, + "paths": { + "/resource": { + "parameters": [ + { + "name": "resId", + "in": "query", + "required": True, + "schema": { + "type": "integer", + }, + }, + ], + "get": { + "responses": { + "default": {"description": "Return the resource."} + } + }, + } + }, + } + + @pytest.fixture(scope="session") + def spec(self, spec_dict): + return Spec.from_dict(spec_dict) + + @pytest.fixture(scope="session") + def request_unmarshaller(self, spec): + return V30RequestUnmarshaller(spec) + + def test_request_missing_param(self, request_unmarshaller): + request = MockRequest("http://example.com", "get", "/resource") + + result = request_unmarshaller.unmarshal(request) + + assert len(result.errors) == 1 + assert type(result.errors[0]) == MissingRequiredParameter + assert result.body is None + assert result.parameters == Parameters() + + def test_request_invalid_param(self, request_unmarshaller): + request = MockRequest( + "http://example.com", + "get", + "/resource", + args={"resId": "invalid"}, + ) + + result = request_unmarshaller.unmarshal(request) + + assert result.errors == [ + ParameterError(name="resId", location="query") + ] + assert type(result.errors[0].__cause__) is CastError + assert result.body is None + assert result.parameters == Parameters() + + def test_request_valid_param(self, request_unmarshaller): + request = MockRequest( + "http://example.com", + "get", + "/resource", + args={"resId": "10"}, + ) + + result = request_unmarshaller.unmarshal(request) + + assert len(result.errors) == 0 + assert result.body is None + assert result.parameters == Parameters(query={"resId": 10}) + + def test_request_override_param(self, spec, spec_dict): + # override path parameter on operation + spec_dict["paths"]["/resource"]["get"]["parameters"] = [ + { + # full valid parameter object required + "name": "resId", + "in": "query", + "required": False, + "schema": { + "type": "integer", + }, + } + ] + request = MockRequest("http://example.com", "get", "/resource") + with pytest.warns(DeprecationWarning): + result = openapi_request_validator.validate( + spec, request, base_url="http://example.com" + ) + + assert len(result.errors) == 0 + assert result.body is None + assert result.parameters == Parameters() + + def test_request_override_param_uniqueness(self, spec, spec_dict): + # add parameter on operation with same name as on path but + # different location + spec_dict["paths"]["/resource"]["get"]["parameters"] = [ + { + # full valid parameter object required + "name": "resId", + "in": "header", + "required": False, + "schema": { + "type": "integer", + }, + } + ] + request = MockRequest("http://example.com", "get", "/resource") + with pytest.warns(DeprecationWarning): + result = openapi_request_validator.validate( + spec, request, base_url="http://example.com" + ) + + assert len(result.errors) == 1 + assert type(result.errors[0]) == MissingRequiredParameter + assert result.body is None + assert result.parameters == Parameters() + + def test_request_object_deep_object_params(self, spec, spec_dict): + # override path parameter on operation + spec_dict["paths"]["/resource"]["parameters"] = [ + { + # full valid parameter object required + "name": "paramObj", + "in": "query", + "required": True, + "schema": { + "x-model": "paramObj", + "type": "object", + "properties": { + "count": {"type": "integer"}, + "name": {"type": "string"}, + }, + }, + "explode": True, + "style": "deepObject", + } + ] + + request = MockRequest( + "http://example.com", + "get", + "/resource", + args={"paramObj[count]": 2, "paramObj[name]": "John"}, + ) + with pytest.warns(DeprecationWarning): + result = openapi_request_validator.validate( + spec, request, base_url="http://example.com" + ) + + assert len(result.errors) == 0 + assert result.body is None + assert len(result.parameters.query) == 1 + assert is_dataclass(result.parameters.query["paramObj"]) + assert result.parameters.query["paramObj"].count == 2 + assert result.parameters.query["paramObj"].name == "John" diff --git a/tests/unit/unmarshalling/test_unmarshal.py b/tests/unit/unmarshalling/test_schema_unmarshallers.py similarity index 75% rename from tests/unit/unmarshalling/test_unmarshal.py rename to tests/unit/unmarshalling/test_schema_unmarshallers.py index 15a604c0..d010f39c 100644 --- a/tests/unit/unmarshalling/test_unmarshal.py +++ b/tests/unit/unmarshalling/test_schema_unmarshallers.py @@ -3,7 +3,7 @@ import pytest from openapi_core.spec.paths import Spec -from openapi_core.unmarshalling.schemas import oas30_write_types_unmarshaller +from openapi_core.unmarshalling.schemas import oas30_types_unmarshaller from openapi_core.unmarshalling.schemas.exceptions import ( FormatterNotFoundError, ) @@ -25,21 +25,49 @@ def create_unmarshaller( custom_formatters = custom_formatters or {} return SchemaUnmarshallersFactory( validators_factory, - oas30_write_types_unmarshaller, + oas30_types_unmarshaller, custom_formatters=custom_formatters, ).create(schema) return create_unmarshaller -class TestOAS30SchemaUnmarshallerCall: - @pytest.fixture - def unmarshaller_factory(self, schema_unmarshaller_factory): - return partial( - schema_unmarshaller_factory, - oas30_write_schema_validators_factory, - ) +@pytest.fixture +def unmarshaller_factory(schema_unmarshaller_factory): + return partial( + schema_unmarshaller_factory, + oas30_write_schema_validators_factory, + ) + + +class TestOAS30SchemaUnmarshallerFactoryCreate: + def test_string_format_unknown(self, unmarshaller_factory): + unknown_format = "unknown" + schema = { + "type": "string", + "format": unknown_format, + } + spec = Spec.from_dict(schema, validator=None) + + with pytest.raises(FormatterNotFoundError): + unmarshaller_factory(spec) + + def test_string_format_invalid_value(self, unmarshaller_factory): + custom_format = "custom" + schema = { + "type": "string", + "format": custom_format, + } + spec = Spec.from_dict(schema, validator=None) + + with pytest.raises( + FormatterNotFoundError, + match="Formatter not found for custom format", + ): + unmarshaller_factory(spec) + +class TestOAS30SchemaUnmarshallerUnmarshal: def test_schema_custom_format_invalid(self, unmarshaller_factory): class CustomFormatter(Formatter): def format(self, value): @@ -56,12 +84,13 @@ def format(self, value): } spec = Spec.from_dict(schema, validator=None) value = "x" + unmarshaller = unmarshaller_factory( + spec, + custom_formatters=custom_formatters, + ) with pytest.raises(FormatUnmarshalError): - unmarshaller_factory( - spec, - custom_formatters=custom_formatters, - )(value) + unmarshaller.unmarshal(value) def test_string_format_custom(self, unmarshaller_factory): formatted = "x-custom" @@ -81,38 +110,38 @@ def format(self, value): custom_formatters = { custom_format: formatter, } - - result = unmarshaller_factory( + unmarshaller = unmarshaller_factory( spec, custom_formatters=custom_formatters - )(value) + ) - assert result == formatted + result = unmarshaller.unmarshal(value) - def test_string_format_custom_formatter(self, unmarshaller_factory): - formatted = "x-custom" + assert result == formatted + def test_array_format_custom_formatter(self, unmarshaller_factory): class CustomFormatter(Formatter): def unmarshal(self, value): - return formatted + return tuple(value) custom_format = "custom" schema = { - "type": "string", + "type": "array", "format": custom_format, } spec = Spec.from_dict(schema, validator=None) - value = "x" + value = ["x"] formatter = CustomFormatter() custom_formatters = { custom_format: formatter, } + unmarshaller = unmarshaller_factory( + spec, custom_formatters=custom_formatters + ) with pytest.warns(DeprecationWarning): - result = unmarshaller_factory( - spec, custom_formatters=custom_formatters - )(value) + result = unmarshaller.unmarshal(value) - assert result == formatted + assert result == tuple(value) def test_string_format_custom_value_error(self, unmarshaller_factory): class CustomFormatter(Formatter): @@ -130,35 +159,9 @@ def format(self, value): custom_formatters = { custom_format: formatter, } + unmarshaller = unmarshaller_factory( + spec, custom_formatters=custom_formatters + ) with pytest.raises(FormatUnmarshalError): - unmarshaller_factory(spec, custom_formatters=custom_formatters)( - value - ) - - def test_string_format_unknown(self, unmarshaller_factory): - unknown_format = "unknown" - schema = { - "type": "string", - "format": unknown_format, - } - spec = Spec.from_dict(schema, validator=None) - value = "x" - - with pytest.raises(FormatterNotFoundError): - unmarshaller_factory(spec)(value) - - def test_string_format_invalid_value(self, unmarshaller_factory): - custom_format = "custom" - schema = { - "type": "string", - "format": custom_format, - } - spec = Spec.from_dict(schema, validator=None) - value = "x" - - with pytest.raises( - FormatterNotFoundError, - match="Formatter not found for custom format", - ): - unmarshaller_factory(spec)(value) + unmarshaller.unmarshal(value) diff --git a/tests/unit/validation/test_request_response_validators.py b/tests/unit/validation/test_request_response_validators.py index 352e1e88..9d526204 100644 --- a/tests/unit/validation/test_request_response_validators.py +++ b/tests/unit/validation/test_request_response_validators.py @@ -3,15 +3,15 @@ import pytest from openapi_schema_validator import OAS31Validator +from openapi_core import RequestValidator +from openapi_core import ResponseValidator +from openapi_core import openapi_request_validator +from openapi_core import openapi_response_validator from openapi_core.unmarshalling.schemas import oas31_types_unmarshaller from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) from openapi_core.unmarshalling.schemas.formatters import Formatter -from openapi_core.validation import openapi_request_validator -from openapi_core.validation import openapi_response_validator -from openapi_core.validation.request.validators import RequestValidator -from openapi_core.validation.response.validators import ResponseValidator from openapi_core.validation.schemas import oas31_schema_validators_factory @@ -32,18 +32,18 @@ def validator(self, schema_unmarshallers_factory): return RequestValidator(schema_unmarshallers_factory) @mock.patch( - "openapi_core.validation.request.validators.APICallRequestValidator." - "validate", + "openapi_core.unmarshalling.request.unmarshallers.APICallRequestUnmarshaller." + "unmarshal", ) - def test_valid(self, mock_validate, validator): + def test_valid(self, mock_unmarshal, validator): spec = mock.sentinel.spec request = mock.sentinel.request with pytest.warns(DeprecationWarning): result = validator.validate(spec, request) - assert result == mock_validate.return_value - mock_validate.assert_called_once_with(request) + assert result == mock_unmarshal.return_value + mock_unmarshal.assert_called_once_with(request) class TestResponseValidatorValidate(BaseTestValidate): @@ -52,10 +52,10 @@ def validator(self, schema_unmarshallers_factory): return ResponseValidator(schema_unmarshallers_factory) @mock.patch( - "openapi_core.validation.response.validators.APICallResponseValidator." - "validate", + "openapi_core.unmarshalling.response.unmarshallers.APICallResponseUnmarshaller." + "unmarshal", ) - def test_valid(self, mock_validate, validator): + def test_valid(self, mock_unmarshal, validator): spec = mock.sentinel.spec request = mock.sentinel.request response = mock.sentinel.response @@ -63,8 +63,8 @@ def test_valid(self, mock_validate, validator): with pytest.warns(DeprecationWarning): result = validator.validate(spec, request, response) - assert result == mock_validate.return_value - mock_validate.assert_called_once_with(request, response) + assert result == mock_unmarshal.return_value + mock_unmarshal.assert_called_once_with(request, response) class TestDetectProxyOpenAPIRequestValidator: @@ -73,17 +73,17 @@ def validator(self): return openapi_request_validator @mock.patch( - "openapi_core.validation.request.validators.APICallRequestValidator." - "validate", + "openapi_core.unmarshalling.request.unmarshallers.APICallRequestUnmarshaller." + "unmarshal", ) - def test_valid(self, mock_validate, validator, spec_v31): + def test_valid(self, mock_unmarshal, validator, spec_v31): request = mock.sentinel.request with pytest.warns(DeprecationWarning): result = validator.validate(spec_v31, request) - assert result == mock_validate.return_value - mock_validate.assert_called_once_with(request) + assert result == mock_unmarshal.return_value + mock_unmarshal.assert_called_once_with(request) class TestDetectProxyOpenAPIResponsealidator: @@ -92,15 +92,15 @@ def validator(self): return openapi_response_validator @mock.patch( - "openapi_core.validation.response.validators.APICallResponseValidator." - "validate", + "openapi_core.unmarshalling.response.unmarshallers.APICallResponseUnmarshaller." + "unmarshal", ) - def test_valid(self, mock_validate, validator, spec_v31): + def test_valid(self, mock_unmarshal, validator, spec_v31): request = mock.sentinel.request response = mock.sentinel.response with pytest.warns(DeprecationWarning): result = validator.validate(spec_v31, request, response) - assert result == mock_validate.return_value - mock_validate.assert_called_once_with(request, response) + assert result == mock_unmarshal.return_value + mock_unmarshal.assert_called_once_with(request, response) diff --git a/tests/unit/unmarshalling/test_validate.py b/tests/unit/validation/test_schema_validators.py similarity index 98% rename from tests/unit/unmarshalling/test_validate.py rename to tests/unit/validation/test_schema_validators.py index afaeac56..099121d1 100644 --- a/tests/unit/unmarshalling/test_validate.py +++ b/tests/unit/validation/test_schema_validators.py @@ -1,11 +1,6 @@ -import datetime - import pytest from openapi_core.spec.paths import Spec -from openapi_core.unmarshalling.schemas.exceptions import ( - FormatterNotFoundError, -) from openapi_core.validation.schemas import ( oas30_write_schema_validators_factory, ) diff --git a/tests/unit/validation/test_shortcuts.py b/tests/unit/validation/test_shortcuts.py deleted file mode 100644 index 0c2846c9..00000000 --- a/tests/unit/validation/test_shortcuts.py +++ /dev/null @@ -1,288 +0,0 @@ -from unittest import mock - -import pytest - -from openapi_core import validate_request -from openapi_core import validate_response -from openapi_core.testing.datatypes import ResultMock -from openapi_core.validation.exceptions import ValidatorDetectError -from openapi_core.validation.request.protocols import Request -from openapi_core.validation.request.protocols import WebhookRequest -from openapi_core.validation.request.validators import APICallRequestValidator -from openapi_core.validation.request.validators import RequestValidator -from openapi_core.validation.request.validators import WebhookRequestValidator -from openapi_core.validation.response.protocols import Response -from openapi_core.validation.response.validators import ( - APICallResponseValidator, -) -from openapi_core.validation.response.validators import ResponseValidator -from openapi_core.validation.response.validators import ( - WebhookResponseValidator, -) - - -class TestValidateRequest: - def test_spec_not_detected(self, spec_invalid): - request = mock.Mock(spec=Request) - - with pytest.raises(ValidatorDetectError): - validate_request(request, spec=spec_invalid) - - def test_request_type_error(self, spec_v31): - request = mock.sentinel.request - - with pytest.raises(TypeError): - validate_request(request, spec=spec_v31) - - def test_spec_type_error(self): - request = mock.Mock(spec=Request) - spec = mock.sentinel.spec - - with pytest.raises(TypeError): - validate_request(request, spec=spec) - - @mock.patch( - "openapi_core.validation.request.validators.APICallRequestValidator." - "validate", - ) - def test_request(self, mock_validate, spec_v31): - request = mock.Mock(spec=Request) - - result = validate_request(request, spec=spec_v31) - - assert result == mock_validate.return_value - mock_validate.assert_called_once_with(request) - - @mock.patch( - "openapi_core.validation.request.validators.APICallRequestValidator." - "validate", - ) - def test_spec_as_first_arg_deprecated(self, mock_validate, spec_v31): - request = mock.Mock(spec=Request) - - with pytest.warns(DeprecationWarning): - result = validate_request(spec_v31, request) - - assert result == mock_validate.return_value - mock_validate.assert_called_once_with(request) - - @mock.patch( - "openapi_core.validation.request.validators.APICallRequestValidator." - "validate", - ) - def test_request_error(self, mock_validate, spec_v31): - request = mock.Mock(spec=Request) - mock_validate.return_value = ResultMock(error_to_raise=ValueError) - - with pytest.raises(ValueError): - validate_request(request, spec=spec_v31) - - mock_validate.assert_called_once_with(request) - - def test_validator(self, spec_v31): - request = mock.Mock(spec=Request) - validator = mock.Mock(spec=RequestValidator) - - with pytest.warns(DeprecationWarning): - result = validate_request( - request, spec=spec_v31, validator=validator - ) - - assert result == validator.validate.return_value - validator.validate.assert_called_once_with( - spec_v31, request, base_url=None - ) - - def test_validator_cls(self, spec_v31): - request = mock.Mock(spec=Request) - validator_cls = mock.Mock(spec=APICallRequestValidator) - - result = validate_request(request, spec=spec_v31, cls=validator_cls) - - assert result == validator_cls().validate.return_value - validator_cls().validate.assert_called_once_with(request) - - @mock.patch( - "openapi_core.validation.request.validators.WebhookRequestValidator." - "validate", - ) - def test_webhook_request(self, mock_validate, spec_v31): - request = mock.Mock(spec=WebhookRequest) - - result = validate_request(request, spec=spec_v31) - - assert result == mock_validate.return_value - mock_validate.assert_called_once_with(request) - - def test_webhook_request_validator_not_found(self, spec_v30): - request = mock.Mock(spec=WebhookRequest) - - with pytest.raises(ValidatorDetectError): - validate_request(request, spec=spec_v30) - - @mock.patch( - "openapi_core.validation.request.validators.WebhookRequestValidator." - "validate", - ) - def test_webhook_request_error(self, mock_validate, spec_v31): - request = mock.Mock(spec=WebhookRequest) - mock_validate.return_value = ResultMock(error_to_raise=ValueError) - - with pytest.raises(ValueError): - validate_request(request, spec=spec_v31) - - mock_validate.assert_called_once_with(request) - - def test_webhook_validator_cls(self, spec_v31): - request = mock.Mock(spec=WebhookRequest) - validator_cls = mock.Mock(spec=WebhookRequestValidator) - - result = validate_request(request, spec=spec_v31, cls=validator_cls) - - assert result == validator_cls().validate.return_value - validator_cls().validate.assert_called_once_with(request) - - -class TestValidateResponse: - def test_spec_not_detected(self, spec_invalid): - request = mock.Mock(spec=Request) - response = mock.Mock(spec=Response) - - with pytest.raises(ValidatorDetectError): - validate_response(request, response, spec=spec_invalid) - - def test_request_type_error(self, spec_v31): - request = mock.sentinel.request - response = mock.Mock(spec=Response) - - with pytest.raises(TypeError): - validate_response(request, response, spec=spec_v31) - - def test_response_type_error(self, spec_v31): - request = mock.Mock(spec=Request) - response = mock.sentinel.response - - with pytest.raises(TypeError): - validate_response(request, response, spec=spec_v31) - - def test_spec_type_error(self): - request = mock.Mock(spec=Request) - response = mock.Mock(spec=Response) - spec = mock.sentinel.spec - - with pytest.raises(TypeError): - validate_response(request, response, spec=spec) - - @mock.patch( - "openapi_core.validation.response.validators.APICallResponseValidator." - "validate", - ) - def test_request_response(self, mock_validate, spec_v31): - request = mock.Mock(spec=Request) - response = mock.Mock(spec=Response) - - result = validate_response(request, response, spec=spec_v31) - - assert result == mock_validate.return_value - mock_validate.assert_called_once_with(request, response) - - @mock.patch( - "openapi_core.validation.response.validators.APICallResponseValidator." - "validate", - ) - def test_spec_as_first_arg_deprecated(self, mock_validate, spec_v31): - request = mock.Mock(spec=Request) - response = mock.Mock(spec=Response) - - with pytest.warns(DeprecationWarning): - result = validate_response(spec_v31, request, response) - - assert result == mock_validate.return_value - mock_validate.assert_called_once_with(request, response) - - @mock.patch( - "openapi_core.validation.response.validators.APICallResponseValidator." - "validate", - ) - def test_request_response_error(self, mock_validate, spec_v31): - request = mock.Mock(spec=Request) - response = mock.Mock(spec=Response) - mock_validate.return_value = ResultMock(error_to_raise=ValueError) - - with pytest.raises(ValueError): - validate_response(request, response, spec=spec_v31) - - mock_validate.assert_called_once_with(request, response) - - def test_validator(self, spec_v31): - request = mock.Mock(spec=Request) - response = mock.Mock(spec=Response) - validator = mock.Mock(spec=ResponseValidator) - - with pytest.warns(DeprecationWarning): - result = validate_response( - request, response, spec=spec_v31, validator=validator - ) - - assert result == validator.validate.return_value - validator.validate.assert_called_once_with( - spec_v31, request, response, base_url=None - ) - - def test_validator_cls(self, spec_v31): - request = mock.Mock(spec=Request) - response = mock.Mock(spec=Response) - validator_cls = mock.Mock(spec=APICallResponseValidator) - - result = validate_response( - request, response, spec=spec_v31, cls=validator_cls - ) - - assert result == validator_cls().validate.return_value - validator_cls().validate.assert_called_once_with(request, response) - - def test_webhook_response_validator_not_found(self, spec_v30): - request = mock.Mock(spec=WebhookRequest) - response = mock.Mock(spec=Response) - - with pytest.raises(ValidatorDetectError): - validate_response(request, response, spec=spec_v30) - - @mock.patch( - "openapi_core.validation.response.validators.WebhookResponseValidator." - "validate", - ) - def test_webhook_request(self, mock_validate, spec_v31): - request = mock.Mock(spec=WebhookRequest) - response = mock.Mock(spec=Response) - - result = validate_response(request, response, spec=spec_v31) - - assert result == mock_validate.return_value - mock_validate.assert_called_once_with(request, response) - - @mock.patch( - "openapi_core.validation.response.validators.WebhookResponseValidator." - "validate", - ) - def test_webhook_request_error(self, mock_validate, spec_v31): - request = mock.Mock(spec=WebhookRequest) - response = mock.Mock(spec=Response) - mock_validate.return_value = ResultMock(error_to_raise=ValueError) - - with pytest.raises(ValueError): - validate_response(request, response, spec=spec_v31) - - mock_validate.assert_called_once_with(request, response) - - def test_webhook_response_cls(self, spec_v31): - request = mock.Mock(spec=WebhookRequest) - response = mock.Mock(spec=Response) - validator_cls = mock.Mock(spec=WebhookResponseValidator) - - result = validate_response( - request, response, spec=spec_v31, cls=validator_cls - ) - - assert result == validator_cls().validate.return_value - validator_cls().validate.assert_called_once_with(request, response) From d8db7cef6783af7696f12368bc4fc68658bc0aba Mon Sep 17 00:00:00 2001 From: p1c2u Date: Sat, 18 Feb 2023 06:58:17 +0000 Subject: [PATCH 3/4] formats assigned to type fix enable tests --- .../unmarshalling/test_unmarshallers.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/tests/integration/unmarshalling/test_unmarshallers.py b/tests/integration/unmarshalling/test_unmarshallers.py index c69d4af7..9476d4ee 100644 --- a/tests/integration/unmarshalling/test_unmarshallers.py +++ b/tests/integration/unmarshalling/test_unmarshallers.py @@ -180,12 +180,6 @@ def test_basic_types_invalid(self, unmarshallers_factory, type, value): in exc_info.value.schema_errors[0].message ) - @pytest.mark.xfail( - reason=( - "Format assigned to type bug. " - "See https://github.com/p1c2u/openapi-core/issues/483" - ) - ) @pytest.mark.parametrize( "format,value,unmarshalled", [ @@ -406,8 +400,8 @@ def test_string_uuid_invalid(self, unmarshallers_factory): @pytest.mark.xfail( reason=( - "Format assigned to type bug. " - "See https://github.com/p1c2u/openapi-core/issues/483" + "Formats raise error for other types. " + "See https://github.com/p1c2u/openapi-schema-validator/issues/66" ) ) @pytest.mark.parametrize( @@ -1310,12 +1304,6 @@ def test_any_all_of_invalid_properties(self, value, unmarshallers_factory): with pytest.raises(InvalidSchemaValue): unmarshaller.unmarshal(value) - @pytest.mark.xfail( - reason=( - "Format assigned to type bug. " - "See https://github.com/p1c2u/openapi-core/issues/483" - ) - ) def test_any_format_one_of(self, unmarshallers_factory): schema = { "format": "date", From 9a57fd6971317341207bbf224d76eb108e9d8819 Mon Sep 17 00:00:00 2001 From: p1c2u Date: Sat, 18 Feb 2023 11:19:07 +0000 Subject: [PATCH 4/4] Format validators and Format unmarshallers; deprecate custom formatters --- docs/customizations.rst | 54 +++-- .../unmarshalling/request/unmarshallers.py | 14 ++ .../unmarshalling/schemas/datatypes.py | 8 +- .../unmarshalling/schemas/factories.py | 72 +++--- .../unmarshalling/schemas/unmarshallers.py | 25 +- openapi_core/unmarshalling/unmarshallers.py | 20 +- openapi_core/validation/request/validators.py | 5 + openapi_core/validation/schemas/datatypes.py | 8 + openapi_core/validation/schemas/factories.py | 39 ++- .../schemas/formatters.py | 0 openapi_core/validation/schemas/validators.py | 3 + openapi_core/validation/validators.py | 11 +- .../test_schema_unmarshallers.py | 227 ++++++++++++++++-- .../test_request_response_validators.py | 13 +- 14 files changed, 408 insertions(+), 91 deletions(-) rename openapi_core/{unmarshalling => validation}/schemas/formatters.py (100%) diff --git a/docs/customizations.rst b/docs/customizations.rst index 70c12c9d..8a5ea64c 100644 --- a/docs/customizations.rst +++ b/docs/customizations.rst @@ -42,42 +42,54 @@ Pass custom defined media type deserializers dictionary with supported mimetypes media_type_deserializers_factory=media_type_deserializers_factory, ) -Formats -------- +Format validators +----------------- OpenAPI defines a ``format`` keyword that hints at how a value should be interpreted, e.g. a ``string`` with the type ``date`` should conform to the RFC 3339 date format. -Openapi-core comes with a set of built-in formatters, but it's also possible to add custom formatters in `SchemaUnmarshallersFactory` and pass it to `RequestValidator` or `ResponseValidator`. +OpenAPI comes with a set of built-in format validators, but it's also possible to add custom ones. Here's how you could add support for a ``usdate`` format that handles dates of the form MM/DD/YYYY: .. code-block:: python - from openapi_core.unmarshalling.schemas.factories import SchemaUnmarshallersFactory - from openapi_schema_validator import OAS30Validator - from datetime import datetime import re - class USDateFormatter: - def validate(self, value) -> bool: - return bool(re.match(r"^\d{1,2}/\d{1,2}/\d{4}$", value)) - - def format(self, value): - return datetime.strptime(value, "%m/%d/%y").date - + def validate_usdate(value): + return bool(re.match(r"^\d{1,2}/\d{1,2}/\d{4}$", value)) - custom_formatters = { - 'usdate': USDateFormatter(), + extra_format_validators = { + 'usdate': validate_usdate, } - schema_unmarshallers_factory = SchemaUnmarshallersFactory( - OAS30Validator, - custom_formatters=custom_formatters, - ) result = validate_response( request, response, spec=spec, - cls=ResponseValidator, - schema_unmarshallers_factory=schema_unmarshallers_factory, + extra_format_validators=extra_format_validators, ) +Format unmarshallers +-------------------- + +Based on ``format`` keyword, openapi-core can also unmarshal values to specific formats. + +Openapi-core comes with a set of built-in format unmarshallers, but it's also possible to add custom ones. + +Here's an example with the ``usdate`` format that converts a value to date object: + +.. code-block:: python + + from datetime import datetime + + def unmarshal_usdate(value): + return datetime.strptime(value, "%m/%d/%y").date + + extra_format_unmarshallers = { + 'usdate': unmarshal_usdate, + } + + result = unmarshal_response( + request, response, + spec=spec, + extra_format_unmarshallers=extra_format_unmarshallers, + ) diff --git a/openapi_core/unmarshalling/request/unmarshallers.py b/openapi_core/unmarshalling/request/unmarshallers.py index e828d8a6..2983d082 100644 --- a/openapi_core/unmarshalling/request/unmarshallers.py +++ b/openapi_core/unmarshalling/request/unmarshallers.py @@ -32,6 +32,9 @@ from openapi_core.unmarshalling.schemas import ( oas31_schema_unmarshallers_factory, ) +from openapi_core.unmarshalling.schemas.datatypes import ( + FormatUnmarshallersDict, +) from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) @@ -72,6 +75,7 @@ V31WebhookRequestValidator, ) from openapi_core.validation.request.validators import WebhookRequestValidator +from openapi_core.validation.schemas.datatypes import FormatValidatorsDict from openapi_core.validation.schemas.factories import SchemaValidatorsFactory @@ -84,10 +88,14 @@ def __init__( parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory, media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + format_validators: Optional[FormatValidatorsDict] = None, + extra_format_validators: Optional[FormatValidatorsDict] = None, security_provider_factory: SecurityProviderFactory = security_provider_factory, schema_unmarshallers_factory: Optional[ SchemaUnmarshallersFactory ] = None, + format_unmarshallers: Optional[FormatUnmarshallersDict] = None, + extra_format_unmarshallers: Optional[FormatUnmarshallersDict] = None, ): BaseUnmarshaller.__init__( self, @@ -97,7 +105,11 @@ def __init__( parameter_deserializers_factory=parameter_deserializers_factory, media_type_deserializers_factory=media_type_deserializers_factory, schema_validators_factory=schema_validators_factory, + format_validators=format_validators, + extra_format_validators=extra_format_validators, schema_unmarshallers_factory=schema_unmarshallers_factory, + format_unmarshallers=format_unmarshallers, + extra_format_unmarshallers=extra_format_unmarshallers, ) BaseRequestValidator.__init__( self, @@ -107,6 +119,8 @@ def __init__( parameter_deserializers_factory=parameter_deserializers_factory, media_type_deserializers_factory=media_type_deserializers_factory, schema_validators_factory=schema_validators_factory, + format_validators=format_validators, + extra_format_validators=extra_format_validators, security_provider_factory=security_provider_factory, ) diff --git a/openapi_core/unmarshalling/schemas/datatypes.py b/openapi_core/unmarshalling/schemas/datatypes.py index 23e0eb0c..2e1892a1 100644 --- a/openapi_core/unmarshalling/schemas/datatypes.py +++ b/openapi_core/unmarshalling/schemas/datatypes.py @@ -1,12 +1,6 @@ from typing import Any from typing import Callable from typing import Dict -from typing import Optional - -from openapi_core.unmarshalling.schemas.formatters import Formatter FormatUnmarshaller = Callable[[Any], Any] - -CustomFormattersDict = Dict[str, Formatter] -FormattersDict = Dict[Optional[str], Formatter] -UnmarshallersDict = Dict[str, Callable[[Any], Any]] +FormatUnmarshallersDict = Dict[str, FormatUnmarshaller] diff --git a/openapi_core/unmarshalling/schemas/factories.py b/openapi_core/unmarshalling/schemas/factories.py index a3c36243..ea796b82 100644 --- a/openapi_core/unmarshalling/schemas/factories.py +++ b/openapi_core/unmarshalling/schemas/factories.py @@ -2,15 +2,10 @@ import warnings from typing import Optional -if sys.version_info >= (3, 8): - from functools import cached_property -else: - from backports.cached_property import cached_property - from openapi_core.spec import Spec -from openapi_core.unmarshalling.schemas.datatypes import CustomFormattersDict -from openapi_core.unmarshalling.schemas.datatypes import FormatUnmarshaller -from openapi_core.unmarshalling.schemas.datatypes import UnmarshallersDict +from openapi_core.unmarshalling.schemas.datatypes import ( + FormatUnmarshallersDict, +) from openapi_core.unmarshalling.schemas.exceptions import ( FormatterNotFoundError, ) @@ -19,6 +14,8 @@ ) from openapi_core.unmarshalling.schemas.unmarshallers import SchemaUnmarshaller from openapi_core.unmarshalling.schemas.unmarshallers import TypesUnmarshaller +from openapi_core.validation.schemas.datatypes import CustomFormattersDict +from openapi_core.validation.schemas.datatypes import FormatValidatorsDict from openapi_core.validation.schemas.factories import SchemaValidatorsFactory @@ -27,26 +24,33 @@ def __init__( self, schema_validators_factory: SchemaValidatorsFactory, types_unmarshaller: TypesUnmarshaller, - format_unmarshallers: Optional[UnmarshallersDict] = None, + format_unmarshallers: Optional[FormatUnmarshallersDict] = None, custom_formatters: Optional[CustomFormattersDict] = None, ): self.schema_validators_factory = schema_validators_factory self.types_unmarshaller = types_unmarshaller - if custom_formatters is None: - custom_formatters = {} if format_unmarshallers is None: format_unmarshallers = {} self.format_unmarshallers = format_unmarshallers + if custom_formatters is None: + custom_formatters = {} + else: + warnings.warn( + "custom_formatters is deprecated. " + "Use extra_format_validators to validate custom formats " + "and use extra_format_unmarshallers to unmarshal custom formats.", + DeprecationWarning, + ) self.custom_formatters = custom_formatters - @cached_property - def formats_unmarshaller(self) -> FormatsUnmarshaller: - return FormatsUnmarshaller( - self.format_unmarshallers, - self.custom_formatters, - ) - - def create(self, schema: Spec) -> SchemaUnmarshaller: + def create( + self, + schema: Spec, + format_validators: Optional[FormatValidatorsDict] = None, + format_unmarshallers: Optional[FormatUnmarshallersDict] = None, + extra_format_validators: Optional[FormatValidatorsDict] = None, + extra_format_unmarshallers: Optional[FormatUnmarshallersDict] = None, + ) -> SchemaUnmarshaller: """Create unmarshaller from the schema.""" if schema is None: raise TypeError("Invalid schema") @@ -54,22 +58,34 @@ def create(self, schema: Spec) -> SchemaUnmarshaller: if schema.getkey("deprecated", False): warnings.warn("The schema is deprecated", DeprecationWarning) - formatters_checks = { - name: formatter.validate - for name, formatter in self.custom_formatters.items() - } + if extra_format_validators is None: + extra_format_validators = {} + extra_format_validators.update( + { + name: formatter.validate + for name, formatter in self.custom_formatters.items() + } + ) schema_validator = self.schema_validators_factory.create( - schema, **formatters_checks + schema, + format_validators=format_validators, + extra_format_validators=extra_format_validators, ) schema_format = schema.getkey("format") + formats_unmarshaller = FormatsUnmarshaller( + format_unmarshallers or self.format_unmarshallers, + extra_format_unmarshallers, + self.custom_formatters, + ) + # FIXME: don;t raise exception on unknown format + # See https://github.com/p1c2u/openapi-core/issues/515 if ( schema_format - and schema_format - not in self.schema_validators_factory.format_checker.checkers - and schema_format not in self.custom_formatters + and schema_format not in schema_validator + and schema_format not in formats_unmarshaller ): raise FormatterNotFoundError(schema_format) @@ -77,5 +93,5 @@ def create(self, schema: Spec) -> SchemaUnmarshaller: schema, schema_validator, self.types_unmarshaller, - self.formats_unmarshaller, + formats_unmarshaller, ) diff --git a/openapi_core/unmarshalling/schemas/unmarshallers.py b/openapi_core/unmarshalling/schemas/unmarshallers.py index 353e50d9..2387541b 100644 --- a/openapi_core/unmarshalling/schemas/unmarshallers.py +++ b/openapi_core/unmarshalling/schemas/unmarshallers.py @@ -12,11 +12,13 @@ from openapi_core.extensions.models.factories import ModelPathFactory from openapi_core.schema.schemas import get_properties from openapi_core.spec import Spec -from openapi_core.unmarshalling.schemas.datatypes import CustomFormattersDict from openapi_core.unmarshalling.schemas.datatypes import FormatUnmarshaller -from openapi_core.unmarshalling.schemas.datatypes import UnmarshallersDict +from openapi_core.unmarshalling.schemas.datatypes import ( + FormatUnmarshallersDict, +) from openapi_core.unmarshalling.schemas.exceptions import FormatUnmarshalError from openapi_core.unmarshalling.schemas.exceptions import UnmarshallerError +from openapi_core.validation.schemas.datatypes import CustomFormattersDict from openapi_core.validation.schemas.validators import SchemaValidator log = logging.getLogger(__name__) @@ -212,12 +214,16 @@ def get_unmarshaller( class FormatsUnmarshaller: def __init__( self, - format_unmarshallers: Optional[UnmarshallersDict] = None, + format_unmarshallers: Optional[FormatUnmarshallersDict] = None, + extra_format_unmarshallers: Optional[FormatUnmarshallersDict] = None, custom_formatters: Optional[CustomFormattersDict] = None, ): if format_unmarshallers is None: format_unmarshallers = {} self.format_unmarshallers = format_unmarshallers + if extra_format_unmarshallers is None: + extra_format_unmarshallers = {} + self.extra_format_unmarshallers = extra_format_unmarshallers if custom_formatters is None: custom_formatters = {} self.custom_formatters = custom_formatters @@ -237,11 +243,24 @@ def get_unmarshaller( if schema_format in self.custom_formatters: formatter = self.custom_formatters[schema_format] return formatter.format + if schema_format in self.extra_format_unmarshallers: + return self.extra_format_unmarshallers[schema_format] if schema_format in self.format_unmarshallers: return self.format_unmarshallers[schema_format] return None + def __contains__(self, schema_format: str) -> bool: + format_unmarshallers_dicts: List[Mapping[str, Any]] = [ + self.custom_formatters, + self.extra_format_unmarshallers, + self.format_unmarshallers, + ] + for content in format_unmarshallers_dicts: + if schema_format in content: + return True + return False + class SchemaUnmarshaller: def __init__( diff --git a/openapi_core/unmarshalling/unmarshallers.py b/openapi_core/unmarshalling/unmarshallers.py index f381811b..61ae6fd7 100644 --- a/openapi_core/unmarshalling/unmarshallers.py +++ b/openapi_core/unmarshalling/unmarshallers.py @@ -18,9 +18,13 @@ ParameterDeserializersFactory, ) from openapi_core.spec import Spec +from openapi_core.unmarshalling.schemas.datatypes import ( + FormatUnmarshallersDict, +) from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) +from openapi_core.validation.schemas.datatypes import FormatValidatorsDict from openapi_core.validation.schemas.factories import SchemaValidatorsFactory from openapi_core.validation.validators import BaseValidator @@ -36,9 +40,13 @@ def __init__( parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory, media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + format_validators: Optional[FormatValidatorsDict] = None, + extra_format_validators: Optional[FormatValidatorsDict] = None, schema_unmarshallers_factory: Optional[ SchemaUnmarshallersFactory ] = None, + format_unmarshallers: Optional[FormatUnmarshallersDict] = None, + extra_format_unmarshallers: Optional[FormatUnmarshallersDict] = None, ): if schema_validators_factory is None and schema_unmarshallers_factory: schema_validators_factory = ( @@ -51,6 +59,8 @@ def __init__( parameter_deserializers_factory=parameter_deserializers_factory, media_type_deserializers_factory=media_type_deserializers_factory, schema_validators_factory=schema_validators_factory, + format_validators=format_validators, + extra_format_validators=extra_format_validators, ) self.schema_unmarshallers_factory = ( schema_unmarshallers_factory or self.schema_unmarshallers_factory @@ -59,9 +69,17 @@ def __init__( raise NotImplementedError( "schema_unmarshallers_factory is not assigned" ) + self.format_unmarshallers = format_unmarshallers + self.extra_format_unmarshallers = extra_format_unmarshallers def _unmarshal_schema(self, schema: Spec, value: Any) -> Any: - unmarshaller = self.schema_unmarshallers_factory.create(schema) + unmarshaller = self.schema_unmarshallers_factory.create( + schema, + format_validators=self.format_validators, + extra_format_validators=self.extra_format_validators, + format_unmarshallers=self.format_unmarshallers, + extra_format_unmarshallers=self.extra_format_unmarshallers, + ) return unmarshaller.unmarshal(value) def _get_param_or_header_value( diff --git a/openapi_core/validation/request/validators.py b/openapi_core/validation/request/validators.py index c79bfe3e..b25246a9 100644 --- a/openapi_core/validation/request/validators.py +++ b/openapi_core/validation/request/validators.py @@ -50,6 +50,7 @@ oas30_write_schema_validators_factory, ) from openapi_core.validation.schemas import oas31_schema_validators_factory +from openapi_core.validation.schemas.datatypes import FormatValidatorsDict from openapi_core.validation.schemas.factories import SchemaValidatorsFactory from openapi_core.validation.validators import BaseAPICallValidator from openapi_core.validation.validators import BaseValidator @@ -65,6 +66,8 @@ def __init__( parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory, media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + format_validators: Optional[FormatValidatorsDict] = None, + extra_format_validators: Optional[FormatValidatorsDict] = None, security_provider_factory: SecurityProviderFactory = security_provider_factory, ): super().__init__( @@ -74,6 +77,8 @@ def __init__( parameter_deserializers_factory=parameter_deserializers_factory, media_type_deserializers_factory=media_type_deserializers_factory, schema_validators_factory=schema_validators_factory, + format_validators=format_validators, + extra_format_validators=extra_format_validators, ) self.security_provider_factory = security_provider_factory diff --git a/openapi_core/validation/schemas/datatypes.py b/openapi_core/validation/schemas/datatypes.py index b3e398f9..89e9c737 100644 --- a/openapi_core/validation/schemas/datatypes.py +++ b/openapi_core/validation/schemas/datatypes.py @@ -1,4 +1,12 @@ from typing import Any from typing import Callable +from typing import Dict +from typing import Optional + +from openapi_core.validation.schemas.formatters import Formatter FormatValidator = Callable[[Any], bool] + +CustomFormattersDict = Dict[str, Formatter] +FormattersDict = Dict[Optional[str], Formatter] +FormatValidatorsDict = Dict[str, FormatValidator] diff --git a/openapi_core/validation/schemas/factories.py b/openapi_core/validation/schemas/factories.py index 41122724..313f9c9f 100644 --- a/openapi_core/validation/schemas/factories.py +++ b/openapi_core/validation/schemas/factories.py @@ -7,7 +7,8 @@ from jsonschema.protocols import Validator from openapi_core.spec import Spec -from openapi_core.validation.schemas.datatypes import FormatValidator +from openapi_core.validation.schemas.datatypes import CustomFormattersDict +from openapi_core.validation.schemas.datatypes import FormatValidatorsDict from openapi_core.validation.schemas.validators import SchemaValidator @@ -23,17 +24,41 @@ def __init__( self.format_checker = format_checker def get_format_checker( - self, **format_checks: FormatValidator + self, + format_validators: Optional[FormatValidatorsDict] = None, + extra_format_validators: Optional[FormatValidatorsDict] = None, ) -> FormatChecker: - format_checker = deepcopy(self.format_checker) - for name, check in format_checks.items(): - format_checker.checks(name)(check) + if format_validators is None: + format_checker = deepcopy(self.format_checker) + else: + format_checker = FormatChecker([]) + format_checker = self._add_validators( + format_checker, format_validators + ) + format_checker = self._add_validators( + format_checker, extra_format_validators + ) return format_checker + def _add_validators( + self, + base_format_checker: FormatChecker, + format_validators: Optional[FormatValidatorsDict] = None, + ) -> FormatChecker: + if format_validators is not None: + for name, check in format_validators.items(): + base_format_checker.checks(name)(check) + return base_format_checker + def create( - self, schema: Spec, **format_checks: FormatValidator + self, + schema: Spec, + format_validators: Optional[FormatValidatorsDict] = None, + extra_format_validators: Optional[FormatValidatorsDict] = None, ) -> Validator: - format_checker = self.get_format_checker(**format_checks) + format_checker = self.get_format_checker( + format_validators, extra_format_validators + ) resolver = schema.accessor.resolver # type: ignore with schema.open() as schema_dict: jsonschema_validator = self.schema_validator_class( diff --git a/openapi_core/unmarshalling/schemas/formatters.py b/openapi_core/validation/schemas/formatters.py similarity index 100% rename from openapi_core/unmarshalling/schemas/formatters.py rename to openapi_core/validation/schemas/formatters.py diff --git a/openapi_core/validation/schemas/validators.py b/openapi_core/validation/schemas/validators.py index e46dad31..2e87dc54 100644 --- a/openapi_core/validation/schemas/validators.py +++ b/openapi_core/validation/schemas/validators.py @@ -30,6 +30,9 @@ def __init__( self.schema = schema self.validator = validator + def __contains__(self, schema_format: str) -> bool: + return schema_format in self.validator.format_checker.checkers + def validate(self, value: Any) -> None: errors_iter = self.validator.iter_errors(value) errors = tuple(errors_iter) diff --git a/openapi_core/validation/validators.py b/openapi_core/validation/validators.py index a465c67d..fc3e93bd 100644 --- a/openapi_core/validation/validators.py +++ b/openapi_core/validation/validators.py @@ -33,6 +33,7 @@ from openapi_core.templating.paths.finders import APICallPathFinder from openapi_core.templating.paths.finders import BasePathFinder from openapi_core.templating.paths.finders import WebhookPathFinder +from openapi_core.validation.schemas.datatypes import FormatValidatorsDict from openapi_core.validation.schemas.factories import SchemaValidatorsFactory @@ -47,6 +48,8 @@ def __init__( parameter_deserializers_factory: ParameterDeserializersFactory = parameter_deserializers_factory, media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory, schema_validators_factory: Optional[SchemaValidatorsFactory] = None, + format_validators: Optional[FormatValidatorsDict] = None, + extra_format_validators: Optional[FormatValidatorsDict] = None, ): self.spec = spec self.base_url = base_url @@ -63,6 +66,8 @@ def __init__( raise NotImplementedError( "schema_validators_factory is not assigned" ) + self.format_validators = format_validators + self.extra_format_validators = extra_format_validators def _get_media_type(self, content: Spec, mimetype: str) -> MediaType: from openapi_core.templating.media_types.finders import MediaTypeFinder @@ -83,7 +88,11 @@ def _cast(self, schema: Spec, value: Any) -> Any: return caster(value) def _validate_schema(self, schema: Spec, value: Any) -> None: - validator = self.schema_validators_factory.create(schema) + validator = self.schema_validators_factory.create( + schema, + format_validators=self.format_validators, + extra_format_validators=self.extra_format_validators, + ) validator.validate(value) def _get_param_or_header_value( diff --git a/tests/unit/unmarshalling/test_schema_unmarshallers.py b/tests/unit/unmarshalling/test_schema_unmarshallers.py index d010f39c..b7a18198 100644 --- a/tests/unit/unmarshalling/test_schema_unmarshallers.py +++ b/tests/unit/unmarshalling/test_schema_unmarshallers.py @@ -1,6 +1,7 @@ from functools import partial import pytest +from openapi_schema_validator import OAS30WriteValidator from openapi_core.spec.paths import Spec from openapi_core.unmarshalling.schemas import oas30_types_unmarshaller @@ -11,23 +12,34 @@ from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) -from openapi_core.unmarshalling.schemas.formatters import Formatter from openapi_core.validation.schemas import ( oas30_write_schema_validators_factory, ) +from openapi_core.validation.schemas.exceptions import InvalidSchemaValue +from openapi_core.validation.schemas.factories import SchemaValidatorsFactory +from openapi_core.validation.schemas.formatters import Formatter @pytest.fixture def schema_unmarshaller_factory(): def create_unmarshaller( - validators_factory, schema, custom_formatters=None + validators_factory, + schema, + format_validators=None, + extra_format_validators=None, + extra_format_unmarshallers=None, + custom_formatters=None, ): - custom_formatters = custom_formatters or {} return SchemaUnmarshallersFactory( validators_factory, oas30_types_unmarshaller, custom_formatters=custom_formatters, - ).create(schema) + ).create( + schema, + format_validators=format_validators, + extra_format_validators=extra_format_validators, + extra_format_unmarshallers=extra_format_unmarshallers, + ) return create_unmarshaller @@ -68,7 +80,9 @@ def test_string_format_invalid_value(self, unmarshaller_factory): class TestOAS30SchemaUnmarshallerUnmarshal: - def test_schema_custom_format_invalid(self, unmarshaller_factory): + def test_schema_custom_formatter_format_invalid( + self, unmarshaller_factory + ): class CustomFormatter(Formatter): def format(self, value): raise ValueError @@ -84,10 +98,11 @@ def format(self, value): } spec = Spec.from_dict(schema, validator=None) value = "x" - unmarshaller = unmarshaller_factory( - spec, - custom_formatters=custom_formatters, - ) + with pytest.warns(DeprecationWarning): + unmarshaller = unmarshaller_factory( + spec, + custom_formatters=custom_formatters, + ) with pytest.raises(FormatUnmarshalError): unmarshaller.unmarshal(value) @@ -110,9 +125,10 @@ def format(self, value): custom_formatters = { custom_format: formatter, } - unmarshaller = unmarshaller_factory( - spec, custom_formatters=custom_formatters - ) + with pytest.warns(DeprecationWarning): + unmarshaller = unmarshaller_factory( + spec, custom_formatters=custom_formatters + ) result = unmarshaller.unmarshal(value) @@ -134,9 +150,10 @@ def unmarshal(self, value): custom_formatters = { custom_format: formatter, } - unmarshaller = unmarshaller_factory( - spec, custom_formatters=custom_formatters - ) + with pytest.warns(DeprecationWarning): + unmarshaller = unmarshaller_factory( + spec, custom_formatters=custom_formatters + ) with pytest.warns(DeprecationWarning): result = unmarshaller.unmarshal(value) @@ -159,9 +176,185 @@ def format(self, value): custom_formatters = { custom_format: formatter, } - unmarshaller = unmarshaller_factory( - spec, custom_formatters=custom_formatters + with pytest.warns(DeprecationWarning): + unmarshaller = unmarshaller_factory( + spec, custom_formatters=custom_formatters + ) + + with pytest.raises(FormatUnmarshalError): + unmarshaller.unmarshal(value) + + def test_schema_extra_format_unmarshaller_format_invalid( + self, schema_unmarshaller_factory, unmarshaller_factory + ): + def custom_format_unmarshaller(value): + raise ValueError + + custom_format = "custom" + schema = { + "type": "string", + "format": "custom", + } + spec = Spec.from_dict(schema, validator=None) + value = "x" + schema_validators_factory = SchemaValidatorsFactory( + OAS30WriteValidator + ) + extra_format_unmarshallers = { + custom_format: custom_format_unmarshaller, + } + unmarshaller = schema_unmarshaller_factory( + schema_validators_factory, + spec, + extra_format_unmarshallers=extra_format_unmarshallers, ) with pytest.raises(FormatUnmarshalError): unmarshaller.unmarshal(value) + + def test_schema_extra_format_unmarshaller_format_custom( + self, schema_unmarshaller_factory + ): + formatted = "x-custom" + + def custom_format_unmarshaller(value): + return formatted + + custom_format = "custom" + schema = { + "type": "string", + "format": custom_format, + } + spec = Spec.from_dict(schema, validator=None) + value = "x" + schema_validators_factory = SchemaValidatorsFactory( + OAS30WriteValidator + ) + extra_format_unmarshallers = { + custom_format: custom_format_unmarshaller, + } + unmarshaller = schema_unmarshaller_factory( + schema_validators_factory, + spec, + extra_format_unmarshallers=extra_format_unmarshallers, + ) + + result = unmarshaller.unmarshal(value) + + assert result == formatted + + def test_schema_extra_format_validator_format_invalid( + self, schema_unmarshaller_factory, unmarshaller_factory + ): + def custom_format_validator(value): + return False + + custom_format = "custom" + schema = { + "type": "string", + "format": custom_format, + } + spec = Spec.from_dict(schema, validator=None) + value = "x" + schema_validators_factory = SchemaValidatorsFactory( + OAS30WriteValidator + ) + extra_format_validators = { + custom_format: custom_format_validator, + } + unmarshaller = schema_unmarshaller_factory( + schema_validators_factory, + spec, + extra_format_validators=extra_format_validators, + ) + + with pytest.raises(InvalidSchemaValue): + unmarshaller.unmarshal(value) + + def test_schema_extra_format_validator_format_custom( + self, schema_unmarshaller_factory + ): + def custom_format_validator(value): + return True + + custom_format = "custom" + schema = { + "type": "string", + "format": custom_format, + } + spec = Spec.from_dict(schema, validator=None) + value = "x" + schema_validators_factory = SchemaValidatorsFactory( + OAS30WriteValidator + ) + extra_format_validators = { + custom_format: custom_format_validator, + } + unmarshaller = schema_unmarshaller_factory( + schema_validators_factory, + spec, + extra_format_validators=extra_format_validators, + ) + + result = unmarshaller.unmarshal(value) + + assert result == value + + @pytest.mark.xfail( + reason=( + "Not registered format raises FormatterNotFoundError" + "See https://github.com/p1c2u/openapi-core/issues/515" + ) + ) + def test_schema_format_validator_format_invalid( + self, schema_unmarshaller_factory, unmarshaller_factory + ): + custom_format = "date" + schema = { + "type": "string", + "format": custom_format, + } + spec = Spec.from_dict(schema, validator=None) + value = "x" + schema_validators_factory = SchemaValidatorsFactory( + OAS30WriteValidator + ) + format_validators = {} + unmarshaller = schema_unmarshaller_factory( + schema_validators_factory, + spec, + format_validators=format_validators, + ) + + result = unmarshaller.unmarshal(value) + + assert result == value + + def test_schema_format_validator_format_custom( + self, schema_unmarshaller_factory, unmarshaller_factory + ): + def custom_format_validator(value): + return True + + custom_format = "date" + schema = { + "type": "string", + "format": custom_format, + } + spec = Spec.from_dict(schema, validator=None) + value = "x" + schema_validators_factory = SchemaValidatorsFactory( + OAS30WriteValidator + ) + format_validators = { + custom_format: custom_format_validator, + } + unmarshaller = schema_unmarshaller_factory( + schema_validators_factory, + spec, + format_validators=format_validators, + ) + + result = unmarshaller.unmarshal(value) + + assert result == value diff --git a/tests/unit/validation/test_request_response_validators.py b/tests/unit/validation/test_request_response_validators.py index 9d526204..31dd2c9a 100644 --- a/tests/unit/validation/test_request_response_validators.py +++ b/tests/unit/validation/test_request_response_validators.py @@ -11,19 +11,20 @@ from openapi_core.unmarshalling.schemas.factories import ( SchemaUnmarshallersFactory, ) -from openapi_core.unmarshalling.schemas.formatters import Formatter from openapi_core.validation.schemas import oas31_schema_validators_factory +from openapi_core.validation.schemas.formatters import Formatter class BaseTestValidate: @pytest.fixture def schema_unmarshallers_factory(self): CUSTOM_FORMATTERS = {"custom": Formatter.from_callables()} - return SchemaUnmarshallersFactory( - oas31_schema_validators_factory, - oas31_types_unmarshaller, - custom_formatters=CUSTOM_FORMATTERS, - ) + with pytest.warns(DeprecationWarning): + return SchemaUnmarshallersFactory( + oas31_schema_validators_factory, + oas31_types_unmarshaller, + custom_formatters=CUSTOM_FORMATTERS, + ) class TestRequestValidatorValidate(BaseTestValidate):