-
Notifications
You must be signed in to change notification settings - Fork 3k
Core decompress body #18581
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Core decompress body #18581
Changes from all commits
fca7114
2eb475a
e701b08
eaaf0d9
1ef0f93
7fe6f99
c85194e
2f76efd
74ef986
4ea0e1d
e3362bf
1dcebea
6e20e09
2785d0d
5cb8420
3c687ff
c9cb8ec
c9eed8d
940c2bc
db86c02
b77ce5f
64ffe34
8d29899
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,9 +25,14 @@ | |
# -------------------------------------------------------------------------- | ||
from typing import Any, Optional, AsyncIterator as AsyncIteratorType | ||
from collections.abc import AsyncIterator | ||
try: | ||
import cchardet as chardet | ||
except ImportError: # pragma: no cover | ||
import chardet # type: ignore | ||
|
||
import logging | ||
import asyncio | ||
import codecs | ||
import aiohttp | ||
from multidict import CIMultiDict | ||
from requests.exceptions import StreamConsumedError | ||
|
@@ -66,7 +71,7 @@ class AioHttpTransport(AsyncHttpTransport): | |
:dedent: 4 | ||
:caption: Asynchronous transport with aiohttp. | ||
""" | ||
def __init__(self, *, session=None, loop=None, session_owner=True, **kwargs): | ||
def __init__(self, *, session: Optional[aiohttp.ClientSession] = None, loop=None, session_owner=True, **kwargs): | ||
self._loop = loop | ||
self._session_owner = session_owner | ||
self.session = session | ||
|
@@ -145,6 +150,11 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR | |
:keyword str proxy: will define the proxy to use all the time | ||
""" | ||
await self.open() | ||
try: | ||
auto_decompress = self.session.auto_decompress # type: ignore | ||
except AttributeError: | ||
# auto_decompress is introduced in Python 3.7. We need this to handle Python 3.6. | ||
auto_decompress = True | ||
|
||
proxies = config.pop('proxies', None) | ||
if proxies and 'proxy' not in config: | ||
|
@@ -171,7 +181,7 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR | |
timeout = config.pop('connection_timeout', self.connection_config.timeout) | ||
read_timeout = config.pop('read_timeout', self.connection_config.read_timeout) | ||
socket_timeout = aiohttp.ClientTimeout(sock_connect=timeout, sock_read=read_timeout) | ||
result = await self.session.request( | ||
result = await self.session.request( # type: ignore | ||
request.method, | ||
request.url, | ||
headers=request.headers, | ||
|
@@ -180,7 +190,9 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR | |
allow_redirects=False, | ||
**config | ||
) | ||
response = AioHttpTransportResponse(request, result, self.connection_config.data_block_size) | ||
response = AioHttpTransportResponse(request, result, | ||
self.connection_config.data_block_size, | ||
decompress=not auto_decompress) | ||
if not stream_response: | ||
await response.load_body() | ||
except aiohttp.client_exceptions.ClientResponseError as err: | ||
|
@@ -196,17 +208,15 @@ class AioHttpStreamDownloadGenerator(AsyncIterator): | |
|
||
:param pipeline: The pipeline object | ||
:param response: The client response object. | ||
:keyword bool decompress: If True which is default, will attempt to decode the body based | ||
on the ‘content-encoding’ header. | ||
:param bool decompress: If True which is default, will attempt to decode the body based | ||
on the *content-encoding* header. | ||
""" | ||
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None: | ||
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, *, decompress=True) -> None: | ||
self.pipeline = pipeline | ||
self.request = response.request | ||
self.response = response | ||
self.block_size = response.block_size | ||
self._decompress = kwargs.pop("decompress", True) | ||
if len(kwargs) > 0: | ||
raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) | ||
self._decompress = decompress | ||
self.content_length = int(response.internal_response.headers.get('Content-Length', 0)) | ||
self._decompressor = None | ||
|
||
|
@@ -250,21 +260,41 @@ class AioHttpTransportResponse(AsyncHttpResponse): | |
:type aiohttp_response: aiohttp.ClientResponse object | ||
:param block_size: block size of data sent over connection. | ||
:type block_size: int | ||
:param bool decompress: If True which is default, will attempt to decode the body based | ||
on the *content-encoding* header. | ||
""" | ||
def __init__(self, request: HttpRequest, aiohttp_response: aiohttp.ClientResponse, block_size=None) -> None: | ||
def __init__(self, request: HttpRequest, | ||
aiohttp_response: aiohttp.ClientResponse, | ||
block_size=None, *, decompress=True) -> None: | ||
super(AioHttpTransportResponse, self).__init__(request, aiohttp_response, block_size=block_size) | ||
# https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse | ||
self.status_code = aiohttp_response.status | ||
self.headers = CIMultiDict(aiohttp_response.headers) | ||
self.reason = aiohttp_response.reason | ||
self.content_type = aiohttp_response.headers.get('content-type') | ||
self._body = None | ||
self._decompressed_body = None | ||
self._decompress = decompress | ||
|
||
def body(self) -> bytes: | ||
"""Return the whole body as bytes in memory. | ||
""" | ||
if self._body is None: | ||
raise ValueError("Body is not available. Call async method load_body, or do your call with stream=False.") | ||
if not self._decompress: | ||
return self._body | ||
enc = self.headers.get('Content-Encoding') | ||
if not enc: | ||
return self._body | ||
enc = enc.lower() | ||
if enc in ("gzip", "deflate"): | ||
if self._decompressed_body: | ||
return self._decompressed_body | ||
import zlib | ||
zlib_mode = 16 + zlib.MAX_WBITS if enc == "gzip" else zlib.MAX_WBITS | ||
decompressor = zlib.decompressobj(wbits=zlib_mode) | ||
self._decompressed_body = decompressor.decompress(self._body) | ||
return self._decompressed_body | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would suggest keeping a single copy of the body around. Unless you still need the compressed version for some reason... |
||
return self._body | ||
|
||
def text(self, encoding: Optional[str] = None) -> str: | ||
|
@@ -274,10 +304,36 @@ def text(self, encoding: Optional[str] = None) -> str: | |
|
||
:param str encoding: The encoding to apply. | ||
""" | ||
# super().text detects charset based on self._body() which is compressed | ||
# implement the decoding explicitly here | ||
body = self.body() | ||
|
||
ctype = self.headers.get(aiohttp.hdrs.CONTENT_TYPE, "").lower() | ||
mimetype = aiohttp.helpers.parse_mimetype(ctype) | ||
|
||
encoding = mimetype.parameters.get("charset") | ||
if encoding: | ||
try: | ||
codecs.lookup(encoding) | ||
except LookupError: | ||
encoding = None | ||
if not encoding: | ||
if mimetype.type == "application" and ( | ||
mimetype.subtype == "json" or mimetype.subtype == "rdap" | ||
): | ||
# RFC 7159 states that the default encoding is UTF-8. | ||
# RFC 7483 defines application/rdap+json | ||
encoding = "utf-8" | ||
elif body is None: | ||
raise RuntimeError( | ||
"Cannot guess the encoding of a not yet read body" | ||
) | ||
else: | ||
encoding = chardet.detect(body)["encoding"] | ||
annatisch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if not encoding: | ||
annatisch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
encoding = self.internal_response.get_encoding() | ||
encoding = "utf-8-sig" | ||
|
||
return super().text(encoding) | ||
return body.decode(encoding) | ||
|
||
async def load_body(self) -> None: | ||
"""Load in memory the body, so it could be accessible from sync methods.""" | ||
|
@@ -289,7 +345,7 @@ def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: | |
:param pipeline: The pipeline object | ||
:type pipeline: azure.core.pipeline.Pipeline | ||
:keyword bool decompress: If True which is default, will attempt to decode the body based | ||
on the ‘content-encoding’ header. | ||
on the *content-encoding* header. | ||
""" | ||
return AioHttpStreamDownloadGenerator(pipeline, self, **kwargs) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
# -------------------------------------------------------------------------- | ||
# | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# | ||
# The MIT License (MIT) | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining a copy | ||
# of this software and associated documentation files (the ""Software""), to deal | ||
# in the Software without restriction, including without limitation the rights | ||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
# copies of the Software, and to permit persons to whom the Software is | ||
# furnished to do so, subject to the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be included in | ||
# all copies or substantial portions of the Software. | ||
# | ||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
# THE SOFTWARE. | ||
# | ||
# -------------------------------------------------------------------------- | ||
import os | ||
import pytest | ||
from azure.core import AsyncPipelineClient | ||
|
||
@pytest.mark.asyncio | ||
async def test_decompress_plain_no_header(): | ||
# expect plain text | ||
account_name = "coretests" | ||
account_url = "https://{}.blob.core.windows.net".format(account_name) | ||
url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name) | ||
client = AsyncPipelineClient(account_url) | ||
request = client.get(url) | ||
pipeline_response = await client._pipeline.run(request, stream=True) | ||
response = pipeline_response.http_response | ||
data = response.stream_download(client._pipeline, decompress=True) | ||
content = b"" | ||
async for d in data: | ||
content += d | ||
decoded = content.decode('utf-8') | ||
assert decoded == "test" | ||
|
||
@pytest.mark.asyncio | ||
async def test_compress_plain_no_header(): | ||
# expect plain text | ||
account_name = "coretests" | ||
account_url = "https://{}.blob.core.windows.net".format(account_name) | ||
url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name) | ||
client = AsyncPipelineClient(account_url) | ||
request = client.get(url) | ||
pipeline_response = await client._pipeline.run(request, stream=True) | ||
response = pipeline_response.http_response | ||
data = response.stream_download(client._pipeline, decompress=False) | ||
content = b"" | ||
async for d in data: | ||
content += d | ||
decoded = content.decode('utf-8') | ||
assert decoded == "test" | ||
|
||
@pytest.mark.asyncio | ||
async def test_decompress_compressed_no_header(): | ||
# expect compressed text | ||
account_name = "coretests" | ||
account_url = "https://{}.blob.core.windows.net".format(account_name) | ||
url = "https://{}.blob.core.windows.net/tests/test.tar.gz".format(account_name) | ||
client = AsyncPipelineClient(account_url) | ||
request = client.get(url) | ||
pipeline_response = await client._pipeline.run(request, stream=True) | ||
response = pipeline_response.http_response | ||
data = response.stream_download(client._pipeline, decompress=True) | ||
content = b"" | ||
async for d in data: | ||
content += d | ||
try: | ||
decoded = content.decode('utf-8') | ||
assert False | ||
except UnicodeDecodeError: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So this raises because we couldn't decompress because no header way found - is that right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because there is no encoding header, we will not try to decompress it. Here raises error because it fails to decode a compressed stream. |
||
pass | ||
|
||
@pytest.mark.asyncio | ||
async def test_compress_compressed_no_header(): | ||
# expect compressed text | ||
account_name = "coretests" | ||
account_url = "https://{}.blob.core.windows.net".format(account_name) | ||
url = "https://{}.blob.core.windows.net/tests/test.tar.gz".format(account_name) | ||
client = AsyncPipelineClient(account_url) | ||
request = client.get(url) | ||
pipeline_response = await client._pipeline.run(request, stream=True) | ||
response = pipeline_response.http_response | ||
data = response.stream_download(client._pipeline, decompress=False) | ||
content = b"" | ||
async for d in data: | ||
content += d | ||
try: | ||
decoded = content.decode('utf-8') | ||
assert False | ||
except UnicodeDecodeError: | ||
pass | ||
|
||
@pytest.mark.asyncio | ||
async def test_decompress_plain_header(): | ||
# expect error | ||
import zlib | ||
account_name = "coretests" | ||
account_url = "https://{}.blob.core.windows.net".format(account_name) | ||
url = "https://{}.blob.core.windows.net/tests/test_with_header.txt".format(account_name) | ||
client = AsyncPipelineClient(account_url) | ||
request = client.get(url) | ||
pipeline_response = await client._pipeline.run(request, stream=True) | ||
response = pipeline_response.http_response | ||
data = response.stream_download(client._pipeline, decompress=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the header that being returned here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right. In this scenario, there is an encoding header and we pass in decompress=True. We will try to decompress the stream which is not in correct format. So the decompression will fail. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this failing because the decompression algorithm mismatches the header? Or because the content itself mismatches the header? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With encoding header "gzip" we will try to use "gzip" algorithm to decompress the stream. But the content of the stream itself is not compressed (it is plain text). What happens here is we try to decompress an un-compressed stream so we fail to decompress. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Gotcha - perfect! |
||
try: | ||
content = b"" | ||
async for d in data: | ||
content += d | ||
assert False | ||
except zlib.error: | ||
pass | ||
|
||
@pytest.mark.asyncio | ||
async def test_compress_plain_header(): | ||
# expect plain text | ||
account_name = "coretests" | ||
account_url = "https://{}.blob.core.windows.net".format(account_name) | ||
url = "https://{}.blob.core.windows.net/tests/test_with_header.txt".format(account_name) | ||
client = AsyncPipelineClient(account_url) | ||
request = client.get(url) | ||
pipeline_response = await client._pipeline.run(request, stream=True) | ||
response = pipeline_response.http_response | ||
data = response.stream_download(client._pipeline, decompress=False) | ||
content = b"" | ||
async for d in data: | ||
content += d | ||
decoded = content.decode('utf-8') | ||
assert decoded == "test" | ||
|
||
@pytest.mark.asyncio | ||
async def test_decompress_compressed_header(): | ||
# expect plain text | ||
account_name = "coretests" | ||
account_url = "https://{}.blob.core.windows.net".format(account_name) | ||
url = "https://{}.blob.core.windows.net/tests/test_with_header.tar.gz".format(account_name) | ||
client = AsyncPipelineClient(account_url) | ||
request = client.get(url) | ||
pipeline_response = await client._pipeline.run(request, stream=True) | ||
response = pipeline_response.http_response | ||
data = response.stream_download(client._pipeline, decompress=True) | ||
content = b"" | ||
async for d in data: | ||
content += d | ||
decoded = content.decode('utf-8') | ||
assert decoded == "test" | ||
|
||
@pytest.mark.asyncio | ||
async def test_compress_compressed_header(): | ||
# expect compressed text | ||
account_name = "coretests" | ||
account_url = "https://{}.blob.core.windows.net".format(account_name) | ||
url = "https://{}.blob.core.windows.net/tests/test_with_header.tar.gz".format(account_name) | ||
client = AsyncPipelineClient(account_url) | ||
request = client.get(url) | ||
pipeline_response = await client._pipeline.run(request, stream=True) | ||
response = pipeline_response.http_response | ||
data = response.stream_download(client._pipeline, decompress=False) | ||
content = b"" | ||
async for d in data: | ||
content += d | ||
try: | ||
decoded = content.decode('utf-8') | ||
assert False | ||
except UnicodeDecodeError: | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would add a comment as to why we need this. I know I would be confused unless I knew the history...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added.