|
14 | 14 |
|
15 | 15 | """Helpers for :mod:`grpc`."""
|
16 | 16 |
|
| 17 | +import collections |
| 18 | + |
17 | 19 | import grpc
|
18 | 20 | import six
|
19 | 21 |
|
@@ -136,3 +138,185 @@ def create_channel(target, credentials=None, scopes=None, **kwargs):
|
136 | 138 |
|
137 | 139 | return google.auth.transport.grpc.secure_authorized_channel(
|
138 | 140 | credentials, request, target, **kwargs)
|
| 141 | + |
| 142 | + |
| 143 | +_MethodCall = collections.namedtuple( |
| 144 | + '_MethodCall', ('request', 'timeout', 'metadata', 'credentials')) |
| 145 | + |
| 146 | +_ChannelRequest = collections.namedtuple( |
| 147 | + '_ChannelRequest', ('method', 'request')) |
| 148 | + |
| 149 | + |
| 150 | +class _CallableStub(object): |
| 151 | + """Stub for the grpc.*MultiCallable interfaces.""" |
| 152 | + |
| 153 | + def __init__(self, method, channel): |
| 154 | + self._method = method |
| 155 | + self._channel = channel |
| 156 | + self.response = None |
| 157 | + """Union[protobuf.Message, Callable[protobuf.Message], exception]: |
| 158 | + The response to give when invoking this callable. If this is a |
| 159 | + callable, it will be invoked with the request protobuf. If it's an |
| 160 | + exception, the exception will be raised when this is invoked. |
| 161 | + """ |
| 162 | + self.responses = None |
| 163 | + """Iterator[ |
| 164 | + Union[protobuf.Message, Callable[protobuf.Message], exception]]: |
| 165 | + An iterator of responses. If specified, self.response will be populated |
| 166 | + on each invocation by calling ``next(self.responses)``.""" |
| 167 | + self.requests = [] |
| 168 | + """List[protobuf.Message]: All requests sent to this callable.""" |
| 169 | + self.calls = [] |
| 170 | + """List[Tuple]: All invocations of this callable. Each tuple is the |
| 171 | + request, timeout, metadata, and credentials.""" |
| 172 | + |
| 173 | + def __call__(self, request, timeout=None, metadata=None, credentials=None): |
| 174 | + self._channel.requests.append( |
| 175 | + _ChannelRequest(self._method, request)) |
| 176 | + self.calls.append( |
| 177 | + _MethodCall(request, timeout, metadata, credentials)) |
| 178 | + self.requests.append(request) |
| 179 | + |
| 180 | + response = self.response |
| 181 | + if self.responses is not None: |
| 182 | + if response is None: |
| 183 | + response = next(self.responses) |
| 184 | + else: |
| 185 | + raise ValueError( |
| 186 | + '{method}.response and {method}.responses are mutually ' |
| 187 | + 'exclusive.'.format(method=self._method)) |
| 188 | + |
| 189 | + if callable(response): |
| 190 | + return response(request) |
| 191 | + |
| 192 | + if isinstance(response, Exception): |
| 193 | + raise response |
| 194 | + |
| 195 | + if response is not None: |
| 196 | + return response |
| 197 | + |
| 198 | + raise ValueError( |
| 199 | + 'Method stub for "{}" has no response.'.format(self._method)) |
| 200 | + |
| 201 | + |
| 202 | +def _simplify_method_name(method): |
| 203 | + """Simplifies a gRPC method name. |
| 204 | +
|
| 205 | + When gRPC invokes the channel to create a callable, it gives a full |
| 206 | + method name like "/google.pubsub.v1.Publisher/CreateTopic". This |
| 207 | + returns just the name of the method, in this case "CreateTopic". |
| 208 | +
|
| 209 | + Args: |
| 210 | + method (str): The name of the method. |
| 211 | +
|
| 212 | + Returns: |
| 213 | + str: The simplified name of the method. |
| 214 | + """ |
| 215 | + return method.rsplit('/', 1).pop() |
| 216 | + |
| 217 | + |
| 218 | +class ChannelStub(grpc.Channel): |
| 219 | + """A testing stub for the grpc.Channel interface. |
| 220 | +
|
| 221 | + This can be used to test any client that eventually uses a gRPC channel |
| 222 | + to communicate. By passing in a channel stub, you can configure which |
| 223 | + responses are returned and track which requests are made. |
| 224 | +
|
| 225 | + For example: |
| 226 | +
|
| 227 | + .. code-block:: python |
| 228 | +
|
| 229 | + channel_stub = grpc_helpers.ChannelStub() |
| 230 | + client = FooClient(channel=channel_stub) |
| 231 | +
|
| 232 | + channel_stub.GetFoo.response = foo_pb2.Foo(name='bar') |
| 233 | +
|
| 234 | + foo = client.get_foo(labels=['baz']) |
| 235 | +
|
| 236 | + assert foo.name == 'bar' |
| 237 | + assert channel_stub.GetFoo.requests[0].labels = ['baz'] |
| 238 | +
|
| 239 | + Each method on the stub can be accessed and configured on the channel. |
| 240 | + Here's some examples of various configurations: |
| 241 | +
|
| 242 | + .. code-block:: python |
| 243 | +
|
| 244 | + # Return a basic response: |
| 245 | +
|
| 246 | + channel_stub.GetFoo.response = foo_pb2.Foo(name='bar') |
| 247 | + assert client.get_foo().name == 'bar' |
| 248 | +
|
| 249 | + # Raise an exception: |
| 250 | + channel_stub.GetFoo.response = NotFound('...') |
| 251 | +
|
| 252 | + with pytest.raises(NotFound): |
| 253 | + client.get_foo() |
| 254 | +
|
| 255 | + # Use a sequence of responses: |
| 256 | + channel_stub.GetFoo.responses = iter([ |
| 257 | + foo_pb2.Foo(name='bar'), |
| 258 | + foo_pb2.Foo(name='baz'), |
| 259 | + ]) |
| 260 | +
|
| 261 | + assert client.get_foo().name == 'bar' |
| 262 | + assert client.get_foo().name == 'baz' |
| 263 | +
|
| 264 | + # Use a callable |
| 265 | +
|
| 266 | + def on_get_foo(request): |
| 267 | + return foo_pb2.Foo(name='bar' + request.id) |
| 268 | +
|
| 269 | + channel_stub.GetFoo.response = on_get_foo |
| 270 | +
|
| 271 | + assert client.get_foo(id='123').name == 'bar123' |
| 272 | + """ |
| 273 | + |
| 274 | + def __init__(self, responses=[]): |
| 275 | + self.requests = [] |
| 276 | + """Sequence[Tuple[str, protobuf.Message]]: A list of all requests made |
| 277 | + on this channel in order. The tuple is of method name, request |
| 278 | + message.""" |
| 279 | + self._method_stubs = {} |
| 280 | + |
| 281 | + def _stub_for_method(self, method): |
| 282 | + method = _simplify_method_name(method) |
| 283 | + self._method_stubs[method] = _CallableStub(method, self) |
| 284 | + return self._method_stubs[method] |
| 285 | + |
| 286 | + def __getattr__(self, key): |
| 287 | + try: |
| 288 | + return self._method_stubs[key] |
| 289 | + except KeyError: |
| 290 | + raise AttributeError |
| 291 | + |
| 292 | + def unary_unary( |
| 293 | + self, method, |
| 294 | + request_serializer=None, response_deserializer=None): |
| 295 | + """grpc.Channel.unary_unary implementation.""" |
| 296 | + return self._stub_for_method(method) |
| 297 | + |
| 298 | + def unary_stream( |
| 299 | + self, method, |
| 300 | + request_serializer=None, response_deserializer=None): |
| 301 | + """grpc.Channel.unary_stream implementation.""" |
| 302 | + return self._stub_for_method(method) |
| 303 | + |
| 304 | + def stream_unary( |
| 305 | + self, method, |
| 306 | + request_serializer=None, response_deserializer=None): |
| 307 | + """grpc.Channel.stream_unary implementation.""" |
| 308 | + return self._stub_for_method(method) |
| 309 | + |
| 310 | + def stream_stream( |
| 311 | + self, method, |
| 312 | + request_serializer=None, response_deserializer=None): |
| 313 | + """grpc.Channel.stream_stream implementation.""" |
| 314 | + return self._stub_for_method(method) |
| 315 | + |
| 316 | + def subscribe(self, callback, try_to_connect=False): |
| 317 | + """grpc.Channel.subscribe implementation.""" |
| 318 | + pass |
| 319 | + |
| 320 | + def unsubscribe(self, callback): |
| 321 | + """grpc.Channel.unsubscribe implementation.""" |
| 322 | + pass |
0 commit comments