21
21
import asyncio
22
22
import functools
23
23
24
+ from typing import Generic , Iterator , AsyncGenerator , TypeVar
25
+
24
26
import grpc
25
27
from grpc import aio
26
28
27
29
from google .api_core import exceptions , grpc_helpers
28
30
31
+ # denotes the proto response type for grpc calls
32
+ P = TypeVar ("P" )
29
33
30
34
# NOTE(lidiz) Alternatively, we can hack "__getattribute__" to perform
31
35
# automatic patching for us. But that means the overhead of creating an
@@ -75,26 +79,26 @@ async def wait_for_connection(self):
75
79
raise exceptions .from_grpc_error (rpc_error ) from rpc_error
76
80
77
81
78
- class _WrappedUnaryResponseMixin (_WrappedCall ):
79
- def __await__ (self ):
82
+ class _WrappedUnaryResponseMixin (Generic [ P ], _WrappedCall ):
83
+ def __await__ (self ) -> Iterator [ P ] :
80
84
try :
81
85
response = yield from self ._call .__await__ ()
82
86
return response
83
87
except grpc .RpcError as rpc_error :
84
88
raise exceptions .from_grpc_error (rpc_error ) from rpc_error
85
89
86
90
87
- class _WrappedStreamResponseMixin (_WrappedCall ):
91
+ class _WrappedStreamResponseMixin (Generic [ P ], _WrappedCall ):
88
92
def __init__ (self ):
89
93
self ._wrapped_async_generator = None
90
94
91
- async def read (self ):
95
+ async def read (self ) -> P :
92
96
try :
93
97
return await self ._call .read ()
94
98
except grpc .RpcError as rpc_error :
95
99
raise exceptions .from_grpc_error (rpc_error ) from rpc_error
96
100
97
- async def _wrapped_aiter (self ):
101
+ async def _wrapped_aiter (self ) -> AsyncGenerator [ P , None ] :
98
102
try :
99
103
# NOTE(lidiz) coverage doesn't understand the exception raised from
100
104
# __anext__ method. It is covered by test case:
@@ -104,7 +108,7 @@ async def _wrapped_aiter(self):
104
108
except grpc .RpcError as rpc_error :
105
109
raise exceptions .from_grpc_error (rpc_error ) from rpc_error
106
110
107
- def __aiter__ (self ):
111
+ def __aiter__ (self ) -> AsyncGenerator [ P , None ] :
108
112
if not self ._wrapped_async_generator :
109
113
self ._wrapped_async_generator = self ._wrapped_aiter ()
110
114
return self ._wrapped_async_generator
@@ -127,26 +131,32 @@ async def done_writing(self):
127
131
# NOTE(lidiz) Implementing each individual class separately, so we don't
128
132
# expose any API that should not be seen. E.g., __aiter__ in unary-unary
129
133
# RPC, or __await__ in stream-stream RPC.
130
- class _WrappedUnaryUnaryCall (_WrappedUnaryResponseMixin , aio .UnaryUnaryCall ):
134
+ class _WrappedUnaryUnaryCall (_WrappedUnaryResponseMixin [ P ] , aio .UnaryUnaryCall ):
131
135
"""Wrapped UnaryUnaryCall to map exceptions."""
132
136
133
137
134
- class _WrappedUnaryStreamCall (_WrappedStreamResponseMixin , aio .UnaryStreamCall ):
138
+ class _WrappedUnaryStreamCall (_WrappedStreamResponseMixin [ P ] , aio .UnaryStreamCall ):
135
139
"""Wrapped UnaryStreamCall to map exceptions."""
136
140
137
141
138
142
class _WrappedStreamUnaryCall (
139
- _WrappedUnaryResponseMixin , _WrappedStreamRequestMixin , aio .StreamUnaryCall
143
+ _WrappedUnaryResponseMixin [ P ] , _WrappedStreamRequestMixin , aio .StreamUnaryCall
140
144
):
141
145
"""Wrapped StreamUnaryCall to map exceptions."""
142
146
143
147
144
148
class _WrappedStreamStreamCall (
145
- _WrappedStreamRequestMixin , _WrappedStreamResponseMixin , aio .StreamStreamCall
149
+ _WrappedStreamRequestMixin , _WrappedStreamResponseMixin [ P ] , aio .StreamStreamCall
146
150
):
147
151
"""Wrapped StreamStreamCall to map exceptions."""
148
152
149
153
154
+ # public type alias denoting the return type of async streaming gapic calls
155
+ GrpcAsyncStream = _WrappedStreamResponseMixin [P ]
156
+ # public type alias denoting the return type of unary gapic calls
157
+ AwaitableGrpcCall = _WrappedUnaryResponseMixin [P ]
158
+
159
+
150
160
def _wrap_unary_errors (callable_ ):
151
161
"""Map errors for Unary-Unary async callables."""
152
162
grpc_helpers ._patch_callable_name (callable_ )
0 commit comments