Skip to content

Core decompress body #18581

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
May 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
82 changes: 69 additions & 13 deletions sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@
# --------------------------------------------------------------------------
from typing import Any, Optional, AsyncIterator as AsyncIteratorType
from collections.abc import AsyncIterator
try:
import cchardet as chardet
except ImportError: # pragma: no cover
import chardet # type: ignore

import logging
import asyncio
import codecs
import aiohttp
from multidict import CIMultiDict
from requests.exceptions import StreamConsumedError
Expand Down Expand Up @@ -66,7 +71,7 @@ class AioHttpTransport(AsyncHttpTransport):
:dedent: 4
:caption: Asynchronous transport with aiohttp.
"""
def __init__(self, *, session=None, loop=None, session_owner=True, **kwargs):
def __init__(self, *, session: Optional[aiohttp.ClientSession] = None, loop=None, session_owner=True, **kwargs):
self._loop = loop
self._session_owner = session_owner
self.session = session
Expand Down Expand Up @@ -145,6 +150,11 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
:keyword str proxy: will define the proxy to use all the time
"""
await self.open()
try:
auto_decompress = self.session.auto_decompress # type: ignore
except AttributeError:
# auto_decompress is introduced in Python 3.7. We need this to handle Python 3.6.
auto_decompress = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a comment as to why we need this. I know I would be confused unless I knew the history...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.


proxies = config.pop('proxies', None)
if proxies and 'proxy' not in config:
Expand All @@ -171,7 +181,7 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
timeout = config.pop('connection_timeout', self.connection_config.timeout)
read_timeout = config.pop('read_timeout', self.connection_config.read_timeout)
socket_timeout = aiohttp.ClientTimeout(sock_connect=timeout, sock_read=read_timeout)
result = await self.session.request(
result = await self.session.request( # type: ignore
request.method,
request.url,
headers=request.headers,
Expand All @@ -180,7 +190,9 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
allow_redirects=False,
**config
)
response = AioHttpTransportResponse(request, result, self.connection_config.data_block_size)
response = AioHttpTransportResponse(request, result,
self.connection_config.data_block_size,
decompress=not auto_decompress)
if not stream_response:
await response.load_body()
except aiohttp.client_exceptions.ClientResponseError as err:
Expand All @@ -196,17 +208,15 @@ class AioHttpStreamDownloadGenerator(AsyncIterator):

:param pipeline: The pipeline object
:param response: The client response object.
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the content-encoding header.
:param bool decompress: If True which is default, will attempt to decode the body based
on the *content-encoding* header.
"""
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, *, decompress=True) -> None:
self.pipeline = pipeline
self.request = response.request
self.response = response
self.block_size = response.block_size
self._decompress = kwargs.pop("decompress", True)
if len(kwargs) > 0:
raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]))
self._decompress = decompress
self.content_length = int(response.internal_response.headers.get('Content-Length', 0))
self._decompressor = None

Expand Down Expand Up @@ -250,21 +260,41 @@ class AioHttpTransportResponse(AsyncHttpResponse):
:type aiohttp_response: aiohttp.ClientResponse object
:param block_size: block size of data sent over connection.
:type block_size: int
:param bool decompress: If True which is default, will attempt to decode the body based
on the *content-encoding* header.
"""
def __init__(self, request: HttpRequest, aiohttp_response: aiohttp.ClientResponse, block_size=None) -> None:
def __init__(self, request: HttpRequest,
             aiohttp_response: aiohttp.ClientResponse,
             block_size=None, *, decompress=True) -> None:
    """Wrap an ``aiohttp.ClientResponse`` as a transport response.

    :param request: The request that produced this response.
    :type request: ~azure.core.pipeline.transport.HttpRequest
    :param aiohttp_response: The aiohttp response to wrap.
    :type aiohttp_response: aiohttp.ClientResponse
    :param int block_size: Block size of data sent over the connection.
    :keyword bool decompress: If True (the default), ``body()`` will attempt
     to decompress the payload based on the *Content-Encoding* header.
    """
    super(AioHttpTransportResponse, self).__init__(request, aiohttp_response, block_size=block_size)
    # https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse
    self.status_code = aiohttp_response.status
    # CIMultiDict gives case-insensitive header lookup.
    self.headers = CIMultiDict(aiohttp_response.headers)
    self.reason = aiohttp_response.reason
    self.content_type = aiohttp_response.headers.get('content-type')
    # _body is populated later by load_body(); body() raises until then.
    self._body = None
    # Lazy cache for the decompressed payload used by body().
    self._decompressed_body = None
    # Whether body() should undo gzip/deflate content-encoding itself.
    self._decompress = decompress

def body(self) -> bytes:
    """Return the whole body as bytes in memory.

    When this response was created with ``decompress=True`` and the
    *Content-Encoding* header is ``gzip`` or ``deflate``, the returned
    bytes are the decompressed payload; decompression is performed lazily
    on first access and cached in ``self._decompressed_body``.

    :raises ValueError: if the body has not been loaded yet
     (call async ``load_body``, or make the call with ``stream=False``).
    :rtype: bytes
    """
    if self._body is None:
        raise ValueError("Body is not available. Call async method load_body, or do your call with stream=False.")
    if not self._decompress:
        return self._body
    enc = self.headers.get('Content-Encoding')
    if not enc:
        return self._body
    enc = enc.lower()
    if enc in ("gzip", "deflate"):
        # Cache check must be "is not None": an empty decompressed payload
        # (b"") is falsy, and a truthiness test would re-decompress it on
        # every call instead of returning the cached result.
        if self._decompressed_body is not None:
            return self._decompressed_body
        import zlib
        # wbits = 16 + MAX_WBITS tells zlib to expect a gzip wrapper;
        # plain MAX_WBITS expects a zlib (deflate) wrapper.
        zlib_mode = 16 + zlib.MAX_WBITS if enc == "gzip" else zlib.MAX_WBITS
        decompressor = zlib.decompressobj(wbits=zlib_mode)
        self._decompressed_body = decompressor.decompress(self._body)
        return self._decompressed_body
    # Unknown/identity encoding: hand back the raw bytes untouched.
    return self._body

def text(self, encoding: Optional[str] = None) -> str:
Expand All @@ -274,10 +304,36 @@ def text(self, encoding: Optional[str] = None) -> str:

:param str encoding: The encoding to apply.
"""
# super().text detects charset based on self._body() which is compressed
# implement the decoding explicitly here
body = self.body()

ctype = self.headers.get(aiohttp.hdrs.CONTENT_TYPE, "").lower()
mimetype = aiohttp.helpers.parse_mimetype(ctype)

encoding = mimetype.parameters.get("charset")
if encoding:
try:
codecs.lookup(encoding)
except LookupError:
encoding = None
if not encoding:
if mimetype.type == "application" and (
mimetype.subtype == "json" or mimetype.subtype == "rdap"
):
# RFC 7159 states that the default encoding is UTF-8.
# RFC 7483 defines application/rdap+json
encoding = "utf-8"
elif body is None:
raise RuntimeError(
"Cannot guess the encoding of a not yet read body"
)
else:
encoding = chardet.detect(body)["encoding"]
if not encoding:
encoding = self.internal_response.get_encoding()
encoding = "utf-8-sig"

return super().text(encoding)
return body.decode(encoding)

async def load_body(self) -> None:
"""Load in memory the body, so it could be accessible from sync methods."""
Expand All @@ -289,7 +345,7 @@ def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
:param pipeline: The pipeline object
:type pipeline: azure.core.pipeline.Pipeline
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the content-encoding header.
on the *content-encoding* header.
"""
return AioHttpStreamDownloadGenerator(pipeline, self, **kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
:param pipeline: The pipeline object
:type pipeline: azure.core.pipeline.Pipeline
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the content-encoding header.
on the *content-encoding* header.
"""

def parts(self) -> AsyncIterator:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class AsyncioStreamDownloadGenerator(AsyncIterator):
:param pipeline: The pipeline object
:param response: The response object.
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the content-encoding header.
on the *content-encoding* header.
"""
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
self.pipeline = pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class StreamDownloadGenerator(object):
:param pipeline: The pipeline object
:param response: The response object.
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the content-encoding header.
on the *content-encoding* header.
"""
def __init__(self, pipeline, response, **kwargs):
self.pipeline = pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class TrioStreamDownloadGenerator(AsyncIterator):
:param pipeline: The pipeline object
:param response: The response object.
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the content-encoding header.
on the *content-encoding* header.
"""
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
self.pipeline = pipeline
Expand Down
176 changes: 176 additions & 0 deletions sdk/core/azure-core/tests/async_tests/test_streaming_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# --------------------------------------------------------------------------
import os
import pytest
from azure.core import AsyncPipelineClient

@pytest.mark.asyncio
async def test_decompress_plain_no_header():
    # Plain-text blob with no content-encoding header: decompress=True is a no-op.
    account_name = "coretests"
    account_url = f"https://{account_name}.blob.core.windows.net"
    url = f"https://{account_name}.blob.core.windows.net/tests/test.txt"
    client = AsyncPipelineClient(account_url)
    request = client.get(url)
    pipeline_response = await client._pipeline.run(request, stream=True)
    stream = pipeline_response.http_response.stream_download(client._pipeline, decompress=True)
    chunks = []
    async for chunk in stream:
        chunks.append(chunk)
    assert b"".join(chunks).decode('utf-8') == "test"

@pytest.mark.asyncio
async def test_compress_plain_no_header():
    # Plain-text blob with no content-encoding header: decompress=False
    # leaves the payload untouched, so it still decodes as plain text.
    account_name = "coretests"
    account_url = f"https://{account_name}.blob.core.windows.net"
    url = f"https://{account_name}.blob.core.windows.net/tests/test.txt"
    client = AsyncPipelineClient(account_url)
    request = client.get(url)
    pipeline_response = await client._pipeline.run(request, stream=True)
    stream = pipeline_response.http_response.stream_download(client._pipeline, decompress=False)
    chunks = []
    async for chunk in stream:
        chunks.append(chunk)
    assert b"".join(chunks).decode('utf-8') == "test"

@pytest.mark.asyncio
async def test_decompress_compressed_no_header():
    # Compressed blob but no content-encoding header: with nothing to key off,
    # decompress=True does not decompress, so the bytes are not valid UTF-8.
    account_name = "coretests"
    account_url = f"https://{account_name}.blob.core.windows.net"
    url = f"https://{account_name}.blob.core.windows.net/tests/test.tar.gz"
    client = AsyncPipelineClient(account_url)
    request = client.get(url)
    pipeline_response = await client._pipeline.run(request, stream=True)
    stream = pipeline_response.http_response.stream_download(client._pipeline, decompress=True)
    payload = b""
    async for chunk in stream:
        payload += chunk
    with pytest.raises(UnicodeDecodeError):
        payload.decode('utf-8')

@pytest.mark.asyncio
async def test_compress_compressed_no_header():
    # Compressed blob, no content-encoding header, decompress=False:
    # raw compressed bytes come through and fail UTF-8 decoding.
    account_name = "coretests"
    account_url = f"https://{account_name}.blob.core.windows.net"
    url = f"https://{account_name}.blob.core.windows.net/tests/test.tar.gz"
    client = AsyncPipelineClient(account_url)
    request = client.get(url)
    pipeline_response = await client._pipeline.run(request, stream=True)
    stream = pipeline_response.http_response.stream_download(client._pipeline, decompress=False)
    payload = b""
    async for chunk in stream:
        payload += chunk
    with pytest.raises(UnicodeDecodeError):
        payload.decode('utf-8')

@pytest.mark.asyncio
async def test_decompress_plain_header():
    # The content-encoding header advertises gzip but the payload is plain
    # text: decompress=True tries to inflate it and must raise zlib.error
    # while draining the stream.
    import zlib
    account_name = "coretests"
    account_url = f"https://{account_name}.blob.core.windows.net"
    url = f"https://{account_name}.blob.core.windows.net/tests/test_with_header.txt"
    client = AsyncPipelineClient(account_url)
    request = client.get(url)
    pipeline_response = await client._pipeline.run(request, stream=True)
    stream = pipeline_response.http_response.stream_download(client._pipeline, decompress=True)
    with pytest.raises(zlib.error):
        payload = b""
        async for chunk in stream:
            payload += chunk

@pytest.mark.asyncio
async def test_compress_plain_header():
    # Content-encoding header present but decompress=False: the header is
    # ignored, and the plain-text payload arrives as-is.
    account_name = "coretests"
    account_url = f"https://{account_name}.blob.core.windows.net"
    url = f"https://{account_name}.blob.core.windows.net/tests/test_with_header.txt"
    client = AsyncPipelineClient(account_url)
    request = client.get(url)
    pipeline_response = await client._pipeline.run(request, stream=True)
    stream = pipeline_response.http_response.stream_download(client._pipeline, decompress=False)
    chunks = []
    async for chunk in stream:
        chunks.append(chunk)
    assert b"".join(chunks).decode('utf-8') == "test"

@pytest.mark.asyncio
async def test_decompress_compressed_header():
    # Compressed blob with a matching content-encoding header and
    # decompress=True: the stream is inflated back to plain text.
    account_name = "coretests"
    account_url = f"https://{account_name}.blob.core.windows.net"
    url = f"https://{account_name}.blob.core.windows.net/tests/test_with_header.tar.gz"
    client = AsyncPipelineClient(account_url)
    request = client.get(url)
    pipeline_response = await client._pipeline.run(request, stream=True)
    stream = pipeline_response.http_response.stream_download(client._pipeline, decompress=True)
    chunks = []
    async for chunk in stream:
        chunks.append(chunk)
    assert b"".join(chunks).decode('utf-8') == "test"

@pytest.mark.asyncio
async def test_compress_compressed_header():
    # Compressed blob with a content-encoding header but decompress=False:
    # the compressed bytes pass through and fail UTF-8 decoding.
    account_name = "coretests"
    account_url = f"https://{account_name}.blob.core.windows.net"
    url = f"https://{account_name}.blob.core.windows.net/tests/test_with_header.tar.gz"
    client = AsyncPipelineClient(account_url)
    request = client.get(url)
    pipeline_response = await client._pipeline.run(request, stream=True)
    stream = pipeline_response.http_response.stream_download(client._pipeline, decompress=False)
    payload = b""
    async for chunk in stream:
        payload += chunk
    with pytest.raises(UnicodeDecodeError):
        payload.decode('utf-8')
Loading