Skip to content

fix: update the async transactional types to not require extra awaits #1045

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions google/cloud/firestore_v1/async_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@
"""Helpers for applying Google Cloud Firestore changes in a transaction."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Coroutine, Optional
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Awaitable,
Callable,
Coroutine,
Optional,
TypeVar,
Protocol,
)

from google.api_core import exceptions, gapic_v1
from google.api_core import retry_async as retries
Expand All @@ -41,6 +51,9 @@
from google.cloud.firestore_v1.query_profile import ExplainOptions


T = TypeVar("T", bound=Callable[..., Any])


class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction):
"""Accumulate read-and-write operations to be sent in a transaction.

Expand Down Expand Up @@ -236,11 +249,13 @@ class _AsyncTransactional(_BaseTransactional):
A coroutine that should be run (and retried) in a transaction.
"""

def __init__(self, to_wrap) -> None:
def __init__(
self, to_wrap: Callable[..., Awaitable[T]]
) -> None:
super(_AsyncTransactional, self).__init__(to_wrap)

async def _pre_commit(
self, transaction: AsyncTransaction, *args, **kwargs
self, transaction: AsyncTransaction, *args: Any, **kwargs: Any
) -> Coroutine:
"""Begin transaction and call the wrapped coroutine.

Expand All @@ -254,7 +269,7 @@ async def _pre_commit(
along to the wrapped coroutine.

Returns:
Any: result of the wrapped coroutine.
T: result of the wrapped coroutine.

Raises:
Exception: Any failure caused by ``to_wrap``.
Expand All @@ -269,20 +284,22 @@ async def _pre_commit(
self.retry_id = self.current_id
return await self.to_wrap(transaction, *args, **kwargs)

async def __call__(self, transaction, *args, **kwargs):
async def __call__(
self, transaction: AsyncTransaction, *args: Any, **kwargs: Any
) -> T:
"""Execute the wrapped callable within a transaction.

Args:
transaction
(:class:`~google.cloud.firestore_v1.transaction.Transaction`):
(:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`):
A transaction to execute the callable within.
args (Tuple[Any, ...]): The extra positional arguments to pass
along to the wrapped callable.
kwargs (Dict[str, Any]): The extra keyword arguments to pass
along to the wrapped callable.

Returns:
Any: The result of the wrapped callable.
T: The result of the wrapped callable.

Raises:
ValueError: If the transaction does not succeed in
Expand Down Expand Up @@ -320,14 +337,17 @@ async def __call__(self, transaction, *args, **kwargs):
raise


class WithAsyncTransaction(Protocol[T]):
def __call__(self, transaction: AsyncTransaction, *args: Any, **kwargs: Any) -> Awaitable[T]: ...

def async_transactional(
to_wrap: Callable[[AsyncTransaction], Any]
) -> _AsyncTransactional:
to_wrap: Callable[..., Awaitable[T]]
) -> WithAsyncTransaction[T]:
"""Decorate a callable so that it runs in a transaction.

Args:
to_wrap
(Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]):
(Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Any]):
A callable that should be run (and retried) in a transaction.

Returns:
Expand Down
Loading