25
25
# --------------------------------------------------------------------------
26
26
from typing import Any , Optional , AsyncIterator as AsyncIteratorType
27
27
from collections .abc import AsyncIterator
28
+ try :
29
+ import cchardet as chardet
30
+ except ImportError : # pragma: no cover
31
+ import chardet # type: ignore
28
32
29
33
import logging
30
34
import asyncio
35
+ import codecs
31
36
import aiohttp
32
37
from multidict import CIMultiDict
33
38
from requests .exceptions import StreamConsumedError
@@ -66,7 +71,7 @@ class AioHttpTransport(AsyncHttpTransport):
66
71
:dedent: 4
67
72
:caption: Asynchronous transport with aiohttp.
68
73
"""
69
- def __init__ (self , * , session = None , loop = None , session_owner = True , ** kwargs ):
74
+ def __init__ (self , * , session : Optional [ aiohttp . ClientSession ] = None , loop = None , session_owner = True , ** kwargs ):
70
75
self ._loop = loop
71
76
self ._session_owner = session_owner
72
77
self .session = session
@@ -145,6 +150,11 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
145
150
:keyword str proxy: will define the proxy to use all the time
146
151
"""
147
152
await self .open ()
153
+ try :
154
+ auto_decompress = self .session .auto_decompress # type: ignore
155
+ except AttributeError :
156
+ # auto_decompress is introduced in aiohttp 3.7. We need this to handle aiohttp 3.6 and earlier.
157
+ auto_decompress = True
148
158
149
159
proxies = config .pop ('proxies' , None )
150
160
if proxies and 'proxy' not in config :
@@ -171,7 +181,7 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
171
181
timeout = config .pop ('connection_timeout' , self .connection_config .timeout )
172
182
read_timeout = config .pop ('read_timeout' , self .connection_config .read_timeout )
173
183
socket_timeout = aiohttp .ClientTimeout (sock_connect = timeout , sock_read = read_timeout )
174
- result = await self .session .request (
184
+ result = await self .session .request ( # type: ignore
175
185
request .method ,
176
186
request .url ,
177
187
headers = request .headers ,
@@ -180,7 +190,9 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
180
190
allow_redirects = False ,
181
191
** config
182
192
)
183
- response = AioHttpTransportResponse (request , result , self .connection_config .data_block_size )
193
+ response = AioHttpTransportResponse (request , result ,
194
+ self .connection_config .data_block_size ,
195
+ decompress = not auto_decompress )
184
196
if not stream_response :
185
197
await response .load_body ()
186
198
except aiohttp .client_exceptions .ClientResponseError as err :
@@ -196,17 +208,15 @@ class AioHttpStreamDownloadGenerator(AsyncIterator):
196
208
197
209
:param pipeline: The pipeline object
198
210
:param response: The client response object.
199
- :keyword bool decompress: If True which is default, will attempt to decode the body based
200
- on the ‘ content-encoding’ header.
211
+ :param bool decompress: If True, which is the default, will attempt to decode the body based
212
+ on the *content-encoding* header.
201
213
"""
202
- def __init__ (self , pipeline : Pipeline , response : AsyncHttpResponse , ** kwargs ) -> None :
214
+ def __init__ (self , pipeline : Pipeline , response : AsyncHttpResponse , * , decompress = True ) -> None :
203
215
self .pipeline = pipeline
204
216
self .request = response .request
205
217
self .response = response
206
218
self .block_size = response .block_size
207
- self ._decompress = kwargs .pop ("decompress" , True )
208
- if len (kwargs ) > 0 :
209
- raise TypeError ("Got an unexpected keyword argument: {}" .format (list (kwargs .keys ())[0 ]))
219
+ self ._decompress = decompress
210
220
self .content_length = int (response .internal_response .headers .get ('Content-Length' , 0 ))
211
221
self ._decompressor = None
212
222
@@ -250,21 +260,41 @@ class AioHttpTransportResponse(AsyncHttpResponse):
250
260
:type aiohttp_response: aiohttp.ClientResponse object
251
261
:param block_size: block size of data sent over connection.
252
262
:type block_size: int
263
+ :param bool decompress: If True, which is the default, will attempt to decode the body based
264
+ on the *content-encoding* header.
253
265
"""
254
- def __init__ (self , request : HttpRequest , aiohttp_response : aiohttp .ClientResponse , block_size = None ) -> None :
266
+ def __init__ (self , request : HttpRequest ,
267
+ aiohttp_response : aiohttp .ClientResponse ,
268
+ block_size = None , * , decompress = True ) -> None :
255
269
super (AioHttpTransportResponse , self ).__init__ (request , aiohttp_response , block_size = block_size )
256
270
# https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse
257
271
self .status_code = aiohttp_response .status
258
272
self .headers = CIMultiDict (aiohttp_response .headers )
259
273
self .reason = aiohttp_response .reason
260
274
self .content_type = aiohttp_response .headers .get ('content-type' )
261
275
self ._body = None
276
+ self ._decompressed_body = None
277
+ self ._decompress = decompress
262
278
263
279
def body (self ) -> bytes :
264
280
"""Return the whole body as bytes in memory.
265
281
"""
266
282
if self ._body is None :
267
283
raise ValueError ("Body is not available. Call async method load_body, or do your call with stream=False." )
284
+ if not self ._decompress :
285
+ return self ._body
286
+ enc = self .headers .get ('Content-Encoding' )
287
+ if not enc :
288
+ return self ._body
289
+ enc = enc .lower ()
290
+ if enc in ("gzip" , "deflate" ):
291
+ if self ._decompressed_body :
292
+ return self ._decompressed_body
293
+ import zlib
294
+ zlib_mode = 16 + zlib .MAX_WBITS if enc == "gzip" else zlib .MAX_WBITS
295
+ decompressor = zlib .decompressobj (wbits = zlib_mode )
296
+ self ._decompressed_body = decompressor .decompress (self ._body )
297
+ return self ._decompressed_body
268
298
return self ._body
269
299
270
300
def text (self , encoding : Optional [str ] = None ) -> str :
@@ -274,10 +304,36 @@ def text(self, encoding: Optional[str] = None) -> str:
274
304
275
305
:param str encoding: The encoding to apply.
276
306
"""
307
+ # super().text detects the charset from self._body, which may still be compressed;
308
+ # implement the decoding explicitly here.
309
+ body = self .body ()
310
+
311
+ ctype = self .headers .get (aiohttp .hdrs .CONTENT_TYPE , "" ).lower ()
312
+ mimetype = aiohttp .helpers .parse_mimetype (ctype )
313
+
314
+ encoding = mimetype .parameters .get ("charset" )
315
+ if encoding :
316
+ try :
317
+ codecs .lookup (encoding )
318
+ except LookupError :
319
+ encoding = None
320
+ if not encoding :
321
+ if mimetype .type == "application" and (
322
+ mimetype .subtype == "json" or mimetype .subtype == "rdap"
323
+ ):
324
+ # RFC 7159 states that the default encoding is UTF-8.
325
+ # RFC 7483 defines application/rdap+json
326
+ encoding = "utf-8"
327
+ elif body is None :
328
+ raise RuntimeError (
329
+ "Cannot guess the encoding of a not yet read body"
330
+ )
331
+ else :
332
+ encoding = chardet .detect (body )["encoding" ]
277
333
if not encoding :
278
- encoding = self . internal_response . get_encoding ()
334
+ encoding = "utf-8-sig"
279
335
280
- return super (). text (encoding )
336
+ return body . decode (encoding )
281
337
282
338
async def load_body (self ) -> None :
283
339
"""Load the body into memory so that it is accessible from sync methods."""
@@ -289,7 +345,7 @@ def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
289
345
:param pipeline: The pipeline object
290
346
:type pipeline: azure.core.pipeline.Pipeline
291
347
:keyword bool decompress: If True, which is the default, will attempt to decode the body based
292
- on the ‘ content-encoding’ header.
348
+ on the *content-encoding* header.
293
349
"""
294
350
return AioHttpStreamDownloadGenerator (pipeline , self , ** kwargs )
295
351
0 commit comments