Skip to content

fix(openapi): validate response serialization when falsy #6119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Feb 20, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,18 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
return self._handle_response(route=route, response=response)

def _handle_response(self, *, route: Route, response: Response):
# Process the response body if it exists
if response.body:
# Validate and serialize the response, if it's JSON
if response.is_json():
# Check if we have a return type defined
if route.dependant.return_param:
try:
# Validate and serialize the response, including None
response.body = self._serialize_response(
field=route.dependant.return_param,
response_content=response.body,
)
except RequestValidationError as e:
logger.error(f"Response validation failed: {str(e)}")
response.status_code = 422
response.body = {"detail": e.errors()}

return response

Expand All @@ -164,42 +168,21 @@ def _serialize_response(
"""
if field:
errors: list[dict[str, Any]] = []
# MAINTENANCE: remove this when we drop pydantic v1
if not hasattr(field, "serializable"):
response_content = self._prepare_response_content(
response_content,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)

value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
if errors:
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)

if hasattr(field, "serialize"):
return field.serialize(
value,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)

return jsonable_encoder(
return field.serialize(
value,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_serializer=self._validation_serializer,
)
else:
# Just serialize the response content returned from the handler
# Just serialize the response content returned from the handler.
return jsonable_encoder(response_content, custom_serializer=self._validation_serializer)

def _prepare_response_content(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from enum import Enum
from pathlib import PurePath
from typing import List, Optional, Tuple
from typing import List, Optional, Set, Tuple

import pytest
from pydantic import BaseModel
Expand Down Expand Up @@ -1128,3 +1128,154 @@ def handler(user_id: int = 123):
# THEN the handler should be invoked and return 200
result = app(minimal_event, {})
assert result["statusCode"] == 200


def test_validate_optional_return_types(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

class Model(BaseModel):
name: str
age: int

# AND handlers defined with different Optional return types
@app.get("/none_not_allowed")
def handler_none_not_allowed() -> Model:
return None # type: ignore

@app.get("/none_allowed")
def handler_none_allowed() -> Optional[Model]:
return None

@app.get("/valid_optional")
def handler_valid_optional() -> Optional[Model]:
return Model(name="John", age=30)

# WHEN returning None for a non-Optional type
gw_event["path"] = "/none_not_allowed"
result = app(gw_event, {})
# THEN it should return a validation error
assert result["statusCode"] == 422
body = json.loads(result["body"])
assert "model_attributes_type" in body["detail"][0]["type"]

# WHEN returning None for an Optional type
gw_event["path"] = "/none_allowed"
result = app(gw_event, {})
# THEN it should succeed
assert result["statusCode"] == 200
assert result["body"] == "null"

# WHEN returning a valid model for an Optional type
gw_event["path"] = "/valid_optional"
result = app(gw_event, {})
# THEN it should succeed and return the serialized model
assert result["statusCode"] == 200
assert json.loads(result["body"]) == {"name": "John", "age": 30}


def test_serialize_response_without_field(gw_event):
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler is defined without return type annotation
@app.get("/test")
def handler():
return {"message": "Hello, World!"}

gw_event["path"] = "/test"

# THEN the handler should be invoked and return 200
# AND the body must be a JSON object
response = app(gw_event, None)
assert response["statusCode"] == 200
assert response["body"] == '{"message":"Hello, World!"}'


def test_serialize_response_list(gw_event):
"""Test serialization of list responses containing complex types"""
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler returns a list containing various types
@app.get("/test")
def handler():
return [{"set": [1, 2, 3]}, {"simple": "value"}]

gw_event["path"] = "/test"

# THEN the response should be properly serialized
response = app(gw_event, None)
assert response["statusCode"] == 200
assert response["body"] == '[{"set":[1,2,3]},{"simple":"value"}]'


def test_serialize_response_nested_dict(gw_event):
"""Test serialization of nested dictionary responses"""
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler returns a nested dictionary with complex types
@app.get("/test")
def handler():
return {"nested": {"date": "2000-01-01", "set": [1, 2, 3]}, "simple": "value"}

gw_event["path"] = "/test"

# THEN the response should be properly serialized
response = app(gw_event, None)
assert response["statusCode"] == 200
assert response["body"] == '{"nested":{"date":"2000-01-01","set":[1,2,3]},"simple":"value"}'


@dataclass
class Person:
name: str
birth_date: str
scores: Set[int]


def test_serialize_response_dataclass(gw_event):
"""Test serialization of dataclass responses"""
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler returns a dataclass instance
@app.get("/test")
def handler():
return Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91])

gw_event["path"] = "/test"

# THEN the response should be properly serialized
response = app(gw_event, None)
assert response["statusCode"] == 200
assert response["body"] == '{"name":"John Doe","birth_date":"1990-01-01","scores":[95,87,91]}'


def test_serialize_response_mixed_types(gw_event):
"""Test serialization of mixed type responses"""
# GIVEN an APIGatewayRestResolver with validation enabled
app = APIGatewayRestResolver(enable_validation=True)

# WHEN a handler returns a response with mixed types
@app.get("/test")
def handler():
person = Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91])
return {
"person": person,
"records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}],
"metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]},
}

gw_event["path"] = "/test"

# THEN the response should be properly serialized
response = app(gw_event, None)
assert response["statusCode"] == 200
expected = {
"person": {"name": "John Doe", "birth_date": "1990-01-01", "scores": [95, 87, 91]},
"records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}],
"metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]},
}
assert json.loads(response["body"]) == expected
Loading