Skip to content

Commit 29d662b

Browse files
committed
update response iterator from async auth
1 parent 3e58f20 commit 29d662b

File tree

4 files changed

+25
-75
lines changed

4 files changed

+25
-75
lines changed

google/api_core/rest_streaming_async.py

Lines changed: 4 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from typing import Union
1818

1919
import proto
20-
import google.protobuf.message
2120
import google.auth.aio.transport
21+
import google.protobuf.message
2222
from google.api_core._rest_streaming_base import BaseResponseIterator
2323

2424

@@ -43,58 +43,18 @@ def __init__(
4343
):
4444
self._response = response
4545
self._chunk_size = 1024
46-
self._response_itr = None
46+
self._response_itr = self._response.content().__aiter__()
4747
super(AsyncResponseIterator, self).__init__(
4848
response_message_cls=response_message_cls
4949
)
5050

51-
async def _create_response_iter(self):
52-
53-
if not isinstance(self._response, google.auth.aio.transport.Response):
54-
raise ValueError(
55-
"Response must be of type google.auth.aio.transport.Response"
56-
)
57-
58-
# TODO (ohmayr): Ideally the response from auth should expose an attribute
59-
# to read streaming response iterator directly i.e.
60-
#
61-
# self -> google.auth.aio.transport.aiohttp.Response:
62-
# def stream_response_itr(self, chunk_size):
63-
# return self._response.content.iter_chunked(chunk_size)
64-
#
65-
# self -> google.auth.aio.transport.httpx.Response:
66-
# def stream_response_itr(self, chunk_size):
67-
# return self._response.aiter_raw(chunk_size)
68-
#
69-
# this way we can just call the property directly to get the appropriate
70-
# response iterator without having to deal with the underlying API differences
71-
# or alternatively, having to check the type of inherited response types here
72-
# i.e we could do: self._response_itr = self._response.stream_response_itr(self._chunk_size)
73-
74-
content = self._response.content
75-
if hasattr(content, "iter_chunked"):
76-
return content.iter_chunked(self._chunk_size)
77-
else:
78-
# TODO (ohmayr): since iter_chunked is only available in an instance of
79-
# google.auth.aio.transport.aiohttp.Response, we indirectly depend on
80-
# on the inherited class.
81-
raise ValueError(
82-
f"Unsupported Response type: {type(self._response)}. Expected google.auth.aio.transport.aiohttp.Response."
83-
)
84-
8551
async def cancel(self):
8652
"""Cancel existing streaming operation."""
8753
await self._response.close()
8854

8955
async def __anext__(self):
9056
while not self._ready_objs:
91-
try:
92-
if not self._response_itr:
93-
self._response_itr = await self._create_response_iter()
94-
# TODO (ohmayr): cleanup
95-
# content = await self._response.content
96-
# self._response_itr = content.iter_chunked(self._chunk_size)
97-
57+
try:
9858
chunk = await self._response_itr.__anext__()
9959
chunk = chunk.decode("utf-8")
10060
self._process_chunk(chunk)
@@ -110,5 +70,5 @@ def __aiter__(self):
11070
return self
11171

11272
async def __aexit__(self, exc_type, exc, tb):
113-
"""Cancel existing streaming operation."""
73+
"""Cancel existing async streaming operation."""
11474
await self._response.close()
File renamed without changes.

tests/unit/test_rest_streaming.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
from google.api import http_pb2
2828
from google.api import httpbody_pb2
2929

30-
# TODO (ohmayr): confirm if relative path is not an issue.
31-
from .._helpers import Composer, Song, EchoResponse, parse_responses
30+
from ..conftest import Composer, Song, EchoResponse, parse_responses
3231

3332

3433
__protobuf__ = proto.module(package=__name__)

tests/unit/test_rest_streaming_async.py

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,17 @@
1717
import random
1818
import time
1919
from typing import List, AsyncIterator
20-
from unittest import mock
20+
import mock
2121

2222
import proto
2323
import pytest
24-
from google.auth.aio.transport import Response
2524

2625
from google.api_core import rest_streaming_async
2726
from google.api import http_pb2
2827
from google.api import httpbody_pb2
28+
from google.auth.aio.transport import Response
2929

30-
# TODO (ohmayr): confirm if relative path is not an issue.
31-
from .._helpers import Composer, Song, EchoResponse, parse_responses
30+
from ..conftest import Composer, Song, EchoResponse, parse_responses
3231

3332
# TODO (ohmayr): check if we need to log.
3433
__protobuf__ = proto.module(package=__name__)
@@ -37,9 +36,10 @@
3736
random.seed(SEED)
3837

3938

40-
class AIOHTTPContentMock:
41-
def __init__(self, iter_chunked):
42-
self.iter_chunked = iter_chunked
39+
async def mock_async_gen(data, chunk_size=1):
40+
for i in range(0, len(data)):
41+
chunk = data[i : i + chunk_size]
42+
yield chunk.encode("utf-8")
4343

4444

4545
class ResponseMock(Response):
@@ -49,6 +49,9 @@ def __init__(self, _response_bytes: bytes, random_split=False):
4949
self._i = 0
5050
self._random_split = random_split
5151

52+
def __aiter__(self):
53+
return self
54+
5255
async def __anext__(self):
5356
if self._i == len(self._responses_bytes):
5457
raise StopAsyncIteration
@@ -74,13 +77,15 @@ def __init__(
7477
async def close(self):
7578
raise NotImplementedError()
7679

77-
@property
78-
def content(self):
79-
iter_chunked = lambda chunk_size: self._ResponseItr(
80-
self._parse_responses(),
81-
random_split=self._random_split,
80+
async def content(self, chunk_size=None):
81+
itr = self._ResponseItr(
82+
self._parse_responses(), random_split=self._random_split
8283
)
83-
return AIOHTTPContentMock(iter_chunked)
84+
async for chunk in itr:
85+
yield chunk
86+
87+
async def read(self):
88+
raise NotImplementedError()
8489

8590
@property
8691
async def headers(self):
@@ -273,14 +278,10 @@ async def test_next_escaped_characters_in_string(
273278
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
274279
async def test_next_not_array(response_type):
275280

276-
mock_content = mock.Mock()
277281
data = '{"hello": 0}'
278-
mock_content.iter_chunked = lambda chunk_size: async_iter(data, chunk_size)
279-
280282
with mock.patch.object(
281-
ResponseMock, "content", new_callable=mock.PropertyMock
283+
ResponseMock, "content", return_value=mock_async_gen(data)
282284
) as mock_method:
283-
mock_method.return_value = mock_content
284285
resp = ResponseMock(responses=[], response_cls=response_type)
285286
itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
286287
with pytest.raises(ValueError):
@@ -321,24 +322,14 @@ async def test_check_buffer(response_type, return_value):
321322
await itr.__anext__()
322323

323324

324-
async def async_iter(data, chunk_size):
325-
for i in range(0, len(data) + chunk_size):
326-
chunk = data[i : i + chunk_size]
327-
yield chunk.encode("utf-8")
328-
329-
330325
@pytest.mark.asyncio
331326
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
332327
async def test_next_html(response_type):
333328

334-
mock_content = mock.Mock()
335329
data = "<!DOCTYPE html><html></html>"
336-
mock_content.iter_chunked = lambda chunk_size: async_iter(data, chunk_size)
337-
338330
with mock.patch.object(
339-
ResponseMock, "content", new_callable=mock.PropertyMock
331+
ResponseMock, "content", return_value=mock_async_gen(data)
340332
) as mock_method:
341-
mock_method.return_value = mock_content
342333
resp = ResponseMock(responses=[], response_cls=response_type)
343334

344335
itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)

0 commit comments

Comments
 (0)