Skip to content

fix(metrics): lambda_handler typing, and **kwargs preservation all middlewares #3460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
7 changes: 5 additions & 2 deletions aws_lambda_powertools/metrics/metrics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# NOTE: keeps for compatibility
from __future__ import annotations

from typing import Any, Callable, Dict, Optional
from typing import Any, Dict

from aws_lambda_powertools.metrics.base import MetricResolution, MetricUnit
from aws_lambda_powertools.metrics.provider.cloudwatch_emf.cloudwatch import AmazonCloudWatchEMFProvider
from aws_lambda_powertools.metrics.provider.cloudwatch_emf.types import CloudWatchEMFOutput
from aws_lambda_powertools.shared.types import AnyCallableT


class Metrics:
Expand Down Expand Up @@ -129,16 +130,18 @@ def flush_metrics(self, raise_on_empty_metrics: bool = False) -> None:

def log_metrics(
self,
lambda_handler: Callable[[Dict, Any], Any] | Optional[Callable[[Dict, Any, Optional[Dict]], Any]] = None,
lambda_handler: AnyCallableT | None = None,
capture_cold_start_metric: bool = False,
raise_on_empty_metrics: bool = False,
default_dimensions: Dict[str, str] | None = None,
**kwargs,
):
return self.provider.log_metrics(
lambda_handler=lambda_handler,
capture_cold_start_metric=capture_cold_start_metric,
raise_on_empty_metrics=raise_on_empty_metrics,
default_dimensions=default_dimensions,
**kwargs,
)

def set_default_dimensions(self, **dimensions) -> None:
Expand Down
5 changes: 3 additions & 2 deletions aws_lambda_powertools/metrics/provider/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import functools
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional
from typing import Any

from aws_lambda_powertools.metrics.provider import cold_start
from aws_lambda_powertools.shared.types import AnyCallableT
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -140,7 +141,7 @@ def add_cold_start_metric(self, context: LambdaContext) -> Any:

def log_metrics(
self,
lambda_handler: Callable[[Dict, Any], Any] | Optional[Callable[[Dict, Any, Optional[Dict]], Any]] = None,
lambda_handler: AnyCallableT | None = None,
capture_cold_start_metric: bool = False,
raise_on_empty_metrics: bool = False,
**kwargs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import warnings
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List

from aws_lambda_powertools.metrics.base import single_metric
from aws_lambda_powertools.metrics.exceptions import MetricValueError, SchemaValidationError
Expand All @@ -22,6 +22,7 @@
from aws_lambda_powertools.metrics.types import MetricNameUnitResolution
from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.shared.functions import resolve_env_var_choice
from aws_lambda_powertools.shared.types import AnyCallableT
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -334,7 +335,7 @@ def flush_metrics(self, raise_on_empty_metrics: bool = False) -> None:

def log_metrics(
self,
lambda_handler: Callable[[Dict, Any], Any] | Optional[Callable[[Dict, Any, Optional[Dict]], Any]] = None,
lambda_handler: AnyCallableT | None = None,
capture_cold_start_metric: bool = False,
raise_on_empty_metrics: bool = False,
**kwargs,
Expand Down
5 changes: 3 additions & 2 deletions aws_lambda_powertools/metrics/provider/datadog/datadog.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import re
import time
import warnings
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List

from aws_lambda_powertools.metrics.exceptions import MetricValueError, SchemaValidationError
from aws_lambda_powertools.metrics.provider import BaseProvider
from aws_lambda_powertools.metrics.provider.datadog.warnings import DatadogDataValidationWarning
from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.shared.functions import resolve_env_var_choice
from aws_lambda_powertools.shared.types import AnyCallableT
from aws_lambda_powertools.utilities.typing import LambdaContext

METRIC_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_.]+$")
Expand Down Expand Up @@ -226,7 +227,7 @@ def add_cold_start_metric(self, context: LambdaContext) -> None:

def log_metrics(
self,
lambda_handler: Callable[[Dict, Any], Any] | Optional[Callable[[Dict, Any, Optional[Dict]], Any]] = None,
lambda_handler: AnyCallableT | None = None,
capture_cold_start_metric: bool = False,
raise_on_empty_metrics: bool = False,
**kwargs,
Expand Down
5 changes: 3 additions & 2 deletions aws_lambda_powertools/metrics/provider/datadog/metrics.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# NOTE: keeps for compatibility
from __future__ import annotations

from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List

from aws_lambda_powertools.metrics.provider.datadog.datadog import DatadogProvider
from aws_lambda_powertools.shared.types import AnyCallableT


class DatadogMetrics:
Expand Down Expand Up @@ -90,7 +91,7 @@ def flush_metrics(self, raise_on_empty_metrics: bool = False) -> None:

def log_metrics(
self,
lambda_handler: Callable[[Dict, Any], Any] | Optional[Callable[[Dict, Any, Optional[Dict]], Any]] = None,
lambda_handler: AnyCallableT | None = None,
capture_cold_start_metric: bool = False,
raise_on_empty_metrics: bool = False,
default_tags: Dict[str, Any] | None = None,
Expand Down
4 changes: 2 additions & 2 deletions aws_lambda_powertools/middleware_factory/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ def final_decorator(func: Optional[Callable] = None, **kwargs: Any):
)

@functools.wraps(func)
def wrapper(event, context):
def wrapper(event, context, **handler_kwargs):
try:
middleware = functools.partial(decorator, func, event, context, **kwargs)
middleware = functools.partial(decorator, func, event, context, **kwargs, **handler_kwargs)
if trace_execution:
tracer = Tracer(auto_patch=False)
with tracer.provider.in_subsegment(name=f"## {decorator.__qualname__}"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def idempotent(
"""

if os.getenv(constants.IDEMPOTENCY_DISABLED_ENV):
return handler(event, context)
return handler(event, context, **kwargs)

config = config or IdempotencyConfig()
config.register_lambda_context(context)
Expand All @@ -91,6 +91,7 @@ def idempotent_function(
persistence_store: BasePersistenceLayer,
config: Optional[IdempotencyConfig] = None,
output_serializer: Optional[Union[BaseIdempotencySerializer, Type[BaseIdempotencyModelSerializer]]] = None,
**kwargs: Any,
) -> Any:
"""
Decorator to handle idempotency of any function
Expand Down Expand Up @@ -136,6 +137,7 @@ def process_order(customer_id: str, order: dict, **kwargs):
persistence_store=persistence_store,
config=config,
output_serializer=output_serializer,
**kwargs,
),
)

Expand Down
13 changes: 9 additions & 4 deletions aws_lambda_powertools/utilities/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@

@lambda_handler_decorator
def event_parser(
handler: Callable[[Any, LambdaContext], EventParserReturnType],
handler: Callable[..., EventParserReturnType],
event: Dict[str, Any],
context: LambdaContext,
model: Optional[Type[Model]] = None,
envelope: Optional[Type[Envelope]] = None,
**kwargs: Any,
) -> EventParserReturnType:
"""Lambda handler decorator to parse & validate events using Pydantic models

Expand Down Expand Up @@ -93,9 +94,13 @@ def handler(event: Order, context: LambdaContext):
"or as the type hint of `event` in the handler that it wraps",
)

parsed_event = parse(event=event, model=model, envelope=envelope) if envelope else parse(event=event, model=model)
if envelope:
parsed_event = parse(event=event, model=model, envelope=envelope)
else:
parsed_event = parse(event=event, model=model)

logger.debug(f"Calling handler {handler.__name__}")
return handler(parsed_event, context)
return handler(parsed_event, context, **kwargs)


@overload
Expand All @@ -104,7 +109,7 @@ def parse(event: Dict[str, Any], model: Type[Model]) -> Model:


@overload
def parse(event: Dict[str, Any], model: Type[Model], envelope: Type[Envelope]):
def parse(event: Dict[str, Any], model: Type[Model], envelope: Type[Envelope]) -> Model:
... # pragma: no cover


Expand Down
3 changes: 2 additions & 1 deletion aws_lambda_powertools/utilities/validation/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def validator(
outbound_formats: Optional[Dict] = None,
envelope: str = "",
jmespath_options: Optional[Dict] = None,
**kwargs: Any,
) -> Any:
"""Lambda handler decorator to validate incoming/outbound data using a JSON Schema

Expand Down Expand Up @@ -128,7 +129,7 @@ def handler(event, context):
logger.debug("Validating inbound event")
validate_data_against_schema(data=event, schema=inbound_schema, formats=inbound_formats)

response = handler(event, context)
response = handler(event, context, **kwargs)

if outbound_schema:
logger.debug("Validating outbound event")
Expand Down