diff --git a/google/api_core/bidi.py b/google/api_core/bidi.py index be52d97d..db19f286 100644 --- a/google/api_core/bidi.py +++ b/google/api_core/bidi.py @@ -240,7 +240,8 @@ class BidiRpc(object): yield. This is useful if an initial request is needed to start the stream. metadata (Sequence[Tuple(str, str)]): RPC metadata to include in - the request. + the request. If no metadata is provided, metadata from + start_rpc will be used. """ def __init__(self, start_rpc, initial_request=None, metadata=None): @@ -277,7 +278,11 @@ def open(self): request_generator = _RequestQueueGenerator( self._request_queue, initial_request=self._initial_request ) - call = self._start_rpc(iter(request_generator), metadata=self._rpc_metadata) + if self._rpc_metadata: + call = self._start_rpc(iter(request_generator), metadata=self._rpc_metadata) + # use metadata from self._start_rpc if no other metadata is specified + else: + call = self._start_rpc(iter(request_generator)) request_generator.call = call diff --git a/tests/unit/test_bidi.py b/tests/unit/test_bidi.py index 52215cbd..24a8ff0d 100644 --- a/tests/unit/test_bidi.py +++ b/tests/unit/test_bidi.py @@ -202,12 +202,12 @@ class _CallAndFuture(grpc.Call, grpc.Future): pass -def make_rpc(): +def make_rpc(metadata=None): """Makes a mock RPC used to test Bidi classes.""" call = mock.create_autospec(_CallAndFuture, instance=True) rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True) - def rpc_side_effect(request, metadata=None): + def rpc_side_effect(request, metadata=metadata): call.is_active.return_value = True call.request = request call.metadata = metadata @@ -265,12 +265,13 @@ def test_metadata(self): assert bidi_rpc.call.metadata == mock.sentinel.A def test_open(self): - rpc, call = make_rpc() - bidi_rpc = bidi.BidiRpc(rpc) + rpc, call = make_rpc(metadata=[(1, 2)]) + bidi_rpc = bidi.BidiRpc(rpc, metadata=[(3, 4)]) bidi_rpc.open() assert bidi_rpc.call == call + assert bidi_rpc.call.metadata == [(3, 4)] assert bidi_rpc.is_active call.add_done_callback.assert_called_once_with(bidi_rpc._on_call_done) @@ -283,6 +284,14 @@ def test_open_error_already_open(self): with pytest.raises(ValueError): bidi_rpc.open() + def test_open_use_start_rpc_metadata(self): + rpc, _ = make_rpc(metadata=[(1, 2)]) + bidi_rpc = bidi.BidiRpc(rpc) + + bidi_rpc.open() + + assert bidi_rpc.call.metadata == [(1, 2)] + def test_close(self): rpc, call = make_rpc() bidi_rpc = bidi.BidiRpc(rpc)