From 8e1b3524d376bf76da9b11acbf0d9dacc3f1763c Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Fri, 16 Jul 2021 22:31:59 +0200 Subject: [PATCH] fix: mypy generic type to preserve signature --- aws_lambda_powertools/tracing/tracer.py | 36 ++++++++++++++++++------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/aws_lambda_powertools/tracing/tracer.py b/aws_lambda_powertools/tracing/tracer.py index 48b7866cf0a..5709b1956c2 100644 --- a/aws_lambda_powertools/tracing/tracer.py +++ b/aws_lambda_powertools/tracing/tracer.py @@ -5,7 +5,7 @@ import logging import numbers import os -from typing import Any, Callable, Dict, Optional, Sequence, Union +from typing import Any, Awaitable, Callable, Dict, Optional, Sequence, TypeVar, Union, cast, overload from ..shared import constants from ..shared.functions import resolve_env_var_choice, resolve_truthy_env_var_choice @@ -18,6 +18,9 @@ aws_xray_sdk = LazyLoader(constants.XRAY_SDK_MODULE, globals(), constants.XRAY_SDK_MODULE) aws_xray_sdk.core = LazyLoader(constants.XRAY_SDK_CORE_MODULE, globals(), constants.XRAY_SDK_CORE_MODULE) +AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) # noqa: VNE001 +AnyAwaitableT = TypeVar("AnyAwaitableT", bound=Awaitable) + class Tracer: """Tracer using AWS-XRay to provide decorators with known defaults for Lambda functions @@ -329,12 +332,26 @@ def decorate(event, context, **kwargs): return decorate + # see #465 + @overload + def capture_method(self, method: "AnyCallableT") -> "AnyCallableT": + ... + + @overload def capture_method( self, - method: Optional[Callable] = None, + method: None = None, capture_response: Optional[bool] = None, capture_error: Optional[bool] = None, - ): + ) -> Callable[["AnyCallableT"], "AnyCallableT"]: + ... + + def capture_method( + self, + method: Optional[AnyCallableT] = None, + capture_response: Optional[bool] = None, + capture_error: Optional[bool] = None, + ) -> AnyCallableT: """Decorator to create subsegment for arbitrary functions It also captures both response and exceptions as metadata @@ -487,8 +504,9 @@ async def async_tasks(): # Return a partial function with args filled if method is None: logger.debug("Decorator called with parameters") - return functools.partial( - self.capture_method, capture_response=capture_response, capture_error=capture_error + return cast( + AnyCallableT, + functools.partial(self.capture_method, capture_response=capture_response, capture_error=capture_error), ) method_name = f"{method.__name__}" @@ -509,7 +527,7 @@ async def async_tasks(): return self._decorate_generator_function( method=method, capture_response=capture_response, capture_error=capture_error, method_name=method_name ) - elif hasattr(method, "__wrapped__") and inspect.isgeneratorfunction(method.__wrapped__): + elif hasattr(method, "__wrapped__") and inspect.isgeneratorfunction(method.__wrapped__): # type: ignore return self._decorate_generator_function_with_context_manager( method=method, capture_response=capture_response, capture_error=capture_error, method_name=method_name ) @@ -602,11 +620,11 @@ def decorate(*args, **kwargs): def _decorate_sync_function( self, - method: Callable, + method: AnyCallableT, capture_response: Optional[Union[bool, str]] = None, capture_error: Optional[Union[bool, str]] = None, method_name: Optional[str] = None, - ): + ) -> AnyCallableT: @functools.wraps(method) def decorate(*args, **kwargs): with self.provider.in_subsegment(name=f"## {method_name}") as subsegment: @@ -628,7 +646,7 @@ def decorate(*args, **kwargs): return response - return decorate + return cast(AnyCallableT, decorate) def _add_response_as_metadata( self,