17
17
import random
18
18
import time
19
19
from typing import List , AsyncIterator
20
- from unittest import mock
20
+ import mock
21
21
22
22
import proto
23
23
import pytest
24
- from google .auth .aio .transport import Response
25
24
26
25
from google .api_core import rest_streaming_async
27
26
from google .api import http_pb2
28
27
from google .api import httpbody_pb2
28
+ from google .auth .aio .transport import Response
29
29
30
- # TODO (ohmayr): confirm if relative path is not an issue.
31
- from .._helpers import Composer , Song , EchoResponse , parse_responses
30
+ from ..conftest import Composer , Song , EchoResponse , parse_responses
32
31
33
32
# TODO (ohmayr): check if we need to log.
34
33
__protobuf__ = proto .module (package = __name__ )
37
36
random .seed (SEED )
38
37
39
38
40
- class AIOHTTPContentMock :
41
- def __init__ (self , iter_chunked ):
42
- self .iter_chunked = iter_chunked
39
+ async def mock_async_gen (data , chunk_size = 1 ):
40
+ for i in range (0 , len (data )):
41
+ chunk = data [i : i + chunk_size ]
42
+ yield chunk .encode ("utf-8" )
43
43
44
44
45
45
class ResponseMock (Response ):
@@ -49,6 +49,9 @@ def __init__(self, _response_bytes: bytes, random_split=False):
49
49
self ._i = 0
50
50
self ._random_split = random_split
51
51
52
+ def __aiter__ (self ):
53
+ return self
54
+
52
55
async def __anext__ (self ):
53
56
if self ._i == len (self ._responses_bytes ):
54
57
raise StopAsyncIteration
@@ -74,13 +77,15 @@ def __init__(
74
77
async def close (self ):
75
78
raise NotImplementedError ()
76
79
77
- @property
78
- def content (self ):
79
- iter_chunked = lambda chunk_size : self ._ResponseItr (
80
- self ._parse_responses (),
81
- random_split = self ._random_split ,
80
+ async def content (self , chunk_size = None ):
81
+ itr = self ._ResponseItr (
82
+ self ._parse_responses (), random_split = self ._random_split
82
83
)
83
- return AIOHTTPContentMock (iter_chunked )
84
+ async for chunk in itr :
85
+ yield chunk
86
+
87
+ async def read (self ):
88
+ raise NotImplementedError ()
84
89
85
90
@property
86
91
async def headers (self ):
@@ -273,14 +278,10 @@ async def test_next_escaped_characters_in_string(
273
278
@pytest .mark .parametrize ("response_type" , [EchoResponse , httpbody_pb2 .HttpBody ])
274
279
async def test_next_not_array (response_type ):
275
280
276
- mock_content = mock .Mock ()
277
281
data = '{"hello": 0}'
278
- mock_content .iter_chunked = lambda chunk_size : async_iter (data , chunk_size )
279
-
280
282
with mock .patch .object (
281
- ResponseMock , "content" , new_callable = mock . PropertyMock
283
+ ResponseMock , "content" , return_value = mock_async_gen ( data )
282
284
) as mock_method :
283
- mock_method .return_value = mock_content
284
285
resp = ResponseMock (responses = [], response_cls = response_type )
285
286
itr = rest_streaming_async .AsyncResponseIterator (resp , response_type )
286
287
with pytest .raises (ValueError ):
@@ -321,24 +322,14 @@ async def test_check_buffer(response_type, return_value):
321
322
await itr .__anext__ ()
322
323
323
324
324
- async def async_iter (data , chunk_size ):
325
- for i in range (0 , len (data ) + chunk_size ):
326
- chunk = data [i : i + chunk_size ]
327
- yield chunk .encode ("utf-8" )
328
-
329
-
330
325
@pytest .mark .asyncio
331
326
@pytest .mark .parametrize ("response_type" , [EchoResponse , httpbody_pb2 .HttpBody ])
332
327
async def test_next_html (response_type ):
333
328
334
- mock_content = mock .Mock ()
335
329
data = "<!DOCTYPE html><html></html>"
336
- mock_content .iter_chunked = lambda chunk_size : async_iter (data , chunk_size )
337
-
338
330
with mock .patch .object (
339
- ResponseMock , "content" , new_callable = mock . PropertyMock
331
+ ResponseMock , "content" , return_value = mock_async_gen ( data )
340
332
) as mock_method :
341
- mock_method .return_value = mock_content
342
333
resp = ResponseMock (responses = [], response_cls = response_type )
343
334
344
335
itr = rest_streaming_async .AsyncResponseIterator (resp , response_type )
0 commit comments