Skip to content

Commit b10e803

Browse files
committed
fix(typing): ensure return type is a T when default_value is set
1 parent ca0091f commit b10e803

File tree

7 files changed

+129
-16
lines changed

7 files changed

+129
-16
lines changed

aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py

+11
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,17 @@ def path_parameters(self) -> Optional[Dict[str, str]]:
283283
def stage_variables(self) -> Optional[Dict[str, str]]:
284284
return self.get("stageVariables")
285285

286+
@overload
287+
def get_header_value(self, name: str, default_value: str, case_sensitive: bool = False) -> str: ...
288+
289+
@overload
290+
def get_header_value(
291+
self,
292+
name: str,
293+
default_value: Optional[str] = None,
294+
case_sensitive: Optional[bool] = False,
295+
) -> Optional[str]: ...
296+
286297
def get_header_value(
287298
self,
288299
name: str,

aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional, Union
1+
from typing import Any, Dict, List, Optional, Union, overload
22

33
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
44
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
@@ -214,6 +214,22 @@ def stash(self) -> Optional[dict]:
214214
a pipeline resolver."""
215215
return self.get("stash")
216216

217+
@overload
218+
def get_header_value(
219+
self,
220+
name: str,
221+
default_value: str,
222+
case_sensitive: Optional[bool] = False,
223+
) -> str: ...
224+
225+
@overload
226+
def get_header_value(
227+
self,
228+
name: str,
229+
default_value: Optional[str] = None,
230+
case_sensitive: Optional[bool] = False,
231+
) -> Optional[str]: ...
232+
217233
def get_header_value(
218234
self,
219235
name: str,

aws_lambda_powertools/utilities/data_classes/common.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
from collections.abc import Mapping
44
from functools import cached_property
5-
from typing import Any, Callable, Dict, Iterator, List, Optional, overload
5+
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, overload
66

77
from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer
88
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
@@ -11,6 +11,8 @@
1111
get_query_string_value,
1212
)
1313

14+
T = TypeVar("T")
15+
1416

1517
class DictWrapper(Mapping):
1618
"""Provides a single read only access to a wrapper dict"""
@@ -86,7 +88,13 @@ def _str_helper(self) -> Dict[str, Any]:
8688
def _properties(self) -> List[str]:
8789
return [p for p in dir(self.__class__) if isinstance(getattr(self.__class__, p), property)]
8890

89-
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
91+
@overload
92+
def get(self, key: str, default: T) -> T: ...
93+
94+
@overload
95+
def get(self, key: str, default: Optional[T] = None) -> Optional[T]: ...
96+
97+
def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
9098
return self._data.get(key, default)
9199

92100
@property
@@ -172,6 +180,12 @@ def http_method(self) -> str:
172180
"""The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT."""
173181
return self["httpMethod"]
174182

183+
@overload
184+
def get_query_string_value(self, name: str, default_value: str) -> str: ...
185+
186+
@overload
187+
def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ...
188+
175189
def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]:
176190
"""Get query string value by name
177191

aws_lambda_powertools/utilities/data_classes/kafka_event.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import base64
22
from functools import cached_property
3-
from typing import Any, Dict, Iterator, List, Optional
3+
from typing import Any, Dict, Iterator, List, Optional, overload
44

55
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
66
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
@@ -69,10 +69,26 @@ def decoded_headers(self) -> Dict[str, bytes]:
6969
"""Decodes the headers as a single dictionary."""
7070
return {k: bytes(v) for chunk in self.headers for k, v in chunk.items()}
7171

72+
@overload
7273
def get_header_value(
7374
self,
7475
name: str,
75-
default_value: Optional[Any] = None,
76+
default_value: str,
77+
case_sensitive: bool = False,
78+
) -> str: ...
79+
80+
@overload
81+
def get_header_value(
82+
self,
83+
name: str,
84+
default_value: Optional[str] = None,
85+
case_sensitive: bool = False,
86+
) -> Optional[str]: ...
87+
88+
def get_header_value(
89+
self,
90+
name: str,
91+
default_value: Optional[str] = None,
7692
case_sensitive: bool = True,
7793
) -> Optional[str]:
7894
"""Get a decoded header value by name."""

aws_lambda_powertools/utilities/data_classes/s3_object_event.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional
1+
from typing import Dict, Optional, overload
22

33
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
44
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
@@ -73,6 +73,22 @@ def headers(self) -> Dict[str, str]:
7373
The case of the original headers is retained in this map."""
7474
return self["headers"]
7575

76+
@overload
77+
def get_header_value(
78+
self,
79+
name: str,
80+
default_value: str,
81+
case_sensitive: Optional[bool] = False,
82+
) -> str: ...
83+
84+
@overload
85+
def get_header_value(
86+
self,
87+
name: str,
88+
default_value: Optional[str] = None,
89+
case_sensitive: Optional[bool] = False,
90+
) -> Optional[str]: ...
91+
7692
def get_header_value(
7793
self,
7894
name: str,

aws_lambda_powertools/utilities/data_classes/shared_functions.py

+44-10
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import base64
4-
from typing import Any, Dict
4+
from typing import Any, Dict, List, Optional, overload
55

66

77
def base64_decode(value: str) -> str:
@@ -21,12 +21,30 @@ def base64_decode(value: str) -> str:
2121
return base64.b64decode(value).decode("UTF-8")
2222

2323

24+
@overload
2425
def get_header_value(
2526
headers: dict[str, Any],
2627
name: str,
27-
default_value: str | None,
28-
case_sensitive: bool | None,
29-
) -> str | None:
28+
default_value: str,
29+
case_sensitive: bool,
30+
) -> str: ...
31+
32+
33+
@overload
34+
def get_header_value(
35+
headers: dict[str, Any],
36+
name: str,
37+
default_value: Optional[str],
38+
case_sensitive: bool,
39+
) -> Optional[str]: ...
40+
41+
42+
def get_header_value(
43+
headers: dict[str, Any],
44+
name: str,
45+
default_value: Optional[str],
46+
case_sensitive: bool,
47+
) -> Optional[str]:
3048
"""
3149
Get the value of a header by its name.
3250
@@ -62,11 +80,27 @@ def get_header_value(
6280
)
6381

6482

83+
@overload
84+
def get_query_string_value(
85+
query_string_parameters: Dict[str, str] | None,
86+
name: str,
87+
default_value: str,
88+
) -> str: ...
89+
90+
91+
@overload
92+
def get_query_string_value(
93+
query_string_parameters: Dict[str, str] | None,
94+
name: str,
95+
default_value: Optional[str] = None,
96+
) -> Optional[str]: ...
97+
98+
6599
def get_query_string_value(
66100
query_string_parameters: Dict[str, str] | None,
67101
name: str,
68-
default_value: str | None = None,
69-
) -> str | None:
102+
default_value: Optional[str] = None,
103+
) -> Optional[str]:
70104
"""
71105
Retrieves the value of a query string parameter specified by the given name.
72106
@@ -87,18 +121,18 @@ def get_query_string_value(
87121

88122

89123
def get_multi_value_query_string_values(
90-
multi_value_query_string_parameters: Dict[str, list[str]] | None,
124+
multi_value_query_string_parameters: Dict[str, List[str]] | None,
91125
name: str,
92-
default_values: list[str] | None = None,
93-
) -> list[str]:
126+
default_values: Optional[List[str]] = None,
127+
) -> List[str]:
94128
"""
95129
Retrieves the values of a multi-value string parameters specified by the given name.
96130
97131
Parameters
98132
----------
99133
name: str
100134
The name of the query string parameter to retrieve.
101-
default_value: list[str], optional
135+
default_value: List[str], optional
102136
The default value to return if the parameter is not found. Defaults to None.
103137
104138
Returns

aws_lambda_powertools/utilities/data_classes/vpc_lattice.py

+6
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ def http_method(self) -> str:
4747
"""The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT."""
4848
return self["method"]
4949

50+
@overload
51+
def get_query_string_value(self, name: str, default_value: str) -> str: ...
52+
53+
@overload
54+
def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ...
55+
5056
def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]:
5157
"""Get query string value by name
5258

0 commit comments

Comments
 (0)