Skip to content

Commit c97a636

Browse files
chore: avoid type checks in error wrapper
1 parent 0cc03bd commit c97a636

File tree

1 file changed

+11
-18
lines changed

1 file changed

+11
-18
lines changed

google/api_core/grpc_helpers_async.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,6 @@ class _WrappedStreamStreamCall(
149149

150150
def _wrap_unary_errors(callable_):
151151
"""Map errors for Unary-Unary async callables."""
152-
grpc_helpers._patch_callable_name(callable_)
153-
154152
@functools.wraps(callable_)
155153
def error_remapped_callable(*args, **kwargs):
156154
call = callable_(*args, **kwargs)
@@ -159,25 +157,13 @@ def error_remapped_callable(*args, **kwargs):
159157
return error_remapped_callable
160158

161159

162-
def _wrap_stream_errors(callable_):
160+
def _wrap_stream_errors(callable_, wrapper_type):
163161
"""Map errors for streaming RPC async callables."""
164-
grpc_helpers._patch_callable_name(callable_)
165-
166162
@functools.wraps(callable_)
167163
async def error_remapped_callable(*args, **kwargs):
168164
call = callable_(*args, **kwargs)
169-
170-
if isinstance(call, aio.UnaryStreamCall):
171-
call = _WrappedUnaryStreamCall().with_call(call)
172-
elif isinstance(call, aio.StreamUnaryCall):
173-
call = _WrappedStreamUnaryCall().with_call(call)
174-
elif isinstance(call, aio.StreamStreamCall):
175-
call = _WrappedStreamStreamCall().with_call(call)
176-
else:
177-
raise TypeError("Unexpected type of call %s" % type(call))
178-
179165
await call.wait_for_connection()
180-
return call
166+
return wrapper_type().with_call(call)
181167

182168
return error_remapped_callable
183169

@@ -197,10 +183,17 @@ def wrap_errors(callable_):
197183
198184
Returns: Callable: The wrapped gRPC callable.
199185
"""
186+
grpc_helpers._patch_callable_name(callable_)
200187
if isinstance(callable_, aio.UnaryUnaryMultiCallable):
201-
return _wrap_unary_errors(callable_)
202188
else:
203-
return _wrap_stream_errors(callable_)
189+
if isinstance(callable_, aio.UnaryStreamMultiCallable):
190+
return _wrap_stream_errors(callable_, _WrappedUnaryStreamCall)
191+
elif isinstance(callable_, aio.StreamUnaryMultiCallable):
192+
return _wrap_stream_errors(callable_, _WrappedStreamUnaryCall)
193+
elif isinstance(callable_, aio.StreamStreamMultiCallable):
194+
return _wrap_stream_errors(callable_, _WrappedStreamStreamCall)
195+
else:
196+
raise TypeError("Unexpected type of callable: {}".format(type(callable_)))
204197

205198

206199
def create_channel(

0 commit comments

Comments
 (0)