diff --git a/.coveragerc b/.coveragerc index d097511c..34417c3f 100644 --- a/.coveragerc +++ b/.coveragerc @@ -11,3 +11,5 @@ exclude_lines = def __repr__ # Ignore abstract methods raise NotImplementedError + # Ignore coverage for code specific to static type checkers + TYPE_CHECKING diff --git a/google/api_core/retry.py b/google/api_core/retry.py index df1e65e0..84b5d0fe 100644 --- a/google/api_core/retry.py +++ b/google/api_core/retry.py @@ -54,13 +54,15 @@ def check_if_exists(): """ -from __future__ import unicode_literals +from __future__ import annotations import datetime import functools import logging import random +import sys import time +from typing import Any, Callable, TypeVar, TYPE_CHECKING import requests.exceptions @@ -68,6 +70,15 @@ def check_if_exists(): from google.api_core import exceptions from google.auth import exceptions as auth_exceptions +if TYPE_CHECKING: + if sys.version_info >= (3, 10): + from typing import ParamSpec + else: + from typing_extensions import ParamSpec + + _P = ParamSpec("_P") + _R = TypeVar("_R") + _LOGGER = logging.getLogger(__name__) _DEFAULT_INITIAL_DELAY = 1.0 # seconds _DEFAULT_MAXIMUM_DELAY = 60.0 # seconds @@ -75,7 +86,9 @@ def check_if_exists(): _DEFAULT_DEADLINE = 60.0 * 2.0 # seconds -def if_exception_type(*exception_types): +def if_exception_type( + *exception_types: type[BaseException], +) -> Callable[[BaseException], bool]: """Creates a predicate to check if the exception is of a given type. Args: @@ -87,7 +100,7 @@ def if_exception_type(*exception_types): exception is of the given type(s). """ - def if_exception_type_predicate(exception): + def if_exception_type_predicate(exception: BaseException) -> bool: """Bound predicate for checking an exception type.""" return isinstance(exception, exception_types) @@ -307,14 +320,14 @@ class Retry(object): def __init__( self, - predicate=if_transient_error, - initial=_DEFAULT_INITIAL_DELAY, - maximum=_DEFAULT_MAXIMUM_DELAY, - multiplier=_DEFAULT_DELAY_MULTIPLIER, - timeout=_DEFAULT_DEADLINE, - on_error=None, - **kwargs - ): + predicate: Callable[[BaseException], bool] = if_transient_error, + initial: float = _DEFAULT_INITIAL_DELAY, + maximum: float = _DEFAULT_MAXIMUM_DELAY, + multiplier: float = _DEFAULT_DELAY_MULTIPLIER, + timeout: float = _DEFAULT_DEADLINE, + on_error: Callable[[BaseException], Any] | None = None, + **kwargs: Any, + ) -> None: self._predicate = predicate self._initial = initial self._multiplier = multiplier @@ -323,7 +336,11 @@ def __init__( self._deadline = self._timeout self._on_error = on_error - def __call__(self, func, on_error=None): + def __call__( + self, + func: Callable[_P, _R], + on_error: Callable[[BaseException], Any] | None = None, + ) -> Callable[_P, _R]: """Wrap a callable with retry behavior. Args: @@ -340,7 +357,7 @@ def __call__(self, func, on_error=None): on_error = self._on_error @functools.wraps(func) - def retry_wrapped_func(*args, **kwargs): + def retry_wrapped_func(*args: _P.args, **kwargs: _P.kwargs) -> _R: """A wrapper that calls target function with retry.""" target = functools.partial(func, *args, **kwargs) sleep_generator = exponential_sleep_generator(