Skip to content

Commit fc12b40

Browse files
daniel-sanchegcf-owl-bot[bot]parthea
authored
feat: add type annotations to wrapped grpc calls (#554)
* add types to grpc call wrappers * fixed tests * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * changed type * changed async types * added tests * fixed lint issues * Update tests/asyncio/test_grpc_helpers_async.py Co-authored-by: Anthonios Partheniou <[email protected]> * turned GrpcStream into a type alias * added test for GrpcStream * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * added comment * reordered types * changed type var to P --------- Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com> Co-authored-by: Anthonios Partheniou <[email protected]>
1 parent 448923a commit fc12b40

File tree

4 files changed

+70
-13
lines changed

4 files changed

+70
-13
lines changed

google/api_core/grpc_helpers.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Helpers for :mod:`grpc`."""
16+
from typing import Generic, TypeVar, Iterator
1617

1718
import collections
1819
import functools
@@ -54,6 +55,9 @@
5455

5556
_LOGGER = logging.getLogger(__name__)
5657

58+
# denotes the proto response type for grpc calls
59+
P = TypeVar("P")
60+
5761

5862
def _patch_callable_name(callable_):
5963
"""Fix-up gRPC callable attributes.
@@ -79,7 +83,7 @@ def error_remapped_callable(*args, **kwargs):
7983
return error_remapped_callable
8084

8185

82-
class _StreamingResponseIterator(grpc.Call):
86+
class _StreamingResponseIterator(Generic[P], grpc.Call):
8387
def __init__(self, wrapped, prefetch_first_result=True):
8488
self._wrapped = wrapped
8589

@@ -97,11 +101,11 @@ def __init__(self, wrapped, prefetch_first_result=True):
97101
# ignore stop iteration at this time. This should be handled outside of retry.
98102
pass
99103

100-
def __iter__(self):
104+
def __iter__(self) -> Iterator[P]:
101105
"""This iterator is also an iterable that returns itself."""
102106
return self
103107

104-
def __next__(self):
108+
def __next__(self) -> P:
105109
"""Get the next response from the stream.
106110
107111
Returns:
@@ -144,6 +148,10 @@ def trailing_metadata(self):
144148
return self._wrapped.trailing_metadata()
145149

146150

151+
# public type alias denoting the return type of streaming gapic calls
152+
GrpcStream = _StreamingResponseIterator[P]
153+
154+
147155
def _wrap_stream_errors(callable_):
148156
"""Wrap errors for Unary-Stream and Stream-Stream gRPC callables.
149157

google/api_core/grpc_helpers_async.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,15 @@
2121
import asyncio
2222
import functools
2323

24+
from typing import Generic, Iterator, AsyncGenerator, TypeVar
25+
2426
import grpc
2527
from grpc import aio
2628

2729
from google.api_core import exceptions, grpc_helpers
2830

31+
# denotes the proto response type for grpc calls
32+
P = TypeVar("P")
2933

3034
# NOTE(lidiz) Alternatively, we can hack "__getattribute__" to perform
3135
# automatic patching for us. But that means the overhead of creating an
@@ -75,26 +79,26 @@ async def wait_for_connection(self):
7579
raise exceptions.from_grpc_error(rpc_error) from rpc_error
7680

7781

78-
class _WrappedUnaryResponseMixin(_WrappedCall):
79-
def __await__(self):
82+
class _WrappedUnaryResponseMixin(Generic[P], _WrappedCall):
83+
def __await__(self) -> Iterator[P]:
8084
try:
8185
response = yield from self._call.__await__()
8286
return response
8387
except grpc.RpcError as rpc_error:
8488
raise exceptions.from_grpc_error(rpc_error) from rpc_error
8589

8690

87-
class _WrappedStreamResponseMixin(_WrappedCall):
91+
class _WrappedStreamResponseMixin(Generic[P], _WrappedCall):
8892
def __init__(self):
8993
self._wrapped_async_generator = None
9094

91-
async def read(self):
95+
async def read(self) -> P:
9296
try:
9397
return await self._call.read()
9498
except grpc.RpcError as rpc_error:
9599
raise exceptions.from_grpc_error(rpc_error) from rpc_error
96100

97-
async def _wrapped_aiter(self):
101+
async def _wrapped_aiter(self) -> AsyncGenerator[P, None]:
98102
try:
99103
# NOTE(lidiz) coverage doesn't understand the exception raised from
100104
# __anext__ method. It is covered by test case:
@@ -104,7 +108,7 @@ async def _wrapped_aiter(self):
104108
except grpc.RpcError as rpc_error:
105109
raise exceptions.from_grpc_error(rpc_error) from rpc_error
106110

107-
def __aiter__(self):
111+
def __aiter__(self) -> AsyncGenerator[P, None]:
108112
if not self._wrapped_async_generator:
109113
self._wrapped_async_generator = self._wrapped_aiter()
110114
return self._wrapped_async_generator
@@ -127,26 +131,32 @@ async def done_writing(self):
127131
# NOTE(lidiz) Implementing each individual class separately, so we don't
128132
# expose any API that should not be seen. E.g., __aiter__ in unary-unary
129133
# RPC, or __await__ in stream-stream RPC.
130-
class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin, aio.UnaryUnaryCall):
134+
class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin[P], aio.UnaryUnaryCall):
131135
"""Wrapped UnaryUnaryCall to map exceptions."""
132136

133137

134-
class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin, aio.UnaryStreamCall):
138+
class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin[P], aio.UnaryStreamCall):
135139
"""Wrapped UnaryStreamCall to map exceptions."""
136140

137141

138142
class _WrappedStreamUnaryCall(
139-
_WrappedUnaryResponseMixin, _WrappedStreamRequestMixin, aio.StreamUnaryCall
143+
_WrappedUnaryResponseMixin[P], _WrappedStreamRequestMixin, aio.StreamUnaryCall
140144
):
141145
"""Wrapped StreamUnaryCall to map exceptions."""
142146

143147

144148
class _WrappedStreamStreamCall(
145-
_WrappedStreamRequestMixin, _WrappedStreamResponseMixin, aio.StreamStreamCall
149+
_WrappedStreamRequestMixin, _WrappedStreamResponseMixin[P], aio.StreamStreamCall
146150
):
147151
"""Wrapped StreamStreamCall to map exceptions."""
148152

149153

154+
# public type alias denoting the return type of async streaming gapic calls
155+
GrpcAsyncStream = _WrappedStreamResponseMixin[P]
156+
# public type alias denoting the return type of unary gapic calls
157+
AwaitableGrpcCall = _WrappedUnaryResponseMixin[P]
158+
159+
150160
def _wrap_unary_errors(callable_):
151161
"""Map errors for Unary-Unary async callables."""
152162
grpc_helpers._patch_callable_name(callable_)

tests/asyncio/test_grpc_helpers_async.py

+22
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,28 @@ def test_wrap_errors_non_streaming(wrap_unary_errors):
266266
wrap_unary_errors.assert_called_once_with(callable_)
267267

268268

269+
def test_grpc_async_stream():
270+
"""
271+
GrpcAsyncStream type should be both an AsyncIterator and a grpc.aio.Call.
272+
"""
273+
instance = grpc_helpers_async.GrpcAsyncStream[int]()
274+
assert isinstance(instance, grpc.aio.Call)
275+
# should implement __aiter__ and __anext__
276+
assert hasattr(instance, "__aiter__")
277+
it = instance.__aiter__()
278+
assert hasattr(it, "__anext__")
279+
280+
281+
def test_awaitable_grpc_call():
282+
"""
283+
AwaitableGrpcCall type should be an Awaitable and a grpc.aio.Call.
284+
"""
285+
instance = grpc_helpers_async.AwaitableGrpcCall[int]()
286+
assert isinstance(instance, grpc.aio.Call)
287+
# should implement __await__
288+
assert hasattr(instance, "__await__")
289+
290+
269291
@mock.patch("google.api_core.grpc_helpers_async._wrap_stream_errors")
270292
def test_wrap_errors_streaming(wrap_stream_errors):
271293
callable_ = mock.create_autospec(aio.UnaryStreamMultiCallable)

tests/unit/test_grpc_helpers.py

+17
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,23 @@ def test_trailing_metadata(self):
195195
wrapped.trailing_metadata.assert_called_once_with()
196196

197197

198+
class TestGrpcStream(Test_StreamingResponseIterator):
199+
@staticmethod
200+
def _make_one(wrapped, **kw):
201+
return grpc_helpers.GrpcStream(wrapped, **kw)
202+
203+
def test_grpc_stream_attributes(self):
204+
"""
205+
Should be both a grpc.Call and an iterable
206+
"""
207+
call = self._make_one(None)
208+
assert isinstance(call, grpc.Call)
209+
# should implement __iter__
210+
assert hasattr(call, "__iter__")
211+
it = call.__iter__()
212+
assert hasattr(it, "__next__")
213+
214+
198215
def test_wrap_stream_okay():
199216
expected_responses = [1, 2, 3]
200217
callable_ = mock.Mock(spec=["__call__"], return_value=iter(expected_responses))

0 commit comments

Comments
 (0)