Skip to content

Commit 240177d

Browse files
fix(openapi): validate response serialization when falsy (#6119)
* fix(openapi): validate response serialization when falsy * revert serialize * change comment * add more tests * revert serialize * fix mypy * Refactoring tests + removing additional code --------- Co-authored-by: Leandro Damascena <[email protected]>
1 parent d9719b9 commit 240177d

File tree

3 files changed

+212
-20
lines changed

3 files changed

+212
-20
lines changed

Diff for: aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

+8-19
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,13 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
136136
return self._handle_response(route=route, response=response)
137137

138138
def _handle_response(self, *, route: Route, response: Response):
139-
# Process the response body if it exists
140-
if response.body:
141-
# Validate and serialize the response, if it's JSON
142-
if response.is_json():
143-
response.body = self._serialize_response(
144-
field=route.dependant.return_param,
145-
response_content=response.body,
146-
)
139+
# Check if we have a return type defined
140+
if route.dependant.return_param:
141+
# Validate and serialize the response, including None
142+
response.body = self._serialize_response(
143+
field=route.dependant.return_param,
144+
response_content=response.body,
145+
)
147146

148147
return response
149148

@@ -164,15 +163,6 @@ def _serialize_response(
164163
"""
165164
if field:
166165
errors: list[dict[str, Any]] = []
167-
# MAINTENANCE: remove this when we drop pydantic v1
168-
if not hasattr(field, "serializable"):
169-
response_content = self._prepare_response_content(
170-
response_content,
171-
exclude_unset=exclude_unset,
172-
exclude_defaults=exclude_defaults,
173-
exclude_none=exclude_none,
174-
)
175-
176166
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
177167
if errors:
178168
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
@@ -187,7 +177,6 @@ def _serialize_response(
187177
exclude_defaults=exclude_defaults,
188178
exclude_none=exclude_none,
189179
)
190-
191180
return jsonable_encoder(
192181
value,
193182
include=include,
@@ -199,7 +188,7 @@ def _serialize_response(
199188
custom_serializer=self._validation_serializer,
200189
)
201190
else:
202-
# Just serialize the response content returned from the handler
191+
# Just serialize the response content returned from the handler.
203192
return jsonable_encoder(response_content, custom_serializer=self._validation_serializer)
204193

205194
def _prepare_response_content(

Diff for: tests/functional/event_handler/_pydantic/test_openapi_serialization.py

+131-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
import json
2-
from typing import Dict
2+
from dataclasses import dataclass
3+
from typing import Dict, Optional, Set
34

45
import pytest
6+
from pydantic import BaseModel
57

68
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
79

810

11+
@dataclass
12+
class Person:
13+
name: str
14+
birth_date: str
15+
scores: Set[int]
16+
17+
918
def test_openapi_duplicated_serialization():
1019
# GIVEN APIGatewayRestResolver is initialized with enable_validation=True
1120
app = APIGatewayRestResolver(enable_validation=True)
@@ -61,3 +70,124 @@ def handler():
6170

6271
# THEN we the custom serializer should be used
6372
assert response["body"] == "hello world"
73+
74+
75+
def test_valid_model_returned_for_optional_type(gw_event):
76+
# GIVEN an APIGatewayRestResolver with validation enabled
77+
app = APIGatewayRestResolver(enable_validation=True)
78+
79+
class Model(BaseModel):
80+
name: str
81+
age: int
82+
83+
@app.get("/valid_optional")
84+
def handler_valid_optional() -> Optional[Model]:
85+
return Model(name="John", age=30)
86+
87+
# WHEN returning a valid model for an Optional type
88+
gw_event["path"] = "/valid_optional"
89+
result = app(gw_event, {})
90+
91+
# THEN it should succeed and return the serialized model
92+
assert result["statusCode"] == 200
93+
assert json.loads(result["body"]) == {"name": "John", "age": 30}
94+
95+
96+
def test_serialize_response_without_field(gw_event):
97+
# GIVEN an APIGatewayRestResolver with validation enabled
98+
app = APIGatewayRestResolver(enable_validation=True)
99+
100+
# WHEN a handler is defined without return type annotation
101+
@app.get("/test")
102+
def handler():
103+
return {"message": "Hello, World!"}
104+
105+
gw_event["path"] = "/test"
106+
107+
# THEN the handler should be invoked and return 200
108+
# AND the body must be a JSON object
109+
response = app(gw_event, None)
110+
assert response["statusCode"] == 200
111+
assert response["body"] == '{"message":"Hello, World!"}'
112+
113+
114+
def test_serialize_response_list(gw_event):
115+
"""Test serialization of list responses containing complex types"""
116+
# GIVEN an APIGatewayRestResolver with validation enabled
117+
app = APIGatewayRestResolver(enable_validation=True)
118+
119+
# WHEN a handler returns a list containing various types
120+
@app.get("/test")
121+
def handler():
122+
return [{"set": [1, 2, 3]}, {"simple": "value"}]
123+
124+
gw_event["path"] = "/test"
125+
126+
# THEN the response should be properly serialized
127+
response = app(gw_event, None)
128+
assert response["statusCode"] == 200
129+
assert response["body"] == '[{"set":[1,2,3]},{"simple":"value"}]'
130+
131+
132+
def test_serialize_response_nested_dict(gw_event):
133+
"""Test serialization of nested dictionary responses"""
134+
# GIVEN an APIGatewayRestResolver with validation enabled
135+
app = APIGatewayRestResolver(enable_validation=True)
136+
137+
# WHEN a handler returns a nested dictionary with complex types
138+
@app.get("/test")
139+
def handler():
140+
return {"nested": {"date": "2000-01-01", "set": [1, 2, 3]}, "simple": "value"}
141+
142+
gw_event["path"] = "/test"
143+
144+
# THEN the response should be properly serialized
145+
response = app(gw_event, None)
146+
assert response["statusCode"] == 200
147+
assert response["body"] == '{"nested":{"date":"2000-01-01","set":[1,2,3]},"simple":"value"}'
148+
149+
150+
def test_serialize_response_dataclass(gw_event):
151+
"""Test serialization of dataclass responses"""
152+
# GIVEN an APIGatewayRestResolver with validation enabled
153+
app = APIGatewayRestResolver(enable_validation=True)
154+
155+
# WHEN a handler returns a dataclass instance
156+
@app.get("/test")
157+
def handler():
158+
return Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91])
159+
160+
gw_event["path"] = "/test"
161+
162+
# THEN the response should be properly serialized
163+
response = app(gw_event, None)
164+
assert response["statusCode"] == 200
165+
assert response["body"] == '{"name":"John Doe","birth_date":"1990-01-01","scores":[95,87,91]}'
166+
167+
168+
def test_serialize_response_mixed_types(gw_event):
169+
"""Test serialization of mixed type responses"""
170+
# GIVEN an APIGatewayRestResolver with validation enabled
171+
app = APIGatewayRestResolver(enable_validation=True)
172+
173+
# WHEN a handler returns a response with mixed types
174+
@app.get("/test")
175+
def handler():
176+
person = Person(name="John Doe", birth_date="1990-01-01", scores=[95, 87, 91])
177+
return {
178+
"person": person,
179+
"records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}],
180+
"metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]},
181+
}
182+
183+
gw_event["path"] = "/test"
184+
185+
# THEN the response should be properly serialized
186+
response = app(gw_event, None)
187+
assert response["statusCode"] == 200
188+
expected = {
189+
"person": {"name": "John Doe", "birth_date": "1990-01-01", "scores": [95, 87, 91]},
190+
"records": [{"date": "2000-01-01"}, {"set": [1, 2, 3]}],
191+
"metadata": {"processed_at": "2050-01-01", "tags": ["tag1", "tag2"]},
192+
}
193+
assert json.loads(response["body"]) == expected

Diff for: tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py

+73
Original file line numberDiff line numberDiff line change
@@ -1128,3 +1128,76 @@ def handler(user_id: int = 123):
11281128
# THEN the handler should be invoked and return 200
11291129
result = app(minimal_event, {})
11301130
assert result["statusCode"] == 200
1131+
1132+
1133+
def test_validation_error_none_returned_non_optional_type(gw_event):
1134+
# GIVEN an APIGatewayRestResolver with validation enabled
1135+
app = APIGatewayRestResolver(enable_validation=True)
1136+
1137+
class Model(BaseModel):
1138+
name: str
1139+
age: int
1140+
1141+
@app.get("/none_not_allowed")
1142+
def handler_none_not_allowed() -> Model:
1143+
return None # type: ignore
1144+
1145+
# WHEN returning None for a non-Optional type
1146+
gw_event["path"] = "/none_not_allowed"
1147+
result = app(gw_event, {})
1148+
1149+
# THEN it should return a validation error
1150+
assert result["statusCode"] == 422
1151+
body = json.loads(result["body"])
1152+
assert "model_attributes_type" in body["detail"][0]["type"]
1153+
1154+
1155+
def test_none_returned_for_optional_type(gw_event):
1156+
# GIVEN an APIGatewayRestResolver with validation enabled
1157+
app = APIGatewayRestResolver(enable_validation=True)
1158+
1159+
class Model(BaseModel):
1160+
name: str
1161+
age: int
1162+
1163+
@app.get("/none_allowed")
1164+
def handler_none_allowed() -> Optional[Model]:
1165+
return None
1166+
1167+
# WHEN returning None for an Optional type
1168+
gw_event["path"] = "/none_allowed"
1169+
result = app(gw_event, {})
1170+
1171+
# THEN it should succeed
1172+
assert result["statusCode"] == 200
1173+
assert result["body"] == "null"
1174+
1175+
1176+
@pytest.mark.parametrize(
1177+
"path, body",
1178+
[
1179+
("/empty_dict", {}),
1180+
("/empty_list", []),
1181+
("/none", "null"),
1182+
("/empty_string", ""),
1183+
],
1184+
ids=["empty_dict", "empty_list", "none", "empty_string"],
1185+
)
1186+
def test_none_returned_for_falsy_return(gw_event, path, body):
1187+
# GIVEN an APIGatewayRestResolver with validation enabled
1188+
app = APIGatewayRestResolver(enable_validation=True)
1189+
1190+
class Model(BaseModel):
1191+
name: str
1192+
age: int
1193+
1194+
@app.get(path)
1195+
def handler_none_allowed() -> Model:
1196+
return body
1197+
1198+
# WHEN returning None for an Optional type
1199+
gw_event["path"] = path
1200+
result = app(gw_event, {})
1201+
1202+
# THEN it should succeed
1203+
assert result["statusCode"] == 422

0 commit comments

Comments
 (0)