Skip to content

Commit c9e3277

Browse files
committed
fix: resolve issue handling protobuf responses in rest streaming
1 parent 82c3118 commit c9e3277

File tree

2 files changed

+163
-55
lines changed

2 files changed

+163
-55
lines changed

google/api_core/rest_streaming.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,28 @@
1616

1717
from collections import deque
1818
import string
19-
from typing import Deque
19+
from typing import Deque, Union
2020

21+
import proto
2122
import requests
23+
import google.protobuf.message
24+
from google.protobuf.json_format import Parse
2225

2326

2427
class ResponseIterator:
2528
"""Iterator over REST API responses.
2629
2730
Args:
2831
response (requests.Response): An API response object.
29-
response_message_cls (Callable[proto.Message]): A proto
32+
response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response
3033
class expected to be returned from an API.
3134
"""
3235

33-
def __init__(self, response: requests.Response, response_message_cls):
36+
def __init__(
37+
self,
38+
response: requests.Response,
39+
response_message_cls: Union[proto.Message, google.protobuf.message.Message],
40+
):
3441
self._response = response
3542
self._response_message_cls = response_message_cls
3643
# Inner iterator over HTTP response's content.
@@ -107,7 +114,10 @@ def __next__(self):
107114

108115
def _grab(self):
109116
# Add extra quotes to make json.loads happy.
110-
return self._response_message_cls.from_json(self._ready_objs.popleft())
117+
if issubclass(self._response_message_cls, proto.Message):
118+
return self._response_message_cls.from_json(self._ready_objs.popleft())
119+
else:
120+
return Parse(self._ready_objs.popleft(), self._response_message_cls())
111121

112122
def __iter__(self):
113123
return self

tests/unit/test_rest_streaming.py

Lines changed: 149 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,11 @@
2424
import requests
2525

2626
from google.api_core import rest_streaming
27+
from google.api import http_pb2
28+
from google.api import httpbody_pb2
2729
from google.protobuf import duration_pb2
2830
from google.protobuf import timestamp_pb2
31+
from google.protobuf.json_format import MessageToJson
2932

3033

3134
__protobuf__ = proto.module(package=__name__)
@@ -98,7 +101,10 @@ def _parse_responses(self, responses: List[proto.Message]) -> bytes:
98101
# json.dumps returns a string surrounded with quotes that need to be stripped
99102
# in order to be an actual JSON.
100103
json_responses = [
101-
self._response_message_cls.to_json(r).strip('"') for r in responses
104+
self._response_message_cls.to_json(r).strip('"')
105+
if issubclass(self._response_message_cls, proto.Message)
106+
else MessageToJson(r).strip('"')
107+
for r in responses
102108
]
103109
logging.info(f"Sending JSON stream: {json_responses}")
104110
ret_val = "[{}]".format(",".join(json_responses))
@@ -114,103 +120,195 @@ def iter_content(self, *args, **kwargs):
114120
)
115121

116122

117-
@pytest.mark.parametrize("random_split", [False])
118-
def test_next_simple(random_split):
119-
responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
123+
@pytest.mark.parametrize(
124+
"random_split,resp_message_is_proto_plus,response_type",
125+
[(False, True, EchoResponse), (False, False, httpbody_pb2.HttpBody)],
126+
)
127+
def test_next_simple(random_split, resp_message_is_proto_plus, response_type):
128+
if resp_message_is_proto_plus:
129+
responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
130+
else:
131+
responses = [
132+
httpbody_pb2.HttpBody(content_type="hello world"),
133+
httpbody_pb2.HttpBody(content_type="yes"),
134+
]
135+
120136
resp = ResponseMock(
121-
responses=responses, random_split=random_split, response_cls=EchoResponse
137+
responses=responses, random_split=random_split, response_cls=response_type
122138
)
123-
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
139+
itr = rest_streaming.ResponseIterator(resp, response_type)
124140
assert list(itr) == responses
125141

126142

127-
@pytest.mark.parametrize("random_split", [True, False])
128-
def test_next_nested(random_split):
129-
responses = [
130-
Song(title="some song", composer=Composer(given_name="some name")),
131-
Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
132-
]
143+
@pytest.mark.parametrize(
144+
"random_split,resp_message_is_proto_plus,response_type",
145+
[
146+
(True, True, Song),
147+
(False, True, Song),
148+
(True, False, http_pb2.HttpRule),
149+
(False, False, http_pb2.HttpRule),
150+
],
151+
)
152+
def test_next_nested(random_split, resp_message_is_proto_plus, response_type):
153+
if resp_message_is_proto_plus:
154+
responses = [
155+
Song(title="some song", composer=Composer(given_name="some name")),
156+
Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
157+
]
158+
else:
159+
# Although `http_pb2.HttpRule`` is used in the response, any response message
160+
# can be used which meets this criteria for the test of having a nested field.
161+
responses = [
162+
http_pb2.HttpRule(
163+
selector="some selector",
164+
custom=http_pb2.CustomHttpPattern(kind="some kind"),
165+
),
166+
http_pb2.HttpRule(
167+
selector="another selector",
168+
custom=http_pb2.CustomHttpPattern(path="some path"),
169+
),
170+
]
133171
resp = ResponseMock(
134-
responses=responses, random_split=random_split, response_cls=Song
172+
responses=responses, random_split=random_split, response_cls=response_type
135173
)
136-
itr = rest_streaming.ResponseIterator(resp, Song)
174+
itr = rest_streaming.ResponseIterator(resp, response_type)
137175
assert list(itr) == responses
138176

139177

140-
@pytest.mark.parametrize("random_split", [True, False])
141-
def test_next_stress(random_split):
178+
@pytest.mark.parametrize(
179+
"random_split,resp_message_is_proto_plus,response_type",
180+
[
181+
(True, True, Song),
182+
(False, True, Song),
183+
(True, False, http_pb2.HttpRule),
184+
(False, False, http_pb2.HttpRule),
185+
],
186+
)
187+
def test_next_stress(random_split, resp_message_is_proto_plus, response_type):
142188
n = 50
143-
responses = [
144-
Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
145-
for i in range(n)
146-
]
189+
if resp_message_is_proto_plus:
190+
responses = [
191+
Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
192+
for i in range(n)
193+
]
194+
else:
195+
responses = [
196+
http_pb2.HttpRule(
197+
selector="selector_%d" % i,
198+
custom=http_pb2.CustomHttpPattern(path="path_%d" % i),
199+
)
200+
for i in range(n)
201+
]
147202
resp = ResponseMock(
148-
responses=responses, random_split=random_split, response_cls=Song
203+
responses=responses, random_split=random_split, response_cls=response_type
149204
)
150-
itr = rest_streaming.ResponseIterator(resp, Song)
205+
itr = rest_streaming.ResponseIterator(resp, response_type)
151206
assert list(itr) == responses
152207

153208

154-
@pytest.mark.parametrize("random_split", [True, False])
155-
def test_next_escaped_characters_in_string(random_split):
156-
composer_with_relateds = Composer()
157-
relateds = ["Artist A", "Artist B"]
158-
composer_with_relateds.relateds = relateds
159-
160-
responses = [
161-
Song(title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")),
162-
Song(
163-
title='{"this is weird": "totally"}', composer=Composer(given_name="\\{}\\")
164-
),
165-
Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
166-
]
209+
@pytest.mark.parametrize(
210+
"random_split,resp_message_is_proto_plus,response_type",
211+
[
212+
(True, True, Song),
213+
(False, True, Song),
214+
(True, False, http_pb2.Http),
215+
(False, False, http_pb2.Http),
216+
],
217+
)
218+
def test_next_escaped_characters_in_string(
219+
random_split, resp_message_is_proto_plus, response_type
220+
):
221+
if resp_message_is_proto_plus:
222+
composer_with_relateds = Composer()
223+
relateds = ["Artist A", "Artist B"]
224+
composer_with_relateds.relateds = relateds
225+
226+
responses = [
227+
Song(
228+
title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")
229+
),
230+
Song(
231+
title='{"this is weird": "totally"}',
232+
composer=Composer(given_name="\\{}\\"),
233+
),
234+
Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
235+
]
236+
else:
237+
responses = [
238+
http_pb2.Http(
239+
rules=[
240+
http_pb2.HttpRule(
241+
selector='ti"tle\nfoo\tbar{}',
242+
custom=http_pb2.CustomHttpPattern(kind="name\n\n\n"),
243+
)
244+
]
245+
),
246+
http_pb2.Http(
247+
rules=[
248+
http_pb2.HttpRule(
249+
selector='{"this is weird": "totally"}',
250+
custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
251+
)
252+
]
253+
),
254+
http_pb2.Http(
255+
rules=[
256+
http_pb2.HttpRule(
257+
selector='\\{"key": ["value",]}\\',
258+
custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
259+
)
260+
]
261+
),
262+
]
167263
resp = ResponseMock(
168-
responses=responses, random_split=random_split, response_cls=Song
264+
responses=responses, random_split=random_split, response_cls=response_type
169265
)
170-
itr = rest_streaming.ResponseIterator(resp, Song)
266+
itr = rest_streaming.ResponseIterator(resp, response_type)
171267
assert list(itr) == responses
172268

173269

174-
def test_next_not_array():
270+
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
271+
def test_next_not_array(response_type):
175272
with patch.object(
176273
ResponseMock, "iter_content", return_value=iter('{"hello": 0}')
177274
) as mock_method:
178-
179-
resp = ResponseMock(responses=[], response_cls=EchoResponse)
180-
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
275+
resp = ResponseMock(responses=[], response_cls=response_type)
276+
itr = rest_streaming.ResponseIterator(resp, response_type)
181277
with pytest.raises(ValueError):
182278
next(itr)
183279
mock_method.assert_called_once()
184280

185281

186-
def test_cancel():
282+
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
283+
def test_cancel(response_type):
187284
with patch.object(ResponseMock, "close", return_value=None) as mock_method:
188-
resp = ResponseMock(responses=[], response_cls=EchoResponse)
189-
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
285+
resp = ResponseMock(responses=[], response_cls=response_type)
286+
itr = rest_streaming.ResponseIterator(resp, response_type)
190287
itr.cancel()
191288
mock_method.assert_called_once()
192289

193290

194-
def test_check_buffer():
291+
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
292+
def test_check_buffer(response_type):
195293
with patch.object(
196294
ResponseMock,
197295
"_parse_responses",
198296
return_value=bytes('[{"content": "hello"}, {', "utf-8"),
199297
):
200-
resp = ResponseMock(responses=[], response_cls=EchoResponse)
201-
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
298+
resp = ResponseMock(responses=[], response_cls=response_type)
299+
itr = rest_streaming.ResponseIterator(resp, response_type)
202300
with pytest.raises(ValueError):
203301
next(itr)
204302
next(itr)
205303

206304

207-
def test_next_html():
305+
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
306+
def test_next_html(response_type):
208307
with patch.object(
209308
ResponseMock, "iter_content", return_value=iter("<!DOCTYPE html><html></html>")
210309
) as mock_method:
211-
212-
resp = ResponseMock(responses=[], response_cls=EchoResponse)
213-
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
310+
resp = ResponseMock(responses=[], response_cls=response_type)
311+
itr = rest_streaming.ResponseIterator(resp, response_type)
214312
with pytest.raises(ValueError):
215313
next(itr)
216314
mock_method.assert_called_once()

0 commit comments

Comments
 (0)