Skip to content

Commit 502c702

Browse files
authored
Core decompress body (#18581)
* decompress body * update * update * update * update * update recorded tests * update * update * update * update * update * add type annotation * update * update * update * update * update * update * update doc * update * update * add comments * update
1 parent 802c887 commit 502c702

File tree

701 files changed

+56247
-111641
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

701 files changed

+56247
-111641
lines changed

sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py

Lines changed: 69 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,14 @@
2525
# --------------------------------------------------------------------------
2626
from typing import Any, Optional, AsyncIterator as AsyncIteratorType
2727
from collections.abc import AsyncIterator
28+
try:
29+
import cchardet as chardet
30+
except ImportError: # pragma: no cover
31+
import chardet # type: ignore
2832

2933
import logging
3034
import asyncio
35+
import codecs
3136
import aiohttp
3237
from multidict import CIMultiDict
3338
from requests.exceptions import StreamConsumedError
@@ -66,7 +71,7 @@ class AioHttpTransport(AsyncHttpTransport):
6671
:dedent: 4
6772
:caption: Asynchronous transport with aiohttp.
6873
"""
69-
def __init__(self, *, session=None, loop=None, session_owner=True, **kwargs):
74+
def __init__(self, *, session: Optional[aiohttp.ClientSession] = None, loop=None, session_owner=True, **kwargs):
7075
self._loop = loop
7176
self._session_owner = session_owner
7277
self.session = session
@@ -145,6 +150,11 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
145150
:keyword str proxy: will define the proxy to use all the time
146151
"""
147152
await self.open()
153+
try:
154+
auto_decompress = self.session.auto_decompress # type: ignore
155+
except AttributeError:
156+
# auto_decompress is introduced in aiohttp 3.7. We need this to handle aiohttp 3.6 and earlier.
157+
auto_decompress = True
148158

149159
proxies = config.pop('proxies', None)
150160
if proxies and 'proxy' not in config:
@@ -171,7 +181,7 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
171181
timeout = config.pop('connection_timeout', self.connection_config.timeout)
172182
read_timeout = config.pop('read_timeout', self.connection_config.read_timeout)
173183
socket_timeout = aiohttp.ClientTimeout(sock_connect=timeout, sock_read=read_timeout)
174-
result = await self.session.request(
184+
result = await self.session.request( # type: ignore
175185
request.method,
176186
request.url,
177187
headers=request.headers,
@@ -180,7 +190,9 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
180190
allow_redirects=False,
181191
**config
182192
)
183-
response = AioHttpTransportResponse(request, result, self.connection_config.data_block_size)
193+
response = AioHttpTransportResponse(request, result,
194+
self.connection_config.data_block_size,
195+
decompress=not auto_decompress)
184196
if not stream_response:
185197
await response.load_body()
186198
except aiohttp.client_exceptions.ClientResponseError as err:
@@ -196,17 +208,15 @@ class AioHttpStreamDownloadGenerator(AsyncIterator):
196208
197209
:param pipeline: The pipeline object
198210
:param response: The client response object.
199-
:keyword bool decompress: If True which is default, will attempt to decode the body based
200-
on the content-encoding header.
211+
:param bool decompress: If True, which is the default, will attempt to decode the body based
212+
on the *content-encoding* header.
201213
"""
202-
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
214+
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, *, decompress=True) -> None:
203215
self.pipeline = pipeline
204216
self.request = response.request
205217
self.response = response
206218
self.block_size = response.block_size
207-
self._decompress = kwargs.pop("decompress", True)
208-
if len(kwargs) > 0:
209-
raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]))
219+
self._decompress = decompress
210220
self.content_length = int(response.internal_response.headers.get('Content-Length', 0))
211221
self._decompressor = None
212222

@@ -250,21 +260,41 @@ class AioHttpTransportResponse(AsyncHttpResponse):
250260
:type aiohttp_response: aiohttp.ClientResponse object
251261
:param block_size: block size of data sent over connection.
252262
:type block_size: int
263+
:param bool decompress: If True, which is the default, will attempt to decode the body based
264+
on the *content-encoding* header.
253265
"""
254-
def __init__(self, request: HttpRequest, aiohttp_response: aiohttp.ClientResponse, block_size=None) -> None:
266+
def __init__(self, request: HttpRequest,
             aiohttp_response: aiohttp.ClientResponse,
             block_size=None, *, decompress=True) -> None:
    """Wrap an aiohttp response and copy its status metadata onto this object.

    :param request: The request that produced this response.
    :param aiohttp_response: The underlying aiohttp response object.
    :param block_size: Block size of data sent over the connection.
    :keyword bool decompress: Stored on the instance and consulted by ``body()``
     to decide whether the payload should be decoded based on the
     *content-encoding* header.
    """
    super().__init__(request, aiohttp_response, block_size=block_size)
    # https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse
    response_headers = aiohttp_response.headers
    self.status_code = aiohttp_response.status
    self.reason = aiohttp_response.reason
    self.headers = CIMultiDict(response_headers)
    self.content_type = response_headers.get('content-type')
    # Raw payload; stays None until it is loaded.
    self._body = None
    # Decompression switch and the lazily filled cache used by body().
    self._decompress = decompress
    self._decompressed_body = None
262278

263279
def body(self) -> bytes:
    """Return the whole body as bytes in memory.

    When ``self._decompress`` is true and the *Content-Encoding* header is
    ``gzip`` or ``deflate`` (compared case-insensitively), the payload is
    inflated with zlib and the result is cached on the instance so later
    calls do not decompress again. In every other case the raw bytes are
    returned unchanged.

    :return: The response body, decompressed when applicable.
    :rtype: bytes
    :raises ValueError: If the body has not been loaded yet.
    """
    if self._body is None:
        raise ValueError("Body is not available. Call async method load_body, or do your call with stream=False.")
    if not self._decompress:
        return self._body
    enc = self.headers.get('Content-Encoding')
    if not enc:
        return self._body
    enc = enc.lower()
    if enc not in ("gzip", "deflate"):
        return self._body
    # Compare against None rather than truthiness: an empty decompressed
    # payload (b"") is a valid cached result and must not be inflated again.
    if self._decompressed_body is None:
        import zlib
        # 16 + MAX_WBITS makes zlib expect a gzip container; MAX_WBITS alone
        # expects a zlib-wrapped deflate stream.
        zlib_mode = 16 + zlib.MAX_WBITS if enc == "gzip" else zlib.MAX_WBITS
        decompressor = zlib.decompressobj(wbits=zlib_mode)
        self._decompressed_body = decompressor.decompress(self._body)
    return self._decompressed_body
269299

270300
def text(self, encoding: Optional[str] = None) -> str:
    """Return the whole body as a string.

    An explicitly supplied ``encoding`` is always honoured. When none is
    given, the encoding is taken from the ``charset`` parameter of the
    *Content-Type* header if Python knows that codec; otherwise UTF-8 is
    assumed for ``application/json`` and ``application/rdap``, and for any
    other content type the encoding is guessed with chardet, falling back
    to ``utf-8-sig``.

    :param str encoding: The encoding to apply. If omitted, it is detected
     as described above.
    :return: The decoded body.
    :rtype: str
    """
    # super().text detects the charset from self._body, which may still be
    # compressed, so decode the output of self.body() explicitly here.
    body = self.body()

    # A caller-specified encoding takes precedence over any detection.
    if encoding:
        return body.decode(encoding)

    ctype = self.headers.get(aiohttp.hdrs.CONTENT_TYPE, "").lower()
    mimetype = aiohttp.helpers.parse_mimetype(ctype)

    encoding = mimetype.parameters.get("charset")
    if encoding:
        try:
            codecs.lookup(encoding)
        except LookupError:
            # The server advertised a codec Python does not know; fall
            # through to detection instead of failing.
            encoding = None
    if not encoding:
        if mimetype.type == "application" and mimetype.subtype in ("json", "rdap"):
            # RFC 7159 states that the default encoding is UTF-8.
            # RFC 7483 defines application/rdap+json
            encoding = "utf-8"
        elif body is None:
            raise RuntimeError(
                "Cannot guess the encoding of a not yet read body"
            )
        else:
            encoding = chardet.detect(body)["encoding"]
    if not encoding:
        encoding = "utf-8-sig"

    return body.decode(encoding)
281337

282338
async def load_body(self) -> None:
283339
"""Load in memory the body, so it could be accessible from sync methods."""
@@ -289,7 +345,7 @@ def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
289345
:param pipeline: The pipeline object
290346
:type pipeline: azure.core.pipeline.Pipeline
291347
:keyword bool decompress: If True, which is the default, will attempt to decode the body based
292-
on the content-encoding header.
348+
on the *content-encoding* header.
293349
"""
294350
return AioHttpStreamDownloadGenerator(pipeline, self, **kwargs)
295351

sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
133133
:param pipeline: The pipeline object
134134
:type pipeline: azure.core.pipeline.Pipeline
135135
:keyword bool decompress: If True, which is the default, will attempt to decode the body based
136-
on the content-encoding header.
136+
on the *content-encoding* header.
137137
"""
138138

139139
def parts(self) -> AsyncIterator:

sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class AsyncioStreamDownloadGenerator(AsyncIterator):
139139
:param pipeline: The pipeline object
140140
:param response: The response object.
141141
:keyword bool decompress: If True, which is the default, will attempt to decode the body based
142-
on the content-encoding header.
142+
on the *content-encoding* header.
143143
"""
144144
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
145145
self.pipeline = pipeline

sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class StreamDownloadGenerator(object):
121121
:param pipeline: The pipeline object
122122
:param response: The response object.
123123
:keyword bool decompress: If True, which is the default, will attempt to decode the body based
124-
on the content-encoding header.
124+
on the *content-encoding* header.
125125
"""
126126
def __init__(self, pipeline, response, **kwargs):
127127
self.pipeline = pipeline

sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class TrioStreamDownloadGenerator(AsyncIterator):
5555
:param pipeline: The pipeline object
5656
:param response: The response object.
5757
:keyword bool decompress: If True, which is the default, will attempt to decode the body based
58-
on the content-encoding header.
58+
on the *content-encoding* header.
5959
"""
6060
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
6161
self.pipeline = pipeline
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# --------------------------------------------------------------------------
2+
#
3+
# Copyright (c) Microsoft Corporation. All rights reserved.
4+
#
5+
# The MIT License (MIT)
6+
#
7+
# Permission is hereby granted, free of charge, to any person obtaining a copy
8+
# of this software and associated documentation files (the ""Software""), to deal
9+
# in the Software without restriction, including without limitation the rights
10+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11+
# copies of the Software, and to permit persons to whom the Software is
12+
# furnished to do so, subject to the following conditions:
13+
#
14+
# The above copyright notice and this permission notice shall be included in
15+
# all copies or substantial portions of the Software.
16+
#
17+
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23+
# THE SOFTWARE.
24+
#
25+
# --------------------------------------------------------------------------
26+
import os
27+
import pytest
28+
from azure.core import AsyncPipelineClient
29+
30+
@pytest.mark.asyncio
async def test_decompress_plain_no_header():
    """Stream ``tests/test.txt`` with ``decompress=True``; the text "test" is expected."""
    # expect plain text
    account_name = "coretests"
    account_url = "https://{}.blob.core.windows.net".format(account_name)
    url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name)
    client = AsyncPipelineClient(account_url)
    result = await client._pipeline.run(client.get(url), stream=True)
    stream = result.http_response.stream_download(client._pipeline, decompress=True)
    # Drain the async iterator into one bytes object.
    payload = b"".join([chunk async for chunk in stream])
    assert payload.decode('utf-8') == "test"
46+
47+
@pytest.mark.asyncio
async def test_compress_plain_no_header():
    """Stream ``tests/test.txt`` with ``decompress=False``; the text "test" is expected."""
    # expect plain text
    account_name = "coretests"
    account_url = "https://{}.blob.core.windows.net".format(account_name)
    url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name)
    client = AsyncPipelineClient(account_url)
    result = await client._pipeline.run(client.get(url), stream=True)
    stream = result.http_response.stream_download(client._pipeline, decompress=False)
    # Drain the async iterator into one bytes object.
    payload = b"".join([chunk async for chunk in stream])
    assert payload.decode('utf-8') == "test"
63+
64+
@pytest.mark.asyncio
async def test_decompress_compressed_no_header():
    """Stream ``tests/test.tar.gz`` with ``decompress=True``.

    The downloaded bytes are expected to still be compressed, so decoding
    them as UTF-8 must fail.
    """
    # expect compressed text
    account_name = "coretests"
    account_url = "https://{}.blob.core.windows.net".format(account_name)
    url = "https://{}.blob.core.windows.net/tests/test.tar.gz".format(account_name)
    client = AsyncPipelineClient(account_url)
    request = client.get(url)
    pipeline_response = await client._pipeline.run(request, stream=True)
    response = pipeline_response.http_response
    data = response.stream_download(client._pipeline, decompress=True)
    content = b"".join([d async for d in data])
    # pytest.raises fails the test when no exception is raised, without
    # relying on a bare `assert False` inside a try block.
    with pytest.raises(UnicodeDecodeError):
        content.decode('utf-8')
83+
84+
@pytest.mark.asyncio
async def test_compress_compressed_no_header():
    """Stream ``tests/test.tar.gz`` with ``decompress=False``.

    The downloaded bytes are expected to still be compressed, so decoding
    them as UTF-8 must fail.
    """
    # expect compressed text
    account_name = "coretests"
    account_url = "https://{}.blob.core.windows.net".format(account_name)
    url = "https://{}.blob.core.windows.net/tests/test.tar.gz".format(account_name)
    client = AsyncPipelineClient(account_url)
    request = client.get(url)
    pipeline_response = await client._pipeline.run(request, stream=True)
    response = pipeline_response.http_response
    data = response.stream_download(client._pipeline, decompress=False)
    content = b"".join([d async for d in data])
    # pytest.raises fails the test when no exception is raised, without
    # relying on a bare `assert False` inside a try block.
    with pytest.raises(UnicodeDecodeError):
        content.decode('utf-8')
103+
104+
@pytest.mark.asyncio
async def test_decompress_plain_header():
    """Stream ``tests/test_with_header.txt`` with ``decompress=True``.

    Consuming the stream is expected to raise ``zlib.error``.
    """
    # expect error
    import zlib
    account_name = "coretests"
    account_url = "https://{}.blob.core.windows.net".format(account_name)
    url = "https://{}.blob.core.windows.net/tests/test_with_header.txt".format(account_name)
    client = AsyncPipelineClient(account_url)
    request = client.get(url)
    pipeline_response = await client._pipeline.run(request, stream=True)
    response = pipeline_response.http_response
    data = response.stream_download(client._pipeline, decompress=True)
    # The error surfaces while iterating; the chunks themselves are not needed.
    with pytest.raises(zlib.error):
        async for _ in data:
            pass
123+
124+
@pytest.mark.asyncio
async def test_compress_plain_header():
    """Stream ``tests/test_with_header.txt`` with ``decompress=False``; the text "test" is expected."""
    # expect plain text
    account_name = "coretests"
    account_url = "https://{}.blob.core.windows.net".format(account_name)
    url = "https://{}.blob.core.windows.net/tests/test_with_header.txt".format(account_name)
    client = AsyncPipelineClient(account_url)
    result = await client._pipeline.run(client.get(url), stream=True)
    stream = result.http_response.stream_download(client._pipeline, decompress=False)
    # Drain the async iterator into one bytes object.
    payload = b"".join([chunk async for chunk in stream])
    assert payload.decode('utf-8') == "test"
140+
141+
@pytest.mark.asyncio
async def test_decompress_compressed_header():
    """Stream ``tests/test_with_header.tar.gz`` with ``decompress=True``; the text "test" is expected."""
    # expect plain text
    account_name = "coretests"
    account_url = "https://{}.blob.core.windows.net".format(account_name)
    url = "https://{}.blob.core.windows.net/tests/test_with_header.tar.gz".format(account_name)
    client = AsyncPipelineClient(account_url)
    result = await client._pipeline.run(client.get(url), stream=True)
    stream = result.http_response.stream_download(client._pipeline, decompress=True)
    # Drain the async iterator into one bytes object.
    payload = b"".join([chunk async for chunk in stream])
    assert payload.decode('utf-8') == "test"
157+
158+
@pytest.mark.asyncio
async def test_compress_compressed_header():
    """Stream ``tests/test_with_header.tar.gz`` with ``decompress=False``.

    The downloaded bytes are expected to still be compressed, so decoding
    them as UTF-8 must fail.
    """
    # expect compressed text
    account_name = "coretests"
    account_url = "https://{}.blob.core.windows.net".format(account_name)
    url = "https://{}.blob.core.windows.net/tests/test_with_header.tar.gz".format(account_name)
    client = AsyncPipelineClient(account_url)
    request = client.get(url)
    pipeline_response = await client._pipeline.run(request, stream=True)
    response = pipeline_response.http_response
    data = response.stream_download(client._pipeline, decompress=False)
    content = b"".join([d async for d in data])
    # pytest.raises fails the test when no exception is raised, without
    # relying on a bare `assert False` inside a try block.
    with pytest.raises(UnicodeDecodeError):
        content.decode('utf-8')

0 commit comments

Comments
 (0)