Skip to content

Commit a157671

Browse files
feat(data_classes): add API Gateway Websocket event (#6287)
* feat(data_classes): add API Gateway Websocket event Pydantic models were added for API Gateway WebSocket events here: https://docs.powertools.aws.dev/lambda/python/latest/utilities/parser/#built-in-models Add the corresponding data model class. * feat(data_classes): increase tests code coverage --------- Co-authored-by: Leandro Damascena <[email protected]>
1 parent f5b4e1f commit a157671

File tree

3 files changed

+239
-0
lines changed

3 files changed

+239
-0
lines changed

Diff for: aws_lambda_powertools/utilities/data_classes/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from .alb_event import ALBEvent
66
from .api_gateway_proxy_event import APIGatewayProxyEvent, APIGatewayProxyEventV2
7+
from .api_gateway_websocket_event import APIGatewayWebSocketEvent
78
from .appsync_resolver_event import AppSyncResolverEvent
89
from .aws_config_rule_event import AWSConfigRuleEvent
910
from .bedrock_agent_event import BedrockAgentEvent
@@ -51,6 +52,7 @@
5152
__all__ = [
5253
"APIGatewayProxyEvent",
5354
"APIGatewayProxyEventV2",
55+
"APIGatewayWebSocketEvent",
5456
"SecretsManagerEvent",
5557
"AppSyncResolverEvent",
5658
"ALBEvent",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from __future__ import annotations
2+
3+
import base64
4+
from functools import cached_property
5+
from typing import Any
6+
7+
from aws_lambda_powertools.utilities.data_classes.common import (
8+
CaseInsensitiveDict,
9+
DictWrapper,
10+
)
11+
12+
13+
class APIGatewayWebSocketEventIdentity(DictWrapper):
14+
@property
15+
def source_ip(self) -> str:
16+
return self["sourceIp"]
17+
18+
@property
19+
def user_agent(self) -> str | None:
20+
return self.get("userAgent")
21+
22+
23+
class APIGatewayWebSocketEventRequestContext(DictWrapper):
24+
@property
25+
def route_key(self) -> str:
26+
return self["routeKey"]
27+
28+
@property
29+
def disconnect_status_code(self) -> int | None:
30+
return self.get("disconnectStatusCode")
31+
32+
@property
33+
def message_id(self) -> str | None:
34+
return self.get("messageId")
35+
36+
@property
37+
def event_type(self) -> str:
38+
return self["eventType"]
39+
40+
@property
41+
def extended_request_id(self) -> str:
42+
return self["extendedRequestId"]
43+
44+
@property
45+
def request_time(self) -> str:
46+
return self["requestTime"]
47+
48+
@property
49+
def message_direction(self) -> str:
50+
return self["messageDirection"]
51+
52+
@property
53+
def disconnect_reason(self) -> str | None:
54+
return self.get("disconnectReason")
55+
56+
@property
57+
def stage(self) -> str:
58+
return self["stage"]
59+
60+
@property
61+
def connected_at(self) -> int:
62+
return self["connectedAt"]
63+
64+
@property
65+
def request_time_epoch(self) -> int:
66+
return self["requestTimeEpoch"]
67+
68+
@property
69+
def identity(self) -> APIGatewayWebSocketEventIdentity:
70+
return APIGatewayWebSocketEventIdentity(self["identity"])
71+
72+
@property
73+
def request_id(self) -> str:
74+
return self["requestId"]
75+
76+
@property
77+
def domain_name(self) -> str:
78+
return self["domainName"]
79+
80+
@property
81+
def connection_id(self) -> str:
82+
return self["connectionId"]
83+
84+
@property
85+
def api_id(self) -> str:
86+
return self["apiId"]
87+
88+
89+
class APIGatewayWebSocketEvent(DictWrapper):
90+
"""AWS proxy integration event for WebSocket API
91+
92+
Documentation:
93+
--------------
94+
- https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-websocket-api-integration-requests.html
95+
"""
96+
97+
@property
98+
def is_base64_encoded(self) -> bool:
99+
return self["isBase64Encoded"]
100+
101+
@property
102+
def body(self) -> str | None:
103+
return self.get("body")
104+
105+
@cached_property
106+
def decoded_body(self) -> str | None:
107+
body = self.body
108+
if self.is_base64_encoded and body:
109+
return base64.b64decode(body.encode()).decode()
110+
return body
111+
112+
@cached_property
113+
def json_body(self) -> Any:
114+
if self.decoded_body:
115+
return self._json_deserializer(self.decoded_body)
116+
return None
117+
118+
@property
119+
def headers(self) -> dict[str, str]:
120+
return CaseInsensitiveDict(self.get("headers"))
121+
122+
@property
123+
def multi_value_headers(self) -> dict[str, list[str]]:
124+
return CaseInsensitiveDict(self.get("multiValueHeaders"))
125+
126+
@property
127+
def request_context(self) -> APIGatewayWebSocketEventRequestContext:
128+
return APIGatewayWebSocketEventRequestContext(self["requestContext"])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import json
2+
3+
from aws_lambda_powertools.utilities.data_classes import APIGatewayWebSocketEvent
4+
from tests.functional.utils import load_event
5+
6+
7+
def test_connect_api_gateway_websocket_event():
8+
raw_event = load_event("apiGatewayWebSocketApiConnect.json")
9+
parsed_event = APIGatewayWebSocketEvent(raw_event)
10+
11+
assert parsed_event.is_base64_encoded is False
12+
assert parsed_event.body is None
13+
assert parsed_event.decoded_body is None
14+
assert parsed_event.json_body is None
15+
assert parsed_event.headers == raw_event["headers"]
16+
assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"]
17+
18+
request_context = parsed_event.request_context
19+
request_context_raw = raw_event["requestContext"]
20+
assert request_context.route_key == request_context_raw["routeKey"]
21+
assert request_context.disconnect_status_code is None
22+
assert request_context.message_id is None
23+
assert request_context.event_type == request_context_raw["eventType"]
24+
assert request_context.extended_request_id == request_context_raw["extendedRequestId"]
25+
assert request_context.request_time == request_context_raw["requestTime"]
26+
assert request_context.message_direction == request_context_raw["messageDirection"]
27+
assert request_context.disconnect_reason is None
28+
assert request_context.stage == request_context_raw["stage"]
29+
assert request_context.connected_at == request_context_raw["connectedAt"]
30+
assert request_context.request_time_epoch == request_context_raw["requestTimeEpoch"]
31+
assert request_context.request_id == request_context_raw["requestId"]
32+
assert request_context.domain_name == request_context_raw["domainName"]
33+
assert request_context.connection_id == request_context_raw["connectionId"]
34+
assert request_context.api_id == request_context_raw["apiId"]
35+
36+
identity = request_context.identity
37+
identity_raw = request_context_raw["identity"]
38+
assert identity.source_ip == identity_raw["sourceIp"]
39+
assert identity.user_agent is None
40+
41+
42+
def test_disconnect_api_gateway_websocket_event():
43+
raw_event = load_event("apiGatewayWebSocketApiDisconnect.json")
44+
parsed_event = APIGatewayWebSocketEvent(raw_event)
45+
46+
assert parsed_event.is_base64_encoded is False
47+
assert parsed_event.body is None
48+
assert parsed_event.decoded_body is None
49+
assert parsed_event.json_body is None
50+
assert parsed_event.headers == raw_event["headers"]
51+
assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"]
52+
53+
request_context = parsed_event.request_context
54+
request_context_raw = raw_event["requestContext"]
55+
assert request_context.route_key == request_context_raw["routeKey"]
56+
assert request_context.disconnect_status_code == request_context_raw["disconnectStatusCode"]
57+
assert request_context.message_id is None
58+
assert request_context.event_type == request_context_raw["eventType"]
59+
assert request_context.extended_request_id == request_context_raw["extendedRequestId"]
60+
assert request_context.request_time == request_context_raw["requestTime"]
61+
assert request_context.message_direction == request_context_raw["messageDirection"]
62+
assert request_context.disconnect_reason == request_context_raw["disconnectReason"]
63+
assert request_context.stage == request_context_raw["stage"]
64+
assert request_context.connected_at == request_context_raw["connectedAt"]
65+
assert request_context.request_time_epoch == request_context_raw["requestTimeEpoch"]
66+
assert request_context.request_id == request_context_raw["requestId"]
67+
assert request_context.domain_name == request_context_raw["domainName"]
68+
assert request_context.connection_id == request_context_raw["connectionId"]
69+
assert request_context.api_id == request_context_raw["apiId"]
70+
71+
identity = request_context.identity
72+
identity_raw = request_context_raw["identity"]
73+
assert identity.source_ip == identity_raw["sourceIp"]
74+
assert identity.user_agent is None
75+
76+
77+
def test_message_api_gateway_websocket_event():
78+
raw_event = load_event("apiGatewayWebSocketApiMessage.json")
79+
parsed_event = APIGatewayWebSocketEvent(raw_event)
80+
81+
assert parsed_event.is_base64_encoded is False
82+
assert parsed_event.body == raw_event["body"]
83+
assert parsed_event.decoded_body == raw_event["body"]
84+
assert parsed_event.json_body == json.loads(raw_event["body"])
85+
assert parsed_event.headers == {}
86+
assert parsed_event.multi_value_headers == {}
87+
88+
request_context = parsed_event.request_context
89+
request_context_raw = raw_event["requestContext"]
90+
assert request_context.route_key == request_context_raw["routeKey"]
91+
assert request_context.disconnect_status_code is None
92+
assert request_context.message_id == request_context_raw["messageId"]
93+
assert request_context.event_type == request_context_raw["eventType"]
94+
assert request_context.extended_request_id == request_context_raw["extendedRequestId"]
95+
assert request_context.request_time == request_context_raw["requestTime"]
96+
assert request_context.message_direction == request_context_raw["messageDirection"]
97+
assert request_context.disconnect_reason is None
98+
assert request_context.stage == request_context_raw["stage"]
99+
assert request_context.connected_at == request_context_raw["connectedAt"]
100+
assert request_context.request_time_epoch == request_context_raw["requestTimeEpoch"]
101+
assert request_context.request_id == request_context_raw["requestId"]
102+
assert request_context.domain_name == request_context_raw["domainName"]
103+
assert request_context.connection_id == request_context_raw["connectionId"]
104+
assert request_context.api_id == request_context_raw["apiId"]
105+
106+
identity = request_context.identity
107+
identity_raw = request_context_raw["identity"]
108+
assert identity.source_ip == identity_raw["sourceIp"]
109+
assert identity.user_agent is None

0 commit comments

Comments
 (0)