Skip to content

Commit e09e3ed

Browse files
Merge pull request Azure#11 from kristapratico/pipeline_ownership
pipeline ownership for queues and files
2 parents 952c125 + 6c73352 commit e09e3ed

16 files changed

+228
-26
lines changed

sdk/storage/azure-storage-file/azure/storage/file/_shared/base_client.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from azure.core import Configuration
2929
from azure.core.exceptions import HttpResponseError
3030
from azure.core.pipeline import Pipeline
31-
from azure.core.pipeline.transport import RequestsTransport
31+
from azure.core.pipeline.transport import RequestsTransport, HttpTransport
3232
from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy
3333
from azure.core.pipeline.policies import RedirectPolicy, ContentDecodePolicy, BearerTokenCredentialPolicy, ProxyPolicy
3434

@@ -216,6 +216,27 @@ def _batch_send(
216216
process_storage_error(error)
217217

218218

219+
class TransportWrapper(HttpTransport):
220+
221+
def __init__(self, transport):
222+
self._transport = transport
223+
224+
def send(self, request, **kwargs):
225+
return self._transport.send(request, **kwargs)
226+
227+
def open(self):
228+
pass
229+
230+
def close(self):
231+
pass
232+
233+
def __enter__(self, *args): # pylint: disable=arguments-differ
234+
pass
235+
236+
def __exit__(self, *args): # pylint: disable=arguments-differ
237+
pass
238+
239+
219240
def format_shared_key_credential(account, credential):
220241
if isinstance(credential, six.string_types):
221242
if len(account) < 2:

sdk/storage/azure-storage-file/azure/storage/file/_shared/base_client_async.py

+22
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
AsyncBearerTokenCredentialPolicy,
1919
AsyncRedirectPolicy)
2020

21+
from azure.core.pipeline.transport import AsyncHttpTransport
2122
from .constants import STORAGE_OAUTH_SCOPE, DEFAULT_SOCKET_TIMEOUT
2223
from .authentication import SharedKeyCredentialPolicy
2324
from .base_client import create_configuration
@@ -122,3 +123,24 @@ async def _batch_send(
122123
return response.parts() # Return an AsyncIterator
123124
except StorageErrorException as error:
124125
process_storage_error(error)
126+
127+
128+
class AsyncTransportWrapper(AsyncHttpTransport):
129+
130+
def __init__(self, async_transport):
131+
self._transport = async_transport
132+
133+
async def send(self, request, **kwargs):
134+
return await self._transport.send(request, **kwargs)
135+
136+
async def open(self):
137+
pass
138+
139+
async def close(self):
140+
pass
141+
142+
async def __aenter__(self, *args): # pylint: disable=arguments-differ
143+
pass
144+
145+
async def __aexit__(self, *args): # pylint: disable=arguments-differ
146+
pass

sdk/storage/azure-storage-file/azure/storage/file/aio/directory_client_async.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from azure.core.polling import async_poller
1313
from azure.core.async_paging import AsyncItemPaged
14-
14+
from azure.core.pipeline import AsyncPipeline
1515
from azure.core.tracing.decorator import distributed_trace
1616
from azure.core.tracing.decorator_async import distributed_trace_async
1717
from .._parser import _get_file_permission, _datetime_to_str
@@ -20,7 +20,7 @@
2020
from .._generated.aio import AzureFileStorage
2121
from .._generated.version import VERSION
2222
from .._generated.models import StorageErrorException
23-
from .._shared.base_client_async import AsyncStorageAccountHostsMixin
23+
from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper
2424
from .._shared.policies_async import ExponentialRetry
2525
from .._shared.request_handlers import add_metadata_headers
2626
from .._shared.response_handlers import return_response_headers, process_storage_error
@@ -112,10 +112,15 @@ def get_file_client(self, file_name, **kwargs):
112112
"""
113113
if self.directory_path:
114114
file_name = self.directory_path.rstrip('/') + "/" + file_name
115+
116+
_pipeline = AsyncPipeline(
117+
transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
118+
policies=self._pipeline._impl_policies # pylint: disable = protected-access
119+
)
115120
return FileClient(
116121
self.url, file_path=file_name, share_name=self.share_name, snapshot=self.snapshot,
117122
credential=self.credential, _hosts=self._hosts, _configuration=self._config,
118-
_pipeline=self._pipeline, _location_mode=self._location_mode, loop=self._loop, **kwargs)
123+
_pipeline=_pipeline, _location_mode=self._location_mode, loop=self._loop, **kwargs)
119124

120125
def get_subdirectory_client(self, directory_name, **kwargs):
121126
# type: (str, Any) -> DirectoryClient
@@ -138,10 +143,15 @@ def get_subdirectory_client(self, directory_name, **kwargs):
138143
:caption: Gets the subdirectory client.
139144
"""
140145
directory_path = self.directory_path.rstrip('/') + "/" + directory_name
146+
147+
_pipeline = AsyncPipeline(
148+
transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
149+
policies=self._pipeline._impl_policies # pylint: disable = protected-access
150+
)
141151
return DirectoryClient(
142152
self.url, share_name=self.share_name, directory_path=directory_path, snapshot=self.snapshot,
143153
credential=self.credential, _hosts=self._hosts, _configuration=self._config,
144-
_pipeline=self._pipeline, _location_mode=self._location_mode, loop=self._loop, **kwargs)
154+
_pipeline=_pipeline, _location_mode=self._location_mode, loop=self._loop, **kwargs)
145155

146156
@distributed_trace_async
147157
async def create_directory(self, **kwargs): # type: ignore

sdk/storage/azure-storage-file/azure/storage/file/aio/file_service_client_async.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212

1313
from azure.core.async_paging import AsyncItemPaged
1414
from azure.core.tracing.decorator import distributed_trace
15+
from azure.core.pipeline import AsyncPipeline
1516
from azure.core.tracing.decorator_async import distributed_trace_async
1617

17-
from .._shared.base_client_async import AsyncStorageAccountHostsMixin
18+
from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper
1819
from .._shared.response_handlers import process_storage_error
1920
from .._shared.policies_async import ExponentialRetry
2021
from .._generated.aio import AzureFileStorage
@@ -314,6 +315,11 @@ def get_share_client(self, share, snapshot=None):
314315
share_name = share.name
315316
except AttributeError:
316317
share_name = share
318+
319+
_pipeline = AsyncPipeline(
320+
transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
321+
policies=self._pipeline._impl_policies # pylint: disable = protected-access
322+
)
317323
return ShareClient(
318324
self.url, share_name=share_name, snapshot=snapshot, credential=self.credential, _hosts=self._hosts,
319-
_configuration=self._config, _pipeline=self._pipeline, _location_mode=self._location_mode, loop=self._loop)
325+
_configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode, loop=self._loop)

sdk/storage/azure-storage-file/azure/storage/file/aio/share_client_async.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010

1111
from azure.core.tracing.decorator import distributed_trace
1212
from azure.core.tracing.decorator_async import distributed_trace_async
13-
13+
from azure.core.pipeline import AsyncPipeline
1414
from .._shared.policies_async import ExponentialRetry
15-
from .._shared.base_client_async import AsyncStorageAccountHostsMixin
15+
from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper
1616
from .._shared.request_handlers import add_metadata_headers, serialize_iso
1717
from .._shared.response_handlers import (
1818
return_response_headers,
@@ -102,9 +102,14 @@ def get_directory_client(self, directory_path=None):
102102
:returns: A Directory Client.
103103
:rtype: ~azure.storage.file.aio.directory_client_async.DirectoryClient
104104
"""
105+
_pipeline = AsyncPipeline(
106+
transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
107+
policies=self._pipeline._impl_policies # pylint: disable = protected-access
108+
)
109+
105110
return DirectoryClient(
106111
self.url, share_name=self.share_name, directory_path=directory_path or "", snapshot=self.snapshot,
107-
credential=self.credential, _hosts=self._hosts, _configuration=self._config, _pipeline=self._pipeline,
112+
credential=self.credential, _hosts=self._hosts, _configuration=self._config, _pipeline=_pipeline,
108113
_location_mode=self._location_mode, loop=self._loop)
109114

110115
def get_file_client(self, file_path):
@@ -117,10 +122,15 @@ def get_file_client(self, file_path):
117122
:returns: A File Client.
118123
:rtype: ~azure.storage.file.aio.file_client_async.FileClient
119124
"""
125+
_pipeline = AsyncPipeline(
126+
transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
127+
policies=self._pipeline._impl_policies # pylint: disable = protected-access
128+
)
129+
120130
return FileClient(
121131
self.url, share_name=self.share_name, file_path=file_path, snapshot=self.snapshot,
122132
credential=self.credential, _hosts=self._hosts, _configuration=self._config,
123-
_pipeline=self._pipeline, _location_mode=self._location_mode, loop=self._loop)
133+
_pipeline=_pipeline, _location_mode=self._location_mode, loop=self._loop)
124134

125135
@distributed_trace_async
126136
async def create_share(self, **kwargs): # type: ignore

sdk/storage/azure-storage-file/azure/storage/file/directory_client.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
import six
1919
from azure.core.polling import LROPoller
2020
from azure.core.paging import ItemPaged
21+
from azure.core.pipeline import Pipeline
2122
from azure.core.tracing.decorator import distributed_trace
2223

2324
from ._generated import AzureFileStorage
2425
from ._generated.version import VERSION
2526
from ._generated.models import StorageErrorException
26-
from ._shared.base_client import StorageAccountHostsMixin, parse_connection_str, parse_query
27+
from ._shared.base_client import StorageAccountHostsMixin, TransportWrapper, parse_connection_str, parse_query
2728
from ._shared.request_handlers import add_metadata_headers
2829
from ._shared.response_handlers import return_response_headers, process_storage_error
2930
from ._shared.parser import _str
@@ -217,10 +218,15 @@ def get_file_client(self, file_name, **kwargs):
217218
"""
218219
if self.directory_path:
219220
file_name = self.directory_path.rstrip('/') + "/" + file_name
221+
222+
_pipeline = Pipeline(
223+
transport=TransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
224+
policies=self._pipeline._impl_policies # pylint: disable = protected-access
225+
)
220226
return FileClient(
221227
self.url, file_path=file_name, share_name=self.share_name, napshot=self.snapshot,
222228
credential=self.credential, _hosts=self._hosts, _configuration=self._config,
223-
_pipeline=self._pipeline, _location_mode=self._location_mode, **kwargs)
229+
_pipeline=_pipeline, _location_mode=self._location_mode, **kwargs)
224230

225231
def get_subdirectory_client(self, directory_name, **kwargs):
226232
# type: (str, Any) -> DirectoryClient
@@ -243,9 +249,14 @@ def get_subdirectory_client(self, directory_name, **kwargs):
243249
:caption: Gets the subdirectory client.
244250
"""
245251
directory_path = self.directory_path.rstrip('/') + "/" + directory_name
252+
253+
_pipeline = Pipeline(
254+
transport=TransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
255+
policies=self._pipeline._impl_policies # pylint: disable = protected-access
256+
)
246257
return DirectoryClient(
247258
self.url, share_name=self.share_name, directory_path=directory_path, snapshot=self.snapshot,
248-
credential=self.credential, _hosts=self._hosts, _configuration=self._config, _pipeline=self._pipeline,
259+
credential=self.credential, _hosts=self._hosts, _configuration=self._config, _pipeline=_pipeline,
249260
_location_mode=self._location_mode, **kwargs)
250261

251262
@distributed_trace

sdk/storage/azure-storage-file/azure/storage/file/file_service_client.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616

1717
from azure.core.paging import ItemPaged
1818
from azure.core.tracing.decorator import distributed_trace
19-
19+
from azure.core.pipeline import Pipeline
2020
from ._shared.shared_access_signature import SharedAccessSignature
2121
from ._shared.models import Services
22-
from ._shared.base_client import StorageAccountHostsMixin, parse_connection_str, parse_query
22+
from ._shared.base_client import StorageAccountHostsMixin, TransportWrapper, parse_connection_str, parse_query
2323
from ._shared.response_handlers import process_storage_error
2424
from ._generated import AzureFileStorage
2525
from ._generated.models import StorageErrorException, StorageServiceProperties
@@ -431,6 +431,11 @@ def get_share_client(self, share, snapshot=None):
431431
share_name = share.name
432432
except AttributeError:
433433
share_name = share
434+
435+
_pipeline = Pipeline(
436+
transport=TransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
437+
policies=self._pipeline._impl_policies # pylint: disable = protected-access
438+
)
434439
return ShareClient(
435440
self.url, share_name=share_name, snapshot=snapshot, credential=self.credential, _hosts=self._hosts,
436-
_configuration=self._config, _pipeline=self._pipeline, _location_mode=self._location_mode)
441+
_configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode)

sdk/storage/azure-storage-file/azure/storage/file/share_client.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
import six
1717
from azure.core.tracing.decorator import distributed_trace
18-
from ._shared.base_client import StorageAccountHostsMixin, parse_connection_str, parse_query
18+
from azure.core.pipeline import Pipeline
19+
from ._shared.base_client import StorageAccountHostsMixin, TransportWrapper, parse_connection_str, parse_query
1920
from ._shared.request_handlers import add_metadata_headers, serialize_iso
2021
from ._shared.response_handlers import (
2122
return_response_headers,
@@ -297,9 +298,14 @@ def get_directory_client(self, directory_path=None):
297298
:returns: A Directory Client.
298299
:rtype: ~azure.storage.file.DirectoryClient
299300
"""
301+
_pipeline = Pipeline(
302+
transport=TransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
303+
policies=self._pipeline._impl_policies # pylint: disable = protected-access
304+
)
305+
300306
return DirectoryClient(
301307
self.url, share_name=self.share_name, directory_path=directory_path or "", snapshot=self.snapshot,
302-
credential=self.credential, _hosts=self._hosts, _configuration=self._config, _pipeline=self._pipeline,
308+
credential=self.credential, _hosts=self._hosts, _configuration=self._config, _pipeline=_pipeline,
303309
_location_mode=self._location_mode)
304310

305311
def get_file_client(self, file_path):
@@ -312,10 +318,15 @@ def get_file_client(self, file_path):
312318
:returns: A File Client.
313319
:rtype: ~azure.storage.file.FileClient
314320
"""
321+
_pipeline = Pipeline(
322+
transport=TransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
323+
policies=self._pipeline._impl_policies # pylint: disable = protected-access
324+
)
325+
315326
return FileClient(
316327
self.url, share_name=self.share_name, file_path=file_path, snapshot=self.snapshot,
317328
credential=self.credential, _hosts=self._hosts, _configuration=self._config,
318-
_pipeline=self._pipeline, _location_mode=self._location_mode)
329+
_pipeline=_pipeline, _location_mode=self._location_mode)
319330

320331
@distributed_trace
321332
def create_share(self, **kwargs): # type: ignore

sdk/storage/azure-storage-file/tests/test_share.py

+13
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import pytest
1212
import requests
13+
from azure.core.pipeline.transport import RequestsTransport
1314
from azure.core.exceptions import (
1415
HttpResponseError,
1516
ResourceNotFoundError,
@@ -762,6 +763,18 @@ def test_create_permission_for_share(self):
762763
# server returned permission
763764
self.assertEquals(permission_key, permission_key2)
764765

766+
@record
767+
def test_transport_closed_only_once(self):
768+
transport = RequestsTransport()
769+
url = self.get_file_url()
770+
credential = self.get_shared_key_credential()
771+
share = self._get_share_reference()
772+
with FileServiceClient(url, credential=credential, transport=transport) as fsc:
773+
assert transport.session is not None
774+
with fsc.get_share_client(share.share_name) as fc:
775+
assert transport.session is not None
776+
assert transport.session is not None # Right now it's None
777+
765778
# ------------------------------------------------------------------------------
766779
if __name__ == '__main__':
767780
unittest.main()

sdk/storage/azure-storage-file/tests/test_share_async.py

+12
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pytest
1212
import requests
1313
from azure.core.pipeline.transport import AioHttpTransport
14+
from azure.core.pipeline.transport import AsyncioRequestsTransport
1415
from multidict import CIMultiDict, CIMultiDictProxy
1516
from azure.core.exceptions import (
1617
HttpResponseError,
@@ -918,6 +919,17 @@ def test_create_permission_for_share_async(self):
918919
loop = asyncio.get_event_loop()
919920
loop.run_until_complete(self._test_create_permission_for_share())
920921

922+
async def test_transport_closed_only_once_async(self):
923+
transport = AsyncioRequestsTransport()
924+
url = self.get_file_url()
925+
credential = self.get_shared_key_credential()
926+
share = self._get_share_reference()
927+
async with FileServiceClient(url, credential=credential, transport=transport) as fsc:
928+
assert transport.session is not None
929+
async with fsc.get_share_client(share.share_name) as fc:
930+
assert transport.session is not None
931+
assert transport.session is not None # Right now it's None
932+
921933
# ------------------------------------------------------------------------------
922934
if __name__ == '__main__':
923935
unittest.main()

sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from azure.core import Configuration
2929
from azure.core.exceptions import HttpResponseError
3030
from azure.core.pipeline import Pipeline
31-
from azure.core.pipeline.transport import RequestsTransport
31+
from azure.core.pipeline.transport import RequestsTransport, HttpTransport
3232
from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy
3333
from azure.core.pipeline.policies import RedirectPolicy, ContentDecodePolicy, BearerTokenCredentialPolicy, ProxyPolicy
3434

@@ -216,6 +216,27 @@ def _batch_send(
216216
process_storage_error(error)
217217

218218

219+
class TransportWrapper(HttpTransport):
220+
221+
def __init__(self, transport):
222+
self._transport = transport
223+
224+
def send(self, request, **kwargs):
225+
return self._transport.send(request, **kwargs)
226+
227+
def open(self):
228+
pass
229+
230+
def close(self):
231+
pass
232+
233+
def __enter__(self, *args): # pylint: disable=arguments-differ
234+
pass
235+
236+
def __exit__(self, *args): # pylint: disable=arguments-differ
237+
pass
238+
239+
219240
def format_shared_key_credential(account, credential):
220241
if isinstance(credential, six.string_types):
221242
if len(account) < 2:

0 commit comments

Comments
 (0)