From ab1fe0af162db2a5cfa23a62fadf79b3151e1ca4 Mon Sep 17 00:00:00 2001 From: antisch Date: Wed, 14 Jul 2021 19:38:43 -0700 Subject: [PATCH 01/15] Partial list deserialization --- .../azure/storage/blob/_container_client.py | 7 +- .../aio/operations/_container_operations.py | 6 +- .../operations/_container_operations.py | 4 +- .../azure/storage/blob/_list_blobs_helper.py | 118 ++++++++++++++---- .../blob/aio/_container_client_async.py | 5 + .../storage/blob/aio/_list_blobs_helper.py | 66 ++++++---- .../T1_legacy_tests/list_blobs.py | 9 +- .../tests/perfstress_tests/list_blobs.py | 7 +- 8 files changed, 157 insertions(+), 65 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py index b63556ba61bc..5d7e4254ed31 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py @@ -763,14 +763,15 @@ def list_blobs(self, name_starts_with=None, include=None, **kwargs): results_per_page = kwargs.pop('results_per_page', None) timeout = kwargs.pop('timeout', None) + select = kwargs.pop('select', None) command = functools.partial( self._client.container.list_blob_flat_segment, include=include, timeout=timeout, **kwargs) return ItemPaged( - command, prefix=name_starts_with, results_per_page=results_per_page, - page_iterator_class=BlobPropertiesPaged) + command, prefix=name_starts_with, results_per_page=results_per_page, select=select, + deserializer=self._client._deserialize, page_iterator_class=BlobPropertiesPaged) @distributed_trace def walk_blobs( @@ -816,6 +817,8 @@ def walk_blobs( command, prefix=name_starts_with, results_per_page=results_per_page, + select=None, + deserializer=self._client._deserialize, delimiter=delimiter) @distributed_trace diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py 
b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py index bec837429209..91edd8c2d0d6 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py @@ -1452,12 +1452,12 @@ async def list_blob_flat_segment( response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id')) response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version')) response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date')) - deserialized = self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response) + #deserialized = self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response) if cls: - return cls(pipeline_response, deserialized, response_headers) + return cls(pipeline_response, None, response_headers) - return deserialized + return None list_blob_flat_segment.metadata = {'url': '/{containerName}'} # type: ignore async def list_blob_hierarchy_segment( diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py index f01bbc4393fe..39db923628be 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py @@ -1471,10 +1471,10 @@ def list_blob_flat_segment( response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id')) response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version')) response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date')) - deserialized = 
self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response) + #deserialized = self._deserialize('ListBlobsFlatSegmentResponse', pipeline_response) if cls: - return cls(pipeline_response, deserialized, response_headers) + return cls(pipeline_response, None, response_headers) return deserialized list_blob_flat_segment.metadata = {'url': '/{containerName}'} # type: ignore diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py index 309d37bd9583..66b4b87d39ed 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py @@ -7,6 +7,8 @@ from azure.core.paging import PageIterator, ItemPaged from azure.core.exceptions import HttpResponseError +from azure.core.pipeline.policies import ContentDecodePolicy + from ._deserialize import get_blob_properties_from_generated_code, parse_tags from ._generated.models import BlobItemInternal, BlobPrefix as GenBlobPrefix, FilterBlobItem from ._models import BlobProperties, FilteredBlob @@ -14,6 +16,64 @@ from ._shared.response_handlers import return_context_and_deserialized, process_storage_error +def deserialize_list_result(pipeline_response, _, headers): + payload = pipeline_response.context[ContentDecodePolicy.CONTEXT_NAME] + location = pipeline_response.http_response.location_mode + return location, payload + +def load_xml_string(element, name): + node = element.find(name) + if node is None or not node.text: + return None + return node.text + +def load_xml_int(element, name): + node = element.find(name) + if node is None or not node.text: + return None + return int(node.text) + +def load_xml_bool(element, name): + node = load_xml_string(element, name) + if node and node.lower() == 'true': + return True + return False + + +def load_single_node(element, name): + return element.find(name) + + +def load_many_nodes(element, 
name, wrapper=None): + if wrapper: + element = load_single_node(element, wrapper) + return list(element.findall(name)) + + +def blob_properties_from_xml(element, select, deserializer): + if not select: + generated = deserializer.deserialize_data(element, 'BlobItemInternal') + return get_blob_properties_from_generated_code(generated) + blob = BlobProperties() + if 'name' in select: + blob.name = load_xml_string(element, 'Name') + if 'deleted' in select: + blob.deleted = load_xml_bool(element, 'Deleted') + if 'snapshot' in select: + blob.snapshot = load_xml_string(element, 'Snapshot') + if 'version' in select: + blob.version_id = load_xml_string(element, 'VersionId') + blob.is_current_version = load_xml_bool(element, 'IsCurrentVersion') + # TODO: Should also support selecting 'tags' and 'metadata', but these are only returned + # if opted-into with the 'include' parameter. + # if 'metadata' in select: + # blob.metadata = None + # blob.encrypted_metadata = None + # if 'tags' in select: + # blob.tags = None + return blob + + class BlobPropertiesPaged(PageIterator): """An Iterable of Blob properties. 
@@ -49,6 +109,8 @@ def __init__( container=None, prefix=None, results_per_page=None, + select=None, + deserializer=None, continuation_token=None, delimiter=None, location_mode=None): @@ -58,10 +120,12 @@ def __init__( continuation_token=continuation_token or "" ) self._command = command + self._deserializer = deserializer self.service_endpoint = None self.prefix = prefix self.marker = None self.results_per_page = results_per_page + self.select = select self.container = container self.delimiter = delimiter self.current_page = None @@ -73,30 +137,29 @@ def _get_next_cb(self, continuation_token): prefix=self.prefix, marker=continuation_token or None, maxresults=self.results_per_page, - cls=return_context_and_deserialized, + cls=deserialize_list_result, use_location=self.location_mode) except HttpResponseError as error: process_storage_error(error) def _extract_data_cb(self, get_next_return): self.location_mode, self._response = get_next_return - self.service_endpoint = self._response.service_endpoint - self.prefix = self._response.prefix - self.marker = self._response.marker - self.results_per_page = self._response.max_results - self.container = self._response.container_name - self.current_page = [self._build_item(item) for item in self._response.segment.blob_items] + self.service_endpoint = self._response.get('ServiceEndpoint') + self.prefix = load_xml_string(self._response, 'Prefix') + self.marker = load_xml_string(self._response, 'Marker') + self.results_per_page = load_xml_int(self._response, 'MaxResults') + self.container = self._response.get('ContainerName') - return self._response.next_marker or None, self.current_page + blobs = load_many_nodes(self._response, 'Blob', wrapper='Blobs') + self.current_page = [self._build_item(blob) for blob in blobs] + + next_marker = load_xml_string(self._response, 'NextMarker') + return next_marker or None, self.current_page def _build_item(self, item): - if isinstance(item, BlobProperties): - return item - if isinstance(item, 
BlobItemInternal): - blob = get_blob_properties_from_generated_code(item) # pylint: disable=protected-access - blob.container = self.container - return blob - return item + blob = blob_properties_from_xml(item, self.select, self._deserializer) + blob.container = self.container + return blob class BlobPrefixPaged(BlobPropertiesPaged): @@ -106,22 +169,23 @@ def __init__(self, *args, **kwargs): def _extract_data_cb(self, get_next_return): continuation_token, _ = super(BlobPrefixPaged, self)._extract_data_cb(get_next_return) - self.current_page = self._response.segment.blob_prefixes + self._response.segment.blob_items - self.current_page = [self._build_item(item) for item in self.current_page] - self.delimiter = self._response.delimiter + + blob_prefixes = load_many_nodes(self._response, 'BlobPrefix', wrapper='Blobs') + blob_prefixes = [self._build_item(blob) for blob in blob_prefixes] + + self.current_page = blob_prefixes + self.current_page + self.delimiter = load_xml_string(self._response, 'Delimiter') return continuation_token, self.current_page def _build_item(self, item): - item = super(BlobPrefixPaged, self)._build_item(item) - if isinstance(item, GenBlobPrefix): - return BlobPrefix( - self._command, - container=self.container, - prefix=item.name, - results_per_page=self.results_per_page, - location_mode=self.location_mode) - return item + return BlobPrefix( + self._command, + container=self.container, + prefix=load_xml_string(item, 'Name'), + results_per_page=self.results_per_page, + location_mode=self.location_mode + ) class BlobPrefix(ItemPaged, DictMixin): diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py index cd0164392ab6..498a646ff1f5 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py @@ -624,6 +624,7 
@@ def list_blobs(self, name_starts_with=None, include=None, **kwargs): results_per_page = kwargs.pop('results_per_page', None) timeout = kwargs.pop('timeout', None) + select = kwargs.pop('select', None) command = functools.partial( self._client.container.list_blob_flat_segment, include=include, @@ -633,6 +634,8 @@ def list_blobs(self, name_starts_with=None, include=None, **kwargs): command, prefix=name_starts_with, results_per_page=results_per_page, + select=select, + deserializer=self._client._deserialize, page_iterator_class=BlobPropertiesPaged ) @@ -680,6 +683,8 @@ def walk_blobs( command, prefix=name_starts_with, results_per_page=results_per_page, + select=None, + deserializer=self._client._deserialize, delimiter=delimiter) @distributed_trace_async diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py index 058572fd270d..b0bb95c63d6c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py @@ -12,6 +12,13 @@ from .._generated.models import BlobItemInternal, BlobPrefix as GenBlobPrefix from .._shared.models import DictMixin from .._shared.response_handlers import return_context_and_deserialized, process_storage_error +from .._list_blobs_helper import ( + deserialize_list_result, + load_many_nodes, + load_xml_string, + load_xml_int, + blob_properties_from_xml +) class BlobPropertiesPaged(AsyncPageIterator): @@ -48,6 +55,8 @@ def __init__( container=None, prefix=None, results_per_page=None, + select=None, + deserializer=None, continuation_token=None, delimiter=None, location_mode=None): @@ -57,10 +66,12 @@ def __init__( continuation_token=continuation_token or "" ) self._command = command + self._deserializer = deserializer self.service_endpoint = None self.prefix = prefix self.marker = None self.results_per_page = results_per_page + 
self.select = select self.container = container self.delimiter = delimiter self.current_page = None @@ -72,30 +83,29 @@ async def _get_next_cb(self, continuation_token): prefix=self.prefix, marker=continuation_token or None, maxresults=self.results_per_page, - cls=return_context_and_deserialized, + cls=deserialize_list_result, use_location=self.location_mode) except HttpResponseError as error: process_storage_error(error) async def _extract_data_cb(self, get_next_return): self.location_mode, self._response = get_next_return - self.service_endpoint = self._response.service_endpoint - self.prefix = self._response.prefix - self.marker = self._response.marker - self.results_per_page = self._response.max_results - self.container = self._response.container_name - self.current_page = [self._build_item(item) for item in self._response.segment.blob_items] + self.service_endpoint = self._response.get('ServiceEndpoint') + self.prefix = load_xml_string(self._response, 'Prefix') + self.marker = load_xml_string(self._response, 'Marker') + self.results_per_page = load_xml_int(self._response, 'MaxResults') + self.container = self._response.get('ContainerName') - return self._response.next_marker or None, self.current_page + blobs = load_many_nodes(self._response, 'Blob', wrapper='Blobs') + self.current_page = [self._build_item(blob) for blob in blobs] + + next_marker = load_xml_string(self._response, 'NextMarker') + return next_marker or None, self.current_page def _build_item(self, item): - if isinstance(item, BlobProperties): - return item - if isinstance(item, BlobItemInternal): - blob = get_blob_properties_from_generated_code(item) # pylint: disable=protected-access - blob.container = self.container - return blob - return item + blob = blob_properties_from_xml(item, self.select, self._deserializer) + blob.container = self.container + return blob class BlobPrefix(AsyncItemPaged, DictMixin): @@ -144,20 +154,22 @@ def __init__(self, *args, **kwargs): self.name = self.prefix async 
def _extract_data_cb(self, get_next_return): - continuation_token, _ = await super(BlobPrefixPaged, self)._extract_data_cb(get_next_return) + continuation_token, current_page = await super(BlobPrefixPaged, self)._extract_data_cb(get_next_return) + + blob_prefixes = load_many_nodes(self._response, 'BlobPrefix', wrapper='Blobs') + blob_prefixes = [self._build_item(blob) for blob in blob_prefixes] + + self.current_page = blob_prefixes + current_page + self.delimiter = load_xml_string(self._response, 'Delimiter') self.current_page = self._response.segment.blob_prefixes + self._response.segment.blob_items - self.current_page = [self._build_item(item) for item in self.current_page] - self.delimiter = self._response.delimiter return continuation_token, self.current_page def _build_item(self, item): - item = super(BlobPrefixPaged, self)._build_item(item) - if isinstance(item, GenBlobPrefix): - return BlobPrefix( - self._command, - container=self.container, - prefix=item.name, - results_per_page=self.results_per_page, - location_mode=self.location_mode) - return item + return BlobPrefix( + self._command, + container=self.container, + prefix=load_xml_string(item, 'Name'), + results_per_page=self.results_per_page, + location_mode=self.location_mode + ) diff --git a/sdk/storage/azure-storage-blob/tests/perfstress_tests/T1_legacy_tests/list_blobs.py b/sdk/storage/azure-storage-blob/tests/perfstress_tests/T1_legacy_tests/list_blobs.py index b3a55bcf23b9..aedc564e0109 100644 --- a/sdk/storage/azure-storage-blob/tests/perfstress_tests/T1_legacy_tests/list_blobs.py +++ b/sdk/storage/azure-storage-blob/tests/perfstress_tests/T1_legacy_tests/list_blobs.py @@ -17,8 +17,12 @@ async def global_setup(self): blob=b"") def run_sync(self): - for _ in self.service_client.list_blobs(container_name=self.container_name): - pass + if self.args.name_only: + for _ in self.service_client.list_blob_names(container_name=self.container_name): + pass + else: + for _ in 
self.service_client.list_blobs(container_name=self.container_name): + pass async def run_async(self): raise NotImplementedError("Async not supported for legacy T1 tests.") @@ -27,3 +31,4 @@ async def run_async(self): def add_arguments(parser): super(LegacyListBlobsTest, LegacyListBlobsTest).add_arguments(parser) parser.add_argument('-c', '--count', nargs='?', type=int, help='Number of blobs to list. Defaults to 100', default=100) + parser.add_argument('--name-only', action='store_true', help='Return only blob name. Defaults to False', default=False) diff --git a/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py b/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py index f5f35a86fff1..b0a84c707573 100644 --- a/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py +++ b/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py @@ -27,14 +27,17 @@ async def global_setup(self): break def run_sync(self): - for _ in self.container_client.list_blobs(): + select = ['name'] if self.args.name_only else None + for _ in self.container_client.list_blobs(select=select): pass async def run_async(self): - async for _ in self.async_container_client.list_blobs(): + select = ['name'] if self.args.name_only else None + async for _ in self.async_container_client.list_blobs(select=select): pass @staticmethod def add_arguments(parser): super(ListBlobsTest, ListBlobsTest).add_arguments(parser) parser.add_argument('-c', '--count', nargs='?', type=int, help='Number of blobs to list. Defaults to 100', default=100) + parser.add_argument('--name-only', action='store_true', help='Return only blob name. 
Defaults to False', default=False) From 5f9d0ac625e2edf0d6a5a6fa26db6f68ee775dfd Mon Sep 17 00:00:00 2001 From: antisch Date: Fri, 16 Jul 2021 14:17:18 -0700 Subject: [PATCH 02/15] XML POC --- .../blob/_shared/xml_deserialization.py | 942 ++++++++++++++++++ 1 file changed, 942 insertions(+) create mode 100644 sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py new file mode 100644 index 000000000000..981348a538cb --- /dev/null +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py @@ -0,0 +1,942 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
+# +# -------------------------------------------------------------------------- + +from base64 import b64decode, b64encode +import calendar +import datetime +import decimal +import email +from enum import Enum +import json +import logging +import re +import sys +import os +try: + from urllib import quote # type: ignore +except ImportError: + from urllib.parse import quote # type: ignore + +if os.environ.get("AZURE_STORAGE_LXML"): + try: + from lxml import etree as ET + except: + import xml.etree.ElementTree as ET +else: + import xml.etree.ElementTree as ET + +import isodate + +from typing import Dict, Any + + +try: + basestring # type: ignore + unicode_str = unicode # type: ignore +except NameError: + basestring = str # type: ignore + unicode_str = str # type: ignore + +_LOGGER = logging.getLogger(__name__) + +try: + _long_type = long # type: ignore +except NameError: + _long_type = int + + +from msrest.exceptions import DeserializationError, raise_with_traceback +from msrest.serialization import ( + TZ_UTC, + _FixedOffset, + _FLATTEN +) + +def full_restapi_key_transformer(key, attr_desc, value): + """A key transformer that returns the full RestAPI key path. + + :param str _: The attribute name + :param dict attr_desc: The attribute metadata + :param object value: The value + :returns: A list of keys using RestAPI syntax. + """ + keys = _FLATTEN.split(attr_desc['key']) + return ([_decode_attribute_map_key(k) for k in keys], value) + +def last_restapi_key_transformer(key, attr_desc, value): + """A key transformer that returns the last RestAPI key. + + :param str key: The attribute name + :param dict attr_desc: The attribute metadata + :param object value: The value + :returns: The last RestAPI key. + """ + key, value = full_restapi_key_transformer(key, attr_desc, value) + return (key[-1], value) + + +def _decode_attribute_map_key(key): + """This decode a key in an _attribute_map to the actual key we want to look at + inside the received data. 
+ + :param str key: A key string from the generated code + """ + return key.replace('\\.', '.') + + +def rest_key_extractor(attr, attr_desc, data): + key = attr_desc['key'] + working_data = data + + while '.' in key: + dict_keys = _FLATTEN.split(key) + if len(dict_keys) == 1: + key = _decode_attribute_map_key(dict_keys[0]) + break + working_key = _decode_attribute_map_key(dict_keys[0]) + working_data = working_data.get(working_key, data) + if working_data is None: + # If at any point while following flatten JSON path see None, it means + # that all properties under are None as well + # https://github.com/Azure/msrest-for-python/issues/197 + return None + key = '.'.join(dict_keys[1:]) + + return working_data.get(key) + +def rest_key_case_insensitive_extractor(attr, attr_desc, data): + key = attr_desc['key'] + working_data = data + + while '.' in key: + dict_keys = _FLATTEN.split(key) + if len(dict_keys) == 1: + key = _decode_attribute_map_key(dict_keys[0]) + break + working_key = _decode_attribute_map_key(dict_keys[0]) + working_data = attribute_key_case_insensitive_extractor(working_key, None, working_data) + if working_data is None: + # If at any point while following flatten JSON path see None, it means + # that all properties under are None as well + # https://github.com/Azure/msrest-for-python/issues/197 + return None + key = '.'.join(dict_keys[1:]) + + if working_data: + return attribute_key_case_insensitive_extractor(key, None, working_data) + +def last_rest_key_extractor(attr, attr_desc, data): + """Extract the attribute in "data" based on the last part of the JSON path key. + """ + key = attr_desc['key'] + dict_keys = _FLATTEN.split(key) + return attribute_key_extractor(dict_keys[-1], None, data) + +def last_rest_key_case_insensitive_extractor(attr, attr_desc, data): + """Extract the attribute in "data" based on the last part of the JSON path key. 
+ + This is the case insensitive version of "last_rest_key_extractor" + """ + key = attr_desc['key'] + dict_keys = _FLATTEN.split(key) + return attribute_key_case_insensitive_extractor(dict_keys[-1], None, data) + +def attribute_key_extractor(attr, _, data): + return data.get(attr) + +def attribute_key_case_insensitive_extractor(attr, _, data): + found_key = None + lower_attr = attr.lower() + for key in data: + if lower_attr == key.lower(): + found_key = key + break + + return data.get(found_key) + +def _extract_name_from_internal_type(internal_type): + """Given an internal type XML description, extract correct XML name with namespace. + + :param dict internal_type: An model type + :rtype: tuple + :returns: A tuple XML name + namespace dict + """ + internal_type_xml_map = getattr(internal_type, "_xml_map", {}) + xml_name = internal_type_xml_map.get('name', internal_type.__name__) + xml_ns = internal_type_xml_map.get("ns", None) + if xml_ns: + xml_name = "{{{}}}{}".format(xml_ns, xml_name) + return xml_name + + +def xml_key_extractor(attr, attr_desc, data): + if isinstance(data, dict): + return None + + # Test if this model is XML ready first + if not isinstance(data, ET.Element): + return None + + xml_desc = attr_desc.get('xml', {}) + xml_name = xml_desc.get('name', attr_desc['key']) + + # Look for a children + is_iter_type = attr_desc['type'].startswith("[") + is_wrapped = xml_desc.get("wrapped", False) + internal_type = attr_desc.get("internalType", None) + internal_type_xml_map = getattr(internal_type, "_xml_map", {}) + + # Integrate namespace if necessary + xml_ns = xml_desc.get('ns', internal_type_xml_map.get("ns", None)) + if xml_ns: + xml_name = "{{{}}}{}".format(xml_ns, xml_name) + + # If it's an attribute, that's simple + if xml_desc.get("attr", False): + return data.get(xml_name) + + # If it's x-ms-text, that's simple too + if xml_desc.get("text", False): + return data.text + + # Scenario where I take the local name: + # - Wrapped node + # - Internal type 
is an enum (considered basic types) + # - Internal type has no XML/Name node + if is_wrapped or (internal_type and (issubclass(internal_type, Enum) or 'name' not in internal_type_xml_map)): + children = data.findall(xml_name) + # If internal type has a local name and it's not a list, I use that name + elif not is_iter_type and internal_type and 'name' in internal_type_xml_map: + xml_name = _extract_name_from_internal_type(internal_type) + children = data.findall(xml_name) + # That's an array + else: + if internal_type: # Complex type, ignore itemsName and use the complex type name + items_name = _extract_name_from_internal_type(internal_type) + else: + items_name = xml_desc.get("itemsName", xml_name) + children = data.findall(items_name) + + if len(children) == 0: + if is_iter_type: + if is_wrapped: + return None # is_wrapped no node, we want None + else: + return [] # not wrapped, assume empty list + return None # Assume it's not there, maybe an optional node. + + # If is_iter_type and not wrapped, return all found children + if is_iter_type: + if not is_wrapped: + return children + else: # Iter and wrapped, should have found one node only (the wrap one) + if len(children) != 1: + raise DeserializationError( + "Tried to deserialize an array not wrapped, and found several nodes '{}'. Maybe you should declare this array as wrapped?".format( + xml_name + )) + return list(children[0]) # Might be empty list and that's ok. + + # Here it's not a itertype, we should have found one element only or empty + if len(children) > 1: + raise DeserializationError("Find several XML '{}' where it was not expected".format(xml_name)) + return children[0] + +class Deserializer(object): + """Response object model deserializer. + + :param dict classes: Class type dictionary for deserializing complex types. + :ivar list key_extractors: Ordered list of extractors to be used by this deserializer. 
+ """ + + basic_types = {str: 'str', int: 'int', bool: 'bool', float: 'float'} + + valid_date = re.compile( + r'\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}' + r'\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?') + + def __init__(self, classes=None): + self.deserialize_type = { + 'iso-8601': Deserializer.deserialize_iso, + 'rfc-1123': Deserializer.deserialize_rfc, + 'unix-time': Deserializer.deserialize_unix, + 'duration': Deserializer.deserialize_duration, + 'date': Deserializer.deserialize_date, + 'time': Deserializer.deserialize_time, + 'decimal': Deserializer.deserialize_decimal, + 'long': Deserializer.deserialize_long, + 'bytearray': Deserializer.deserialize_bytearray, + 'base64': Deserializer.deserialize_base64, + 'object': self.deserialize_object, + '[]': self.deserialize_iter, + '{}': self.deserialize_dict + } + self.deserialize_expected_types = { + 'duration': (isodate.Duration, datetime.timedelta), + 'iso-8601': (datetime.datetime) + } + self.dependencies = dict(classes) if classes else {} + self.key_extractors = [ + rest_key_extractor, + xml_key_extractor + ] + # Additional properties only works if the "rest_key_extractor" is used to + # extract the keys. Making it to work whatever the key extractor is too much + # complicated, with no real scenario for now. + # So adding a flag to disable additional properties detection. This flag should be + # used if your expect the deserialization to NOT come from a JSON REST syntax. + # Otherwise, result are unexpected + self.additional_properties_detection = True + + def __call__(self, target_obj, response_data, content_type=None): + """Call the deserializer to process a REST response. + + :param str target_obj: Target data type to deserialize to. + :param requests.Response response_data: REST response object. + :param str content_type: Swagger "produces" if available. + :raises: DeserializationError if deserialization fails. + :return: Deserialized object. 
+ """ + data = self._unpack_content(response_data, content_type) + return self._deserialize(target_obj, data) + + def _deserialize(self, target_obj, data): + """Call the deserializer on a model. + + Data needs to be already deserialized as JSON or XML ElementTree + + :param str target_obj: Target data type to deserialize to. + :param object data: Object to deserialize. + :raises: DeserializationError if deserialization fails. + :return: Deserialized object. + """ + # This is already a model, go recursive just in case + if hasattr(data, "_attribute_map"): + constants = [name for name, config in getattr(data, '_validation', {}).items() + if config.get('constant')] + try: + for attr, mapconfig in data._attribute_map.items(): + if attr in constants: + continue + value = getattr(data, attr) + if value is None: + continue + local_type = mapconfig['type'] + internal_data_type = local_type.strip('[]{}') + if internal_data_type not in self.dependencies or isinstance(internal_data_type, Enum): + continue + setattr( + data, + attr, + self._deserialize(local_type, value) + ) + return data + except AttributeError: + return + + response, class_name = self._classify_target(target_obj, data) + + if isinstance(response, basestring): + return self.deserialize_data(data, response) + elif isinstance(response, type) and issubclass(response, Enum): + return self.deserialize_enum(data, response) + + if data is None: + return data + try: + attributes = response._attribute_map + d_attrs = {} + for attr, attr_desc in attributes.items(): + # Check empty string. If it's not empty, someone has a real "additionalProperties"... 
+ if attr == "additional_properties" and attr_desc["key"] == '': + continue + raw_value = None + # Enhance attr_desc with some dynamic data + attr_desc = attr_desc.copy() # Do a copy, do not change the real one + internal_data_type = attr_desc["type"].strip('[]{}') + if internal_data_type in self.dependencies: + attr_desc["internalType"] = self.dependencies[internal_data_type] + + for key_extractor in self.key_extractors: + found_value = key_extractor(attr, attr_desc, data) + if found_value is not None: + if raw_value is not None and raw_value != found_value: + msg = ("Ignoring extracted value '%s' from %s for key '%s'" + " (duplicate extraction, follow extractors order)" ) + _LOGGER.warning( + msg, + found_value, + key_extractor, + attr + ) + continue + raw_value = found_value + + value = self.deserialize_data(raw_value, attr_desc['type']) + d_attrs[attr] = value + except (AttributeError, TypeError, KeyError) as err: + msg = "Unable to deserialize to object: " + class_name + raise_with_traceback(DeserializationError, msg, err) + else: + additional_properties = self._build_additional_properties(attributes, data) + return self._instantiate_model(response, d_attrs, additional_properties) + + def _build_additional_properties(self, attribute_map, data): + if not self.additional_properties_detection: + return None + if "additional_properties" in attribute_map and attribute_map.get("additional_properties", {}).get("key") != '': + # Check empty string. 
If it's not empty, someone has a real "additionalProperties"
+            return None
+        if isinstance(data, ET.Element):
+            data = {el.tag: el.text for el in data}
+
+        known_keys = {_decode_attribute_map_key(_FLATTEN.split(desc['key'])[0])
+                      for desc in attribute_map.values() if desc['key'] != ''}
+        present_keys = set(data.keys())
+        missing_keys = present_keys - known_keys
+        return {key: data[key] for key in missing_keys}
+
+    def _classify_target(self, target, data):
+        """Check to see whether the deserialization target object can
+        be classified into a subclass.
+        Once classification has been determined, initialize object.
+
+        :param str target: The target object type to deserialize to.
+        :param str/dict data: The response data to deserialize.
+        """
+        if target is None:
+            return None, None
+
+        if isinstance(target, basestring):
+            try:
+                target = self.dependencies[target]
+            except KeyError:
+                return target, target
+
+        try:
+            target = target._classify(data, self.dependencies)
+        except AttributeError:
+            pass  # Target is not a Model, no classify
+        return target, target.__class__.__name__
+
+    def failsafe_deserialize(self, target_obj, data, content_type=None):
+        """Ignores any errors encountered in deserialization,
+        and falls back to not deserializing the object. Recommended
+        for use in error deserialization, as we want to return the
+        HttpResponseError to users, and not have them deal with
+        a deserialization error.
+
+        :param str target_obj: The target object type to deserialize to.
+        :param str/dict data: The response data to deserialize.
+        :param str content_type: Swagger "produces" if available.
+        """
+        try:
+            return self(target_obj, data, content_type=content_type)
+        except:
+            _LOGGER.warning(
+                "Ran into a deserialization error. Ignoring since this is failsafe deserialization",
+                exc_info=True
+            )
+            return None
+
+    @staticmethod
+    def _unpack_content(raw_data, content_type=None):
+        """Extract the correct structure for deserialization.
+ + If raw_data is a PipelineResponse, try to extract the result of RawDeserializer. + if we can't, raise. Your Pipeline should have a RawDeserializer. + + If not a pipeline response and raw_data is bytes or string, use content-type + to decode it. If no content-type, try JSON. + + If raw_data is something else, bypass all logic and return it directly. + + :param raw_data: Data to be processed. + :param content_type: How to parse if raw_data is a string/bytes. + :raises JSONDecodeError: If JSON is requested and parsing is impossible. + :raises UnicodeDecodeError: If bytes is not UTF8 + """ + # This avoids a circular dependency. We might want to consider RawDesializer is more generic + # than the pipeline concept, and put it in a toolbox, used both here and in pipeline. TBD. + from .pipeline.universal import RawDeserializer + + # Assume this is enough to detect a Pipeline Response without importing it + context = getattr(raw_data, "context", {}) + if context: + if RawDeserializer.CONTEXT_NAME in context: + return context[RawDeserializer.CONTEXT_NAME] + raise ValueError("This pipeline didn't have the RawDeserializer policy; can't deserialize") + + #Assume this is enough to recognize universal_http.ClientResponse without importing it + if hasattr(raw_data, "body"): + return RawDeserializer.deserialize_from_http_generics( + raw_data.text(), + raw_data.headers + ) + + # Assume this enough to recognize requests.Response without importing it. + if hasattr(raw_data, '_content_consumed'): + return RawDeserializer.deserialize_from_http_generics( + raw_data.text, + raw_data.headers + ) + + if isinstance(raw_data, (basestring, bytes)) or hasattr(raw_data, 'read'): + return RawDeserializer.deserialize_from_text(raw_data, content_type) + return raw_data + + def _instantiate_model(self, response, attrs, additional_properties=None): + """Instantiate a response model passing in deserialized args. + + :param response: The response model class. 
+ :param d_attrs: The deserialized response attributes. + """ + if callable(response): + subtype = getattr(response, '_subtype_map', {}) + try: + readonly = [k for k, v in response._validation.items() + if v.get('readonly')] + const = [k for k, v in response._validation.items() + if v.get('constant')] + kwargs = {k: v for k, v in attrs.items() + if k not in subtype and k not in readonly + const} + response_obj = response(**kwargs) + for attr in readonly: + setattr(response_obj, attr, attrs.get(attr)) + if additional_properties: + response_obj.additional_properties = additional_properties + return response_obj + except TypeError as err: + msg = "Unable to deserialize {} into model {}. ".format( + kwargs, response) + raise DeserializationError(msg + str(err)) + else: + try: + for attr, value in attrs.items(): + setattr(response, attr, value) + return response + except Exception as exp: + msg = "Unable to populate response model. " + msg += "Type: {}, Error: {}".format(type(response), exp) + raise DeserializationError(msg) + + def deserialize_data(self, data, data_type): + """Process data for deserialization according to data type. + + :param str data: The response string to be deserialized. + :param str data_type: The type to deserialize to. + :raises: DeserializationError if deserialization fails. + :return: Deserialized object. 
+ """ + if data is None: + return data + + try: + if not data_type: + return data + if data_type in self.basic_types.values(): + return self.deserialize_basic(data, data_type) + if data_type in self.deserialize_type: + if isinstance(data, self.deserialize_expected_types.get(data_type, tuple())): + return data + + is_a_text_parsing_type = lambda x: x not in ["object", "[]", r"{}"] + if isinstance(data, ET.Element) and is_a_text_parsing_type(data_type) and not data.text: + return None + data_val = self.deserialize_type[data_type](data) + return data_val + + iter_type = data_type[0] + data_type[-1] + if iter_type in self.deserialize_type: + return self.deserialize_type[iter_type](data, data_type[1:-1]) + + obj_type = self.dependencies[data_type] + if issubclass(obj_type, Enum): + if isinstance(data, ET.Element): + data = data.text + return self.deserialize_enum(data, obj_type) + + except (ValueError, TypeError, AttributeError) as err: + msg = "Unable to deserialize response data." + msg += " Data: {}, {}".format(data, data_type) + raise_with_traceback(DeserializationError, msg, err) + else: + return self._deserialize(obj_type, data) + + def deserialize_iter(self, attr, iter_type): + """Deserialize an iterable. + + :param list attr: Iterable to be deserialized. + :param str iter_type: The type of object in the iterable. + :rtype: list + """ + if attr is None: + return None + if isinstance(attr, ET.Element): # If I receive an element here, get the children + attr = list(attr) + if not isinstance(attr, (list, set)): + raise DeserializationError("Cannot deserialize as [{}] an object of type {}".format( + iter_type, + type(attr) + )) + return [self.deserialize_data(a, iter_type) for a in attr] + + def deserialize_dict(self, attr, dict_type): + """Deserialize a dictionary. + + :param dict/list attr: Dictionary to be deserialized. Also accepts + a list of key, value pairs. + :param str dict_type: The object type of the items in the dictionary. 
+        :rtype: dict
+        """
+        if isinstance(attr, list):
+            return {x['key']: self.deserialize_data(x['value'], dict_type) for x in attr}
+
+        if isinstance(attr, ET.Element):
+            # Transform value into {"Key": "value"}
+            attr = {el.tag: el.text for el in attr}
+        return {k: self.deserialize_data(v, dict_type) for k, v in attr.items()}
+
+    def deserialize_object(self, attr, **kwargs):
+        """Deserialize a generic object.
+        This will be handled as a dictionary.
+
+        :param dict attr: Dictionary to be deserialized.
+        :rtype: dict
+        :raises: TypeError if non-builtin datatype encountered.
+        """
+        if attr is None:
+            return None
+        if isinstance(attr, ET.Element):
+            # Do not recurse on XML, just return the tree as-is
+            return attr
+        if isinstance(attr, basestring):
+            return self.deserialize_basic(attr, 'str')
+        obj_type = type(attr)
+        if obj_type in self.basic_types:
+            return self.deserialize_basic(attr, self.basic_types[obj_type])
+        if obj_type is _long_type:
+            return self.deserialize_long(attr)
+
+        if obj_type == dict:
+            deserialized = {}
+            for key, value in attr.items():
+                try:
+                    deserialized[key] = self.deserialize_object(
+                        value, **kwargs)
+                except ValueError:
+                    deserialized[key] = None
+            return deserialized
+
+        if obj_type == list:
+            deserialized = []
+            for obj in attr:
+                try:
+                    deserialized.append(self.deserialize_object(
+                        obj, **kwargs))
+                except ValueError:
+                    pass
+            return deserialized
+
+        else:
+            error = "Cannot deserialize generic object with type: "
+            raise TypeError(error + str(obj_type))
+
+    def deserialize_basic(self, attr, data_type):
+        """Deserialize basic builtin data type from string.
+        Will attempt to convert to str, int, float and bool.
+        This function will also accept '1', '0', 'true' and 'false' as
+        valid bool values.
+
+        :param str attr: response string to be deserialized.
+        :param str data_type: deserialization data type.
+        :rtype: str, int, float or bool
+        :raises: TypeError if string format is not valid.
+ """ + # If we're here, data is supposed to be a basic type. + # If it's still an XML node, take the text + if isinstance(attr, ET.Element): + attr = attr.text + if not attr: + if data_type == "str": + # None or '', node is empty string. + return '' + else: + # None or '', node with a strong type is None. + # Don't try to model "empty bool" or "empty int" + return None + + if data_type == 'bool': + if attr in [True, False, 1, 0]: + return bool(attr) + elif isinstance(attr, basestring): + if attr.lower() in ['true', '1']: + return True + elif attr.lower() in ['false', '0']: + return False + raise TypeError("Invalid boolean value: {}".format(attr)) + + if data_type == 'str': + return self.deserialize_unicode(attr) + return eval(data_type)(attr) + + @staticmethod + def deserialize_unicode(data): + """Preserve unicode objects in Python 2, otherwise return data + as a string. + + :param str data: response string to be deserialized. + :rtype: str or unicode + """ + # We might be here because we have an enum modeled as string, + # and we try to deserialize a partial dict with enum inside + if isinstance(data, Enum): + return data + + # Consider this is real string + try: + if isinstance(data, unicode): + return data + except NameError: + return str(data) + else: + return str(data) + + @staticmethod + def deserialize_enum(data, enum_obj): + """Deserialize string into enum object. + + If the string is not a valid enum value it will be returned as-is + and a warning will be logged. + + :param str data: Response string to be deserialized. If this value is + None or invalid it will be returned as-is. + :param Enum enum_obj: Enum object to deserialize to. + :rtype: Enum + """ + if isinstance(data, enum_obj) or data is None: + return data + if isinstance(data, Enum): + data = data.value + if isinstance(data, int): + # Workaround. We might consider remove it in the future. 
+ # https://github.com/Azure/azure-rest-api-specs/issues/141 + try: + return list(enum_obj.__members__.values())[data] + except IndexError: + error = "{!r} is not a valid index for enum {!r}" + raise DeserializationError(error.format(data, enum_obj)) + try: + return enum_obj(str(data)) + except ValueError: + for enum_value in enum_obj: + if enum_value.value.lower() == str(data).lower(): + return enum_value + # We don't fail anymore for unknown value, we deserialize as a string + _LOGGER.warning("Deserializer is not able to find %s as valid enum in %s", data, enum_obj) + return Deserializer.deserialize_unicode(data) + + @staticmethod + def deserialize_bytearray(attr): + """Deserialize string into bytearray. + + :param str attr: response string to be deserialized. + :rtype: bytearray + :raises: TypeError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + return bytearray(b64decode(attr)) + + @staticmethod + def deserialize_base64(attr): + """Deserialize base64 encoded string into string. + + :param str attr: response string to be deserialized. + :rtype: bytearray + :raises: TypeError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + padding = '=' * (3 - (len(attr) + 3) % 4) + attr = attr + padding + encoded = attr.replace('-', '+').replace('_', '/') + return b64decode(encoded) + + @staticmethod + def deserialize_decimal(attr): + """Deserialize string into Decimal object. + + :param str attr: response string to be deserialized. + :rtype: Decimal + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + return decimal.Decimal(attr) + except decimal.DecimalException as err: + msg = "Invalid decimal {}".format(attr) + raise_with_traceback(DeserializationError, msg, err) + + @staticmethod + def deserialize_long(attr): + """Deserialize string into long (Py2) or int (Py3). + + :param str attr: response string to be deserialized. 
+ :rtype: long or int + :raises: ValueError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + return _long_type(attr) + + @staticmethod + def deserialize_duration(attr): + """Deserialize ISO-8601 formatted string into TimeDelta object. + + :param str attr: response string to be deserialized. + :rtype: TimeDelta + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + duration = isodate.parse_duration(attr) + except(ValueError, OverflowError, AttributeError) as err: + msg = "Cannot deserialize duration object." + raise_with_traceback(DeserializationError, msg, err) + else: + return duration + + @staticmethod + def deserialize_date(attr): + """Deserialize ISO-8601 formatted string into Date object. + + :param str attr: response string to be deserialized. + :rtype: Date + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + if re.search(r"[^\W\d_]", attr, re.I + re.U): + raise DeserializationError("Date must have only digits and -. Received: %s" % attr) + # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception. + return isodate.parse_date(attr, defaultmonth=None, defaultday=None) + + @staticmethod + def deserialize_time(attr): + """Deserialize ISO-8601 formatted string into time object. + + :param str attr: response string to be deserialized. + :rtype: datetime.time + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + if re.search(r"[^\W\d_]", attr, re.I + re.U): + raise DeserializationError("Date must have only digits and -. Received: %s" % attr) + return isodate.parse_time(attr) + + @staticmethod + def deserialize_rfc(attr): + """Deserialize RFC-1123 formatted string into Datetime object. + + :param str attr: response string to be deserialized. 
+ :rtype: Datetime + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + parsed_date = email.utils.parsedate_tz(attr) + date_obj = datetime.datetime( + *parsed_date[:6], + tzinfo=_FixedOffset(datetime.timedelta(minutes=(parsed_date[9] or 0)/60)) + ) + if not date_obj.tzinfo: + date_obj = date_obj.astimezone(tz=TZ_UTC) + except ValueError as err: + msg = "Cannot deserialize to rfc datetime object." + raise_with_traceback(DeserializationError, msg, err) + else: + return date_obj + + @staticmethod + def deserialize_iso(attr): + """Deserialize ISO-8601 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: Datetime + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + attr = attr.upper() + match = Deserializer.valid_date.match(attr) + if not match: + raise ValueError("Invalid datetime string: " + attr) + + check_decimal = attr.split('.') + if len(check_decimal) > 1: + decimal_str = "" + for digit in check_decimal[1]: + if digit.isdigit(): + decimal_str += digit + else: + break + if len(decimal_str) > 6: + attr = attr.replace(decimal_str, decimal_str[0:6]) + + date_obj = isodate.parse_datetime(attr) + test_utc = date_obj.utctimetuple() + if test_utc.tm_year > 9999 or test_utc.tm_year < 1: + raise OverflowError("Hit max or min date") + except(ValueError, OverflowError, AttributeError) as err: + msg = "Cannot deserialize datetime object." + raise_with_traceback(DeserializationError, msg, err) + else: + return date_obj + + @staticmethod + def deserialize_unix(attr): + """Serialize Datetime object into IntTime format. + This is represented as seconds. + + :param int attr: Object to be serialized. 
+ :rtype: Datetime + :raises: DeserializationError if format invalid + """ + if isinstance(attr, ET.Element): + attr = int(attr.text) + try: + date_obj = datetime.datetime.fromtimestamp(attr, TZ_UTC) + except ValueError as err: + msg = "Cannot deserialize to unix datetime object." + raise_with_traceback(DeserializationError, msg, err) + else: + return date_obj \ No newline at end of file From 2bb536c990467a760ab8c0953db33d7eeccdee34 Mon Sep 17 00:00:00 2001 From: antisch Date: Mon, 19 Jul 2021 08:51:59 -0700 Subject: [PATCH 03/15] Plug in deserializer --- .../azure/storage/blob/_blob_client.py | 2 ++ .../azure/storage/blob/_blob_service_client.py | 2 ++ .../azure/storage/blob/_container_client.py | 2 ++ .../azure/storage/blob/_shared/base_client.py | 14 ++++++++++++++ .../storage/blob/_shared/xml_deserialization.py | 14 +++++++------- .../azure/storage/blob/aio/_blob_client_async.py | 2 ++ .../storage/blob/aio/_blob_service_client_async.py | 2 ++ .../storage/blob/aio/_container_client_async.py | 2 ++ 8 files changed, 33 insertions(+), 7 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py index e3a659f9c867..f1289aeec4cf 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py @@ -31,6 +31,7 @@ validate_and_format_range_headers) from ._shared.response_handlers import return_response_headers, process_storage_error, return_headers_and_deserialized from ._generated import AzureBlobStorage +from ._generated import models as generated_models from ._generated.models import ( # pylint: disable=unused-import DeleteSnapshotsOptionType, BlobHTTPHeaders, @@ -175,6 +176,7 @@ def __init__( self._query_str, credential = self._format_query_string(sas_token, credential, snapshot=self.snapshot) super(BlobClient, self).__init__(parsed_url, service='blob', credential=credential, 
**kwargs) self._client = AzureBlobStorage(self.url, pipeline=self._pipeline) + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py index d277a094921a..1096d06d2a61 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py @@ -28,6 +28,7 @@ from ._shared.response_handlers import return_response_headers, process_storage_error, \ parse_to_internal_user_delegation_key from ._generated import AzureBlobStorage +from ._generated import models as generated_models from ._generated.models import StorageServiceProperties, KeyInfo from ._container_client import ContainerClient from ._blob_client import BlobClient @@ -134,6 +135,7 @@ def __init__( self._query_str, credential = self._format_query_string(sas_token, credential) super(BlobServiceClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs) self._client = AzureBlobStorage(self.url, pipeline=self._pipeline) + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py index 5d7e4254ed31..89b579950f73 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py @@ -34,6 +34,7 @@ return_response_headers, 
return_headers_and_deserialized) from ._generated import AzureBlobStorage +from ._generated import models as generated_models from ._generated.models import SignedIdentifier from ._deserialize import deserialize_container_properties from ._serialize import get_modify_conditions, get_container_cpk_scope_info, get_api_version, get_access_conditions @@ -156,6 +157,7 @@ def __init__( self._query_str, credential = self._format_query_string(sas_token, credential) super(ContainerClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs) self._client = AzureBlobStorage(self.url, pipeline=self._pipeline) + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index a2efa2170228..f02a928b023f 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -35,6 +35,7 @@ AzureSasCredentialPolicy ) +from .xml_deserialization import Deserializer from .constants import STORAGE_OAUTH_SCOPE, SERVICE_HOST_BASE, CONNECTION_TIMEOUT, READ_TIMEOUT from .models import LocationMode from .authentication import SharedKeyCredentialPolicy @@ -198,6 +199,19 @@ def api_version(self): :type: str """ return self._client._config.version # pylint: disable=protected-access + + def _custom_xml_deserializer(self, generated_models): + """Reset the deserializer on the generated client to be Storage implementation""" + client_models = {k: v for k, v in generated_models.__dict__.items() if isinstance(v, type)} + custom_deserialize = Deserializer(client_models) + self._client._deserialize = custom_deserialize + 
self._client.service._deserialize = custom_deserialize + self._client.container._deserialize = custom_deserialize + self._client.directory._deserialize = custom_deserialize + self._client.blob._deserialize = custom_deserialize + self._client.page_blob._deserialize = custom_deserialize + self._client.append_blob._deserialize = custom_deserialize + self._client.block_blob._deserialize = custom_deserialize def _format_query_string(self, sas_token, credential, snapshot=None, share_snapshot=None): query_str = "?" diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py index 981348a538cb..5b808a638092 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py @@ -481,31 +481,31 @@ def _unpack_content(raw_data, content_type=None): """ # This avoids a circular dependency. We might want to consider RawDesializer is more generic # than the pipeline concept, and put it in a toolbox, used both here and in pipeline. TBD. 
- from .pipeline.universal import RawDeserializer + from azure.core.pipeline.policies import ContentDecodePolicy # Assume this is enough to detect a Pipeline Response without importing it context = getattr(raw_data, "context", {}) if context: - if RawDeserializer.CONTEXT_NAME in context: - return context[RawDeserializer.CONTEXT_NAME] - raise ValueError("This pipeline didn't have the RawDeserializer policy; can't deserialize") + if ContentDecodePolicy.CONTEXT_NAME in context: + return context[ContentDecodePolicy.CONTEXT_NAME] + raise ValueError("This pipeline didn't have the ContentDecodePolicy policy; can't deserialize") #Assume this is enough to recognize universal_http.ClientResponse without importing it if hasattr(raw_data, "body"): - return RawDeserializer.deserialize_from_http_generics( + return ContentDecodePolicy.deserialize_from_http_generics( raw_data.text(), raw_data.headers ) # Assume this enough to recognize requests.Response without importing it. if hasattr(raw_data, '_content_consumed'): - return RawDeserializer.deserialize_from_http_generics( + return ContentDecodePolicy.deserialize_from_http_generics( raw_data.text, raw_data.headers ) if isinstance(raw_data, (basestring, bytes)) or hasattr(raw_data, 'read'): - return RawDeserializer.deserialize_from_text(raw_data, content_type) + return ContentDecodePolicy.deserialize_from_text(raw_data, content_type) return raw_data def _instantiate_model(self, response, attrs, additional_properties=None): diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py index d13de28f8711..39e2351e2263 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py @@ -21,6 +21,7 @@ from .._deserialize import get_page_ranges_result, parse_tags, deserialize_pipeline_response_into_cls from .._serialize import 
get_modify_conditions, get_api_version, get_access_conditions from .._generated.aio import AzureBlobStorage +from .._generated import models as generated_models from .._generated.models import CpkInfo from .._deserialize import deserialize_blob_properties from .._blob_client import BlobClient as BlobClientBase @@ -120,6 +121,7 @@ def __init__( credential=credential, **kwargs) self._client = AzureBlobStorage(url=self.url, pipeline=self._pipeline) + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py index d50661d8e2d7..f673c5b35e75 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py @@ -24,6 +24,7 @@ from .._shared.parser import _to_utc_datetime from .._shared.response_handlers import parse_to_internal_user_delegation_key from .._generated.aio import AzureBlobStorage +from .._generated import models as generated_models from .._generated.models import StorageServiceProperties, KeyInfo from .._blob_service_client import BlobServiceClient as BlobServiceClientBase from ._container_client_async import ContainerClient @@ -118,6 +119,7 @@ def __init__( credential=credential, **kwargs) self._client = AzureBlobStorage(url=self.url, pipeline=self._pipeline) + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access diff --git 
a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py index 498a646ff1f5..7c58993b6aea 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py @@ -26,6 +26,7 @@ return_response_headers, return_headers_and_deserialized) from .._generated.aio import AzureBlobStorage +from .._generated import models as generated_models from .._generated.models import SignedIdentifier from .._deserialize import deserialize_container_properties from .._serialize import get_modify_conditions, get_container_cpk_scope_info, get_api_version, get_access_conditions @@ -117,6 +118,7 @@ def __init__( credential=credential, **kwargs) self._client = AzureBlobStorage(url=self.url, pipeline=self._pipeline) + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access From 7b78ee76c8d6869c2f9f832b0a6286b34bc4f701 Mon Sep 17 00:00:00 2001 From: antisch Date: Mon, 19 Jul 2021 15:52:53 -0700 Subject: [PATCH 04/15] Remove content decide policy --- .../aio/operations/_container_operations.py | 6 +- .../operations/_container_operations.py | 8 +- .../azure/storage/blob/_list_blobs_helper.py | 22 ++- .../azure/storage/blob/_shared/base_client.py | 2 - .../storage/blob/_shared/base_client_async.py | 2 - .../storage/blob/_shared/response_handlers.py | 7 +- .../blob/_shared/xml_deserialization.py | 133 ++++++++++++++---- .../storage/blob/aio/_list_blobs_helper.py | 13 +- 8 files changed, 131 insertions(+), 62 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py 
b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py index 91edd8c2d0d6..f4d4e68cdaa1 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py @@ -1564,12 +1564,12 @@ async def list_blob_hierarchy_segment( response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id')) response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version')) response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date')) - deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response) + # deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response) if cls: - return cls(pipeline_response, deserialized, response_headers) + return cls(pipeline_response, None, response_headers) - return deserialized + return None list_blob_hierarchy_segment.metadata = {'url': '/{containerName}'} # type: ignore async def get_account_info( diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py index 39db923628be..018c93984bf2 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py @@ -1476,7 +1476,7 @@ def list_blob_flat_segment( if cls: return cls(pipeline_response, None, response_headers) - return deserialized + return None # deserialized list_blob_flat_segment.metadata = {'url': '/{containerName}'} # type: ignore def list_blob_hierarchy_segment( @@ -1584,12 +1584,12 @@ def list_blob_hierarchy_segment( 
response_headers['x-ms-request-id']=self._deserialize('str', response.headers.get('x-ms-request-id')) response_headers['x-ms-version']=self._deserialize('str', response.headers.get('x-ms-version')) response_headers['Date']=self._deserialize('rfc-1123', response.headers.get('Date')) - deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response) + # deserialized = self._deserialize('ListBlobsHierarchySegmentResponse', pipeline_response) if cls: - return cls(pipeline_response, deserialized, response_headers) + return cls(pipeline_response, None, response_headers) - return deserialized + return None # deserialized list_blob_hierarchy_segment.metadata = {'url': '/{containerName}'} # type: ignore def get_account_info( diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py index 66b4b87d39ed..e67ca5434c1f 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py @@ -7,17 +7,19 @@ from azure.core.paging import PageIterator, ItemPaged from azure.core.exceptions import HttpResponseError -from azure.core.pipeline.policies import ContentDecodePolicy +#from azure.core.pipeline.policies import ContentDecodePolicy from ._deserialize import get_blob_properties_from_generated_code, parse_tags from ._generated.models import BlobItemInternal, BlobPrefix as GenBlobPrefix, FilterBlobItem from ._models import BlobProperties, FilteredBlob from ._shared.models import DictMixin +from ._shared.xml_deserialization import deserialize_from_http_generics from ._shared.response_handlers import return_context_and_deserialized, process_storage_error def deserialize_list_result(pipeline_response, _, headers): - payload = pipeline_response.context[ContentDecodePolicy.CONTEXT_NAME] + #payload = pipeline_response.context[ContentDecodePolicy.CONTEXT_NAME] + payload = 
deserialize_from_http_generics(pipeline_response.http_response) location = pipeline_response.http_response.location_mode return location, payload @@ -64,13 +66,6 @@ def blob_properties_from_xml(element, select, deserializer): if 'version' in select: blob.version_id = load_xml_string(element, 'VersionId') blob.is_current_version = load_xml_bool(element, 'IsCurrentVersion') - # TODO: Should also support selecting 'tags' and 'metadata', but these are only returned - # if opted-into with the 'include' parameter. - # if 'metadata' in select: - # blob.metadata = None - # blob.encrypted_metadata = None - # if 'tags' in select: - # blob.tags = None return blob @@ -171,20 +166,23 @@ def _extract_data_cb(self, get_next_return): continuation_token, _ = super(BlobPrefixPaged, self)._extract_data_cb(get_next_return) blob_prefixes = load_many_nodes(self._response, 'BlobPrefix', wrapper='Blobs') - blob_prefixes = [self._build_item(blob) for blob in blob_prefixes] + blob_prefixes = [self._build_prefix(blob) for blob in blob_prefixes] self.current_page = blob_prefixes + self.current_page self.delimiter = load_xml_string(self._response, 'Delimiter') return continuation_token, self.current_page - def _build_item(self, item): + def _build_prefix(self, item): return BlobPrefix( self._command, container=self.container, prefix=load_xml_string(item, 'Name'), results_per_page=self.results_per_page, - location_mode=self.location_mode + location_mode=self.location_mode, + select=self.select, + deserializer=self._deserializer, + delimiter=self.delimiter ) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index f02a928b023f..7aaaeadd4d60 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -26,7 +26,6 @@ from azure.core.pipeline.transport import RequestsTransport, 
HttpTransport from azure.core.pipeline.policies import ( RedirectPolicy, - ContentDecodePolicy, BearerTokenCredentialPolicy, ProxyPolicy, DistributedTracingPolicy, @@ -254,7 +253,6 @@ def _create_pipeline(self, credential, **kwargs): config.proxy_policy, config.user_agent_policy, StorageContentValidation(), - ContentDecodePolicy(response_encoding="utf-8"), RedirectPolicy(**kwargs), StorageHosts(hosts=self._hosts, **kwargs), config.retry_policy, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py index 3e619c90fd71..9df1b28e1069 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py @@ -15,7 +15,6 @@ from azure.core.async_paging import AsyncList from azure.core.exceptions import HttpResponseError from azure.core.pipeline.policies import ( - ContentDecodePolicy, AsyncBearerTokenCredentialPolicy, AsyncRedirectPolicy, DistributedTracingPolicy, @@ -97,7 +96,6 @@ def _create_pipeline(self, credential, **kwargs): StorageContentValidation(), StorageRequestHook(**kwargs), self._credential_policy, - ContentDecodePolicy(response_encoding="utf-8"), AsyncRedirectPolicy(**kwargs), StorageHosts(hosts=self._hosts, **kwargs), # type: ignore config.retry_policy, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/response_handlers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/response_handlers.py index e5a351417e60..cb1efd798821 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/response_handlers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/response_handlers.py @@ -10,7 +10,7 @@ import logging from xml.etree.ElementTree import Element -from azure.core.pipeline.policies import ContentDecodePolicy +#from azure.core.pipeline.policies import ContentDecodePolicy from 
azure.core.exceptions import ( HttpResponseError, ResourceNotFoundError, @@ -19,6 +19,7 @@ ClientAuthenticationError, DecodeError) +from .xml_deserialization import deserialize_from_http_generics from .parser import _to_utc_datetime from .models import StorageErrorCode, UserDelegationKey, get_enum_value @@ -96,7 +97,7 @@ def process_storage_error(storage_error): # pylint:disable=too-many-statements additional_data = {} error_dict = {} try: - error_body = ContentDecodePolicy.deserialize_from_http_generics(storage_error.response) + error_body = deserialize_from_http_generics(storage_error.response) # If it is an XML response if isinstance(error_body, Element): error_dict = { @@ -108,7 +109,7 @@ def process_storage_error(storage_error): # pylint:disable=too-many-statements error_dict = error_body.get('error', {}) elif not error_code: _LOGGER.warning( - 'Unexpected return type % from ContentDecodePolicy.deserialize_from_http_generics.', type(error_body)) + 'Unexpected return type % from deserialize_from_http_generics.', type(error_body)) error_dict = {'message': str(error_body)} # If we extracted from a Json or XML response diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py index 5b808a638092..7bbba18466f0 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py @@ -50,7 +50,7 @@ import isodate -from typing import Dict, Any +from typing import Dict, Any, cast, IO try: @@ -68,6 +68,7 @@ _long_type = int +from azure.core.exceptions import DecodeError from msrest.exceptions import DeserializationError, raise_with_traceback from msrest.serialization import ( TZ_UTC, @@ -75,6 +76,7 @@ _FLATTEN ) + def full_restapi_key_transformer(key, attr_desc, value): """A key transformer that returns the full RestAPI key path. 
@@ -479,34 +481,12 @@ def _unpack_content(raw_data, content_type=None): :raises JSONDecodeError: If JSON is requested and parsing is impossible. :raises UnicodeDecodeError: If bytes is not UTF8 """ - # This avoids a circular dependency. We might want to consider RawDesializer is more generic - # than the pipeline concept, and put it in a toolbox, used both here and in pipeline. TBD. - from azure.core.pipeline.policies import ContentDecodePolicy - - # Assume this is enough to detect a Pipeline Response without importing it - context = getattr(raw_data, "context", {}) - if context: - if ContentDecodePolicy.CONTEXT_NAME in context: - return context[ContentDecodePolicy.CONTEXT_NAME] - raise ValueError("This pipeline didn't have the ContentDecodePolicy policy; can't deserialize") - - #Assume this is enough to recognize universal_http.ClientResponse without importing it - if hasattr(raw_data, "body"): - return ContentDecodePolicy.deserialize_from_http_generics( - raw_data.text(), - raw_data.headers - ) - - # Assume this enough to recognize requests.Response without importing it. - if hasattr(raw_data, '_content_consumed'): - return ContentDecodePolicy.deserialize_from_http_generics( - raw_data.text, - raw_data.headers - ) - if isinstance(raw_data, (basestring, bytes)) or hasattr(raw_data, 'read'): - return ContentDecodePolicy.deserialize_from_text(raw_data, content_type) - return raw_data + return deserialize_from_text(raw_data, content_type) + try: + return deserialize_from_http_generics(raw_data.http_response) + except AttributeError: + return raw_data def _instantiate_model(self, response, attrs, additional_properties=None): """Instantiate a response model passing in deserialized args. @@ -939,4 +919,99 @@ def deserialize_unix(attr): msg = "Cannot deserialize to unix datetime object." 
raise_with_traceback(DeserializationError, msg, err) else: - return date_obj \ No newline at end of file + return date_obj + + + +def deserialize_from_text( + data, # type: Optional[Union[AnyStr, IO]] + mime_type=None, # Optional[str] + response=None # Optional[Union[HttpResponse, AsyncHttpResponse]] +): + """Decode response data according to content-type. + Accept a stream of data as well, but will be load at once in memory for now. + If no content-type, will return the string version (not bytes, not stream) + :param response: The HTTP response. + :type response: ~azure.core.pipeline.transport.HttpResponse + :param str mime_type: The mime type. As mime type, charset is not expected. + :param response: If passed, exception will be annotated with that response + :raises ~azure.core.exceptions.DecodeError: If deserialization fails + :returns: A dict or XML tree, depending of the mime_type + """ + if not data: + return None + + if hasattr(data, 'read'): + # Assume a stream + data = cast(IO, data).read() + + if isinstance(data, bytes): + data_as_str = data.decode(encoding='utf-8-sig') + else: + # Explain to mypy the correct type. + data_as_str = cast(str, data) + + if mime_type is None: + return data_as_str + + if "xml" in (mime_type or []): + try: + try: + if isinstance(data, unicode): # type: ignore + # If I'm Python 2.7 and unicode XML will scream if I try a "fromstring" on unicode string + data_as_str = cast(str, data_as_str.encode(encoding="utf-8")) + except NameError: + pass + return ET.fromstring(data_as_str) # nosec + except ET.ParseError: + # It might be because the server has an issue, and returned JSON with + # content-type XML.... 
+ # So let's try a JSON load, and if it's still broken + # let's flow the initial exception + def _json_attemp(data): + try: + return True, json.loads(data) + except ValueError: + return False, None # Don't care about this one + success, json_result = _json_attemp(data) + if success: + return json_result + # If i'm here, it's not JSON, it's not XML, let's scream + # and raise the last context in this block (the XML exception) + # The function hack is because Py2.7 messes up with exception + # context otherwise. + _LOGGER.critical("Wasn't XML not JSON, failing") + raise_with_traceback(DecodeError, message="XML is invalid", response=response) + elif mime_type.startswith("text/"): + return data_as_str + else: + try: + return json.loads(data_as_str) + except ValueError as err: + raise DecodeError(message="JSON is invalid: {}".format(err), response=response, error=err) + raise DecodeError("Cannot deserialize content-type: {}".format(mime_type)) + + +def deserialize_from_http_generics( + response, # Union[HttpResponse, AsyncHttpResponse] + encoding=None, # Optional[str] +): + """Deserialize from HTTP response. + Headers will tested for "content-type" + :param response: The HTTP response + :param encoding: The encoding to use if known for this service (will disable auto-detection) + :raises ~azure.core.exceptions.DecodeError: If deserialization fails + :returns: A dict or XML tree, depending of the mime-type + """ + # Try to use content-type from headers if available + if response.content_type: + mime_type = response.content_type.split(";")[0].strip().lower() + # Ouch, this server did not declare what it sent... + # Let's guess it's JSON... + # Also, since Autorest was considering that an empty body was a valid JSON, + # need that test as well.... 
+ else: + mime_type = "application/json" + + # Rely on transport implementation to give me "text()" decoded correctly + return deserialize_from_text(response.text(encoding), mime_type, response=response) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py index b0bb95c63d6c..a4822bd5445f 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py @@ -155,21 +155,20 @@ def __init__(self, *args, **kwargs): async def _extract_data_cb(self, get_next_return): continuation_token, current_page = await super(BlobPrefixPaged, self)._extract_data_cb(get_next_return) - blob_prefixes = load_many_nodes(self._response, 'BlobPrefix', wrapper='Blobs') - blob_prefixes = [self._build_item(blob) for blob in blob_prefixes] - + blob_prefixes = [self._build_prefix(blob) for blob in blob_prefixes] self.current_page = blob_prefixes + current_page self.delimiter = load_xml_string(self._response, 'Delimiter') - self.current_page = self._response.segment.blob_prefixes + self._response.segment.blob_items - return continuation_token, self.current_page - def _build_item(self, item): + def _build_prefix(self, item): return BlobPrefix( self._command, container=self.container, prefix=load_xml_string(item, 'Name'), results_per_page=self.results_per_page, - location_mode=self.location_mode + location_mode=self.location_mode, + select=self.select, + deserializer=self._deserializer, + delimiter=self.delimiter ) From 12a96ab4d0d6ea563cb962341bbee7b48879fda4 Mon Sep 17 00:00:00 2001 From: antisch Date: Tue, 20 Jul 2021 08:26:32 -0700 Subject: [PATCH 05/15] Some code cleanup --- .../blob/_shared/xml_deserialization.py | 100 +----------------- 1 file changed, 1 insertion(+), 99 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py 
b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py index 7bbba18466f0..07f664e04bdc 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py @@ -77,29 +77,6 @@ ) -def full_restapi_key_transformer(key, attr_desc, value): - """A key transformer that returns the full RestAPI key path. - - :param str _: The attribute name - :param dict attr_desc: The attribute metadata - :param object value: The value - :returns: A list of keys using RestAPI syntax. - """ - keys = _FLATTEN.split(attr_desc['key']) - return ([_decode_attribute_map_key(k) for k in keys], value) - -def last_restapi_key_transformer(key, attr_desc, value): - """A key transformer that returns the last RestAPI key. - - :param str key: The attribute name - :param dict attr_desc: The attribute metadata - :param object value: The value - :returns: The last RestAPI key. - """ - key, value = full_restapi_key_transformer(key, attr_desc, value) - return (key[-1], value) - - def _decode_attribute_map_key(key): """This decode a key in an _attribute_map to the actual key we want to look at inside the received data. @@ -109,76 +86,6 @@ def _decode_attribute_map_key(key): return key.replace('\\.', '.') -def rest_key_extractor(attr, attr_desc, data): - key = attr_desc['key'] - working_data = data - - while '.' 
in key: - dict_keys = _FLATTEN.split(key) - if len(dict_keys) == 1: - key = _decode_attribute_map_key(dict_keys[0]) - break - working_key = _decode_attribute_map_key(dict_keys[0]) - working_data = working_data.get(working_key, data) - if working_data is None: - # If at any point while following flatten JSON path see None, it means - # that all properties under are None as well - # https://github.com/Azure/msrest-for-python/issues/197 - return None - key = '.'.join(dict_keys[1:]) - - return working_data.get(key) - -def rest_key_case_insensitive_extractor(attr, attr_desc, data): - key = attr_desc['key'] - working_data = data - - while '.' in key: - dict_keys = _FLATTEN.split(key) - if len(dict_keys) == 1: - key = _decode_attribute_map_key(dict_keys[0]) - break - working_key = _decode_attribute_map_key(dict_keys[0]) - working_data = attribute_key_case_insensitive_extractor(working_key, None, working_data) - if working_data is None: - # If at any point while following flatten JSON path see None, it means - # that all properties under are None as well - # https://github.com/Azure/msrest-for-python/issues/197 - return None - key = '.'.join(dict_keys[1:]) - - if working_data: - return attribute_key_case_insensitive_extractor(key, None, working_data) - -def last_rest_key_extractor(attr, attr_desc, data): - """Extract the attribute in "data" based on the last part of the JSON path key. - """ - key = attr_desc['key'] - dict_keys = _FLATTEN.split(key) - return attribute_key_extractor(dict_keys[-1], None, data) - -def last_rest_key_case_insensitive_extractor(attr, attr_desc, data): - """Extract the attribute in "data" based on the last part of the JSON path key. 
- - This is the case insensitive version of "last_rest_key_extractor" - """ - key = attr_desc['key'] - dict_keys = _FLATTEN.split(key) - return attribute_key_case_insensitive_extractor(dict_keys[-1], None, data) - -def attribute_key_extractor(attr, _, data): - return data.get(attr) - -def attribute_key_case_insensitive_extractor(attr, _, data): - found_key = None - lower_attr = attr.lower() - for key in data: - if lower_attr == key.lower(): - found_key = key - break - - return data.get(found_key) - def _extract_name_from_internal_type(internal_type): """Given an internal type XML description, extract correct XML name with namespace. @@ -302,7 +209,7 @@ def __init__(self, classes=None): } self.dependencies = dict(classes) if classes else {} self.key_extractors = [ - rest_key_extractor, + # rest_key_extractor, xml_key_extractor ] # Additional properties only works if the "rest_key_extractor" is used to @@ -437,11 +344,6 @@ def _classify_target(self, target, data): target = self.dependencies[target] except KeyError: return target, target - - try: - target = target._classify(data, self.dependencies) - except AttributeError: - pass # Target is not a Model, no classify return target, target.__class__.__name__ def failsafe_deserialize(self, target_obj, data, content_type=None): From 85dd1802c8be316adf6e256e05080b4a2e097342 Mon Sep 17 00:00:00 2001 From: antisch Date: Tue, 20 Jul 2021 13:30:23 -0700 Subject: [PATCH 06/15] Refactor part 1 --- .../blob/_shared/xml_deserialization.py | 602 +++++++++--------- 1 file changed, 306 insertions(+), 296 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py index 07f664e04bdc..cd350f0b6bb2 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py @@ -61,6 +61,9 @@ unicode_str = 
str # type: ignore _LOGGER = logging.getLogger(__name__) +_valid_date = re.compile( + r'\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}' + r'\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?') try: _long_type = long # type: ignore @@ -76,6 +79,282 @@ _FLATTEN ) +def deserialize_bytearray(attr, *args): + """Deserialize string into bytearray. + + :param str attr: response string to be deserialized. + :rtype: bytearray + :raises: TypeError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + return bytearray(b64decode(attr)) + + +def deserialize_base64(attr, *args): + """Deserialize base64 encoded string into string. + + :param str attr: response string to be deserialized. + :rtype: bytearray + :raises: TypeError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + padding = '=' * (3 - (len(attr) + 3) % 4) + attr = attr + padding + encoded = attr.replace('-', '+').replace('_', '/') + return b64decode(encoded) + + +def deserialize_decimal(attr, *args): + """Deserialize string into Decimal object. + + :param str attr: response string to be deserialized. + :rtype: Decimal + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + return decimal.Decimal(attr) + except decimal.DecimalException as err: + msg = "Invalid decimal {}".format(attr) + raise_with_traceback(DeserializationError, msg, err) + + +def deserialize_long(attr, *args): + """Deserialize string into long (Py2) or int (Py3). + + :param str attr: response string to be deserialized. + :rtype: long or int + :raises: ValueError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + return _long_type(attr) + + +def deserialize_duration(attr, *args): + """Deserialize ISO-8601 formatted string into TimeDelta object. + + :param str attr: response string to be deserialized. + :rtype: TimeDelta + :raises: DeserializationError if string format invalid. 
+ """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + duration = isodate.parse_duration(attr) + except(ValueError, OverflowError, AttributeError) as err: + msg = "Cannot deserialize duration object." + raise_with_traceback(DeserializationError, msg, err) + else: + return duration + + +def deserialize_date(attr, *args): + """Deserialize ISO-8601 formatted string into Date object. + + :param str attr: response string to be deserialized. + :rtype: Date + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + if re.search(r"[^\W\d_]", attr, re.I + re.U): + raise DeserializationError("Date must have only digits and -. Received: %s" % attr) + # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception. + return isodate.parse_date(attr, defaultmonth=None, defaultday=None) + + +def deserialize_time(attr, *args): + """Deserialize ISO-8601 formatted string into time object. + + :param str attr: response string to be deserialized. + :rtype: datetime.time + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + if re.search(r"[^\W\d_]", attr, re.I + re.U): + raise DeserializationError("Date must have only digits and -. Received: %s" % attr) + return isodate.parse_time(attr) + + +def deserialize_rfc(attr, *args): + """Deserialize RFC-1123 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: Datetime + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + parsed_date = email.utils.parsedate_tz(attr) + date_obj = datetime.datetime( + *parsed_date[:6], + tzinfo=_FixedOffset(datetime.timedelta(minutes=(parsed_date[9] or 0)/60)) + ) + if not date_obj.tzinfo: + date_obj = date_obj.astimezone(tz=TZ_UTC) + except ValueError as err: + msg = "Cannot deserialize to rfc datetime object." 
+ raise_with_traceback(DeserializationError, msg, err) + else: + return date_obj + + +def deserialize_iso(attr, *args): + """Deserialize ISO-8601 formatted string into Datetime object. + + :param str attr: response string to be deserialized. + :rtype: Datetime + :raises: DeserializationError if string format invalid. + """ + if isinstance(attr, ET.Element): + attr = attr.text + try: + attr = attr.upper() + match = _valid_date.match(attr) + if not match: + raise ValueError("Invalid datetime string: " + attr) + + check_decimal = attr.split('.') + if len(check_decimal) > 1: + decimal_str = "" + for digit in check_decimal[1]: + if digit.isdigit(): + decimal_str += digit + else: + break + if len(decimal_str) > 6: + attr = attr.replace(decimal_str, decimal_str[0:6]) + + date_obj = isodate.parse_datetime(attr) + test_utc = date_obj.utctimetuple() + if test_utc.tm_year > 9999 or test_utc.tm_year < 1: + raise OverflowError("Hit max or min date") + except(ValueError, OverflowError, AttributeError) as err: + msg = "Cannot deserialize datetime object." + raise_with_traceback(DeserializationError, msg, err) + else: + return date_obj + + +def deserialize_unix(attr, *args): + """Serialize Datetime object into IntTime format. + This is represented as seconds. + + :param int attr: Object to be serialized. + :rtype: Datetime + :raises: DeserializationError if format invalid + """ + if isinstance(attr, ET.Element): + attr = int(attr.text) + try: + date_obj = datetime.datetime.fromtimestamp(attr, TZ_UTC) + except ValueError as err: + msg = "Cannot deserialize to unix datetime object." + raise_with_traceback(DeserializationError, msg, err) + else: + return date_obj + + +def deserialize_unicode(data, *args): + """Preserve unicode objects in Python 2, otherwise return data + as a string. + + :param str data: response string to be deserialized. 
+ :rtype: str or unicode + """ + # We might be here because we have an enum modeled as string, + # and we try to deserialize a partial dict with enum inside + if isinstance(data, Enum): + return data + + # Consider this is real string + try: + if isinstance(data, unicode): + return data + except NameError: + return str(data) + else: + return str(data) + + +def deserialize_enum(data, enum_obj): + """Deserialize string into enum object. + + If the string is not a valid enum value it will be returned as-is + and a warning will be logged. + + :param str data: Response string to be deserialized. If this value is + None or invalid it will be returned as-is. + :param Enum enum_obj: Enum object to deserialize to. + :rtype: Enum + """ + if isinstance(data, enum_obj) or data is None: + return data + if isinstance(data, Enum): + data = data.value + if isinstance(data, int): + # Workaround. We might consider remove it in the future. + # https://github.com/Azure/azure-rest-api-specs/issues/141 + try: + return list(enum_obj.__members__.values())[data] + except IndexError: + error = "{!r} is not a valid index for enum {!r}" + raise DeserializationError(error.format(data, enum_obj)) + try: + return enum_obj(str(data)) + except ValueError: + for enum_value in enum_obj: + if enum_value.value.lower() == str(data).lower(): + return enum_value + # We don't fail anymore for unknown value, we deserialize as a string + _LOGGER.warning("Deserializer is not able to find %s as valid enum in %s", data, enum_obj) + return deserialize_unicode(data) + + +def deserialize_basic(attr, data_type): + """Deserialize baisc builtin data type from string. + Will attempt to convert to str, int, float and bool. + This function will also accept '1', '0', 'true' and 'false' as + valid bool values. + + :param str attr: response string to be deserialized. + :param str data_type: deserialization data type. + :rtype: str, int, float or bool + :raises: TypeError if string format is not valid. 
+ """ + # If we're here, data is supposed to be a basic type. + # If it's still an XML node, take the text + if isinstance(attr, ET.Element): + attr = attr.text + if not attr: + if data_type == "str": + # None or '', node is empty string. + return '' + else: + # None or '', node with a strong type is None. + # Don't try to model "empty bool" or "empty int" + return None + + if data_type == 'bool': + if attr in [True, False, 1, 0]: + return bool(attr) + elif isinstance(attr, basestring): + if attr.lower() in ['true', '1']: + return True + elif attr.lower() in ['false', '0']: + return False + raise TypeError("Invalid boolean value: {}".format(attr)) + + if data_type == 'str': + return deserialize_unicode(attr) + return eval(data_type)(attr) + def _decode_attribute_map_key(key): """This decode a key in an _attribute_map to the actual key we want to look at @@ -183,22 +462,22 @@ class Deserializer(object): basic_types = {str: 'str', int: 'int', bool: 'bool', float: 'float'} - valid_date = re.compile( - r'\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}' - r'\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?') - def __init__(self, classes=None): self.deserialize_type = { - 'iso-8601': Deserializer.deserialize_iso, - 'rfc-1123': Deserializer.deserialize_rfc, - 'unix-time': Deserializer.deserialize_unix, - 'duration': Deserializer.deserialize_duration, - 'date': Deserializer.deserialize_date, - 'time': Deserializer.deserialize_time, - 'decimal': Deserializer.deserialize_decimal, - 'long': Deserializer.deserialize_long, - 'bytearray': Deserializer.deserialize_bytearray, - 'base64': Deserializer.deserialize_base64, + 'str': deserialize_basic, + 'int': deserialize_basic, + 'bool': deserialize_basic, + 'float': deserialize_basic, + 'iso-8601': deserialize_iso, + 'rfc-1123': deserialize_rfc, + 'unix-time': deserialize_unix, + 'duration': deserialize_duration, + 'date': deserialize_date, + 'time': deserialize_time, + 'decimal': deserialize_decimal, + 'long': deserialize_long, + 'bytearray': 
deserialize_bytearray, + 'base64': deserialize_base64, 'object': self.deserialize_object, '[]': self.deserialize_iter, '{}': self.deserialize_dict @@ -209,7 +488,6 @@ def __init__(self, classes=None): } self.dependencies = dict(classes) if classes else {} self.key_extractors = [ - # rest_key_extractor, xml_key_extractor ] # Additional properties only works if the "rest_key_extractor" is used to @@ -229,6 +507,12 @@ def __call__(self, target_obj, response_data, content_type=None): :raises: DeserializationError if deserialization fails. :return: Deserialized object. """ + if response_data is None: + return None + try: + return self.deserialize_type[target_obj](response_data, target_obj) + except KeyError: + pass data = self._unpack_content(response_data, content_type) return self._deserialize(target_obj, data) @@ -271,6 +555,7 @@ def _deserialize(self, target_obj, data): if isinstance(response, basestring): return self.deserialize_data(data, response) elif isinstance(response, type) and issubclass(response, Enum): + raise Exception("BOOM additional enum") return self.deserialize_enum(data, response) if data is None: @@ -293,6 +578,7 @@ def _deserialize(self, target_obj, data): found_value = key_extractor(attr, attr_desc, data) if found_value is not None: if raw_value is not None and raw_value != found_value: + raise Exception("BOOM raw_value") msg = ("Ignoring extracted value '%s' from %s for key '%s'" " (duplicate extraction, follow extractors order)" ) _LOGGER.warning( @@ -440,9 +726,10 @@ def deserialize_data(self, data, data_type): if not data_type: return data if data_type in self.basic_types.values(): - return self.deserialize_basic(data, data_type) + return deserialize_basic(data, data_type) if data_type in self.deserialize_type: if isinstance(data, self.deserialize_expected_types.get(data_type, tuple())): + raise Exception("BOOM expected types") return data is_a_text_parsing_type = lambda x: x not in ["object", "[]", r"{}"] @@ -516,12 +803,12 @@ def 
deserialize_object(self, attr, **kwargs): # Do no recurse on XML, just return the tree as-is return attr if isinstance(attr, basestring): - return self.deserialize_basic(attr, 'str') + return deserialize_basic(attr, 'str') obj_type = type(attr) if obj_type in self.basic_types: - return self.deserialize_basic(attr, self.basic_types[obj_type]) + return deserialize_basic(attr, self.basic_types[obj_type]) if obj_type is _long_type: - return self.deserialize_long(attr) + return deserialize_long(attr) if obj_type == dict: deserialized = {} @@ -547,283 +834,6 @@ def deserialize_object(self, attr, **kwargs): error = "Cannot deserialize generic object with type: " raise TypeError(error + str(obj_type)) - def deserialize_basic(self, attr, data_type): - """Deserialize baisc builtin data type from string. - Will attempt to convert to str, int, float and bool. - This function will also accept '1', '0', 'true' and 'false' as - valid bool values. - - :param str attr: response string to be deserialized. - :param str data_type: deserialization data type. - :rtype: str, int, float or bool - :raises: TypeError if string format is not valid. - """ - # If we're here, data is supposed to be a basic type. - # If it's still an XML node, take the text - if isinstance(attr, ET.Element): - attr = attr.text - if not attr: - if data_type == "str": - # None or '', node is empty string. - return '' - else: - # None or '', node with a strong type is None. 
- # Don't try to model "empty bool" or "empty int" - return None - - if data_type == 'bool': - if attr in [True, False, 1, 0]: - return bool(attr) - elif isinstance(attr, basestring): - if attr.lower() in ['true', '1']: - return True - elif attr.lower() in ['false', '0']: - return False - raise TypeError("Invalid boolean value: {}".format(attr)) - - if data_type == 'str': - return self.deserialize_unicode(attr) - return eval(data_type)(attr) - - @staticmethod - def deserialize_unicode(data): - """Preserve unicode objects in Python 2, otherwise return data - as a string. - - :param str data: response string to be deserialized. - :rtype: str or unicode - """ - # We might be here because we have an enum modeled as string, - # and we try to deserialize a partial dict with enum inside - if isinstance(data, Enum): - return data - - # Consider this is real string - try: - if isinstance(data, unicode): - return data - except NameError: - return str(data) - else: - return str(data) - - @staticmethod - def deserialize_enum(data, enum_obj): - """Deserialize string into enum object. - - If the string is not a valid enum value it will be returned as-is - and a warning will be logged. - - :param str data: Response string to be deserialized. If this value is - None or invalid it will be returned as-is. - :param Enum enum_obj: Enum object to deserialize to. - :rtype: Enum - """ - if isinstance(data, enum_obj) or data is None: - return data - if isinstance(data, Enum): - data = data.value - if isinstance(data, int): - # Workaround. We might consider remove it in the future. 
- # https://github.com/Azure/azure-rest-api-specs/issues/141 - try: - return list(enum_obj.__members__.values())[data] - except IndexError: - error = "{!r} is not a valid index for enum {!r}" - raise DeserializationError(error.format(data, enum_obj)) - try: - return enum_obj(str(data)) - except ValueError: - for enum_value in enum_obj: - if enum_value.value.lower() == str(data).lower(): - return enum_value - # We don't fail anymore for unknown value, we deserialize as a string - _LOGGER.warning("Deserializer is not able to find %s as valid enum in %s", data, enum_obj) - return Deserializer.deserialize_unicode(data) - - @staticmethod - def deserialize_bytearray(attr): - """Deserialize string into bytearray. - - :param str attr: response string to be deserialized. - :rtype: bytearray - :raises: TypeError if string format invalid. - """ - if isinstance(attr, ET.Element): - attr = attr.text - return bytearray(b64decode(attr)) - - @staticmethod - def deserialize_base64(attr): - """Deserialize base64 encoded string into string. - - :param str attr: response string to be deserialized. - :rtype: bytearray - :raises: TypeError if string format invalid. - """ - if isinstance(attr, ET.Element): - attr = attr.text - padding = '=' * (3 - (len(attr) + 3) % 4) - attr = attr + padding - encoded = attr.replace('-', '+').replace('_', '/') - return b64decode(encoded) - - @staticmethod - def deserialize_decimal(attr): - """Deserialize string into Decimal object. - - :param str attr: response string to be deserialized. - :rtype: Decimal - :raises: DeserializationError if string format invalid. - """ - if isinstance(attr, ET.Element): - attr = attr.text - try: - return decimal.Decimal(attr) - except decimal.DecimalException as err: - msg = "Invalid decimal {}".format(attr) - raise_with_traceback(DeserializationError, msg, err) - - @staticmethod - def deserialize_long(attr): - """Deserialize string into long (Py2) or int (Py3). - - :param str attr: response string to be deserialized. 
- :rtype: long or int - :raises: ValueError if string format invalid. - """ - if isinstance(attr, ET.Element): - attr = attr.text - return _long_type(attr) - - @staticmethod - def deserialize_duration(attr): - """Deserialize ISO-8601 formatted string into TimeDelta object. - - :param str attr: response string to be deserialized. - :rtype: TimeDelta - :raises: DeserializationError if string format invalid. - """ - if isinstance(attr, ET.Element): - attr = attr.text - try: - duration = isodate.parse_duration(attr) - except(ValueError, OverflowError, AttributeError) as err: - msg = "Cannot deserialize duration object." - raise_with_traceback(DeserializationError, msg, err) - else: - return duration - - @staticmethod - def deserialize_date(attr): - """Deserialize ISO-8601 formatted string into Date object. - - :param str attr: response string to be deserialized. - :rtype: Date - :raises: DeserializationError if string format invalid. - """ - if isinstance(attr, ET.Element): - attr = attr.text - if re.search(r"[^\W\d_]", attr, re.I + re.U): - raise DeserializationError("Date must have only digits and -. Received: %s" % attr) - # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception. - return isodate.parse_date(attr, defaultmonth=None, defaultday=None) - - @staticmethod - def deserialize_time(attr): - """Deserialize ISO-8601 formatted string into time object. - - :param str attr: response string to be deserialized. - :rtype: datetime.time - :raises: DeserializationError if string format invalid. - """ - if isinstance(attr, ET.Element): - attr = attr.text - if re.search(r"[^\W\d_]", attr, re.I + re.U): - raise DeserializationError("Date must have only digits and -. Received: %s" % attr) - return isodate.parse_time(attr) - - @staticmethod - def deserialize_rfc(attr): - """Deserialize RFC-1123 formatted string into Datetime object. - - :param str attr: response string to be deserialized. 
- :rtype: Datetime - :raises: DeserializationError if string format invalid. - """ - if isinstance(attr, ET.Element): - attr = attr.text - try: - parsed_date = email.utils.parsedate_tz(attr) - date_obj = datetime.datetime( - *parsed_date[:6], - tzinfo=_FixedOffset(datetime.timedelta(minutes=(parsed_date[9] or 0)/60)) - ) - if not date_obj.tzinfo: - date_obj = date_obj.astimezone(tz=TZ_UTC) - except ValueError as err: - msg = "Cannot deserialize to rfc datetime object." - raise_with_traceback(DeserializationError, msg, err) - else: - return date_obj - - @staticmethod - def deserialize_iso(attr): - """Deserialize ISO-8601 formatted string into Datetime object. - - :param str attr: response string to be deserialized. - :rtype: Datetime - :raises: DeserializationError if string format invalid. - """ - if isinstance(attr, ET.Element): - attr = attr.text - try: - attr = attr.upper() - match = Deserializer.valid_date.match(attr) - if not match: - raise ValueError("Invalid datetime string: " + attr) - - check_decimal = attr.split('.') - if len(check_decimal) > 1: - decimal_str = "" - for digit in check_decimal[1]: - if digit.isdigit(): - decimal_str += digit - else: - break - if len(decimal_str) > 6: - attr = attr.replace(decimal_str, decimal_str[0:6]) - - date_obj = isodate.parse_datetime(attr) - test_utc = date_obj.utctimetuple() - if test_utc.tm_year > 9999 or test_utc.tm_year < 1: - raise OverflowError("Hit max or min date") - except(ValueError, OverflowError, AttributeError) as err: - msg = "Cannot deserialize datetime object." - raise_with_traceback(DeserializationError, msg, err) - else: - return date_obj - - @staticmethod - def deserialize_unix(attr): - """Serialize Datetime object into IntTime format. - This is represented as seconds. - - :param int attr: Object to be serialized. 
- :rtype: Datetime - :raises: DeserializationError if format invalid - """ - if isinstance(attr, ET.Element): - attr = int(attr.text) - try: - date_obj = datetime.datetime.fromtimestamp(attr, TZ_UTC) - except ValueError as err: - msg = "Cannot deserialize to unix datetime object." - raise_with_traceback(DeserializationError, msg, err) - else: - return date_obj - - def deserialize_from_text( data, # type: Optional[Union[AnyStr, IO]] From f984f6143bbe5adfcacbdedd93d864d9e398a4d2 Mon Sep 17 00:00:00 2001 From: antisch Date: Tue, 20 Jul 2021 14:27:32 -0700 Subject: [PATCH 07/15] Refactor part 2 --- .../azure/storage/blob/_list_blobs_helper.py | 6 +- .../storage/blob/_shared/response_handlers.py | 7 +- .../blob/_shared/xml_deserialization.py | 212 ++++++------------ 3 files changed, 77 insertions(+), 148 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py index e67ca5434c1f..790d417ea35b 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py @@ -7,19 +7,17 @@ from azure.core.paging import PageIterator, ItemPaged from azure.core.exceptions import HttpResponseError -#from azure.core.pipeline.policies import ContentDecodePolicy from ._deserialize import get_blob_properties_from_generated_code, parse_tags from ._generated.models import BlobItemInternal, BlobPrefix as GenBlobPrefix, FilterBlobItem from ._models import BlobProperties, FilteredBlob from ._shared.models import DictMixin -from ._shared.xml_deserialization import deserialize_from_http_generics +from ._shared.xml_deserialization import unpack_xml_content from ._shared.response_handlers import return_context_and_deserialized, process_storage_error def deserialize_list_result(pipeline_response, _, headers): - #payload = pipeline_response.context[ContentDecodePolicy.CONTEXT_NAME] - payload = 
deserialize_from_http_generics(pipeline_response.http_response) + payload = unpack_xml_content(pipeline_response.http_response) location = pipeline_response.http_response.location_mode return location, payload diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/response_handlers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/response_handlers.py index cb1efd798821..e5a351417e60 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/response_handlers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/response_handlers.py @@ -10,7 +10,7 @@ import logging from xml.etree.ElementTree import Element -#from azure.core.pipeline.policies import ContentDecodePolicy +from azure.core.pipeline.policies import ContentDecodePolicy from azure.core.exceptions import ( HttpResponseError, ResourceNotFoundError, @@ -19,7 +19,6 @@ ClientAuthenticationError, DecodeError) -from .xml_deserialization import deserialize_from_http_generics from .parser import _to_utc_datetime from .models import StorageErrorCode, UserDelegationKey, get_enum_value @@ -97,7 +96,7 @@ def process_storage_error(storage_error): # pylint:disable=too-many-statements additional_data = {} error_dict = {} try: - error_body = deserialize_from_http_generics(storage_error.response) + error_body = ContentDecodePolicy.deserialize_from_http_generics(storage_error.response) # If it is an XML response if isinstance(error_body, Element): error_dict = { @@ -109,7 +108,7 @@ def process_storage_error(storage_error): # pylint:disable=too-many-statements error_dict = error_body.get('error', {}) elif not error_code: _LOGGER.warning( - 'Unexpected return type % from deserialize_from_http_generics.', type(error_body)) + 'Unexpected return type % from ContentDecodePolicy.deserialize_from_http_generics.', type(error_body)) error_dict = {'message': str(error_body)} # If we extracted from a Json or XML response diff --git 
a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py index cd350f0b6bb2..21d2f4a61964 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py @@ -79,6 +79,36 @@ _FLATTEN ) + +def unpack_xml_content(response_data, content_type=None): + """Extract the correct structure for deserialization. + + If raw_data is a PipelineResponse, try to extract the result of RawDeserializer. + if we can't, raise. Your Pipeline should have a RawDeserializer. + + If not a pipeline response and raw_data is bytes or string, use content-type + to decode it. If no content-type, try JSON. + + If raw_data is something else, bypass all logic and return it directly. + + :param raw_data: Data to be processed. + :param content_type: How to parse if raw_data is a string/bytes. + :raises UnicodeDecodeError: If bytes is not UTF8 + """ + data_as_str = response_data.text() + try: + try: + if isinstance(raw_data, unicode): # type: ignore + # If I'm Python 2.7 and unicode XML will scream if I try a "fromstring" on unicode string + data_as_str = cast(str, data_as_str.encode(encoding="utf-8")) + except NameError: + pass + return ET.fromstring(data_as_str) # nosec + except ET.ParseError: + _LOGGER.critical("Response body invalid XML") + raise_with_traceback(DecodeError, message="XML is invalid", response=response_data) + + def deserialize_bytearray(attr, *args): """Deserialize string into bytearray. @@ -508,13 +538,53 @@ def __call__(self, target_obj, response_data, content_type=None): :return: Deserialized object. """ if response_data is None: + # No data. Moving on. return None try: + # Data is a basic data type. 
return self.deserialize_type[target_obj](response_data, target_obj) except KeyError: pass - data = self._unpack_content(response_data, content_type) - return self._deserialize(target_obj, data) + try: + # Data is an XML model. + target_obj = self.dependencies[target_obj] + decoded_data = unpack_xml_content(response_data.http_response, content_type) + return self._deserialize(target_obj, decoded_data) + except KeyError: + pass + try: + # Data is in a dict/list. + structure = target_obj[0] + target_obj[-1] + inner_type = target_obj[1:-1] + try: + self.dependencies[inner_type] + response_data = unpack_xml_content(response_data.http_response, content_type) + except KeyError: + pass + return self.deserialize_type[structure](response_data, inner_type) + except (KeyError, IndexError): + pass + raise Exception("No idea what to do with {}, {}".format(target_obj, response_data)) + + def failsafe_deserialize(self, target_obj, data, content_type=None): + """Ignores any errors encountered in deserialization, + and falls back to not deserializing the object. Recommended + for use in error deserialization, as we want to return the + HttpResponseError to users, and not have them deal with + a deserialization error. + + :param str target_obj: The target object type to deserialize to. + :param str/dict data: The response data to deseralize. + :param str content_type: Swagger "produces" if available. + """ + try: + return self(target_obj, data, content_type=content_type) + except: + _LOGGER.warning( + "Ran into a deserialization error. Ignoring since this is failsafe deserialization", + exc_info=True + ) + return None def _deserialize(self, target_obj, data): """Call the deserializer on a model. 
@@ -632,50 +702,6 @@ def _classify_target(self, target, data): return target, target return target, target.__class__.__name__ - def failsafe_deserialize(self, target_obj, data, content_type=None): - """Ignores any errors encountered in deserialization, - and falls back to not deserializing the object. Recommended - for use in error deserialization, as we want to return the - HttpResponseError to users, and not have them deal with - a deserialization error. - - :param str target_obj: The target object type to deserialize to. - :param str/dict data: The response data to deseralize. - :param str content_type: Swagger "produces" if available. - """ - try: - return self(target_obj, data, content_type=content_type) - except: - _LOGGER.warning( - "Ran into a deserialization error. Ignoring since this is failsafe deserialization", - exc_info=True - ) - return None - - @staticmethod - def _unpack_content(raw_data, content_type=None): - """Extract the correct structure for deserialization. - - If raw_data is a PipelineResponse, try to extract the result of RawDeserializer. - if we can't, raise. Your Pipeline should have a RawDeserializer. - - If not a pipeline response and raw_data is bytes or string, use content-type - to decode it. If no content-type, try JSON. - - If raw_data is something else, bypass all logic and return it directly. - - :param raw_data: Data to be processed. - :param content_type: How to parse if raw_data is a string/bytes. - :raises JSONDecodeError: If JSON is requested and parsing is impossible. - :raises UnicodeDecodeError: If bytes is not UTF8 - """ - if isinstance(raw_data, (basestring, bytes)) or hasattr(raw_data, 'read'): - return deserialize_from_text(raw_data, content_type) - try: - return deserialize_from_http_generics(raw_data.http_response) - except AttributeError: - return raw_data - def _instantiate_model(self, response, attrs, additional_properties=None): """Instantiate a response model passing in deserialized args. 
@@ -833,97 +859,3 @@ def deserialize_object(self, attr, **kwargs): else: error = "Cannot deserialize generic object with type: " raise TypeError(error + str(obj_type)) - - -def deserialize_from_text( - data, # type: Optional[Union[AnyStr, IO]] - mime_type=None, # Optional[str] - response=None # Optional[Union[HttpResponse, AsyncHttpResponse]] -): - """Decode response data according to content-type. - Accept a stream of data as well, but will be load at once in memory for now. - If no content-type, will return the string version (not bytes, not stream) - :param response: The HTTP response. - :type response: ~azure.core.pipeline.transport.HttpResponse - :param str mime_type: The mime type. As mime type, charset is not expected. - :param response: If passed, exception will be annotated with that response - :raises ~azure.core.exceptions.DecodeError: If deserialization fails - :returns: A dict or XML tree, depending of the mime_type - """ - if not data: - return None - - if hasattr(data, 'read'): - # Assume a stream - data = cast(IO, data).read() - - if isinstance(data, bytes): - data_as_str = data.decode(encoding='utf-8-sig') - else: - # Explain to mypy the correct type. - data_as_str = cast(str, data) - - if mime_type is None: - return data_as_str - - if "xml" in (mime_type or []): - try: - try: - if isinstance(data, unicode): # type: ignore - # If I'm Python 2.7 and unicode XML will scream if I try a "fromstring" on unicode string - data_as_str = cast(str, data_as_str.encode(encoding="utf-8")) - except NameError: - pass - return ET.fromstring(data_as_str) # nosec - except ET.ParseError: - # It might be because the server has an issue, and returned JSON with - # content-type XML.... 
- # So let's try a JSON load, and if it's still broken - # let's flow the initial exception - def _json_attemp(data): - try: - return True, json.loads(data) - except ValueError: - return False, None # Don't care about this one - success, json_result = _json_attemp(data) - if success: - return json_result - # If i'm here, it's not JSON, it's not XML, let's scream - # and raise the last context in this block (the XML exception) - # The function hack is because Py2.7 messes up with exception - # context otherwise. - _LOGGER.critical("Wasn't XML not JSON, failing") - raise_with_traceback(DecodeError, message="XML is invalid", response=response) - elif mime_type.startswith("text/"): - return data_as_str - else: - try: - return json.loads(data_as_str) - except ValueError as err: - raise DecodeError(message="JSON is invalid: {}".format(err), response=response, error=err) - raise DecodeError("Cannot deserialize content-type: {}".format(mime_type)) - - -def deserialize_from_http_generics( - response, # Union[HttpResponse, AsyncHttpResponse] - encoding=None, # Optional[str] -): - """Deserialize from HTTP response. - Headers will tested for "content-type" - :param response: The HTTP response - :param encoding: The encoding to use if known for this service (will disable auto-detection) - :raises ~azure.core.exceptions.DecodeError: If deserialization fails - :returns: A dict or XML tree, depending of the mime-type - """ - # Try to use content-type from headers if available - if response.content_type: - mime_type = response.content_type.split(";")[0].strip().lower() - # Ouch, this server did not declare what it sent... - # Let's guess it's JSON... - # Also, since Autorest was considering that an empty body was a valid JSON, - # need that test as well.... 
- else: - mime_type = "application/json" - - # Rely on transport implementation to give me "text()" decoded correctly - return deserialize_from_text(response.text(encoding), mime_type, response=response) From 5ccd5ea9d362fa81ea7ff016c0801f91beb33d49 Mon Sep 17 00:00:00 2001 From: antisch Date: Tue, 20 Jul 2021 20:46:24 -0700 Subject: [PATCH 08/15] Refactor part 3 --- .../blob/_shared/xml_deserialization.py | 210 +++--------------- 1 file changed, 35 insertions(+), 175 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py index 21d2f4a61964..511aa41ca00c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py @@ -537,34 +537,15 @@ def __call__(self, target_obj, response_data, content_type=None): :raises: DeserializationError if deserialization fails. :return: Deserialized object. """ + try: + # First, unpack the response if we have one. + response_data = unpack_xml_content(response_data.http_response, content_type) + except AttributeError: + pass if response_data is None: # No data. Moving on. return None - try: - # Data is a basic data type. - return self.deserialize_type[target_obj](response_data, target_obj) - except KeyError: - pass - try: - # Data is an XML model. - target_obj = self.dependencies[target_obj] - decoded_data = unpack_xml_content(response_data.http_response, content_type) - return self._deserialize(target_obj, decoded_data) - except KeyError: - pass - try: - # Data is in a dict/list. 
- structure = target_obj[0] + target_obj[-1] - inner_type = target_obj[1:-1] - try: - self.dependencies[inner_type] - response_data = unpack_xml_content(response_data.http_response, content_type) - except KeyError: - pass - return self.deserialize_type[structure](response_data, inner_type) - except (KeyError, IndexError): - pass - raise Exception("No idea what to do with {}, {}".format(target_obj, response_data)) + return self._deserialize(target_obj, response_data) def failsafe_deserialize(self, target_obj, data, content_type=None): """Ignores any errors encountered in deserialization, @@ -596,42 +577,17 @@ def _deserialize(self, target_obj, data): :raises: DeserializationError if deserialization fails. :return: Deserialized object. """ - # This is already a model, go recursive just in case - if hasattr(data, "_attribute_map"): - constants = [name for name, config in getattr(data, '_validation', {}).items() - if config.get('constant')] - try: - for attr, mapconfig in data._attribute_map.items(): - if attr in constants: - continue - value = getattr(data, attr) - if value is None: - continue - local_type = mapconfig['type'] - internal_data_type = local_type.strip('[]{}') - if internal_data_type not in self.dependencies or isinstance(internal_data_type, Enum): - continue - setattr( - data, - attr, - self._deserialize(local_type, value) - ) - return data - except AttributeError: - return - - response, class_name = self._classify_target(target_obj, data) - - if isinstance(response, basestring): - return self.deserialize_data(data, response) - elif isinstance(response, type) and issubclass(response, Enum): - raise Exception("BOOM additional enum") - return self.deserialize_enum(data, response) + try: + model_type = self.dependencies[target_obj] + if issubclass(model_type, Enum): + return deserialize_enum(data.text, model_type) + except KeyError: + return self.deserialize_data(data, target_obj) if data is None: return data try: - attributes = response._attribute_map + 
attributes = model_type._attribute_map d_attrs = {} for attr, attr_desc in attributes.items(): # Check empty string. If it's not empty, someone has a real "additionalProperties"... @@ -647,27 +603,16 @@ def _deserialize(self, target_obj, data): for key_extractor in self.key_extractors: found_value = key_extractor(attr, attr_desc, data) if found_value is not None: - if raw_value is not None and raw_value != found_value: - raise Exception("BOOM raw_value") - msg = ("Ignoring extracted value '%s' from %s for key '%s'" - " (duplicate extraction, follow extractors order)" ) - _LOGGER.warning( - msg, - found_value, - key_extractor, - attr - ) - continue raw_value = found_value value = self.deserialize_data(raw_value, attr_desc['type']) d_attrs[attr] = value except (AttributeError, TypeError, KeyError) as err: - msg = "Unable to deserialize to object: " + class_name + msg = "Unable to deserialize to object: " + str(target_obj) raise_with_traceback(DeserializationError, msg, err) else: additional_properties = self._build_additional_properties(attributes, data) - return self._instantiate_model(response, d_attrs, additional_properties) + return self._instantiate_model(model_type, d_attrs, additional_properties) def _build_additional_properties(self, attribute_map, data): if not self.additional_properties_detection: @@ -684,58 +629,26 @@ def _build_additional_properties(self, attribute_map, data): missing_keys = present_keys - known_keys return {key: data[key] for key in missing_keys} - def _classify_target(self, target, data): - """Check to see whether the deserialization target object can - be classified into a subclass. - Once classification has been determined, initialize object. - - :param str target: The target object type to deserialize to. - :param str/dict data: The response data to deseralize. 
- """ - if target is None: - return None, None - - if isinstance(target, basestring): - try: - target = self.dependencies[target] - except KeyError: - return target, target - return target, target.__class__.__name__ - def _instantiate_model(self, response, attrs, additional_properties=None): """Instantiate a response model passing in deserialized args. :param response: The response model class. :param d_attrs: The deserialized response attributes. """ - if callable(response): - subtype = getattr(response, '_subtype_map', {}) - try: - readonly = [k for k, v in response._validation.items() - if v.get('readonly')] - const = [k for k, v in response._validation.items() - if v.get('constant')] - kwargs = {k: v for k, v in attrs.items() - if k not in subtype and k not in readonly + const} - response_obj = response(**kwargs) - for attr in readonly: - setattr(response_obj, attr, attrs.get(attr)) - if additional_properties: - response_obj.additional_properties = additional_properties - return response_obj - except TypeError as err: - msg = "Unable to deserialize {} into model {}. ".format( - kwargs, response) - raise DeserializationError(msg + str(err)) - else: - try: - for attr, value in attrs.items(): - setattr(response, attr, value) - return response - except Exception as exp: - msg = "Unable to populate response model. " - msg += "Type: {}, Error: {}".format(type(response), exp) - raise DeserializationError(msg) + try: + readonly = [k for k, v in response._validation.items() if v.get('readonly')] + const = [k for k, v in response._validation.items() if v.get('constant')] + kwargs = {k: v for k, v in attrs.items() if k not in readonly + const} + response_obj = response(**kwargs) + for attr in readonly: + setattr(response_obj, attr, attrs.get(attr)) + if additional_properties: + response_obj.additional_properties = additional_properties + return response_obj + except Exception as err: + msg = "Unable to deserialize {} into model {}. 
".format( + kwargs, response) + raise DeserializationError(msg + str(err)) def deserialize_data(self, data, data_type): """Process data for deserialization according to data type. @@ -754,10 +667,6 @@ def deserialize_data(self, data, data_type): if data_type in self.basic_types.values(): return deserialize_basic(data, data_type) if data_type in self.deserialize_type: - if isinstance(data, self.deserialize_expected_types.get(data_type, tuple())): - raise Exception("BOOM expected types") - return data - is_a_text_parsing_type = lambda x: x not in ["object", "[]", r"{}"] if isinstance(data, ET.Element) and is_a_text_parsing_type(data_type) and not data.text: return None @@ -768,18 +677,12 @@ def deserialize_data(self, data, data_type): if iter_type in self.deserialize_type: return self.deserialize_type[iter_type](data, data_type[1:-1]) - obj_type = self.dependencies[data_type] - if issubclass(obj_type, Enum): - if isinstance(data, ET.Element): - data = data.text - return self.deserialize_enum(data, obj_type) - except (ValueError, TypeError, AttributeError) as err: msg = "Unable to deserialize response data." msg += " Data: {}, {}".format(data, data_type) raise_with_traceback(DeserializationError, msg, err) else: - return self._deserialize(obj_type, data) + return self._deserialize(data_type, data) def deserialize_iter(self, attr, iter_type): """Deserialize an iterable. @@ -790,14 +693,7 @@ def deserialize_iter(self, attr, iter_type): """ if attr is None: return None - if isinstance(attr, ET.Element): # If I receive an element here, get the children - attr = list(attr) - if not isinstance(attr, (list, set)): - raise DeserializationError("Cannot deserialize as [{}] an object of type {}".format( - iter_type, - type(attr) - )) - return [self.deserialize_data(a, iter_type) for a in attr] + return [self.deserialize_data(a, iter_type) for a in list(attr)] def deserialize_dict(self, attr, dict_type): """Deserialize a dictionary. 
@@ -807,12 +703,8 @@ def deserialize_dict(self, attr, dict_type): :param str dict_type: The object type of the items in the dictionary. :rtype: dict """ - if isinstance(attr, list): - return {x['key']: self.deserialize_data(x['value'], dict_type) for x in attr} - - if isinstance(attr, ET.Element): - # Transform value into {"Key": "value"} - attr = {el.tag: el.text for el in attr} + # Transform value into {"Key": "value"} + attr = {el.tag: el.text for el in attr} return {k: self.deserialize_data(v, dict_type) for k, v in attr.items()} def deserialize_object(self, attr, **kwargs): @@ -825,37 +717,5 @@ def deserialize_object(self, attr, **kwargs): """ if attr is None: return None - if isinstance(attr, ET.Element): - # Do no recurse on XML, just return the tree as-is - return attr - if isinstance(attr, basestring): - return deserialize_basic(attr, 'str') - obj_type = type(attr) - if obj_type in self.basic_types: - return deserialize_basic(attr, self.basic_types[obj_type]) - if obj_type is _long_type: - return deserialize_long(attr) - - if obj_type == dict: - deserialized = {} - for key, value in attr.items(): - try: - deserialized[key] = self.deserialize_object( - value, **kwargs) - except ValueError: - deserialized[key] = None - return deserialized - - if obj_type == list: - deserialized = [] - for obj in attr: - try: - deserialized.append(self.deserialize_object( - obj, **kwargs)) - except ValueError: - pass - return deserialized - - else: - error = "Cannot deserialize generic object with type: " - raise TypeError(error + str(obj_type)) + # Do no recurse on XML, just return the tree as-is + return attr From 5d868e20838be5e48fd3ee0181fd1a1224ce3037 Mon Sep 17 00:00:00 2001 From: antisch Date: Wed, 21 Jul 2021 10:15:06 -0700 Subject: [PATCH 09/15] Refactor part 4 --- .../blob/_shared/xml_deserialization.py | 75 +++++-------------- 1 file changed, 20 insertions(+), 55 deletions(-) diff --git 
a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py index 511aa41ca00c..b388daf37770 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py @@ -116,8 +116,6 @@ def deserialize_bytearray(attr, *args): :rtype: bytearray :raises: TypeError if string format invalid. """ - if isinstance(attr, ET.Element): - attr = attr.text return bytearray(b64decode(attr)) @@ -128,8 +126,6 @@ def deserialize_base64(attr, *args): :rtype: bytearray :raises: TypeError if string format invalid. """ - if isinstance(attr, ET.Element): - attr = attr.text padding = '=' * (3 - (len(attr) + 3) % 4) attr = attr + padding encoded = attr.replace('-', '+').replace('_', '/') @@ -143,8 +139,6 @@ def deserialize_decimal(attr, *args): :rtype: Decimal :raises: DeserializationError if string format invalid. """ - if isinstance(attr, ET.Element): - attr = attr.text try: return decimal.Decimal(attr) except decimal.DecimalException as err: @@ -159,8 +153,6 @@ def deserialize_long(attr, *args): :rtype: long or int :raises: ValueError if string format invalid. """ - if isinstance(attr, ET.Element): - attr = attr.text return _long_type(attr) @@ -171,8 +163,6 @@ def deserialize_duration(attr, *args): :rtype: TimeDelta :raises: DeserializationError if string format invalid. """ - if isinstance(attr, ET.Element): - attr = attr.text try: duration = isodate.parse_duration(attr) except(ValueError, OverflowError, AttributeError) as err: @@ -189,8 +179,6 @@ def deserialize_date(attr, *args): :rtype: Date :raises: DeserializationError if string format invalid. """ - if isinstance(attr, ET.Element): - attr = attr.text if re.search(r"[^\W\d_]", attr, re.I + re.U): raise DeserializationError("Date must have only digits and -. 
Received: %s" % attr) # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception. @@ -204,8 +192,6 @@ def deserialize_time(attr, *args): :rtype: datetime.time :raises: DeserializationError if string format invalid. """ - if isinstance(attr, ET.Element): - attr = attr.text if re.search(r"[^\W\d_]", attr, re.I + re.U): raise DeserializationError("Date must have only digits and -. Received: %s" % attr) return isodate.parse_time(attr) @@ -218,8 +204,6 @@ def deserialize_rfc(attr, *args): :rtype: Datetime :raises: DeserializationError if string format invalid. """ - if isinstance(attr, ET.Element): - attr = attr.text try: parsed_date = email.utils.parsedate_tz(attr) date_obj = datetime.datetime( @@ -242,8 +226,6 @@ def deserialize_iso(attr, *args): :rtype: Datetime :raises: DeserializationError if string format invalid. """ - if isinstance(attr, ET.Element): - attr = attr.text try: attr = attr.upper() match = _valid_date.match(attr) @@ -280,10 +262,8 @@ def deserialize_unix(attr, *args): :rtype: Datetime :raises: DeserializationError if format invalid """ - if isinstance(attr, ET.Element): - attr = int(attr.text) try: - date_obj = datetime.datetime.fromtimestamp(attr, TZ_UTC) + date_obj = datetime.datetime.fromtimestamp(int(attr), TZ_UTC) except ValueError as err: msg = "Cannot deserialize to unix datetime object." raise_with_traceback(DeserializationError, msg, err) @@ -298,6 +278,8 @@ def deserialize_unicode(data, *args): :param str data: response string to be deserialized. :rtype: str or unicode """ + if data is None: + return "" # We might be here because we have an enum modeled as string, # and we try to deserialize a partial dict with enum inside if isinstance(data, Enum): @@ -358,19 +340,6 @@ def deserialize_basic(attr, data_type): :rtype: str, int, float or bool :raises: TypeError if string format is not valid. """ - # If we're here, data is supposed to be a basic type. 
- # If it's still an XML node, take the text - if isinstance(attr, ET.Element): - attr = attr.text - if not attr: - if data_type == "str": - # None or '', node is empty string. - return '' - else: - # None or '', node with a strong type is None. - # Don't try to model "empty bool" or "empty int" - return None - if data_type == 'bool': if attr in [True, False, 1, 0]: return bool(attr) @@ -414,10 +383,6 @@ def xml_key_extractor(attr, attr_desc, data): if isinstance(data, dict): return None - # Test if this model is XML ready first - if not isinstance(data, ET.Element): - return None - xml_desc = attr_desc.get('xml', {}) xml_name = xml_desc.get('name', attr_desc['key']) @@ -658,25 +623,29 @@ def deserialize_data(self, data, data_type): :raises: DeserializationError if deserialization fails. :return: Deserialized object. """ - if data is None: - return data + if not data_type or data is None: + return None + try: + xml_data = data.text + except AttributeError: + xml_data = data try: - if not data_type: - return data - if data_type in self.basic_types.values(): - return deserialize_basic(data, data_type) - if data_type in self.deserialize_type: - is_a_text_parsing_type = lambda x: x not in ["object", "[]", r"{}"] - if isinstance(data, ET.Element) and is_a_text_parsing_type(data_type) and not data.text: - return None - data_val = self.deserialize_type[data_type](data) - return data_val + basic_deserialize = self.deserialize_type[data_type] + if not xml_data and data_type != 'str': + return None + return basic_deserialize(xml_data, data_type) + except KeyError: + pass + except (ValueError, TypeError, AttributeError) as err: + msg = "Unable to deserialize response data." 
+ msg += " Data: {}, {}".format(data, data_type) + raise_with_traceback(DeserializationError, msg, err) + try: iter_type = data_type[0] + data_type[-1] if iter_type in self.deserialize_type: return self.deserialize_type[iter_type](data, data_type[1:-1]) - except (ValueError, TypeError, AttributeError) as err: msg = "Unable to deserialize response data." msg += " Data: {}, {}".format(data, data_type) @@ -691,8 +660,6 @@ def deserialize_iter(self, attr, iter_type): :param str iter_type: The type of object in the iterable. :rtype: list """ - if attr is None: - return None return [self.deserialize_data(a, iter_type) for a in list(attr)] def deserialize_dict(self, attr, dict_type): @@ -715,7 +682,5 @@ def deserialize_object(self, attr, **kwargs): :rtype: dict :raises: TypeError if non-builtin datatype encountered. """ - if attr is None: - return None # Do no recurse on XML, just return the tree as-is return attr From 9372dc12263fbf83803f4bb5054ad333b003283e Mon Sep 17 00:00:00 2001 From: antisch Date: Wed, 21 Jul 2021 16:25:20 -0700 Subject: [PATCH 10/15] Refactor part 5 --- .../blob/_shared/xml_deserialization.py | 195 +++++------------- 1 file changed, 53 insertions(+), 142 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py index b388daf37770..9291bf4a9231 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py @@ -355,108 +355,11 @@ def deserialize_basic(attr, data_type): return eval(data_type)(attr) -def _decode_attribute_map_key(key): - """This decode a key in an _attribute_map to the actual key we want to look at - inside the received data. 
- - :param str key: A key string from the generated code - """ - return key.replace('\\.', '.') - - -def _extract_name_from_internal_type(internal_type): - """Given an internal type XML description, extract correct XML name with namespace. - - :param dict internal_type: An model type - :rtype: tuple - :returns: A tuple XML name + namespace dict - """ - internal_type_xml_map = getattr(internal_type, "_xml_map", {}) - xml_name = internal_type_xml_map.get('name', internal_type.__name__) - xml_ns = internal_type_xml_map.get("ns", None) - if xml_ns: - xml_name = "{{{}}}{}".format(xml_ns, xml_name) - return xml_name - - -def xml_key_extractor(attr, attr_desc, data): - if isinstance(data, dict): - return None - - xml_desc = attr_desc.get('xml', {}) - xml_name = xml_desc.get('name', attr_desc['key']) - - # Look for a children - is_iter_type = attr_desc['type'].startswith("[") - is_wrapped = xml_desc.get("wrapped", False) - internal_type = attr_desc.get("internalType", None) - internal_type_xml_map = getattr(internal_type, "_xml_map", {}) - - # Integrate namespace if necessary - xml_ns = xml_desc.get('ns', internal_type_xml_map.get("ns", None)) - if xml_ns: - xml_name = "{{{}}}{}".format(xml_ns, xml_name) - - # If it's an attribute, that's simple - if xml_desc.get("attr", False): - return data.get(xml_name) - - # If it's x-ms-text, that's simple too - if xml_desc.get("text", False): - return data.text - - # Scenario where I take the local name: - # - Wrapped node - # - Internal type is an enum (considered basic types) - # - Internal type has no XML/Name node - if is_wrapped or (internal_type and (issubclass(internal_type, Enum) or 'name' not in internal_type_xml_map)): - children = data.findall(xml_name) - # If internal type has a local name and it's not a list, I use that name - elif not is_iter_type and internal_type and 'name' in internal_type_xml_map: - xml_name = _extract_name_from_internal_type(internal_type) - children = data.findall(xml_name) - # That's an array - 
else: - if internal_type: # Complex type, ignore itemsName and use the complex type name - items_name = _extract_name_from_internal_type(internal_type) - else: - items_name = xml_desc.get("itemsName", xml_name) - children = data.findall(items_name) - - if len(children) == 0: - if is_iter_type: - if is_wrapped: - return None # is_wrapped no node, we want None - else: - return [] # not wrapped, assume empty list - return None # Assume it's not there, maybe an optional node. - - # If is_iter_type and not wrapped, return all found children - if is_iter_type: - if not is_wrapped: - return children - else: # Iter and wrapped, should have found one node only (the wrap one) - if len(children) != 1: - raise DeserializationError( - "Tried to deserialize an array not wrapped, and found several nodes '{}'. Maybe you should declare this array as wrapped?".format( - xml_name - )) - return list(children[0]) # Might be empty list and that's ok. - - # Here it's not a itertype, we should have found one element only or empty - if len(children) > 1: - raise DeserializationError("Find several XML '{}' where it was not expected".format(xml_name)) - return children[0] - class Deserializer(object): """Response object model deserializer. :param dict classes: Class type dictionary for deserializing complex types. - :ivar list key_extractors: Ordered list of extractors to be used by this deserializer. 
""" - - basic_types = {str: 'str', int: 'int', bool: 'bool', float: 'float'} - def __init__(self, classes=None): self.deserialize_type = { 'str': deserialize_basic, @@ -477,21 +380,8 @@ def __init__(self, classes=None): '[]': self.deserialize_iter, '{}': self.deserialize_dict } - self.deserialize_expected_types = { - 'duration': (isodate.Duration, datetime.timedelta), - 'iso-8601': (datetime.datetime) - } + self.dependencies = dict(classes) if classes else {} - self.key_extractors = [ - xml_key_extractor - ] - # Additional properties only works if the "rest_key_extractor" is used to - # extract the keys. Making it to work whatever the key extractor is too much - # complicated, with no real scenario for now. - # So adding a flag to disable additional properties detection. This flag should be - # used if your expect the deserialization to NOT come from a JSON REST syntax. - # Otherwise, result are unexpected - self.additional_properties_detection = True def __call__(self, target_obj, response_data, content_type=None): """Call the deserializer to process a REST response. @@ -510,7 +400,8 @@ def __call__(self, target_obj, response_data, content_type=None): if response_data is None: # No data. Moving on. return None - return self._deserialize(target_obj, response_data) + #return self._deserialize(target_obj, response_data) + return self.deserialize_data(response_data, target_obj) def failsafe_deserialize(self, target_obj, data, content_type=None): """Ignores any errors encountered in deserialization, @@ -554,45 +445,31 @@ def _deserialize(self, target_obj, data): try: attributes = model_type._attribute_map d_attrs = {} + include_extra_props = False for attr, attr_desc in attributes.items(): # Check empty string. If it's not empty, someone has a real "additionalProperties"... 
if attr == "additional_properties" and attr_desc["key"] == '': + include_extra_props = True continue - raw_value = None - # Enhance attr_desc with some dynamic data - attr_desc = attr_desc.copy() # Do a copy, do not change the real one - internal_data_type = attr_desc["type"].strip('[]{}') - if internal_data_type in self.dependencies: - attr_desc["internalType"] = self.dependencies[internal_data_type] - - for key_extractor in self.key_extractors: - found_value = key_extractor(attr, attr_desc, data) - if found_value is not None: - raw_value = found_value - - value = self.deserialize_data(raw_value, attr_desc['type']) + attr_type = attr_desc["type"] + try: + subtype = self.dependencies[attr_type.strip('[]{}')] + except KeyError: + subtype = None + if attr_type[0] == '[': + raw_value = self.multi_xml_key_extractor(attr_desc, data, subtype) + else: + raw_value = self.xml_key_extractor(attr_desc, data, subtype) + value = self.deserialize_data(raw_value, attr_type) d_attrs[attr] = value except (AttributeError, TypeError, KeyError) as err: msg = "Unable to deserialize to object: " + str(target_obj) raise_with_traceback(DeserializationError, msg, err) else: - additional_properties = self._build_additional_properties(attributes, data) - return self._instantiate_model(model_type, d_attrs, additional_properties) - - def _build_additional_properties(self, attribute_map, data): - if not self.additional_properties_detection: - return None - if "additional_properties" in attribute_map and attribute_map.get("additional_properties", {}).get("key") != '': - # Check empty string. 
If it's not empty, someone has a real "additionalProperties" - return None - if isinstance(data, ET.Element): - data = {el.tag: el.text for el in data} - - known_keys = {_decode_attribute_map_key(_FLATTEN.split(desc['key'])[0]) - for desc in attribute_map.values() if desc['key'] != ''} - present_keys = set(data.keys()) - missing_keys = present_keys - known_keys - return {key: data[key] for key in missing_keys} + if include_extra_props: + extra = {el.tag: el.text for el in data if el.tag not in d_attrs} + return self._instantiate_model(model_type, d_attrs, extra) + return self._instantiate_model(model_type, d_attrs) def _instantiate_model(self, response, attrs, additional_properties=None): """Instantiate a response model passing in deserialized args. @@ -615,6 +492,40 @@ def _instantiate_model(self, response, attrs, additional_properties=None): kwargs, response) raise DeserializationError(msg + str(err)) + def multi_xml_key_extractor(self, attr_desc, data, subtype): + xml_desc = attr_desc.get('xml', {}) + xml_name = xml_desc.get('name', attr_desc['key']) + is_wrapped = xml_desc.get("wrapped", False) + subtype_xml_map = getattr(subtype, "_xml_map", {}) + if is_wrapped: + items_name = xml_name + elif subtype: + items_name = subtype_xml_map.get('name', xml_name) + else: + items_name = xml_desc.get("itemsName", xml_name) + children = data.findall(items_name) + if is_wrapped: + if len(children) == 0: + return None + return list(children[0]) + return children + + def xml_key_extractor(self, attr_desc, data, subtype): + xml_desc = attr_desc.get('xml', {}) + xml_name = xml_desc.get('name', attr_desc['key']) + + # If it's an attribute, that's simple + if xml_desc.get("attr", False): + return data.get(xml_name) + + # If it's x-ms-text, that's simple too + if xml_desc.get("text", False): + return data.text + + subtype_xml_map = getattr(subtype, "_xml_map", {}) + xml_name = subtype_xml_map.get('name', xml_name) + return data.find(xml_name) + def deserialize_data(self, data, 
data_type): """Process data for deserialization according to data type. From 406647f06569950b5f01ddd050fdb07b0bfc2a3b Mon Sep 17 00:00:00 2001 From: antisch Date: Thu, 22 Jul 2021 07:44:12 -0700 Subject: [PATCH 11/15] Make xml pipeline opt-in --- .../azure/storage/blob/_blob_client.py | 3 +- .../storage/blob/_blob_service_client.py | 7 +- .../azure/storage/blob/_container_client.py | 15 +++-- .../azure/storage/blob/_list_blobs_helper.py | 11 ++-- .../azure/storage/blob/_shared/base_client.py | 6 +- .../storage/blob/_shared/base_client_async.py | 3 + .../blob/_shared/xml_deserialization.py | 65 ++++++++----------- .../storage/blob/aio/_blob_client_async.py | 3 +- .../blob/aio/_blob_service_client_async.py | 7 +- .../blob/aio/_container_client_async.py | 9 +-- .../storage/blob/aio/_list_blobs_helper.py | 7 +- .../tests/perfstress_tests/_test_base.py | 2 + 12 files changed, 74 insertions(+), 64 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py index f1289aeec4cf..c86e738c7ef7 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py @@ -176,7 +176,8 @@ def __init__( self._query_str, credential = self._format_query_string(sas_token, credential, snapshot=self.snapshot) super(BlobClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs) self._client = AzureBlobStorage(self.url, pipeline=self._pipeline) - self._custom_xml_deserializer(generated_models) + if not self._msrest_xml: + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py 
b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py index 1096d06d2a61..cd0f4f3137ac 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py @@ -135,7 +135,8 @@ def __init__( self._query_str, credential = self._format_query_string(sas_token, credential) super(BlobServiceClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs) self._client = AzureBlobStorage(self.url, pipeline=self._pipeline) - self._custom_xml_deserializer(generated_models) + if not self._msrest_xml: + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access @@ -678,7 +679,7 @@ def get_container_client(self, container): credential=self.credential, api_version=self.api_version, _configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts, require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - key_resolver_function=self.key_resolver_function) + key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml) def get_blob_client( self, container, # type: Union[ContainerProperties, str] @@ -731,4 +732,4 @@ def get_blob_client( credential=self.credential, api_version=self.api_version, _configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts, require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - key_resolver_function=self.key_resolver_function) + key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py index 
89b579950f73..e6698bc87af9 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py @@ -157,7 +157,8 @@ def __init__( self._query_str, credential = self._format_query_string(sas_token, credential) super(ContainerClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs) self._client = AzureBlobStorage(self.url, pipeline=self._pipeline) - self._custom_xml_deserializer(generated_models) + if not self._msrest_xml: + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access @@ -772,8 +773,12 @@ def list_blobs(self, name_starts_with=None, include=None, **kwargs): timeout=timeout, **kwargs) return ItemPaged( - command, prefix=name_starts_with, results_per_page=results_per_page, select=select, - deserializer=self._client._deserialize, page_iterator_class=BlobPropertiesPaged) + command, + prefix=name_starts_with, + results_per_page=results_per_page, + select=select, + deserializer=self._client._deserialize, # pylint: disable=protected-access + page_iterator_class=BlobPropertiesPaged) @distributed_trace def walk_blobs( @@ -820,7 +825,7 @@ def walk_blobs( prefix=name_starts_with, results_per_page=results_per_page, select=None, - deserializer=self._client._deserialize, + deserializer=self._client._deserialize, # pylint: disable=protected-access delimiter=delimiter) @distributed_trace @@ -1553,4 +1558,4 @@ def get_blob_client( credential=self.credential, api_version=self.api_version, _configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts, require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - key_resolver_function=self.key_resolver_function) + 
key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py index 790d417ea35b..a7575b338213 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py @@ -9,30 +9,33 @@ from azure.core.exceptions import HttpResponseError from ._deserialize import get_blob_properties_from_generated_code, parse_tags -from ._generated.models import BlobItemInternal, BlobPrefix as GenBlobPrefix, FilterBlobItem +from ._generated.models import FilterBlobItem from ._models import BlobProperties, FilteredBlob from ._shared.models import DictMixin from ._shared.xml_deserialization import unpack_xml_content from ._shared.response_handlers import return_context_and_deserialized, process_storage_error -def deserialize_list_result(pipeline_response, _, headers): +def deserialize_list_result(pipeline_response, *args): payload = unpack_xml_content(pipeline_response.http_response) location = pipeline_response.http_response.location_mode return location, payload + def load_xml_string(element, name): node = element.find(name) if node is None or not node.text: return None return node.text + def load_xml_int(element, name): node = element.find(name) if node is None or not node.text: return None return int(node.text) + def load_xml_bool(element, name): node = load_xml_string(element, name) if node and node.lower() == 'true': @@ -58,7 +61,7 @@ def blob_properties_from_xml(element, select, deserializer): if 'name' in select: blob.name = load_xml_string(element, 'Name') if 'deleted' in select: - blob.deleted = load_xml_bool(element, 'Deleted') + blob.deleted = load_xml_bool(element, 'Deleted') if 'snapshot' in select: blob.snapshot = load_xml_string(element, 'Snapshot') if 'version' in select: @@ -67,7 +70,7 @@ def 
blob_properties_from_xml(element, select, deserializer): return blob -class BlobPropertiesPaged(PageIterator): +class BlobPropertiesPaged(PageIterator): # pylint: disable=too-many-instance-attributes """An Iterable of Blob properties. :ivar str service_endpoint: The service URL. diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index 7aaaeadd4d60..781b73e7f9e3 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -26,6 +26,7 @@ from azure.core.pipeline.transport import RequestsTransport, HttpTransport from azure.core.pipeline.policies import ( RedirectPolicy, + ContentDecodePolicy, BearerTokenCredentialPolicy, ProxyPolicy, DistributedTracingPolicy, @@ -74,6 +75,7 @@ def __init__( # type: (...) -> None self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY) self._hosts = kwargs.get("_hosts") + self._msrest_xml = kwargs.get('msrest_xml', True) self.scheme = parsed_url.scheme if service not in ["blob", "queue", "file-share", "dfs"]: @@ -250,13 +252,13 @@ def _create_pipeline(self, credential, **kwargs): config.transport = RequestsTransport(**kwargs) policies = [ QueueMessagePolicy(), + config.headers_policy, config.proxy_policy, config.user_agent_policy, StorageContentValidation(), RedirectPolicy(**kwargs), StorageHosts(hosts=self._hosts, **kwargs), config.retry_policy, - config.headers_policy, StorageRequestHook(**kwargs), self._credential_policy, config.logging_policy, @@ -264,6 +266,8 @@ def _create_pipeline(self, credential, **kwargs): DistributedTracingPolicy(**kwargs), HttpLoggingPolicy(**kwargs) ] + if self._msrest_xml: + policies.insert(5, ContentDecodePolicy(response_encoding="utf-8")) if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") return config, 
Pipeline(config.transport, policies=policies) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py index 9df1b28e1069..5deca436299b 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py @@ -20,6 +20,7 @@ DistributedTracingPolicy, HttpLoggingPolicy, AzureSasCredentialPolicy, + ContentDecodePolicy ) from azure.core.pipeline.transport import AsyncHttpTransport @@ -104,6 +105,8 @@ def _create_pipeline(self, credential, **kwargs): DistributedTracingPolicy(**kwargs), HttpLoggingPolicy(**kwargs), ] + if self._msrest_xml: + policies.insert(5, ContentDecodePolicy(response_encoding="utf-8")) if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") return config, AsyncPipeline(config.transport, policies=policies) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py index 9291bf4a9231..9378ad4d2743 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py @@ -24,33 +24,32 @@ # # -------------------------------------------------------------------------- -from base64 import b64decode, b64encode -import calendar +from base64 import b64decode +from typing import cast import datetime import decimal import email from enum import Enum -import json import logging import re -import sys import os -try: - from urllib import quote # type: ignore -except ImportError: - from urllib.parse import quote # type: ignore +import ast if os.environ.get("AZURE_STORAGE_LXML"): try: from lxml import etree as ET - except: + except: # pylint: disable=bare-except 
import xml.etree.ElementTree as ET else: import xml.etree.ElementTree as ET import isodate - -from typing import Dict, Any, cast, IO +from azure.core.exceptions import DecodeError +from msrest.exceptions import DeserializationError, raise_with_traceback +from msrest.serialization import ( + TZ_UTC, + _FixedOffset +) try: @@ -71,15 +70,6 @@ _long_type = int -from azure.core.exceptions import DecodeError -from msrest.exceptions import DeserializationError, raise_with_traceback -from msrest.serialization import ( - TZ_UTC, - _FixedOffset, - _FLATTEN -) - - def unpack_xml_content(response_data, content_type=None): """Extract the correct structure for deserialization. @@ -109,7 +99,7 @@ def unpack_xml_content(response_data, content_type=None): raise_with_traceback(DecodeError, message="XML is invalid", response=response_data) -def deserialize_bytearray(attr, *args): +def deserialize_bytearray(attr, *_): """Deserialize string into bytearray. :param str attr: response string to be deserialized. @@ -119,7 +109,7 @@ def deserialize_bytearray(attr, *args): return bytearray(b64decode(attr)) -def deserialize_base64(attr, *args): +def deserialize_base64(attr, *_): """Deserialize base64 encoded string into string. :param str attr: response string to be deserialized. @@ -132,7 +122,7 @@ def deserialize_base64(attr, *args): return b64decode(encoded) -def deserialize_decimal(attr, *args): +def deserialize_decimal(attr, *_): """Deserialize string into Decimal object. :param str attr: response string to be deserialized. @@ -146,7 +136,7 @@ def deserialize_decimal(attr, *args): raise_with_traceback(DeserializationError, msg, err) -def deserialize_long(attr, *args): +def deserialize_long(attr, *_): """Deserialize string into long (Py2) or int (Py3). :param str attr: response string to be deserialized. 
@@ -156,7 +146,7 @@ def deserialize_long(attr, *args): return _long_type(attr) -def deserialize_duration(attr, *args): +def deserialize_duration(attr, *_): """Deserialize ISO-8601 formatted string into TimeDelta object. :param str attr: response string to be deserialized. @@ -172,7 +162,7 @@ def deserialize_duration(attr, *args): return duration -def deserialize_date(attr, *args): +def deserialize_date(attr, *_): """Deserialize ISO-8601 formatted string into Date object. :param str attr: response string to be deserialized. @@ -185,7 +175,7 @@ def deserialize_date(attr, *args): return isodate.parse_date(attr, defaultmonth=None, defaultday=None) -def deserialize_time(attr, *args): +def deserialize_time(attr, *_): """Deserialize ISO-8601 formatted string into time object. :param str attr: response string to be deserialized. @@ -197,7 +187,7 @@ def deserialize_time(attr, *args): return isodate.parse_time(attr) -def deserialize_rfc(attr, *args): +def deserialize_rfc(attr, *_): """Deserialize RFC-1123 formatted string into Datetime object. :param str attr: response string to be deserialized. @@ -219,7 +209,7 @@ def deserialize_rfc(attr, *args): return date_obj -def deserialize_iso(attr, *args): +def deserialize_iso(attr, *_): """Deserialize ISO-8601 formatted string into Datetime object. :param str attr: response string to be deserialized. @@ -254,7 +244,7 @@ def deserialize_iso(attr, *args): return date_obj -def deserialize_unix(attr, *args): +def deserialize_unix(attr, *_): """Serialize Datetime object into IntTime format. This is represented as seconds. @@ -271,7 +261,7 @@ def deserialize_unix(attr, *args): return date_obj -def deserialize_unicode(data, *args): +def deserialize_unicode(data, *_): """Preserve unicode objects in Python 2, otherwise return data as a string. @@ -340,19 +330,20 @@ def deserialize_basic(attr, data_type): :rtype: str, int, float or bool :raises: TypeError if string format is not valid. 
""" + if data_type == 'str': + return deserialize_unicode(attr) if data_type == 'bool': if attr in [True, False, 1, 0]: return bool(attr) - elif isinstance(attr, basestring): + if isinstance(attr, basestring): if attr.lower() in ['true', '1']: return True elif attr.lower() in ['false', '0']: return False raise TypeError("Invalid boolean value: {}".format(attr)) - - if data_type == 'str': - return deserialize_unicode(attr) - return eval(data_type)(attr) + if data_type == 'int': + return int(attr) + return float(attr) class Deserializer(object): @@ -458,7 +449,7 @@ def _deserialize(self, target_obj, data): subtype = None if attr_type[0] == '[': raw_value = self.multi_xml_key_extractor(attr_desc, data, subtype) - else: + else: raw_value = self.xml_key_extractor(attr_desc, data, subtype) value = self.deserialize_data(raw_value, attr_type) d_attrs[attr] = value diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py index 39e2351e2263..ec5266183770 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py @@ -121,7 +121,8 @@ def __init__( credential=credential, **kwargs) self._client = AzureBlobStorage(url=self.url, pipeline=self._pipeline) - self._custom_xml_deserializer(generated_models) + if not self._msrest_xml: + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py index f673c5b35e75..3a2fd1eb4b14 100644 --- 
a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py @@ -119,7 +119,8 @@ def __init__( credential=credential, **kwargs) self._client = AzureBlobStorage(url=self.url, pipeline=self._pipeline) - self._custom_xml_deserializer(generated_models) + if not self._msrest_xml: + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access @@ -621,7 +622,7 @@ def get_container_client(self, container): credential=self.credential, api_version=self.api_version, _configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts, require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - key_resolver_function=self.key_resolver_function) + key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml) def get_blob_client( self, container, # type: Union[ContainerProperties, str] @@ -676,4 +677,4 @@ def get_blob_client( credential=self.credential, api_version=self.api_version, _configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts, require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - key_resolver_function=self.key_resolver_function) + key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py index 7c58993b6aea..7d888153a049 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py @@ -118,7 +118,8 
@@ def __init__( credential=credential, **kwargs) self._client = AzureBlobStorage(url=self.url, pipeline=self._pipeline) - self._custom_xml_deserializer(generated_models) + if not self._msrest_xml: + self._custom_xml_deserializer(generated_models) default_api_version = self._client._config.version # pylint: disable=protected-access self._client._config.version = get_api_version(kwargs, default_api_version) # pylint: disable=protected-access @@ -637,7 +638,7 @@ def list_blobs(self, name_starts_with=None, include=None, **kwargs): prefix=name_starts_with, results_per_page=results_per_page, select=select, - deserializer=self._client._deserialize, + deserializer=self._client._deserialize, # pylint: disable=protected-access page_iterator_class=BlobPropertiesPaged ) @@ -686,7 +687,7 @@ def walk_blobs( prefix=name_starts_with, results_per_page=results_per_page, select=None, - deserializer=self._client._deserialize, + deserializer=self._client._deserialize, # pylint: disable=protected-access delimiter=delimiter) @distributed_trace_async @@ -1213,4 +1214,4 @@ def get_blob_client( credential=self.credential, api_version=self.api_version, _configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode, _hosts=self._hosts, require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - key_resolver_function=self.key_resolver_function) + key_resolver_function=self.key_resolver_function, msrest_xml=self._msrest_xml) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py index a4822bd5445f..9a11087d7020 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_list_blobs_helper.py @@ -7,11 +7,8 @@ from azure.core.async_paging import AsyncPageIterator, AsyncItemPaged from azure.core.exceptions import HttpResponseError -from 
.._deserialize import get_blob_properties_from_generated_code -from .._models import BlobProperties -from .._generated.models import BlobItemInternal, BlobPrefix as GenBlobPrefix from .._shared.models import DictMixin -from .._shared.response_handlers import return_context_and_deserialized, process_storage_error +from .._shared.response_handlers import process_storage_error from .._list_blobs_helper import ( deserialize_list_result, load_many_nodes, @@ -21,7 +18,7 @@ ) -class BlobPropertiesPaged(AsyncPageIterator): +class BlobPropertiesPaged(AsyncPageIterator): # pylint: disable=too-many-instance-attributes """An Iterable of Blob properties. :ivar str service_endpoint: The service URL. diff --git a/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py b/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py index ca46e67ffccb..f90adea2645a 100644 --- a/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py +++ b/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py @@ -25,6 +25,7 @@ def __init__(self, arguments): self._client_kwargs['max_single_put_size'] = self.args.max_put_size self._client_kwargs['max_block_size'] = self.args.max_block_size self._client_kwargs['min_large_block_upload_threshold'] = self.args.buffer_threshold + self._client_kwargs['msrest_xml'] = self.args.msrest_xml # self._client_kwargs['api_version'] = '2019-02-02' # Used only for comparison with T1 legacy tests if not _ServiceTest.service_client or self.args.no_client_share: @@ -46,6 +47,7 @@ def add_arguments(parser): parser.add_argument('--max-concurrency', nargs='?', type=int, help='Maximum number of concurrent threads used for data transfer. Defaults to 1', default=1) parser.add_argument('-s', '--size', nargs='?', type=int, help='Size of data to transfer. Default is 10240.', default=10240) parser.add_argument('--no-client-share', action='store_true', help='Create one ServiceClient per test instance. 
Default is to share a single ServiceClient.', default=False) + parser.add_argument('--msrest-xml', action='store_true', help='Use the msrest XML derialization pipeline. Defaults to True', default=True) class _ContainerTest(_ServiceTest): From e4e3450ade298e23580792c818838bd17d9f2bc1 Mon Sep 17 00:00:00 2001 From: antisch Date: Thu, 22 Jul 2021 09:01:07 -0700 Subject: [PATCH 12/15] Some code cleanup --- .../azure/storage/blob/_list_blobs_helper.py | 2 +- .../azure/storage/blob/_shared/base_client.py | 5 +- .../blob/_shared/xml_deserialization.py | 241 +++++++++--------- .../tests/perfstress_tests/_test_base.py | 4 +- 4 files changed, 133 insertions(+), 119 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py index a7575b338213..6b4e1e6bd6b4 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_list_blobs_helper.py @@ -16,7 +16,7 @@ from ._shared.response_handlers import return_context_and_deserialized, process_storage_error -def deserialize_list_result(pipeline_response, *args): +def deserialize_list_result(pipeline_response, *_): payload = unpack_xml_content(pipeline_response.http_response) location = pipeline_response.http_response.location_mode return location, payload diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index 781b73e7f9e3..955a8073f5db 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -75,7 +75,7 @@ def __init__( # type: (...) 
-> None self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY) self._hosts = kwargs.get("_hosts") - self._msrest_xml = kwargs.get('msrest_xml', True) + self._msrest_xml = kwargs.get('msrest_xml', False) self.scheme = parsed_url.scheme if service not in ["blob", "queue", "file-share", "dfs"]: @@ -200,9 +200,10 @@ def api_version(self): :type: str """ return self._client._config.version # pylint: disable=protected-access - + def _custom_xml_deserializer(self, generated_models): """Reset the deserializer on the generated client to be Storage implementation""" + # pylint: disable=protected-access client_models = {k: v for k, v in generated_models.__dict__.items() if isinstance(v, type)} custom_deserialize = Deserializer(client_models) self._client._deserialize = custom_deserialize diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py index 9378ad4d2743..db62d8b7b0d9 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py @@ -33,15 +33,6 @@ import logging import re import os -import ast - -if os.environ.get("AZURE_STORAGE_LXML"): - try: - from lxml import etree as ET - except: # pylint: disable=bare-except - import xml.etree.ElementTree as ET -else: - import xml.etree.ElementTree as ET import isodate from azure.core.exceptions import DecodeError @@ -51,9 +42,16 @@ _FixedOffset ) +if os.environ.get("AZURE_STORAGE_LXML"): + try: + from lxml import etree as ET + except: # pylint: disable=bare-except + import xml.etree.ElementTree as ET +else: + import xml.etree.ElementTree as ET try: - basestring # type: ignore + basestring # pylint: disable=pointless-statement unicode_str = unicode # type: ignore except NameError: basestring = str # type: ignore @@ -70,7 +68,7 @@ _long_type = int -def 
unpack_xml_content(response_data, content_type=None): +def unpack_xml_content(response_data, **kwargs): """Extract the correct structure for deserialization. If raw_data is a PipelineResponse, try to extract the result of RawDeserializer. @@ -96,7 +94,7 @@ def unpack_xml_content(response_data, content_type=None): return ET.fromstring(data_as_str) # nosec except ET.ParseError: _LOGGER.critical("Response body invalid XML") - raise_with_traceback(DecodeError, message="XML is invalid", response=response_data) + raise_with_traceback(DecodeError, message="XML is invalid", response=response_data, **kwargs) def deserialize_bytearray(attr, *_): @@ -136,6 +134,43 @@ def deserialize_decimal(attr, *_): raise_with_traceback(DeserializationError, msg, err) +def deserialize_bool(attr, *args): + """Deserialize string into bool. + + :param str attr: response string to be deserialized. + :rtype: bool + :raises: TypeError if string format is not valid. + """ + if attr in [True, False, 1, 0]: + return bool(attr) + if isinstance(attr, basestring): + if attr.lower() in ['true', '1']: + return True + if attr.lower() in ['false', '0']: + return False + raise TypeError("Invalid boolean value: {}".format(attr)) + + +def deserialize_int(attr, *_): + """Deserialize string into int. + + :param str attr: response string to be deserialized. + :rtype: int + :raises: ValueError or TypeError if string format invalid. + """ + return int(attr) + + +def deserialize_float(attr, *_): + """Deserialize string into float. + + :param str attr: response string to be deserialized. + :rtype: float + :raises: ValueError if string format invalid. + """ + return float(attr) + + def deserialize_long(attr, *_): """Deserialize string into long (Py2) or int (Py3). @@ -244,6 +279,19 @@ def deserialize_iso(attr, *_): return date_obj +def deserialize_object(attr, *_): + """Deserialize a generic object. + This will be handled as a dictionary. + + :param dict attr: Dictionary to be deserialized. 
+ :rtype: dict + :raises: TypeError if non-builtin datatype encountered. + """ + # Do no recurse on XML, just return the tree as-is + # TODO: This probably needs work + return attr + + def deserialize_unix(attr, *_): """Serialize Datetime object into IntTime format. This is represented as seconds. @@ -319,31 +367,61 @@ def deserialize_enum(data, enum_obj): return deserialize_unicode(data) -def deserialize_basic(attr, data_type): - """Deserialize baisc builtin data type from string. - Will attempt to convert to str, int, float and bool. - This function will also accept '1', '0', 'true' and 'false' as - valid bool values. +def instantiate_model(response, attrs, additional_properties=None): + """Instantiate a response model passing in deserialized args. - :param str attr: response string to be deserialized. - :param str data_type: deserialization data type. - :rtype: str, int, float or bool - :raises: TypeError if string format is not valid. + :param response: The response model class. + :param d_attrs: The deserialized response attributes. 
""" - if data_type == 'str': - return deserialize_unicode(attr) - if data_type == 'bool': - if attr in [True, False, 1, 0]: - return bool(attr) - if isinstance(attr, basestring): - if attr.lower() in ['true', '1']: - return True - elif attr.lower() in ['false', '0']: - return False - raise TypeError("Invalid boolean value: {}".format(attr)) - if data_type == 'int': - return int(attr) - return float(attr) + try: + readonly = [k for k, v in response._validation.items() if v.get('readonly')] # pylint:disable=protected-access + const = [k for k, v in response._validation.items() if v.get('constant')] # pylint:disable=protected-access + kwargs = {k: v for k, v in attrs.items() if k not in readonly + const} + response_obj = response(**kwargs) + for attr in readonly: + setattr(response_obj, attr, attrs.get(attr)) + if additional_properties: + response_obj.additional_properties = additional_properties + return response_obj + except Exception as err: + msg = "Unable to deserialize {} into model {}. 
".format( + kwargs, response) + raise DeserializationError(msg + str(err)) + + +def multi_xml_key_extractor(attr_desc, data, subtype): + xml_desc = attr_desc.get('xml', {}) + xml_name = xml_desc.get('name', attr_desc['key']) + is_wrapped = xml_desc.get("wrapped", False) + subtype_xml_map = getattr(subtype, "_xml_map", {}) + if is_wrapped: + items_name = xml_name + elif subtype: + items_name = subtype_xml_map.get('name', xml_name) + else: + items_name = xml_desc.get("itemsName", xml_name) + children = data.findall(items_name) + if is_wrapped: + if len(children) == 0: + return None + return list(children[0]) + return children + +def xml_key_extractor(attr_desc, data, subtype): + xml_desc = attr_desc.get('xml', {}) + xml_name = xml_desc.get('name', attr_desc['key']) + + # If it's an attribute, that's simple + if xml_desc.get("attr", False): + return data.get(xml_name) + + # If it's x-ms-text, that's simple too + if xml_desc.get("text", False): + return data.text + + subtype_xml_map = getattr(subtype, "_xml_map", {}) + xml_name = subtype_xml_map.get('name', xml_name) + return data.find(xml_name) class Deserializer(object): @@ -353,10 +431,10 @@ class Deserializer(object): """ def __init__(self, classes=None): self.deserialize_type = { - 'str': deserialize_basic, - 'int': deserialize_basic, - 'bool': deserialize_basic, - 'float': deserialize_basic, + 'str': deserialize_unicode, + 'int': deserialize_int, + 'bool': deserialize_bool, + 'float': deserialize_float, 'iso-8601': deserialize_iso, 'rfc-1123': deserialize_rfc, 'unix-time': deserialize_unix, @@ -367,14 +445,14 @@ def __init__(self, classes=None): 'long': deserialize_long, 'bytearray': deserialize_bytearray, 'base64': deserialize_base64, - 'object': self.deserialize_object, + 'object': deserialize_object, '[]': self.deserialize_iter, '{}': self.deserialize_dict } self.dependencies = dict(classes) if classes else {} - def __call__(self, target_obj, response_data, content_type=None): + def __call__(self, target_obj, 
response_data, **kwargs): """Call the deserializer to process a REST response. :param str target_obj: Target data type to deserialize to. @@ -385,7 +463,7 @@ def __call__(self, target_obj, response_data, content_type=None): """ try: # First, unpack the response if we have one. - response_data = unpack_xml_content(response_data.http_response, content_type) + response_data = unpack_xml_content(response_data.http_response, **kwargs) except AttributeError: pass if response_data is None: @@ -407,7 +485,7 @@ def failsafe_deserialize(self, target_obj, data, content_type=None): """ try: return self(target_obj, data, content_type=content_type) - except: + except: # pylint: disable=bare-except _LOGGER.warning( "Ran into a deserialization error. Ignoring since this is failsafe deserialization", exc_info=True @@ -434,7 +512,7 @@ def _deserialize(self, target_obj, data): if data is None: return data try: - attributes = model_type._attribute_map + attributes = model_type._attribute_map # pylint:disable=protected-access d_attrs = {} include_extra_props = False for attr, attr_desc in attributes.items(): @@ -444,13 +522,14 @@ def _deserialize(self, target_obj, data): continue attr_type = attr_desc["type"] try: + # TODO: Validate this subtype logic subtype = self.dependencies[attr_type.strip('[]{}')] except KeyError: subtype = None if attr_type[0] == '[': - raw_value = self.multi_xml_key_extractor(attr_desc, data, subtype) + raw_value = multi_xml_key_extractor(attr_desc, data, subtype) else: - raw_value = self.xml_key_extractor(attr_desc, data, subtype) + raw_value = xml_key_extractor(attr_desc, data, subtype) value = self.deserialize_data(raw_value, attr_type) d_attrs[attr] = value except (AttributeError, TypeError, KeyError) as err: @@ -459,63 +538,8 @@ def _deserialize(self, target_obj, data): else: if include_extra_props: extra = {el.tag: el.text for el in data if el.tag not in d_attrs} - return self._instantiate_model(model_type, d_attrs, extra) - return 
self._instantiate_model(model_type, d_attrs) - - def _instantiate_model(self, response, attrs, additional_properties=None): - """Instantiate a response model passing in deserialized args. - - :param response: The response model class. - :param d_attrs: The deserialized response attributes. - """ - try: - readonly = [k for k, v in response._validation.items() if v.get('readonly')] - const = [k for k, v in response._validation.items() if v.get('constant')] - kwargs = {k: v for k, v in attrs.items() if k not in readonly + const} - response_obj = response(**kwargs) - for attr in readonly: - setattr(response_obj, attr, attrs.get(attr)) - if additional_properties: - response_obj.additional_properties = additional_properties - return response_obj - except Exception as err: - msg = "Unable to deserialize {} into model {}. ".format( - kwargs, response) - raise DeserializationError(msg + str(err)) - - def multi_xml_key_extractor(self, attr_desc, data, subtype): - xml_desc = attr_desc.get('xml', {}) - xml_name = xml_desc.get('name', attr_desc['key']) - is_wrapped = xml_desc.get("wrapped", False) - subtype_xml_map = getattr(subtype, "_xml_map", {}) - if is_wrapped: - items_name = xml_name - elif subtype: - items_name = subtype_xml_map.get('name', xml_name) - else: - items_name = xml_desc.get("itemsName", xml_name) - children = data.findall(items_name) - if is_wrapped: - if len(children) == 0: - return None - return list(children[0]) - return children - - def xml_key_extractor(self, attr_desc, data, subtype): - xml_desc = attr_desc.get('xml', {}) - xml_name = xml_desc.get('name', attr_desc['key']) - - # If it's an attribute, that's simple - if xml_desc.get("attr", False): - return data.get(xml_name) - - # If it's x-ms-text, that's simple too - if xml_desc.get("text", False): - return data.text - - subtype_xml_map = getattr(subtype, "_xml_map", {}) - xml_name = subtype_xml_map.get('name', xml_name) - return data.find(xml_name) + return instantiate_model(model_type, d_attrs, 
extra) + return instantiate_model(model_type, d_attrs) def deserialize_data(self, data, data_type): """Process data for deserialization according to data type. @@ -575,14 +599,3 @@ def deserialize_dict(self, attr, dict_type): # Transform value into {"Key": "value"} attr = {el.tag: el.text for el in attr} return {k: self.deserialize_data(v, dict_type) for k, v in attr.items()} - - def deserialize_object(self, attr, **kwargs): - """Deserialize a generic object. - This will be handled as a dictionary. - - :param dict attr: Dictionary to be deserialized. - :rtype: dict - :raises: TypeError if non-builtin datatype encountered. - """ - # Do no recurse on XML, just return the tree as-is - return attr diff --git a/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py b/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py index f90adea2645a..8d9cdaf49829 100644 --- a/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py +++ b/sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py @@ -25,7 +25,7 @@ def __init__(self, arguments): self._client_kwargs['max_single_put_size'] = self.args.max_put_size self._client_kwargs['max_block_size'] = self.args.max_block_size self._client_kwargs['min_large_block_upload_threshold'] = self.args.buffer_threshold - self._client_kwargs['msrest_xml'] = self.args.msrest_xml + self._client_kwargs['msrest_xml'] = not self.args.no_msrest # self._client_kwargs['api_version'] = '2019-02-02' # Used only for comparison with T1 legacy tests if not _ServiceTest.service_client or self.args.no_client_share: @@ -47,7 +47,7 @@ def add_arguments(parser): parser.add_argument('--max-concurrency', nargs='?', type=int, help='Maximum number of concurrent threads used for data transfer. Defaults to 1', default=1) parser.add_argument('-s', '--size', nargs='?', type=int, help='Size of data to transfer. 
Default is 10240.', default=10240) parser.add_argument('--no-client-share', action='store_true', help='Create one ServiceClient per test instance. Default is to share a single ServiceClient.', default=False) - parser.add_argument('--msrest-xml', action='store_true', help='Use the msrest XML derialization pipeline. Defaults to True', default=True) + parser.add_argument('--no-msrest', action='store_true', help='Do not use the msrest XML derialization pipeline. Defaults to False', default=False) class _ContainerTest(_ServiceTest): From 24dbf61ae6eea9684492afc2243397f0916c3125 Mon Sep 17 00:00:00 2001 From: antisch Date: Thu, 22 Jul 2021 09:45:28 -0700 Subject: [PATCH 13/15] Don't decode payload --- .../blob/_shared/xml_deserialization.py | 22 +++++-------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py index db62d8b7b0d9..b3441d038a7a 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py @@ -33,22 +33,17 @@ import logging import re import os +import xml.etree.ElementTree as ET import isodate -from azure.core.exceptions import DecodeError from msrest.exceptions import DeserializationError, raise_with_traceback from msrest.serialization import ( TZ_UTC, _FixedOffset ) -if os.environ.get("AZURE_STORAGE_LXML"): - try: - from lxml import etree as ET - except: # pylint: disable=bare-except - import xml.etree.ElementTree as ET -else: - import xml.etree.ElementTree as ET +from azure.core.exceptions import DecodeError + try: basestring # pylint: disable=pointless-statement @@ -83,15 +78,8 @@ def unpack_xml_content(response_data, **kwargs): :param content_type: How to parse if raw_data is a string/bytes. 
:raises UnicodeDecodeError: If bytes is not UTF8 """ - data_as_str = response_data.text() try: - try: - if isinstance(raw_data, unicode): # type: ignore - # If I'm Python 2.7 and unicode XML will scream if I try a "fromstring" on unicode string - data_as_str = cast(str, data_as_str.encode(encoding="utf-8")) - except NameError: - pass - return ET.fromstring(data_as_str) # nosec + return ET.fromstring(response_data.body()) # nosec except ET.ParseError: _LOGGER.critical("Response body invalid XML") raise_with_traceback(DecodeError, message="XML is invalid", response=response_data, **kwargs) @@ -134,7 +122,7 @@ def deserialize_decimal(attr, *_): raise_with_traceback(DeserializationError, msg, err) -def deserialize_bool(attr, *args): +def deserialize_bool(attr, *_): """Deserialize string into bool. :param str attr: response string to be deserialized. From 95557f746d8819e70bc70a6542f438c7b40a12a9 Mon Sep 17 00:00:00 2001 From: antisch Date: Thu, 22 Jul 2021 12:43:45 -0700 Subject: [PATCH 14/15] Fix stats test --- .../storage/blob/_shared/xml_deserialization.py | 2 -- .../tests/test_blob_service_stats.py | 16 ++++++++-------- .../tests/test_blob_service_stats_async.py | 16 ++++++++-------- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py index b3441d038a7a..20f53f028354 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/xml_deserialization.py @@ -25,14 +25,12 @@ # -------------------------------------------------------------------------- from base64 import b64decode -from typing import cast import datetime import decimal import email from enum import Enum import logging import re -import os import xml.etree.ElementTree as ET import isodate diff --git 
a/sdk/storage/azure-storage-blob/tests/test_blob_service_stats.py b/sdk/storage/azure-storage-blob/tests/test_blob_service_stats.py index 1de16c8a6538..2a36c6b8f9c2 100644 --- a/sdk/storage/azure-storage-blob/tests/test_blob_service_stats.py +++ b/sdk/storage/azure-storage-blob/tests/test_blob_service_stats.py @@ -12,13 +12,13 @@ from _shared.testcase import GlobalStorageAccountPreparer, GlobalResourceGroupPreparer -SERVICE_UNAVAILABLE_RESP_BODY = 'unavailable ' +SERVICE_UNAVAILABLE_RESP_BODY = b'unavailable ' -SERVICE_LIVE_RESP_BODY = 'liveWed, 19 Jan 2021 22:28:43 GMT ' +SERVICE_LIVE_RESP_BODY = b'liveWed, 19 Jan 2021 22:28:43 GMT ' # --Test Class ----------------------------------------------------------------- class ServiceStatsTest(StorageTestCase): @@ -39,11 +39,11 @@ def _assert_stats_unavailable(self, stats): @staticmethod def override_response_body_with_live_status(response): - response.http_response.text = lambda encoding=None: SERVICE_LIVE_RESP_BODY + response.http_response.body = lambda: SERVICE_LIVE_RESP_BODY @staticmethod def override_response_body_with_unavailable_status(response): - response.http_response.text = lambda encoding=None: SERVICE_UNAVAILABLE_RESP_BODY + response.http_response.body = lambda: SERVICE_UNAVAILABLE_RESP_BODY # --Test cases per service --------------------------------------- @GlobalResourceGroupPreparer() diff --git a/sdk/storage/azure-storage-blob/tests/test_blob_service_stats_async.py b/sdk/storage/azure-storage-blob/tests/test_blob_service_stats_async.py index 380fa67b024d..4b545a2eafab 100644 --- a/sdk/storage/azure-storage-blob/tests/test_blob_service_stats_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_blob_service_stats_async.py @@ -15,14 +15,14 @@ from devtools_testutils.storage.aio import AsyncStorageTestCase -SERVICE_UNAVAILABLE_RESP_BODY = 'unavailable ' +SERVICE_UNAVAILABLE_RESP_BODY = b'unavailable ' -SERVICE_LIVE_RESP_BODY = 'liveWed, 19 Jan 2021 22:28:43 GMT ' +SERVICE_LIVE_RESP_BODY = b'liveWed, 19 
Jan 2021 22:28:43 GMT ' class AiohttpTestTransport(AioHttpTransport): @@ -55,11 +55,11 @@ def _assert_stats_unavailable(self, stats): @staticmethod def override_response_body_with_live_status(response): - response.http_response.text = lambda encoding=None: SERVICE_LIVE_RESP_BODY + response.http_response.body = lambda: SERVICE_LIVE_RESP_BODY @staticmethod def override_response_body_with_unavailable_status(response): - response.http_response.text = lambda encoding=None: SERVICE_UNAVAILABLE_RESP_BODY + response.http_response.body = lambda: SERVICE_UNAVAILABLE_RESP_BODY # --Test cases per service --------------------------------------- @GlobalResourceGroupPreparer() From efb9c56e296445b6862b2a9f4649b2560584bfe9 Mon Sep 17 00:00:00 2001 From: antisch Date: Mon, 26 Jul 2021 10:38:25 -0700 Subject: [PATCH 15/15] Surfaced as separate API --- .../azure/storage/blob/_container_client.py | 33 ++++++++++++++++-- .../blob/aio/_container_client_async.py | 34 +++++++++++++++++-- .../tests/perfstress_tests/list_blobs.py | 18 ++++++---- 3 files changed, 75 insertions(+), 10 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py index e6698bc87af9..d8ebad17a640 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py @@ -766,7 +766,6 @@ def list_blobs(self, name_starts_with=None, include=None, **kwargs): results_per_page = kwargs.pop('results_per_page', None) timeout = kwargs.pop('timeout', None) - select = kwargs.pop('select', None) command = functools.partial( self._client.container.list_blob_flat_segment, include=include, @@ -776,7 +775,37 @@ def list_blobs(self, name_starts_with=None, include=None, **kwargs): command, prefix=name_starts_with, results_per_page=results_per_page, - select=select, + select=None, + deserializer=self._client._deserialize, # pylint: 
disable=protected-access + page_iterator_class=BlobPropertiesPaged) + + @distributed_trace + def list_blob_names(self, **kwargs): + # type: (**Any) -> ItemPaged[str] + """Returns a generator to list the names of blobs under the specified container. + The generator will lazily follow the continuation tokens returned by + the service. + + :keyword str name_starts_with: + Filters the results to return only blobs whose names + begin with the specified prefix. + :keyword int timeout: + The timeout parameter is expressed in seconds. + :returns: An iterable (auto-paging) response of blob names as strings. + :rtype: ~azure.core.paging.ItemPaged[str] + """ + name_starts_with = kwargs.pop('name_starts_with', None) + results_per_page = kwargs.pop('results_per_page', None) + timeout = kwargs.pop('timeout', None) + command = functools.partial( + self._client.container.list_blob_flat_segment, + timeout=timeout, + **kwargs) + return ItemPaged( + command, + prefix=name_starts_with, + results_per_page=results_per_page, + select=["name"], deserializer=self._client._deserialize, # pylint: disable=protected-access page_iterator_class=BlobPropertiesPaged) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py index 7d888153a049..9a919ec6c36b 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py @@ -627,7 +627,6 @@ def list_blobs(self, name_starts_with=None, include=None, **kwargs): results_per_page = kwargs.pop('results_per_page', None) timeout = kwargs.pop('timeout', None) - select = kwargs.pop('select', None) command = functools.partial( self._client.container.list_blob_flat_segment, include=include, @@ -637,7 +636,38 @@ def list_blobs(self, name_starts_with=None, include=None, **kwargs): command, prefix=name_starts_with, 
results_per_page=results_per_page, - select=select, + select=None, + deserializer=self._client._deserialize, # pylint: disable=protected-access + page_iterator_class=BlobPropertiesPaged + ) + + @distributed_trace + def list_blob_names(self, **kwargs): + # type: (**Any) -> AsyncItemPaged[str] + """Returns a generator to list the names of blobs under the specified container. + The generator will lazily follow the continuation tokens returned by + the service. + + :keyword str name_starts_with: + Filters the results to return only blobs whose names + begin with the specified prefix. + :keyword int timeout: + The timeout parameter is expressed in seconds. + :returns: An iterable (auto-paging) response of blob names as strings. + :rtype: ~azure.core.async_paging.AsyncItemPaged[str] + """ + name_starts_with = kwargs.pop('name_starts_with', None) + results_per_page = kwargs.pop('results_per_page', None) + timeout = kwargs.pop('timeout', None) + command = functools.partial( + self._client.container.list_blob_flat_segment, + timeout=timeout, + **kwargs) + return AsyncItemPaged( + command, + prefix=name_starts_with, + results_per_page=results_per_page, + select=["name"], deserializer=self._client._deserialize, # pylint: disable=protected-access page_iterator_class=BlobPropertiesPaged ) diff --git a/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py b/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py index b0a84c707573..65894b044e3f 100644 --- a/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py +++ b/sdk/storage/azure-storage-blob/tests/perfstress_tests/list_blobs.py @@ -27,14 +27,20 @@ async def global_setup(self): break def run_sync(self): - select = ['name'] if self.args.name_only else None - for _ in self.container_client.list_blobs(select=select): - pass + if self.args.name_only: + for _ in self.container_client.list_blob_names(): + pass + else: + for _ in self.container_client.list_blobs(): + pass async def 
run_async(self): - select = ['name'] if self.args.name_only else None - async for _ in self.async_container_client.list_blobs(select=select): - pass + if self.args.name_only: + async for _ in self.async_container_client.list_blob_names(): + pass + else: + async for _ in self.async_container_client.list_blobs(): + pass @staticmethod def add_arguments(parser):