From d53588ccdf879c2a8240517431269b616dd0c416 Mon Sep 17 00:00:00 2001 From: Aza Tulepbergenov Date: Thu, 9 Dec 2021 15:59:28 -0800 Subject: [PATCH 1/8] feat: files for REST streaming. --- google/api_core/rest_streaming.py | 99 ++++++++++++++++++ tests/unit/test_rest_streaming.py | 160 ++++++++++++++++++++++++++++++ 2 files changed, 259 insertions(+) create mode 100644 google/api_core/rest_streaming.py create mode 100644 tests/unit/test_rest_streaming.py diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py new file mode 100644 index 00000000..6b288c06 --- /dev/null +++ b/google/api_core/rest_streaming.py @@ -0,0 +1,99 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for server-side streaming in REST.""" + +import json +import string + +import requests + + +class ResponseIterator: + """Iterator over REST API responses. + + Args: + response (requests.Response): An API response object. + response_message_cls (Callable[proto.Message]): A proto + class expected to be returned from an API. + """ + + def __init__(self, response: requests.Response, response_message_cls): + self._response = response + self._response_message_cls = response_message_cls + # Inner iterator over HTTP response's content. + self._response_itr = self._response.iter_content(decode_unicode=True) + # Contains a list of JSON responses ready to be sent to user. + self._ready_objs = [] + # Current JSON response being built. + self._obj = "" + # Keeps track of the nesting level within a JSON object. + self._level = 0 + # Keeps track whether HTTP response is currently sending values + # inside of a string value. + self._in_string = False + + def cancel(self): + """Cancel existing streaming operation. + """ + self._response.close() + + def _process_chunk(self, chunk: str): + if self._level == 0: + if chunk[0] != "[": + raise ValueError( + "Can only parse array of JSON objects, instead got %s" % chunk + ) + for char in chunk: + if char == "{": + if self._level == 1: + # Level 1 corresponds to the outermost JSON object + # (i.e. the one we care about). + self._obj = "" + if not self._in_string: + self._level += 1 + self._obj += char + elif char == '"': + self._in_string = not self._in_string + self._obj += char + elif char == "}": + self._obj += char + if not self._in_string: + self._level -= 1 + if not self._in_string and self._level == 1: + self._ready_objs.append(self._obj) + elif char in string.whitespace: + if self._in_string: + self._obj += char + elif char == "[": + self._level += 1 + elif char == "]": + self._level -= 1 + else: + self._obj += char + + def __next__(self): + while not self._ready_objs: + chunk = next(self._response_itr) + self._process_chunk(chunk) + return self._grab() + + def _grab(self): + obj = self._ready_objs[0] + self._ready_objs = self._ready_objs[1:] + # Add extra quotes to make json.loads happy. + return self._response_message_cls.from_json(json.loads('"' + obj + '"')) + + def __iter__(self): + return self diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py new file mode 100644 index 00000000..e0d5d304 --- /dev/null +++ b/tests/unit/test_rest_streaming.py @@ -0,0 +1,160 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import random +from typing import List +from unittest.mock import patch + +import proto +import pytest +import requests + +from google.api_core import rest_streaming + + +class Composer(proto.Message): + given_name = proto.Field(proto.STRING, number=1) + family_name = proto.Field(proto.STRING, number=2) + + +class Song(proto.Message): + composer = proto.Field(Composer, number=1) + title = proto.Field(proto.STRING, number=2) + lyrics = proto.Field(proto.STRING, number=3) + year = proto.Field(proto.INT32, number=4) + + +class EchoResponse(proto.Message): + content = proto.Field(proto.STRING, number=1) + + +class ResponseMock(requests.Response): + class _ResponseItr: + def __init__( + self, + _response_bytes: bytes, + decode_unicode: bool, + random_split=False, + seed=0, + ): + self._responses_bytes = _response_bytes + self._i = 0 + self._decode_unicode = decode_unicode + self._random_split = random_split + random.seed(seed) + + def __next__(self): + if self._i == len(self._responses_bytes): + raise StopIteration + if self._random_split: + n = random.randint(1, len(self._responses_bytes[self._i :])) + else: + n = 1 + x = self._responses_bytes[self._i : self._i + n] + self._i += n + if self._decode_unicode: + x = x.decode("utf-8") + return x + + def __iter__(self): + return self + + def __init__( + self, + *args, + responses: List[proto.Message], + random_split=False, + response_cls, + **kwargs + ): + super().__init__(*args, **kwargs) + self._responses = responses + self._random_split = random_split + self._response_message_cls = response_cls + + def _parse_responses(self, responses: List[proto.Message]) -> bytes: + ret_val = "[" + # json.dumps returns a string surrounded with quotes that need to be stripped + # in order to be an actual JSON. + json_responses = [ + json.dumps(self._response_message_cls.to_json(r))[1:-1] for r in responses + ] + for x in json_responses: + ret_val += x + ret_val += "," + ret_val = ret_val[:-1] # Remove trailing comma. + ret_val += "]" + return bytes(ret_val, "utf-8") + + def close(self): + raise NotImplementedError() + + def iter_content(self, *args, decode_unicode=True, **kwargs): + return self._ResponseItr( + self._parse_responses(self._responses), + decode_unicode, + random_split=self._random_split, + ) + + +def test_next_simple(): + responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")] + resp = ResponseMock( + responses=responses, random_split=False, response_cls=EchoResponse + ) + itr = rest_streaming.ResponseIterator(resp, EchoResponse) + assert list(itr) == responses + + +def test_next_nested(): + responses = [ + Song(title="some song", composer=Composer(given_name="some name")), + Song(title="another song"), + ] + resp = ResponseMock(responses=responses, random_split=True, response_cls=Song) + itr = rest_streaming.ResponseIterator(resp, Song) + assert list(itr) == responses + + +def test_next_stress(): + n = 50 + responses = [ + Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i)) + for i in range(n) + ] + resp = ResponseMock(responses=responses, random_split=True, response_cls=Song) + itr = rest_streaming.ResponseIterator(resp, Song) + assert list(itr) == responses + + +def test_next_escaped_characters_in_string(): + responses = [ + Song(title="title\nfoo\tbar", composer=Composer(given_name="name\n\n\n")) + ] + resp = ResponseMock(responses=responses, random_split=True, response_cls=Song) + itr = rest_streaming.ResponseIterator(resp, Song) + assert list(itr) == responses + + +def test_next_not_array(): + with patch.object( + ResponseMock, "iter_content", return_value=iter('{"hello": 0}') + ) as mock_method: + + resp = ResponseMock(responses=[], response_cls=EchoResponse) + itr = rest_streaming.ResponseIterator(resp, EchoResponse) + with pytest.raises(ValueError): + next(itr) + mock_method.assert_called_once() From bdf469842c0eff8096fc1d280d6b210fd066d369 Mon Sep 17 00:00:00 2001 From: Aza Tulepbergenov Date: Mon, 13 Dec 2021 15:33:22 -0800 Subject: [PATCH 2/8] chore: adds more test cases and fixes. --- google/api_core/rest_streaming.py | 10 ++++-- tests/unit/test_rest_streaming.py | 52 ++++++++++++++++++++++--------- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py index 6b288c06..cec4ddf8 100644 --- a/google/api_core/rest_streaming.py +++ b/google/api_core/rest_streaming.py @@ -77,9 +77,15 @@ def _process_chunk(self, chunk: str): if self._in_string: self._obj += char elif char == "[": - self._level += 1 + if self._level == 0: + self._level += 1 + else: + self._obj += char elif char == "]": - self._level -= 1 + if self._level == 1: + self._level -= 1 + else: + self._obj += char else: self._obj += char diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py index e0d5d304..0bb10157 100644 --- a/tests/unit/test_rest_streaming.py +++ b/tests/unit/test_rest_streaming.py @@ -24,9 +24,18 @@ from google.api_core import rest_streaming +class Genre(proto.Enum): + GENRE_UNSPECIFIED = 0 + CLASSICAL = 1 + JAZZ = 2 + ROCK = 3 + + class Composer(proto.Message): given_name = proto.Field(proto.STRING, number=1) family_name = proto.Field(proto.STRING, number=2) + relateds = proto.RepeatedField(proto.STRING, number=3) + indices = proto.MapField(proto.STRING, proto.STRING, number=4) class Song(proto.Message): @@ -34,6 +43,10 @@ class Song(proto.Message): title = proto.Field(proto.STRING, number=2) lyrics = proto.Field(proto.STRING, number=3) year = proto.Field(proto.INT32, number=4) + genre = proto.Field(Genre, number=5) + is_five_mins_longer = proto.Field(proto.BOOL, number=6) + score = proto.Field(proto.DOUBLE, number=7) + likes = proto.Field(proto.INT64, number=8) class EchoResponse(proto.Message): @@ -43,15 +56,10 @@ class EchoResponse(proto.Message): class ResponseMock(requests.Response): class _ResponseItr: def __init__( - self, - _response_bytes: bytes, - decode_unicode: bool, - random_split=False, - seed=0, + self, _response_bytes: bytes, random_split=False, seed=0, ): self._responses_bytes = _response_bytes self._i = 0 - self._decode_unicode = decode_unicode self._random_split = random_split random.seed(seed) @@ -64,9 +72,7 @@ def __next__(self): n = 1 x = self._responses_bytes[self._i : self._i + n] self._i += n - if self._decode_unicode: - x = x.decode("utf-8") - return x + return x.decode("utf-8") def __iter__(self): return self @@ -101,11 +107,9 @@ def _parse_responses(self, responses: List[proto.Message]) -> bytes: def close(self): raise NotImplementedError() - def iter_content(self, *args, decode_unicode=True, **kwargs): + def iter_content(self, *args, **kwargs): return self._ResponseItr( - self._parse_responses(self._responses), - decode_unicode, - random_split=self._random_split, + self._parse_responses(self._responses), random_split=self._random_split, ) @@ -119,6 +123,16 @@ def test_next_simple(): def test_next_nested(): + responses = [ + Song(title="some song", composer=Composer(given_name="some name")), + Song(title="another song"), + ] + resp = ResponseMock(responses=responses, random_split=False, response_cls=Song) + itr = rest_streaming.ResponseIterator(resp, Song) + assert list(itr) == responses + + +def test_next_random(): responses = [ Song(title="some song", composer=Composer(given_name="some name")), Song(title="another song"), @@ -141,7 +155,7 @@ def test_next_stress(): def test_next_escaped_characters_in_string(): responses = [ - Song(title="title\nfoo\tbar", composer=Composer(given_name="name\n\n\n")) + Song(title="title\nfoo\tbar{}", composer=Composer(given_name="name\n\n\n")) ] resp = ResponseMock(responses=responses, random_split=True, response_cls=Song) itr = rest_streaming.ResponseIterator(resp, Song) @@ -158,3 +172,13 @@ def test_next_not_array(): with pytest.raises(ValueError): next(itr) mock_method.assert_called_once() + + +def test_cancel(): + with patch.object( + rest_streaming.ResponseIterator, "cancel", return_value=None + ) as mock_method: + resp = ResponseMock(responses=[], response_cls=EchoResponse) + itr = rest_streaming.ResponseIterator(resp, EchoResponse) + itr.cancel() + mock_method.assert_called_once() From ffe4a568e3747caa852a9522f6c82975e40ffba2 Mon Sep 17 00:00:00 2001 From: Aza Tulepbergenov Date: Mon, 13 Dec 2021 15:45:38 -0800 Subject: [PATCH 3/8] chore: fixes coverage. --- tests/unit/test_rest_streaming.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py index 0bb10157..3485e307 100644 --- a/tests/unit/test_rest_streaming.py +++ b/tests/unit/test_rest_streaming.py @@ -74,8 +74,6 @@ def __next__(self): self._i += n return x.decode("utf-8") - def __iter__(self): - return self def __init__( self, @@ -176,7 +174,7 @@ def test_next_not_array(): def test_cancel(): with patch.object( - rest_streaming.ResponseIterator, "cancel", return_value=None + ResponseMock, "close", return_value=None ) as mock_method: resp = ResponseMock(responses=[], response_cls=EchoResponse) itr = rest_streaming.ResponseIterator(resp, EchoResponse) From f2213ab0af1289f8878507cf4b2f1e506d5b0249 Mon Sep 17 00:00:00 2001 From: Aza Tulepbergenov Date: Tue, 14 Dec 2021 12:16:32 -0800 Subject: [PATCH 4/8] chore: pr comments. --- google/api_core/rest_streaming.py | 10 +++-- tests/unit/test_rest_streaming.py | 71 ++++++++++++++----------------- 2 files changed, 38 insertions(+), 43 deletions(-) diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py index cec4ddf8..a6310dd6 100644 --- a/google/api_core/rest_streaming.py +++ b/google/api_core/rest_streaming.py @@ -14,8 +14,10 @@ """Helpers for server-side streaming in REST.""" +from collections import deque import json import string +from typing import Deque import requests @@ -35,7 +37,7 @@ def __init__(self, response: requests.Response, response_message_cls): # Inner iterator over HTTP response's content. self._response_itr = self._response.iter_content(decode_unicode=True) # Contains a list of JSON responses ready to be sent to user. - self._ready_objs = [] + self._ready_objs: Deque[str] = deque() # Current JSON response being built. self._obj = "" # Keeps track of the nesting level within a JSON object. @@ -96,10 +98,10 @@ def __next__(self): return self._grab() def _grab(self): - obj = self._ready_objs[0] - self._ready_objs = self._ready_objs[1:] # Add extra quotes to make json.loads happy. - return self._response_message_cls.from_json(json.loads('"' + obj + '"')) + return self._response_message_cls.from_json( + json.loads('"' + self._ready_objs.popleft() + '"') + ) def __iter__(self): return self diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py index 3485e307..a0022493 100644 --- a/tests/unit/test_rest_streaming.py +++ b/tests/unit/test_rest_streaming.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import logging import random from typing import List from unittest.mock import patch @@ -22,6 +23,13 @@ import requests from google.api_core import rest_streaming +from google.protobuf import duration_pb2 +from google.protobuf import timestamp_pb2 + + +SEED = 0 +logging.info(f"Starting rest streaming tests with random seed: {SEED}") +random.seed(SEED) class Genre(proto.Enum): @@ -47,6 +55,8 @@ class Song(proto.Message): is_five_mins_longer = proto.Field(proto.BOOL, number=6) score = proto.Field(proto.DOUBLE, number=7) likes = proto.Field(proto.INT64, number=8) + duration = proto.Field(duration_pb2.Duration, number=9) + date_added = proto.Field(timestamp_pb2.Timestamp, number=10) class EchoResponse(proto.Message): @@ -55,13 +65,10 @@ class EchoResponse(proto.Message): class ResponseMock(requests.Response): class _ResponseItr: - def __init__( - self, _response_bytes: bytes, random_split=False, seed=0, - ): + def __init__(self, _response_bytes: bytes, random_split=False): self._responses_bytes = _response_bytes self._i = 0 self._random_split = random_split - random.seed(seed) def __next__(self): if self._i == len(self._responses_bytes): @@ -74,32 +81,22 @@ def __next__(self): self._i += n return x.decode("utf-8") - def __init__( - self, - *args, - responses: List[proto.Message], - random_split=False, - response_cls, - **kwargs + self, responses: List[proto.Message], response_cls, random_split=False, ): - super().__init__(*args, **kwargs) + super().__init__() self._responses = responses self._random_split = random_split self._response_message_cls = response_cls def _parse_responses(self, responses: List[proto.Message]) -> bytes: - ret_val = "[" # json.dumps returns a string surrounded with quotes that need to be stripped # in order to be an actual JSON. json_responses = [ - json.dumps(self._response_message_cls.to_json(r))[1:-1] for r in responses + json.dumps(self._response_message_cls.to_json(r)).strip('"') + for r in responses ] - for x in json_responses: - ret_val += x - ret_val += "," - ret_val = ret_val[:-1] # Remove trailing comma. - ret_val += "]" + ret_val = "[{}]".format(",".join(json_responses)) return bytes(ret_val, "utf-8") def close(self): @@ -111,36 +108,31 @@ def iter_content(self, *args, **kwargs): ) -def test_next_simple(): +@pytest.mark.parametrize("random_split", [True, False]) +def test_next_simple(random_split): responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")] resp = ResponseMock( - responses=responses, random_split=False, response_cls=EchoResponse + responses=responses, random_split=random_split, response_cls=EchoResponse ) itr = rest_streaming.ResponseIterator(resp, EchoResponse) assert list(itr) == responses -def test_next_nested(): +@pytest.mark.parametrize("random_split", [True, False]) +def test_next_nested(random_split): responses = [ Song(title="some song", composer=Composer(given_name="some name")), Song(title="another song"), ] - resp = ResponseMock(responses=responses, random_split=False, response_cls=Song) - itr = rest_streaming.ResponseIterator(resp, Song) - assert list(itr) == responses - - -def test_next_random(): - responses = [ - Song(title="some song", composer=Composer(given_name="some name")), - Song(title="another song"), - ] - resp = ResponseMock(responses=responses, random_split=True, response_cls=Song) + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=Song + ) itr = rest_streaming.ResponseIterator(resp, Song) assert list(itr) == responses -def test_next_stress(): +@pytest.mark.parametrize("random_split", [True, False]) +def test_next_stress(random_split): n = 50 responses = [ Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i)) @@ -151,11 +143,14 @@ def test_next_stress(): assert list(itr) == responses -def test_next_escaped_characters_in_string(): +@pytest.mark.parametrize("random_split", [True, False]) +def test_next_escaped_characters_in_string(random_split): responses = [ Song(title="title\nfoo\tbar{}", composer=Composer(given_name="name\n\n\n")) ] - resp = ResponseMock(responses=responses, random_split=True, response_cls=Song) + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=Song + ) itr = rest_streaming.ResponseIterator(resp, Song) assert list(itr) == responses @@ -173,9 +168,7 @@ def test_next_not_array(): def test_cancel(): - with patch.object( - ResponseMock, "close", return_value=None - ) as mock_method: + with patch.object(ResponseMock, "close", return_value=None) as mock_method: resp = ResponseMock(responses=[], response_cls=EchoResponse) itr = rest_streaming.ResponseIterator(resp, EchoResponse) itr.cancel() From 3f39b67b94f74a5145b4f4a2a2bb5aa7805b94ac Mon Sep 17 00:00:00 2001 From: Aza Tulepbergenov Date: Thu, 16 Dec 2021 13:33:10 -0800 Subject: [PATCH 5/8] chore: pr comments. --- google/api_core/rest_streaming.py | 39 +++++++++++++++-------- tests/unit/test_rest_streaming.py | 51 ++++++++++++++++++++++++++----- 2 files changed, 69 insertions(+), 21 deletions(-) diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py index a6310dd6..7a202b2c 100644 --- a/google/api_core/rest_streaming.py +++ b/google/api_core/rest_streaming.py @@ -15,7 +15,6 @@ """Helpers for server-side streaming in REST.""" from collections import deque -import json import string from typing import Deque @@ -45,6 +44,8 @@ def __init__(self, response: requests.Response, response_message_cls): # Keeps track whether HTTP response is currently sending values # inside of a string value. self._in_string = False + # Whether an escape symbol "\" was encountered. + self._next_should_be_escaped = False def cancel(self): """Cancel existing streaming operation. @@ -62,19 +63,19 @@ def _process_chunk(self, chunk: str): if self._level == 1: # Level 1 corresponds to the outermost JSON object # (i.e. the one we care about). - self._obj = "" + self._obj = char if not self._in_string: self._level += 1 - self._obj += char - elif char == '"': - self._in_string = not self._in_string - self._obj += char elif char == "}": - self._obj += char if not self._in_string: self._level -= 1 if not self._in_string and self._level == 1: - self._ready_objs.append(self._obj) + self._ready_objs.append(self._obj + char) + elif char == '"': + # Helps to deal with an escaped quotes inside of a string. + if not self._next_should_be_escaped: + self._in_string = not self._in_string + self._obj += char elif char in string.whitespace: if self._in_string: self._obj += char @@ -91,17 +92,29 @@ def _process_chunk(self, chunk: str): else: self._obj += char + if char == "\\": + # Escaping the "\". + if self._next_should_be_escaped: + self._next_should_be_escaped = False + else: + self._next_should_be_escaped = True + else: + self._next_should_be_escaped = False + def __next__(self): while not self._ready_objs: - chunk = next(self._response_itr) - self._process_chunk(chunk) + try: + chunk = next(self._response_itr) + self._process_chunk(chunk) + except StopIteration as e: + if self._level > 0: + raise ValueError("Unfinished stream: %s" % self._obj) + raise e return self._grab() def _grab(self): # Add extra quotes to make json.loads happy. - return self._response_message_cls.from_json( - json.loads('"' + self._ready_objs.popleft() + '"') - ) + return self._response_message_cls.from_json(self._ready_objs.popleft()) def __iter__(self): return self diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py index a0022493..2a771c54 100644 --- a/tests/unit/test_rest_streaming.py +++ b/tests/unit/test_rest_streaming.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json +import datetime import logging import random +import time from typing import List from unittest.mock import patch @@ -27,7 +28,7 @@ from google.protobuf import timestamp_pb2 -SEED = 0 +SEED = int(time.time()) logging.info(f"Starting rest streaming tests with random seed: {SEED}") random.seed(SEED) @@ -93,8 +94,7 @@ def _parse_responses(self, responses: List[proto.Message]) -> bytes: # json.dumps returns a string surrounded with quotes that need to be stripped # in order to be an actual JSON. json_responses = [ - json.dumps(self._response_message_cls.to_json(r)).strip('"') - for r in responses + self._response_message_cls.to_json(r).strip('"') for r in responses ] ret_val = "[{}]".format(",".join(json_responses)) return bytes(ret_val, "utf-8") @@ -108,7 +108,7 @@ def iter_content(self, *args, **kwargs): ) -@pytest.mark.parametrize("random_split", [True, False]) +@pytest.mark.parametrize("random_split", [False]) def test_next_simple(random_split): responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")] resp = ResponseMock( @@ -122,7 +122,7 @@ def test_next_simple(random_split): def test_next_nested(random_split): responses = [ Song(title="some song", composer=Composer(given_name="some name")), - Song(title="another song"), + Song(title="another song", duration=datetime.datetime()), ] resp = ResponseMock( responses=responses, random_split=random_split, response_cls=Song @@ -138,15 +138,25 @@ def test_next_stress(random_split): Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i)) for i in range(n) ] - resp = ResponseMock(responses=responses, random_split=True, response_cls=Song) + resp = ResponseMock( + responses=responses, random_split=random_split, response_cls=Song + ) itr = rest_streaming.ResponseIterator(resp, Song) assert list(itr) == responses @pytest.mark.parametrize("random_split", [True, False]) def test_next_escaped_characters_in_string(random_split): + composer_with_relateds = Composer() + relateds = ["Artist A", "Artist B"] + composer_with_relateds.relateds = relateds + responses = [ - Song(title="title\nfoo\tbar{}", composer=Composer(given_name="name\n\n\n")) + Song(title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")), + Song( + title='{"this is weird": "totally"}', composer=Composer(given_name="\\{}\\") + ), + Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds), ] resp = ResponseMock( responses=responses, random_split=random_split, response_cls=Song @@ -173,3 +183,28 @@ def test_cancel(): itr = rest_streaming.ResponseIterator(resp, EchoResponse) itr.cancel() mock_method.assert_called_once() + + +def test_check_buffer(): + with patch.object( + ResponseMock, + "_parse_responses", + return_value=bytes('[{"content": "hello"}, {', "utf-8"), + ): + resp = ResponseMock(responses=[], response_cls=EchoResponse) + itr = rest_streaming.ResponseIterator(resp, EchoResponse) + with pytest.raises(ValueError): + next(itr) + next(itr) + + +def test_next_html(): + with patch.object( + ResponseMock, "iter_content", return_value=iter("") + ) as mock_method: + + resp = ResponseMock(responses=[], response_cls=EchoResponse) + itr = rest_streaming.ResponseIterator(resp, EchoResponse) + with pytest.raises(ValueError): + next(itr) + mock_method.assert_called_once() From ac7ef79110a097ad418930ef678287003fade714 Mon Sep 17 00:00:00 2001 From: Aza Tulepbergenov Date: Thu, 16 Dec 2021 13:50:55 -0800 Subject: [PATCH 6/8] chore: fixes bugs. --- google/api_core/rest_streaming.py | 7 +++++-- tests/unit/test_rest_streaming.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py index 7a202b2c..6ca29f26 100644 --- a/google/api_core/rest_streaming.py +++ b/google/api_core/rest_streaming.py @@ -63,14 +63,16 @@ def _process_chunk(self, chunk: str): if self._level == 1: # Level 1 corresponds to the outermost JSON object # (i.e. the one we care about). - self._obj = char + self._obj = "" if not self._in_string: self._level += 1 + self._obj += char elif char == "}": + self._obj += char if not self._in_string: self._level -= 1 if not self._in_string and self._level == 1: - self._ready_objs.append(self._obj + char) + self._ready_objs.append(self._obj) elif char == '"': # Helps to deal with an escaped quotes inside of a string. if not self._next_should_be_escaped: @@ -114,6 +116,7 @@ def __next__(self): def _grab(self): # Add extra quotes to make json.loads happy. + print(self._ready_objs) return self._response_message_cls.from_json(self._ready_objs.popleft()) def __iter__(self): diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py index 2a771c54..4be59580 100644 --- a/tests/unit/test_rest_streaming.py +++ b/tests/unit/test_rest_streaming.py @@ -96,6 +96,7 @@ def _parse_responses(self, responses: List[proto.Message]) -> bytes: json_responses = [ self._response_message_cls.to_json(r).strip('"') for r in responses ] + logging.info(f"Sending JSON stream: {json_responses}") ret_val = "[{}]".format(",".join(json_responses)) return bytes(ret_val, "utf-8") @@ -122,7 +123,7 @@ def test_next_simple(random_split): def test_next_nested(random_split): responses = [ Song(title="some song", composer=Composer(given_name="some name")), - Song(title="another song", duration=datetime.datetime()), + Song(title="another song", date_added=datetime.datetime(2021, 12, 17)), ] resp = ResponseMock( responses=responses, random_split=random_split, response_cls=Song From 1f8a9c0c6734becea47e9859ca6c2825ed26a379 Mon Sep 17 00:00:00 2001 From: Aza Tulepbergenov Date: Thu, 16 Dec 2021 13:52:26 -0800 Subject: [PATCH 7/8] chore: removes print statement. --- google/api_core/rest_streaming.py | 1 - 1 file changed, 1 deletion(-) diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py index 6ca29f26..0db66ed0 100644 --- a/google/api_core/rest_streaming.py +++ b/google/api_core/rest_streaming.py @@ -116,7 +116,6 @@ def __next__(self): def _grab(self): # Add extra quotes to make json.loads happy. - print(self._ready_objs) return self._response_message_cls.from_json(self._ready_objs.popleft()) def __iter__(self): From b4936f7da6b4cf71d93613f75a42475c9fa76ca5 Mon Sep 17 00:00:00 2001 From: Aza Tulepbergenov Date: Fri, 7 Jan 2022 20:10:11 -0800 Subject: [PATCH 8/8] feat: small refactor. --- google/api_core/rest_streaming.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py index 0db66ed0..69f5b41b 100644 --- a/google/api_core/rest_streaming.py +++ b/google/api_core/rest_streaming.py @@ -45,7 +45,7 @@ def __init__(self, response: requests.Response, response_message_cls): # inside of a string value. self._in_string = False # Whether an escape symbol "\" was encountered. - self._next_should_be_escaped = False + self._escape_next = False def cancel(self): """Cancel existing streaming operation. @@ -75,7 +75,7 @@ def _process_chunk(self, chunk: str): self._ready_objs.append(self._obj) elif char == '"': # Helps to deal with an escaped quotes inside of a string. - if not self._next_should_be_escaped: + if not self._escape_next: self._in_string = not self._in_string self._obj += char elif char in string.whitespace: @@ -93,15 +93,7 @@ def _process_chunk(self, chunk: str): self._obj += char else: self._obj += char - - if char == "\\": - # Escaping the "\". - if self._next_should_be_escaped: - self._next_should_be_escaped = False - else: - self._next_should_be_escaped = True - else: - self._next_should_be_escaped = False + self._escape_next = not self._escape_next if char == "\\" else False def __next__(self): while not self._ready_objs: