Skip to content

Commit f9f2696

Browse files
authored
feat: iterator for processing JSON responses in REST streaming. (googleapis#317)
* feat: files for REST server streaming.
1 parent 69a99d8 commit f9f2696

File tree

2 files changed

+325
-0
lines changed

2 files changed

+325
-0
lines changed

google/api_core/rest_streaming.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Helpers for server-side streaming in REST."""
16+
17+
from collections import deque
18+
import string
19+
from typing import Deque
20+
21+
import requests
22+
23+
24+
class ResponseIterator:
    """Iterator over REST API responses.

    The response body is consumed chunk by chunk and is expected to be one
    JSON array of objects. Every completed top-level object of that array is
    turned into a message via ``response_message_cls.from_json`` and yielded.

    Args:
        response (requests.Response): An API response object.
        response_message_cls (Callable[proto.Message]): A proto
            class expected to be returned from an API.

    Raises:
        ValueError: While iterating, if the body does not start with a JSON
            array or if the stream ends before the array is closed.
    """

    def __init__(self, response: "requests.Response", response_message_cls):
        self._response = response
        self._response_message_cls = response_message_cls
        # Inner iterator over HTTP response's content.
        self._response_itr = self._response.iter_content(decode_unicode=True)
        # Contains a list of JSON responses ready to be sent to user.
        self._ready_objs: Deque[str] = deque()
        # Current JSON response being built.
        self._obj = ""
        # Keeps track of the nesting level within a JSON object.
        self._level = 0
        # Keeps track whether HTTP response is currently sending values
        # inside of a string value.
        self._in_string = False
        # Whether an escape symbol "\" was encountered.
        self._escape_next = False

    def cancel(self):
        """Cancel existing streaming operation by closing the response."""
        self._response.close()

    def _process_chunk(self, chunk: str):
        """Feed one chunk of response text into the incremental parser.

        Completed top-level objects are appended to ``self._ready_objs``.

        Raises:
            ValueError: If, outside of the array, the chunk starts with
                anything other than whitespace or ``[``.
        """
        if self._level == 0:
            # Outside of the array only whitespace may precede "[". Empty and
            # whitespace-only chunks (e.g. a trailing newline after "]")
            # carry no data and must not be reported as malformed.
            head = chunk.lstrip(string.whitespace)
            if not head:
                return
            if head[0] != "[":
                raise ValueError(
                    "Can only parse array of JSON objects, instead got %s" % chunk
                )
        for char in chunk:
            if char == "{":
                if self._level == 1:
                    # Level 1 corresponds to the outermost JSON object
                    # (i.e. the one we care about).
                    self._obj = ""
                if not self._in_string:
                    self._level += 1
                self._obj += char
            elif char == "}":
                self._obj += char
                if not self._in_string:
                    self._level -= 1
                if not self._in_string and self._level == 1:
                    # Back at array level: the object is complete.
                    self._ready_objs.append(self._obj)
            elif char == '"':
                # Helps to deal with an escaped quotes inside of a string.
                if not self._escape_next:
                    self._in_string = not self._in_string
                self._obj += char
            elif char in string.whitespace:
                # Whitespace is only significant inside of string values.
                if self._in_string:
                    self._obj += char
            elif char == "[":
                if self._level == 0:
                    # Opening bracket of the enclosing array.
                    self._level += 1
                else:
                    self._obj += char
            elif char == "]":
                if self._level == 1:
                    # Closing bracket of the enclosing array.
                    self._level -= 1
                else:
                    self._obj += char
            else:
                self._obj += char
            # A backslash escapes the next character unless it is itself
            # escaped; any other character clears the flag.
            self._escape_next = not self._escape_next if char == "\\" else False

    def __next__(self):
        """Return the next parsed message.

        Raises:
            StopIteration: When the response content is exhausted and the
                array was properly closed.
            ValueError: When the response content is exhausted while the
                array or one of its objects is still open.
        """
        while not self._ready_objs:
            try:
                chunk = next(self._response_itr)
            except StopIteration:
                if self._level > 0:
                    raise ValueError("Unfinished stream: %s" % self._obj)
                raise
            self._process_chunk(chunk)
        return self._grab()

    def _grab(self):
        """Pop the oldest buffered JSON object and parse it into a message."""
        return self._response_message_cls.from_json(self._ready_objs.popleft())

    def __iter__(self):
        return self

tests/unit/test_rest_streaming.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import datetime
16+
import logging
17+
import random
18+
import time
19+
from typing import List
20+
from unittest.mock import patch
21+
22+
import proto
23+
import pytest
24+
import requests
25+
26+
from google.api_core import rest_streaming
27+
from google.protobuf import duration_pb2
28+
from google.protobuf import timestamp_pb2
29+
30+
31+
# The random chunk splitting in these tests is seeded from the wall clock.
# The seed is logged so that a failing run can be replayed with the same value.
SEED = int(time.time())
logging.info("Starting rest streaming tests with random seed: %s", SEED)
random.seed(SEED)
34+
35+
36+
class Genre(proto.Enum):
    """Enum fixture used as the type of ``Song.genre``."""

    GENRE_UNSPECIFIED = 0
    CLASSICAL = 1
    JAZZ = 2
    ROCK = 3
41+
42+
43+
class Composer(proto.Message):
    """Message fixture with string, repeated-string and map fields.

    Nested inside ``Song`` so that streamed objects contain sub-objects,
    arrays and maps.
    """

    given_name = proto.Field(proto.STRING, number=1)
    family_name = proto.Field(proto.STRING, number=2)
    relateds = proto.RepeatedField(proto.STRING, number=3)
    indices = proto.MapField(proto.STRING, proto.STRING, number=4)
48+
49+
50+
class Song(proto.Message):
    """Message fixture covering a mix of field types.

    It has a nested message, strings, integers, an enum, a bool, a double,
    and the well-known ``Duration`` and ``Timestamp`` types.
    """

    composer = proto.Field(Composer, number=1)
    title = proto.Field(proto.STRING, number=2)
    lyrics = proto.Field(proto.STRING, number=3)
    year = proto.Field(proto.INT32, number=4)
    genre = proto.Field(Genre, number=5)
    is_five_mins_longer = proto.Field(proto.BOOL, number=6)
    score = proto.Field(proto.DOUBLE, number=7)
    likes = proto.Field(proto.INT64, number=8)
    duration = proto.Field(duration_pb2.Duration, number=9)
    date_added = proto.Field(timestamp_pb2.Timestamp, number=10)
61+
62+
63+
class EchoResponse(proto.Message):
    """Minimal message fixture with a single string field."""

    content = proto.Field(proto.STRING, number=1)
65+
66+
67+
class ResponseMock(requests.Response):
    """Fake ``requests.Response`` that streams messages as a JSON array.

    Args:
        responses (List[proto.Message]): Messages to serialize into the body.
        response_cls: Message class whose ``to_json`` serializes each message.
        random_split (bool): When true, the body is handed out in pieces of
            random length; otherwise one character at a time.
    """

    class _ResponseItr:
        """Iterator that hands out the payload text piece by piece."""

        def __init__(self, _response_bytes: bytes, random_split=False):
            # Decode once up front and slice the text rather than the bytes,
            # so a slice can never end inside a multi-byte UTF-8 sequence.
            self._text = _response_bytes.decode("utf-8")
            self._i = 0
            self._random_split = random_split

        def __iter__(self):
            return self

        def __next__(self):
            if self._i == len(self._text):
                raise StopIteration
            if self._random_split:
                # Anything from one character up to everything that is left.
                n = random.randint(1, len(self._text) - self._i)
            else:
                n = 1
            x = self._text[self._i : self._i + n]
            self._i += n
            return x

    def __init__(
        self, responses: List[proto.Message], response_cls, random_split=False,
    ):
        super().__init__()
        self._responses = responses
        self._random_split = random_split
        self._response_message_cls = response_cls

    def _parse_responses(self, responses: List[proto.Message]) -> bytes:
        """Serialize ``responses`` into a UTF-8 encoded JSON array."""
        # Each message is serialized with ``to_json``; any double quotes
        # surrounding the result are stripped before the items are joined.
        json_responses = [
            self._response_message_cls.to_json(r).strip('"') for r in responses
        ]
        logging.info("Sending JSON stream: %s", json_responses)
        ret_val = "[{}]".format(",".join(json_responses))
        return bytes(ret_val, "utf-8")

    def close(self):
        # Tests that exercise closing patch this method.
        raise NotImplementedError()

    def iter_content(self, *args, **kwargs):
        """Return an iterator over the serialized body; arguments are ignored."""
        return self._ResponseItr(
            self._parse_responses(self._responses), random_split=self._random_split,
        )
110+
111+
112+
@pytest.mark.parametrize("random_split", [False])
def test_next_simple(random_split):
    """Flat messages come back from the iterator unchanged."""
    expected = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
    mock_response = ResponseMock(
        responses=expected, response_cls=EchoResponse, random_split=random_split
    )
    parsed = list(rest_streaming.ResponseIterator(mock_response, EchoResponse))
    assert parsed == expected
120+
121+
122+
@pytest.mark.parametrize("random_split", [True, False])
def test_next_nested(random_split):
    """Messages with a nested message or a timestamp come back unchanged."""
    song_with_composer = Song(
        title="some song", composer=Composer(given_name="some name")
    )
    song_with_date = Song(
        title="another song", date_added=datetime.datetime(2021, 12, 17)
    )
    expected = [song_with_composer, song_with_date]
    mock_response = ResponseMock(
        responses=expected, response_cls=Song, random_split=random_split
    )
    parsed = list(rest_streaming.ResponseIterator(mock_response, Song))
    assert parsed == expected
133+
134+
135+
@pytest.mark.parametrize("random_split", [True, False])
def test_next_stress(random_split):
    """Fifty nested messages in one stream come back unchanged and in order."""
    count = 50
    expected = []
    for index in range(count):
        composer = Composer(given_name="name_%d" % index)
        expected.append(Song(title="title_%d" % index, composer=composer))
    mock_response = ResponseMock(
        responses=expected, response_cls=Song, random_split=random_split
    )
    parsed = list(rest_streaming.ResponseIterator(mock_response, Song))
    assert parsed == expected
147+
148+
149+
@pytest.mark.parametrize("random_split", [True, False])
def test_next_escaped_characters_in_string(random_split):
    """Quotes, backslashes, braces and brackets inside strings are kept intact."""
    related_composer = Composer(relateds=["Artist A", "Artist B"])

    quotes_and_whitespace = Song(
        title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")
    )
    json_like_title = Song(
        title='{"this is weird": "totally"}', composer=Composer(given_name="\\{}\\")
    )
    backslashes_and_brackets = Song(
        title='\\{"key": ["value",]}\\', composer=related_composer
    )
    expected = [quotes_and_whitespace, json_like_title, backslashes_and_brackets]

    mock_response = ResponseMock(
        responses=expected, response_cls=Song, random_split=random_split
    )
    parsed = list(rest_streaming.ResponseIterator(mock_response, Song))
    assert parsed == expected
167+
168+
169+
def test_next_not_array():
    """A body that is a JSON object rather than an array raises ValueError."""
    body = iter('{"hello": 0}')
    with patch.object(
        ResponseMock, "iter_content", return_value=body
    ) as iter_content_mock:
        mock_response = ResponseMock(responses=[], response_cls=EchoResponse)
        stream = rest_streaming.ResponseIterator(mock_response, EchoResponse)
        with pytest.raises(ValueError):
            next(stream)
        iter_content_mock.assert_called_once()
179+
180+
181+
def test_cancel():
    """``cancel`` closes the underlying response exactly once."""
    with patch.object(ResponseMock, "close", return_value=None) as close_mock:
        mock_response = ResponseMock(responses=[], response_cls=EchoResponse)
        stream = rest_streaming.ResponseIterator(mock_response, EchoResponse)
        stream.cancel()
        close_mock.assert_called_once()
187+
188+
189+
def test_check_buffer():
    """A stream that ends in the middle of an object raises ValueError.

    The complete object that precedes the truncated one must still be
    delivered; only the following ``next`` call is expected to fail.
    """
    with patch.object(
        ResponseMock,
        "_parse_responses",
        return_value=bytes('[{"content": "hello"}, {', "utf-8"),
    ):
        resp = ResponseMock(responses=[], response_cls=EchoResponse)
        itr = rest_streaming.ResponseIterator(resp, EchoResponse)
        # Kept outside of ``pytest.raises`` so that a failure on the complete
        # first object cannot make the test pass.
        assert next(itr) == EchoResponse(content="hello")
        with pytest.raises(ValueError):
            next(itr)
200+
201+
202+
def test_next_html():
    """An HTML body instead of a JSON array raises ValueError."""
    body = iter("<!DOCTYPE html><html></html>")
    with patch.object(
        ResponseMock, "iter_content", return_value=body
    ) as iter_content_mock:
        mock_response = ResponseMock(responses=[], response_cls=EchoResponse)
        stream = rest_streaming.ResponseIterator(mock_response, EchoResponse)
        with pytest.raises(ValueError):
            next(stream)
        iter_content_mock.assert_called_once()

0 commit comments

Comments
 (0)