From b73d3258e17b835345fb914585d46e22abc06844 Mon Sep 17 00:00:00 2001 From: swathipil Date: Wed, 2 Oct 2024 20:44:00 -0700 Subject: [PATCH 01/11] remove verify from pyamqp JWTToken --- .../azure/eventhub/_transport/_pyamqp_transport.py | 1 - .../azure/eventhub/aio/_transport/_pyamqp_transport_async.py | 1 - 2 files changed, 2 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py index 457a711ec79a..cd90adcdde2d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py @@ -579,7 +579,6 @@ def create_token_auth( timeout=config.auth_timeout, custom_endpoint_hostname=config.custom_endpoint_hostname, port=config.connection_port, - verify=config.connection_verify, ) # if update_token: # token_auth.update_token() # TODO: why don't we need to update in pyamqp? diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py index c7cf800abb34..e6ad90bc9792 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py @@ -367,7 +367,6 @@ async def create_token_auth_async( timeout=config.auth_timeout, custom_endpoint_hostname=config.custom_endpoint_hostname, port=config.connection_port, - verify=config.connection_verify, ) # if update_token: # token_auth.update_token() # TODO: why don't we need to update in pyamqp? From 948d5d82bf5de788687f6287701dbe6a24479db5 Mon Sep 17 00:00:00 2001 From: swathipil Date: Wed, 2 Oct 2024 23:44:14 -0700 Subject: [PATCH 02/11] add ssl_context kwarg --- .../azure/eventhub/_configuration.py | 10 +++++++- .../azure/eventhub/_consumer_client.py | 9 ++++++++ .../azure/eventhub/_producer_client.py | 9 ++++++++ .../azure/eventhub/_pyamqp/_transport.py | 23 ++++++++----------- .../eventhub/_pyamqp/aio/_client_async.py | 14 +++++------ .../eventhub/_pyamqp/aio/_transport_async.py | 12 ++-------- .../azure/eventhub/_pyamqp/client.py | 21 ++++++++++------- .../eventhub/aio/_consumer_client_async.py | 9 ++++++++ .../eventhub/aio/_producer_client_async.py | 9 ++++++++ 9 files changed, 76 insertions(+), 40 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py index 23082a5a4ef7..ce2445ae6043 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py @@ -2,12 +2,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +import warnings from typing import Optional, Dict, Any, Union, TYPE_CHECKING from urllib.parse import urlparse from azure.core.pipeline.policies import RetryMode from ._constants import TransportType, DEFAULT_AMQPS_PORT, DEFAULT_AMQP_WSS_PORT if TYPE_CHECKING: + from ssl import SSLContext from ._transport._base import AmqpTransport from .aio._transport._base_async import AmqpTransportAsync @@ -36,6 +38,7 @@ def __init__( custom_endpoint_address: Optional[str] = None, connection_verify: Optional[str] = None, use_tls: bool = True, + ssl_context: Optional["SSLContext"] = None, **kwargs: Any ): self.user_agent = user_agent @@ -57,7 +60,12 @@ def __init__( self.receive_timeout = receive_timeout self.send_timeout = send_timeout self.custom_endpoint_address = custom_endpoint_address - self.connection_verify = connection_verify + if ssl_context: + if connection_verify: + warnings.warn("ssl_context is specified, connection_verify will be ignored.") + self.connection_verify = ssl_context + else: + self.connection_verify = connection_verify self.custom_endpoint_hostname = None self.hostname = hostname self.use_tls = use_tls diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py index 266864daadd3..051f18631818 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: + from ssl import SSLContext from ._eventprocessor.partition_context import PartitionContext from ._common import EventData from ._client_base import CredentialTypes @@ -124,6 +125,9 @@ class EventHubConsumerClient( authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: str or None + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -244,6 +248,7 @@ def from_connection_string( load_balancing_strategy: Union[str, LoadBalancingStrategy] = LoadBalancingStrategy.GREEDY, custom_endpoint_address: Optional[str] = None, connection_verify: Optional[str] = None, + ssl_context: Optional["SSLContext"] = None, uamqp_transport: bool = False, **kwargs: Any ) -> "EventHubConsumerClient": @@ -319,6 +324,9 @@ def from_connection_string( authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: str or None + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -356,6 +364,7 @@ def from_connection_string( load_balancing_strategy=load_balancing_strategy, custom_endpoint_address=custom_endpoint_address, connection_verify=connection_verify, + ssl_context=ssl_context, uamqp_transport=uamqp_transport, **kwargs, ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index 320bea5c601f..b4ffc8bdc3c6 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -30,6 +30,7 @@ from .exceptions import ConnectError, EventHubError if TYPE_CHECKING: + from ssl import SSLContext from ._client_base import CredentialTypes SendEventTypes = List[Union[EventData, AmqpAnnotatedMessage]] @@ -130,6 +131,9 @@ class EventHubProducerClient( authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: Optional[str] + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -466,6 +470,7 @@ def from_connection_string( transport_type: Optional["TransportType"] = TransportType.Amqp, custom_endpoint_address: Optional[str] = None, connection_verify: Optional[str] = None, + ssl_context: Optional["SSLContext"] = None, uamqp_transport: bool = False, **kwargs: Any ) -> "EventHubProducerClient": @@ -553,6 +558,9 @@ def from_connection_string( authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: Optional[str] + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -589,6 +597,7 @@ def from_connection_string( transport_type=transport_type, custom_endpoint_address=custom_endpoint_address, connection_verify=connection_verify, + ssl_context=ssl_context, uamqp_transport=uamqp_transport, **kwargs ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index 9431efcdc22c..d6e7c5f80b0a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -34,6 +34,7 @@ from __future__ import absolute_import, unicode_literals +from typing import Dict, Optional, Any import errno import re @@ -46,8 +47,6 @@ import logging from threading import Lock -import certifi - from ._platform import KNOWN_TCP_OPTS, SOL_TCP from ._encode import encode_frame from ._decode import decode_frame, decode_empty_frame @@ -171,7 +170,7 @@ def __init__( socket_timeout=SOCKET_TIMEOUT, socket_settings=None, raise_on_initial_eintr=True, - use_tls: bool =True, + use_tls: bool = True, **kwargs ): self._quick_recv = None @@ -516,20 +515,16 @@ def _setup_transport(self): self.sock = self._wrap_socket(self.sock, **self.sslopts) self._quick_recv = self.sock.recv - def _wrap_socket(self, sock, context=None, **sslopts): + def _wrap_socket( + self, + sock: socket.socket, + context: Optional[ssl.SSLContext] = None, + **sslopts: Dict[str, Any] + ): if context: - return self._wrap_context(sock, sslopts, **context) + return context.wrap_socket(sock, **sslopts) return self._wrap_socket_sni(sock, **sslopts) - def _wrap_context( - self, sock, sslopts, check_hostname=None, **ctx_options - ): - ctx = ssl.create_default_context(**ctx_options) - ctx.verify_mode = ssl.CERT_REQUIRED - ctx.load_verify_locations(cafile=certifi.where()) - ctx.check_hostname = check_hostname - return ctx.wrap_socket(sock, **sslopts) - def _wrap_socket_sni( self, sock, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py index 2c33bad20da9..51bffe2015da 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py @@ -131,9 +131,9 @@ class AMQPClientAsync(AMQPClientSync): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. + authenticate the identity of the connection endpoint OR an instance of ssl.SSLContext to be used. Default is None in which case `certifi.where()` will be used. - :paramtype connection_verify: str + :paramtype connection_verify: str or ssl.SSLContext or None :keyword float socket_timeout: The maximum time in seconds that the underlying socket in the transport should wait when reading or writing data before timing out. The default value is 0.2 (for transport type Amqp), and 1 for transport type AmqpOverWebsocket. @@ -246,7 +246,7 @@ async def open_async(self, connection=None): self._connection = Connection( "amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname, sasl_credential=self._auth.sasl, - ssl_opts={'ca_certs': self._connection_verify or certifi.where()}, + ssl_opts=self._connection_verify, container_id=self._name, max_frame_size=self._max_frame_size, channel_max=self._channel_max, @@ -495,9 +495,9 @@ class SendClientAsync(SendClientSync, AMQPClientAsync): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. + authenticate the identity of the connection endpoint OR an instance of ssl.SSLContext to be used. Default is None in which case `certifi.where()` will be used. - :paramtype connection_verify: str + :paramtype connection_verify: str or ssl.SSLContext or None """ async def _client_ready_async(self): @@ -717,9 +717,9 @@ class ReceiveClientAsync(ReceiveClientSync, AMQPClientAsync): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. + authenticate the identity of the connection endpoint OR an instance of ssl.SSLContext to be used. Default is None in which case `certifi.where()` will be used. - :paramtype connection_verify: str + :paramtype connection_verify: str or ssl.SSLContext or None """ async def _client_ready_async(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index 8bdbf9fa4b5e..a0f6e54e5955 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -40,6 +40,7 @@ from ssl import SSLError from io import BytesIO import logging +from typing import Dict, Optional, Any @@ -182,7 +183,7 @@ def _build_ssl_opts(self, sslopts): return sslopts try: if "context" in sslopts: - return self._build_ssl_context(**sslopts.pop("context")) + return sslopts["context"] ssl_version = sslopts.get("ssl_version") if ssl_version is None: ssl_version = ssl.PROTOCOL_TLS_CLIENT @@ -224,15 +225,6 @@ def _build_ssl_opts(self, sslopts): "SSL configuration must be a dictionary, or the value True." ) from None - def _build_ssl_context( - self, check_hostname=None, **ctx_options - ): - ctx = ssl.create_default_context(**ctx_options) - ctx.verify_mode = ssl.CERT_REQUIRED - ctx.load_verify_locations(cafile=certifi.where()) - ctx.check_hostname = check_hostname - return ctx - class AsyncTransport( AsyncTransportMixin diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py index f0c7e4ae81ff..acbfaf5e674d 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py @@ -145,9 +145,9 @@ class AMQPClient( If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. + authenticate the identity of the connection endpoint OR an instance of ssl.SSLContext to be used. Default is None in which case `certifi.where()` will be used. - :paramtype connection_verify: str + :paramtype connection_verify: str or ssl.SSLContext or None :keyword float socket_timeout: The maximum time in seconds that the underlying socket in the transport should wait when reading or writing data before timing out. The default value is 0.2 (for transport type Amqp), and 1 for transport type AmqpOverWebsocket. @@ -212,7 +212,12 @@ def __init__(self, hostname, **kwargs): # Custom Endpoint self._custom_endpoint_address = kwargs.get("custom_endpoint_address") - self._connection_verify = kwargs.get("connection_verify") + connection_verify = kwargs.get("connection_verify") + if connection_verify and not isinstance(connection_verify, str): # ssl.SSLContext + self._connection_verify = {"context": connection_verify} + else: # str or None + self._connection_verify = {"ca_certs": connection_verify or certifi.where()} + # Emulator self._use_tls: bool = kwargs.get("use_tls", True) @@ -317,7 +322,7 @@ def open(self, connection=None): self._connection = Connection( "amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname, sasl_credential=self._auth.sasl, - ssl_opts={"ca_certs": self._connection_verify or certifi.where()}, + ssl_opts=self._connection_verify, container_id=self._name, max_frame_size=self._max_frame_size, channel_max=self._channel_max, @@ -567,9 +572,9 @@ class SendClient(AMQPClient): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. + authenticate the identity of the connection endpoint OR an instance of ssl.SSLContext to be used. Default is None in which case `certifi.where()` will be used. - :paramtype connection_verify: str + :paramtype connection_verify: str or ssl.SSLContext or None """ def __init__(self, hostname, target, **kwargs): @@ -806,9 +811,9 @@ class ReceiveClient(AMQPClient): # pylint:disable=too-many-instance-attributes If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. + authenticate the identity of the connection endpoint OR an instance of ssl.SSLContext to be used. Default is None in which case `certifi.where()` will be used. - :paramtype connection_verify: str + :paramtype connection_verify: str or ssl.SSLContext or None """ def __init__(self, hostname, source, **kwargs): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py index 7834539bbf51..43a54be82f2f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: + from ssl import SSLContext from ._client_base_async import CredentialTypes from ._eventprocessor.partition_context import PartitionContext from ._eventprocessor.checkpoint_store import CheckpointStore @@ -136,6 +137,9 @@ class EventHubConsumerClient( authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: Optional[str] + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -258,6 +262,7 @@ def from_connection_string( load_balancing_strategy: Union[str, LoadBalancingStrategy] = LoadBalancingStrategy.GREEDY, custom_endpoint_address: Optional[str] = None, connection_verify: Optional[str] = None, + ssl_context: Optional["SSLContext"] = None, uamqp_transport: bool = False, **kwargs: Any ) -> "EventHubConsumerClient": @@ -333,6 +338,9 @@ def from_connection_string( authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: Optional[str] + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -369,6 +377,7 @@ def from_connection_string( load_balancing_strategy=load_balancing_strategy, custom_endpoint_address=custom_endpoint_address, connection_verify=connection_verify, + ssl_context=ssl_context, uamqp_transport=uamqp_transport, **kwargs, ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py index bdac602669dd..a6bec44be22f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py @@ -20,6 +20,7 @@ from .._common import EventDataBatch, EventData if TYPE_CHECKING: + from ssl import SSLContext from ._client_base_async import CredentialTypes SendEventTypes = List[Union[EventData, AmqpAnnotatedMessage]] @@ -117,6 +118,9 @@ class EventHubProducerClient( authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: Optional[str] + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -451,6 +455,7 @@ def from_connection_string( transport_type: TransportType = TransportType.Amqp, custom_endpoint_address: Optional[str] = None, connection_verify: Optional[str] = None, + ssl_context: Optional["SSLContext"] = None, uamqp_transport: bool = False, **kwargs: Any ) -> "EventHubProducerClient": @@ -529,6 +534,9 @@ def from_connection_string( authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: Optional[str] + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -564,6 +572,7 @@ def from_connection_string( transport_type=transport_type, custom_endpoint_address=custom_endpoint_address, connection_verify=connection_verify, + ssl_context=ssl_context, uamqp_transport=uamqp_transport, **kwargs ) From fc201952abec31c69f9cba7fa7c10cd40ea0b200 Mon Sep 17 00:00:00 2001 From: swathipil Date: Thu, 3 Oct 2024 01:03:27 -0700 Subject: [PATCH 03/11] add tests --- sdk/eventhub/azure-eventhub/CHANGELOG.md | 10 ++ sdk/eventhub/azure-eventhub/tests/conftest.py | 9 +- .../livetest/asynctests/test_auth_async.py | 105 +++++++++++++++++ .../tests/livetest/synctests/test_auth.py | 110 +++++++++++++++++- .../synctests/test_consumer_client.py | 1 - 5 files changed, 232 insertions(+), 3 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index c756e771869b..732824a08e23 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -1,8 +1,18 @@ # Release History +<<<<<<< Updated upstream +======= +## 5.12.3 (2024-10-09) + +### Bugs Fixed + +- Fixed a bug where creating the SSL context in the async clients was making a blocking call outside of the constructor.([#37246](https://github.com/Azure/azure-sdk-for-python/issues/37246)) + +>>>>>>> Stashed changes ## 5.12.2 (2024-10-02) ### Bugs Fixed + - Implemented backpressure for async consumer to address a memory leak issue. ([#36398](https://github.com/Azure/azure-sdk-for-python/issues/36398)) ## 5.12.1 (2024-06-11) diff --git a/sdk/eventhub/azure-eventhub/tests/conftest.py b/sdk/eventhub/azure-eventhub/tests/conftest.py index b4da03be45bc..fe7cfe350330 100644 --- a/sdk/eventhub/azure-eventhub/tests/conftest.py +++ b/sdk/eventhub/azure-eventhub/tests/conftest.py @@ -15,7 +15,7 @@ from azure.core.settings import settings from azure.mgmt.eventhub import EventHubManagementClient -from azure.eventhub import EventHubProducerClient +from azure.eventhub import EventHubProducerClient, TransportType from azure.eventhub._pyamqp import ReceiveClient from azure.eventhub._pyamqp.authentication import SASTokenAuth from azure.eventhub.extensions.checkpointstoreblob import BlobCheckpointStore @@ -28,6 +28,9 @@ uamqp_transport_params = [False] uamqp_transport_ids = ["pyamqp"] +socket_transports = [TransportType.Amqp, TransportType.AmqpOverWebsocket] +socket_transport_ids = ["amqp", "ws"] + from devtools_testutils import get_region_override, get_credential as get_devtools_credential from tracing_common import FakeSpan @@ -56,6 +59,10 @@ def sleep(request): def uamqp_transport(request): return request.param +@pytest.fixture(scope="session", params=socket_transports, ids=socket_transport_ids) +def socket_transport(request): + return request.param + @pytest.fixture(scope="session") def storage_url(): try: diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py index 871767420e42..e38218f42f01 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py @@ -6,10 +6,13 @@ import pytest import asyncio import time +import ssl +import certifi from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential from azure.identity.aio import EnvironmentCredential from azure.eventhub import EventData +from azure.eventhub.exceptions import ConnectError from azure.eventhub.aio import ( EventHubConsumerClient, EventHubProducerClient, @@ -175,3 +178,105 @@ async def test_client_azure_named_key_credential_async(live_eventhub, uamqp_tran credential.update(live_eventhub["key_name"], live_eventhub["access_key"]) assert (await consumer_client.get_eventhub_properties()) is not None + +# New feature only for Pure Python AMQP, not uamqp. +@pytest.mark.liveTest +@pytest.mark.asyncio +async def test_client_with_ssl_context_async( + auth_credentials_async, + socket_transport +): + fully_qualified_namespace, eventhub_name, credential = auth_credentials_async + + # Check that SSLContext with invalid/nonexistent cert file raises an error + context = ssl.SSLContext(cafile='fakecert.pem') + context.verify_mode = ssl.CERT_REQUIRED + + producer = EventHubProducerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + transport_type=socket_transport, + ssl_context=context, + retry_total=0, + ) + async with producer: + with pytest.raises(ConnectError): + batch = await producer.create_batch() + + async def on_event(partition_context, event): + on_event.called = True + on_event.partition_id = partition_context.partition_id + on_event.event = event + + async def on_error(partition_context, error): + on_error.error = error + await consumer.close() + + consumer = EventHubConsumerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + consumer_group="$default", + transport_type=socket_transport, + ssl_context=context, + retry_total=0, + ) + on_error.error = None + async with consumer: + task = asyncio.ensure_future( + consumer.receive(on_event, on_error=on_error, starting_position="-1") + ) + await asyncio.sleep(15) + await task + assert isinstance(on_error.error, ConnectError) + + # Check that SSLContext with valid cert file doesn't raise an error + async def verify_context(): + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + asyncio.to_thread(context.load_verify_locations(certifi.where())) + purpose = ssl.Purpose.SERVER_AUTH + asyncio.to_thread(context.load_default_certs(purpose=purpose)) + return context + + context = await verify_context() + + producer = EventHubProducerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + transport_type=socket_transport, + ssl_context=context, + ) + async with producer: + batch = await producer.create_batch() + batch.add(EventData(body="A single message")) + batch.add(EventData(body="A second message")) + await producer.send_batch(batch) + + async def on_event(partition_context, event): + on_event.total += 1 + + async def on_error(partition_context, error): + on_error.error = error + await consumer.close() + + consumer = EventHubConsumerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + consumer_group="$default", + transport_type=socket_transport, + ssl_context=context, + ) + on_event.total = 0 + on_error.error = None + + async with consumer: + task = asyncio.ensure_future( + consumer.receive(on_event, on_error=on_error, starting_position="-1") + ) + await asyncio.sleep(15) + await task + assert on_event.total == 2 + assert on_error.error is None diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py index 24c08b2e5cfd..52a7082351c5 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py @@ -6,14 +6,16 @@ import pytest import time import threading +import ssl +import certifi -from azure.identity import EnvironmentCredential from azure.eventhub import ( EventData, EventHubProducerClient, EventHubConsumerClient, EventHubSharedKeyCredential, ) +from azure.eventhub.exceptions import ConnectError from azure.eventhub._client_base import EventHubSASTokenCredential from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential @@ -172,3 +174,109 @@ def test_client_azure_named_key_credential(live_eventhub, uamqp_transport): credential.update(live_eventhub["key_name"], live_eventhub["access_key"]) assert consumer_client.get_eventhub_properties() is not None + +# New feature only for Pure Python AMQP, not uamqp. +@pytest.mark.liveTest +def test_client_with_ssl_context( + auth_credentials, + socket_transport +): + fully_qualified_namespace, eventhub_name, credential = auth_credentials + + # Check that SSLContext with invalid/nonexistent cert file raises an error + context = ssl.SSLContext(cafile='fakecert.pem') + context.verify_mode = ssl.CERT_REQUIRED + + producer = EventHubProducerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + transport_type=socket_transport, + ssl_context=context, + retry_total=0, + ) + with producer: + with pytest.raises(ConnectError): + batch = producer.create_batch() + + def on_event(partition_context, event): + on_event.called = True + on_event.partition_id = partition_context.partition_id + on_event.event = event + + def on_error(partition_context, error): + on_error.error = error + consumer.close() + + consumer = EventHubConsumerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + consumer_group="$default", + transport_type=socket_transport, + ssl_context=context, + retry_total=0, + ) + on_error.error = None + with consumer: + thread = threading.Thread( + target=consumer.receive, + args=(on_event,), + kwargs={"on_error": on_error, "starting_position": "-1"}, + ) + thread.daemon = True + thread.start() + time.sleep(15) + assert isinstance(on_error.error, ConnectError) + thread.join() + + # Check that SSLContext with valid cert file doesn't raise an error + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.load_verify_locations(certifi.where()) + purpose = ssl.Purpose.SERVER_AUTH + context.load_default_certs(purpose=purpose) + + producer = EventHubProducerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + transport_type=socket_transport, + ssl_context=context, + ) + with producer: + batch = producer.create_batch() + batch.add(EventData(body="A single message")) + batch.add(EventData(body="A second message")) + producer.send_batch(batch) + + def on_event(partition_context, event): + on_event.total += 1 + + def on_error(partition_context, error): + on_error.error = error + consumer.close() + + consumer = EventHubConsumerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + consumer_group="$default", + transport_type=socket_transport, + ssl_context=context, + ) + on_event.total = 0 + on_error.error = None + + with consumer: + thread2 = threading.Thread( + target=consumer.receive, + args=(on_event,), + kwargs={"on_error": on_error, "starting_position": "-1"}, + ) + thread2.daemon = True + thread2.start() + time.sleep(15) + assert on_event.total == 2 + assert on_error.error is None + + thread2.join() diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py index cd4ce62c7927..b8169be5122f 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py @@ -3,7 +3,6 @@ import threading import sys -from azure.core.tracing import SpanKind from azure.eventhub import EventData from azure.eventhub import EventHubConsumerClient from azure.eventhub._eventprocessor.in_memory_checkpoint_store import ( From f0f7b0a21cdf2aedbd9c4675408a1bece1a83cb8 Mon Sep 17 00:00:00 2001 From: swathipil Date: Thu, 3 Oct 2024 01:05:38 -0700 Subject: [PATCH 04/11] fix merge --- sdk/eventhub/azure-eventhub/CHANGELOG.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index 732824a08e23..03d194437b2e 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -1,14 +1,11 @@ # Release History -<<<<<<< Updated upstream -======= ## 5.12.3 (2024-10-09) ### Bugs Fixed - Fixed a bug where creating the SSL context in the async clients was making a blocking call outside of the constructor.([#37246](https://github.com/Azure/azure-sdk-for-python/issues/37246)) ->>>>>>> Stashed changes ## 5.12.2 (2024-10-02) ### Bugs Fixed From d6f113159ddb88e14ff0b766963916005d8acd76 Mon Sep 17 00:00:00 2001 From: swathipil Date: Thu, 3 Oct 2024 09:02:26 -0700 Subject: [PATCH 05/11] fix failing/lint/mypy --- .../azure/eventhub/_configuration.py | 1 + .../azure/eventhub/_pyamqp/_transport.py | 8 ++++---- .../eventhub/_pyamqp/aio/_client_async.py | 1 - .../eventhub/_pyamqp/aio/_transport_async.py | 5 ----- .../livetest/asynctests/test_auth_async.py | 19 +++++++++++++++---- 5 files changed, 20 insertions(+), 14 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py index ce2445ae6043..e4595b7609e0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py @@ -60,6 +60,7 @@ def __init__( self.receive_timeout = receive_timeout self.send_timeout = send_timeout self.custom_endpoint_address = custom_endpoint_address + self.connection_verify: Optional[Union[str, "SSLContext"]] if ssl_context: if connection_verify: warnings.warn("ssl_context is specified, connection_verify will be ignored.") diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index d6e7c5f80b0a..12feee10ca7f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -517,11 +517,11 @@ def _setup_transport(self): def _wrap_socket( self, - sock: socket.socket, - context: Optional[ssl.SSLContext] = None, - **sslopts: Dict[str, Any] + sock, + **sslopts ): - if context: + if "context" in sslopts: + context = sslopts.pop("context") return context.wrap_socket(sock, **sslopts) return self._wrap_socket_sni(sock, **sslopts) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py index 51bffe2015da..7b2acfa44fbd 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py @@ -12,7 +12,6 @@ from functools import partial from typing import Any, Callable, Coroutine, List, Dict, Optional, Tuple, Union, overload, cast from typing_extensions import Literal -import certifi from ..outcomes import Accepted, Modified, Received, Rejected, Released from ._connection_async import Connection diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index a0f6e54e5955..80b70270451c 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -40,11 +40,6 @@ from ssl import SSLError from io import BytesIO import logging -from typing import Dict, Optional, Any - - - -import certifi from .._platform import KNOWN_TCP_OPTS, SOL_TCP from .._encode import encode_frame diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py index e38218f42f01..0f735c222ce3 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py @@ -232,14 +232,25 @@ async def on_error(partition_context, error): assert isinstance(on_error.error, ConnectError) # Check that SSLContext with valid cert file doesn't raise an error - async def verify_context(): + async def verify_context_async(): + # asyncio.to_thread only available in Python 3.9+ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - asyncio.to_thread(context.load_verify_locations(certifi.where())) + await asyncio.to_thread(context.load_verify_locations(certifi.where())) purpose = ssl.Purpose.SERVER_AUTH - asyncio.to_thread(context.load_default_certs(purpose=purpose)) + await asyncio.to_thread(context.load_default_certs(purpose=purpose)) return context - context = await verify_context() + def verify_context(): # for Python 3.8 + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.load_verify_locations(certifi.where()) + purpose = ssl.Purpose.SERVER_AUTH + context.load_default_certs(purpose=purpose) + return context + + if hasattr(asyncio, "to_thread"): + context = await verify_context_async() + else: + context = verify_context() producer = EventHubProducerClient( fully_qualified_namespace=fully_qualified_namespace, From a570a329dd8746f008c37ec77012f938ea05a92d Mon Sep 17 00:00:00 2001 From: swathipil Date: Thu, 3 Oct 2024 10:32:16 -0700 Subject: [PATCH 06/11] lint --- sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index 12feee10ca7f..01d1a7182408 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -34,7 +34,6 @@ from __future__ import absolute_import, unicode_literals -from typing import Dict, Optional, Any import errno import re From 49a98a792ed35ee1aecb42f91d055250194b4e64 Mon Sep 17 00:00:00 2001 From: swathipil Date: Thu, 31 Oct 2024 22:28:59 -0700 Subject: [PATCH 07/11] separate ssl_context property from conn verify --- .../azure/eventhub/_configuration.py | 11 +++--- .../eventhub/_pyamqp/aio/_client_async.py | 26 ++++++++------ .../azure/eventhub/_pyamqp/client.py | 34 ++++++++++++------- .../eventhub/_transport/_pyamqp_transport.py | 3 ++ .../aio/_transport/_pyamqp_transport_async.py | 3 ++ .../livetest/asynctests/test_auth_async.py | 4 +-- 6 files changed, 49 insertions(+), 32 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py index e4595b7609e0..ca217ecd0c80 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py @@ -60,13 +60,10 @@ def __init__( self.receive_timeout = receive_timeout self.send_timeout = send_timeout self.custom_endpoint_address = custom_endpoint_address - self.connection_verify: Optional[Union[str, "SSLContext"]] - if ssl_context: - if connection_verify: - warnings.warn("ssl_context is specified, connection_verify will be ignored.") - self.connection_verify = ssl_context - else: - self.connection_verify = connection_verify + self.connection_verify = connection_verify + self.ssl_context = ssl_context + if self.ssl_context and self.connection_verify: + warnings.warn("ssl_context is specified, connection_verify will be ignored.") self.custom_endpoint_hostname = None self.hostname = hostname self.use_tls = use_tls diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py index 7b2acfa44fbd..49d5e69a1023 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py @@ -130,9 +130,11 @@ class AMQPClientAsync(AMQPClientSync): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint OR an instance of ssl.SSLContext to be used. - Default is None in which case `certifi.where()` will be used. - :paramtype connection_verify: str or ssl.SSLContext or None + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. + :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword float socket_timeout: The maximum time in seconds that the underlying socket in the transport should wait when reading or writing data before timing out. The default value is 0.2 (for transport type Amqp), and 1 for transport type AmqpOverWebsocket. @@ -245,7 +247,7 @@ async def open_async(self, connection=None): self._connection = Connection( "amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname, sasl_credential=self._auth.sasl, - ssl_opts=self._connection_verify, + ssl_opts=self._ssl_opts, container_id=self._name, max_frame_size=self._max_frame_size, channel_max=self._channel_max, @@ -494,9 +496,11 @@ class SendClientAsync(SendClientSync, AMQPClientAsync): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint OR an instance of ssl.SSLContext to be used. - Default is None in which case `certifi.where()` will be used. - :paramtype connection_verify: str or ssl.SSLContext or None + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. + :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ async def _client_ready_async(self): @@ -716,9 +720,11 @@ class ReceiveClientAsync(ReceiveClientSync, AMQPClientAsync): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint OR an instance of ssl.SSLContext to be used. - Default is None in which case `certifi.where()` will be used. - :paramtype connection_verify: str or ssl.SSLContext or None + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. + :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ async def _client_ready_async(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py index acbfaf5e674d..1989b469ac42 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py @@ -145,9 +145,11 @@ class AMQPClient( If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint OR an instance of ssl.SSLContext to be used. - Default is None in which case `certifi.where()` will be used. - :paramtype connection_verify: str or ssl.SSLContext or None + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. + :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword float socket_timeout: The maximum time in seconds that the underlying socket in the transport should wait when reading or writing data before timing out. The default value is 0.2 (for transport type Amqp), and 1 for transport type AmqpOverWebsocket. @@ -213,10 +215,12 @@ def __init__(self, hostname, **kwargs): # Custom Endpoint self._custom_endpoint_address = kwargs.get("custom_endpoint_address") connection_verify = kwargs.get("connection_verify") - if connection_verify and not isinstance(connection_verify, str): # ssl.SSLContext - self._connection_verify = {"context": connection_verify} + ssl_context = kwargs.get("ssl_context") + self._ssl_opts = {} + if ssl_context: + self._ssl_opts["context"] = ssl_context else: # str or None - self._connection_verify = {"ca_certs": connection_verify or certifi.where()} + self._ssl_opts["ca_certs"] = connection_verify or certifi.where() # Emulator @@ -322,7 +326,7 @@ def open(self, connection=None): self._connection = Connection( "amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname, sasl_credential=self._auth.sasl, - ssl_opts=self._connection_verify, + ssl_opts=self._ssl_opts, container_id=self._name, max_frame_size=self._max_frame_size, channel_max=self._channel_max, @@ -572,9 +576,11 @@ class SendClient(AMQPClient): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint OR an instance of ssl.SSLContext to be used. - Default is None in which case `certifi.where()` will be used. - :paramtype connection_verify: str or ssl.SSLContext or None + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. + :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ def __init__(self, hostname, target, **kwargs): @@ -811,9 +817,11 @@ class ReceiveClient(AMQPClient): # pylint:disable=too-many-instance-attributes If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint OR an instance of ssl.SSLContext to be used. - Default is None in which case `certifi.where()` will be used. - :paramtype connection_verify: str or ssl.SSLContext or None + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. + :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ def __init__(self, hostname, source, **kwargs): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py index cd90adcdde2d..0ffd6016a511 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py @@ -340,6 +340,7 @@ def create_send_client(# pylint: disable=unused-argument target, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, transport_type=config.transport_type, http_proxy=config.http_proxy, socket_timeout=config.socket_timeout, @@ -494,6 +495,7 @@ def create_receive_client( transport_type=config.transport_type, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, socket_timeout=config.socket_timeout, auth=auth, idle_timeout=idle_timeout, @@ -605,6 +607,7 @@ def create_mgmt_client( http_proxy=config.http_proxy, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, use_tls=config.use_tls, ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py index e6ad90bc9792..b1bd9c82d047 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py @@ -131,6 +131,7 @@ def create_send_client(# pylint: disable=unused-argument target, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, transport_type=config.transport_type, http_proxy=config.http_proxy, socket_timeout=config.socket_timeout, @@ -215,6 +216,7 @@ def create_receive_client( transport_type=config.transport_type, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, socket_timeout=config.socket_timeout, auth=auth, idle_timeout=idle_timeout, @@ -391,6 +393,7 @@ def create_mgmt_client(address, mgmt_auth, config): # pylint: disable=unused-ar http_proxy=config.http_proxy, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, use_tls=config.use_tls, ) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py index 0f735c222ce3..b1cb39d4f43e 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py @@ -235,9 +235,9 @@ async def on_error(partition_context, error): async def verify_context_async(): # asyncio.to_thread only available in Python 3.9+ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - await asyncio.to_thread(context.load_verify_locations(certifi.where())) + await asyncio.to_thread(context.load_verify_locations, certifi.where()) purpose = ssl.Purpose.SERVER_AUTH - await asyncio.to_thread(context.load_default_certs(purpose=purpose)) + await asyncio.to_thread(context.load_default_certs, purpose=purpose) return context def verify_context(): # for Python 3.8 From 02d0e3c5a65463b3da2ef542e9bea3450b60680c Mon Sep 17 00:00:00 2001 From: swathipil Date: Thu, 31 Oct 2024 23:07:50 -0700 Subject: [PATCH 08/11] make sb changes --- sdk/eventhub/azure-eventhub/CHANGELOG.md | 6 ++-- sdk/servicebus/azure-servicebus/CHANGELOG.md | 2 ++ .../servicebus/_common/_configuration.py | 2 ++ .../azure/servicebus/_pyamqp/_transport.py | 20 +++++-------- .../servicebus/_pyamqp/aio/_client_async.py | 21 +++++++++----- .../_pyamqp/aio/_transport_async.py | 11 +------ .../azure/servicebus/_pyamqp/client.py | 29 ++++++++++++++----- .../azure/servicebus/_servicebus_client.py | 19 +++++++++++- .../_transport/_pyamqp_transport.py | 3 +- .../aio/_servicebus_client_async.py | 18 +++++++++++- .../aio/_transport/_pyamqp_transport_async.py | 3 +- 11 files changed, 89 insertions(+), 45 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index 03d194437b2e..22e6f1a75e12 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -1,10 +1,10 @@ # Release History -## 5.12.3 (2024-10-09) +## 5.12.3 (Unreleased) -### Bugs Fixed +### Features Added -- Fixed a bug where creating the SSL context in the async clients was making a blocking call outside of the constructor.([#37246](https://github.com/Azure/azure-sdk-for-python/issues/37246)) +- Added `ssl_context` parameter to the clients allow users to pass in the SSL context, in which case, `connection_verify` will be ignored if specified. ## 5.12.2 (2024-10-02) diff --git a/sdk/servicebus/azure-servicebus/CHANGELOG.md b/sdk/servicebus/azure-servicebus/CHANGELOG.md index b31709522598..2ef5f7c4c25f 100644 --- a/sdk/servicebus/azure-servicebus/CHANGELOG.md +++ b/sdk/servicebus/azure-servicebus/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +- Added `ssl_context` parameter to the clients allow users to pass in the SSL context, in which case, `connection_verify` will be ignored if specified.([#37246](https://github.com/Azure/azure-sdk-for-python/issues/37246)) + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_configuration.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_configuration.py index 9a1391cdcc3f..79f66e463b01 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_configuration.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_configuration.py @@ -9,6 +9,7 @@ from .constants import DEFAULT_AMQPS_PORT, DEFAULT_AMQP_WSS_PORT, TransportType if TYPE_CHECKING: + from ssl import SSLContext from .._transport._base import AmqpTransport from ..aio._transport._base_async import AmqpTransportAsync @@ -25,6 +26,7 @@ def __init__(self, **kwargs): self.custom_endpoint_address: Optional[str] = kwargs.get("custom_endpoint_address") self.connection_verify: Optional[str] = kwargs.get("connection_verify") + self.ssl_context: Optional["SSLContext"] = kwargs.get("ssl_context") self.connection_port = DEFAULT_AMQPS_PORT self.custom_endpoint_hostname = None self.hostname = kwargs.pop("hostname") diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py index 5c3c5277b1ec..3c31d40fe768 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py @@ -46,8 +46,6 @@ import logging from threading import Lock -import certifi - from ._platform import KNOWN_TCP_OPTS, SOL_TCP from ._encode import encode_frame from ._decode import decode_frame, decode_empty_frame @@ -503,18 +501,16 @@ def _setup_transport(self): self.sock = self._wrap_socket(self.sock, **self.sslopts) self._quick_recv = self.sock.recv - def _wrap_socket(self, sock, context=None, **sslopts): - if context: - return self._wrap_context(sock, sslopts, **context) + def _wrap_socket( + self, + sock, + **sslopts + ): + if "context" in sslopts: + context = sslopts.pop("context") + return context.wrap_socket(sock, **sslopts) return self._wrap_socket_sni(sock, **sslopts) - def _wrap_context(self, sock, sslopts, check_hostname=None, **ctx_options): - ctx = ssl.create_default_context(**ctx_options) - ctx.verify_mode = ssl.CERT_REQUIRED - ctx.load_verify_locations(cafile=certifi.where()) - ctx.check_hostname = check_hostname - return ctx.wrap_socket(sock, **sslopts) - def _wrap_socket_sni( self, sock, diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_client_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_client_async.py index 1ca7ba786654..6ab90f4659fe 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_client_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_client_async.py @@ -12,7 +12,6 @@ from functools import partial from typing import Any, Callable, Coroutine, List, Dict, Optional, Tuple, Union, overload, cast from typing_extensions import Literal -import certifi from ..outcomes import Accepted, Modified, Received, Rejected, Released from ._connection_async import Connection @@ -126,9 +125,11 @@ class AMQPClientAsync(AMQPClientSync): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword float socket_timeout: The maximum time in seconds that the underlying socket in the transport should wait when reading or writing data before timing out. The default value is 0.2 (for transport type Amqp), and 1 for transport type AmqpOverWebsocket. @@ -238,7 +239,7 @@ async def open_async(self, connection=None): self._connection = Connection( "amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname, sasl_credential=self._auth.sasl, - ssl_opts={"ca_certs": self._connection_verify or certifi.where()}, + ssl_opts=self._ssl_opts, container_id=self._name, max_frame_size=self._max_frame_size, channel_max=self._channel_max, @@ -480,9 +481,11 @@ class SendClientAsync(SendClientSync, AMQPClientAsync): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ async def _client_ready_async(self): @@ -686,9 +689,11 @@ class ReceiveClientAsync(ReceiveClientSync, AMQPClientAsync): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ async def _client_ready_async(self): diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py index 0778a89962eb..59cac3ba9338 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py @@ -42,8 +42,6 @@ import logging -import certifi - from .._platform import KNOWN_TCP_OPTS, SOL_TCP from .._encode import encode_frame from .._decode import decode_frame, decode_empty_frame @@ -167,7 +165,7 @@ def _build_ssl_opts(self, sslopts): return sslopts try: if "context" in sslopts: - return self._build_ssl_context(**sslopts.pop("context")) + return sslopts["context"] ssl_version = sslopts.get("ssl_version") if ssl_version is None: ssl_version = ssl.PROTOCOL_TLS_CLIENT @@ -206,13 +204,6 @@ def _build_ssl_opts(self, sslopts): except TypeError: raise TypeError("SSL configuration must be a dictionary, or the value True.") from None - def _build_ssl_context(self, check_hostname=None, **ctx_options): - ctx = ssl.create_default_context(**ctx_options) - ctx.verify_mode = ssl.CERT_REQUIRED - ctx.load_verify_locations(cafile=certifi.where()) - ctx.check_hostname = check_hostname - return ctx - class AsyncTransport(AsyncTransportMixin): # pylint: disable=too-many-instance-attributes """Common superclass for TCP and SSL transports.""" diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py index 34613ea6a0a2..0b6e3d506b82 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py @@ -143,9 +143,11 @@ class AMQPClient(object): # pylint: disable=too-many-instance-attributes If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword float socket_timeout: The maximum time in seconds that the underlying socket in the transport should wait when reading or writing data before timing out. The default value is 0.2 (for transport type Amqp), and 1 for transport type AmqpOverWebsocket. @@ -201,7 +203,14 @@ def __init__(self, hostname, **kwargs): # Custom Endpoint self._custom_endpoint_address = kwargs.get("custom_endpoint_address") - self._connection_verify = kwargs.get("connection_verify") + connection_verify = kwargs.get("connection_verify") + ssl_context = kwargs.get("ssl_context") + self._ssl_opts = {} + if ssl_context: + self._ssl_opts["context"] = ssl_context + else: # str or None + self._ssl_opts["ca_certs"] = connection_verify or certifi.where() + # Emulator self._use_tls: bool = kwargs.get("use_tls", True) @@ -306,7 +315,7 @@ def open(self, connection=None): self._connection = Connection( "amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname, sasl_credential=self._auth.sasl, - ssl_opts={"ca_certs": self._connection_verify or certifi.where()}, + ssl_opts=self._ssl_opts, container_id=self._name, max_frame_size=self._max_frame_size, channel_max=self._channel_max, @@ -556,9 +565,11 @@ class SendClient(AMQPClient): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ def __init__(self, hostname, target, **kwargs): @@ -779,9 +790,11 @@ class ReceiveClient(AMQPClient): # pylint:disable=too-many-instance-attributes If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ def __init__(self, hostname, source, **kwargs): diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py index 320e928fb20c..402d651a7c7d 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py @@ -86,6 +86,9 @@ class ServiceBusClient(object): # pylint: disable=client-accepts-api-version-ke :keyword str connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -145,6 +148,7 @@ def __init__( self._handlers: WeakSet = WeakSet() self._custom_endpoint_address = kwargs.get("custom_endpoint_address") self._connection_verify = kwargs.get("connection_verify") + self._ssl_context = kwargs.get("ssl_context") def __enter__(self) -> "ServiceBusClient": if self._connection_sharing: @@ -156,12 +160,17 @@ def __exit__(self, *args: Any) -> None: def _create_connection(self): auth = create_authentication(self) + if self._ssl_context: + ssl_opts = {"context": self._ssl_context} + else: + ssl_opts = {"ca_certs": self._connection_verify or certifi.where()} + self._connection = self._amqp_transport.create_connection( host=self.fully_qualified_namespace, auth=auth.sasl, network_trace=self._config.logging_enable, custom_endpoint_address=self._custom_endpoint_address, - ssl_opts={"ca_certs": self._connection_verify or certifi.where()}, + ssl_opts=ssl_opts, transport_type=self._config.transport_type, http_proxy=self._config.http_proxy, ) @@ -229,6 +238,9 @@ def from_connection_string( :keyword str connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -308,6 +320,7 @@ def get_queue_sender(self, queue_name: str, **kwargs: Any) -> ServiceBusSender: retry_backoff_max=self._config.retry_backoff_max, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) @@ -428,6 +441,7 @@ def get_queue_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) @@ -487,6 +501,7 @@ def get_topic_sender( retry_backoff_max=self._config.retry_backoff_max, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, client_identifier=client_identifier, socket_timeout=socket_timeout, @@ -611,6 +626,7 @@ def get_subscription_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) @@ -639,6 +655,7 @@ def get_subscription_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_pyamqp_transport.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_pyamqp_transport.py index fbb78214e9ec..e2000b42d083 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_pyamqp_transport.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_pyamqp_transport.py @@ -460,6 +460,7 @@ def create_send_client(config: "Configuration", **kwargs: Any) -> "SendClient": keep_alive_interval=config.keep_alive, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, transport_type=config.transport_type, http_proxy=config.http_proxy, socket_timeout=config.socket_timeout, @@ -583,6 +584,7 @@ def create_receive_client(receiver: "ServiceBusReceiver", **kwargs: "Any") -> "R transport_type=config.transport_type, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, receive_settle_mode=PyamqpTransport.ServiceBusToAMQPReceiveModeMap[receive_mode], send_settle_mode=( constants.SenderSettleMode.Settled @@ -851,7 +853,6 @@ def create_token_auth( timeout=config.auth_timeout, custom_endpoint_hostname=config.custom_endpoint_hostname, port=config.connection_port, - verify=config.connection_verify, ) @staticmethod diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py index 7187cb430263..0fa098a7d503 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py @@ -79,6 +79,9 @@ class ServiceBusClient(object): # pylint: disable=client-accepts-api-version-ke :keyword str connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -135,6 +138,7 @@ def __init__( self._handlers: WeakSet = WeakSet() self._custom_endpoint_address = kwargs.get("custom_endpoint_address") self._connection_verify = kwargs.get("connection_verify") + self._ssl_context = kwargs.get("ssl_context") async def __aenter__(self) -> "ServiceBusClient": if self._connection_sharing: @@ -146,12 +150,16 @@ async def __aexit__(self, *args: Any) -> None: async def _create_connection(self): auth = await create_authentication(self) + if self._ssl_context: + ssl_opts = {"context": self._ssl_context} + else: + ssl_opts = {"ca_certs": self._connection_verify or certifi.where()} self._connection = self._amqp_transport.create_connection_async( host=self.fully_qualified_namespace, auth=auth.sasl, network_trace=self._config.logging_enable, custom_endpoint_address=self._custom_endpoint_address, - ssl_opts={"ca_certs": self._connection_verify or certifi.where()}, + ssl_opts=ssl_opts, transport_type=self._config.transport_type, http_proxy=self._config.http_proxy, ) @@ -197,6 +205,9 @@ def from_connection_string( :keyword str connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -297,6 +308,7 @@ def get_queue_sender(self, queue_name: str, **kwargs: Any) -> ServiceBusSender: retry_backoff_max=self._config.retry_backoff_max, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) @@ -417,6 +429,7 @@ def get_queue_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) @@ -469,6 +482,7 @@ def get_topic_sender(self, topic_name: str, **kwargs: Any) -> ServiceBusSender: retry_backoff_max=self._config.retry_backoff_max, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) @@ -591,6 +605,7 @@ def get_subscription_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) @@ -619,6 +634,7 @@ def get_subscription_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_pyamqp_transport_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_pyamqp_transport_async.py index f3cf1f28bbd6..f90ee4d8db53 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_pyamqp_transport_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_pyamqp_transport_async.py @@ -112,6 +112,7 @@ def create_send_client_async(config: "Configuration", **kwargs: Any) -> "SendCli keep_alive_interval=config.keep_alive, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, transport_type=config.transport_type, http_proxy=config.http_proxy, socket_timeout=config.socket_timeout, @@ -209,6 +210,7 @@ def create_receive_client_async( transport_type=config.transport_type, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, receive_settle_mode=PyamqpTransportAsync.ServiceBusToAMQPReceiveModeMap[receive_mode], send_settle_mode=( constants.SenderSettleMode.Settled @@ -398,7 +400,6 @@ async def create_token_auth_async( timeout=config.auth_timeout, custom_endpoint_hostname=config.custom_endpoint_hostname, port=config.connection_port, - verify=config.connection_verify, ) @staticmethod From 22d73364f1b872df17e746d506c10ac87a2c66e0 Mon Sep 17 00:00:00 2001 From: swathipil Date: Thu, 31 Oct 2024 23:56:32 -0700 Subject: [PATCH 09/11] add tests --- .../tests/async_tests/test_sb_client_async.py | 84 ++++++++++++++++++- .../azure-servicebus/tests/test_sb_client.py | 69 ++++++++++++++- .../azure-servicebus/tests/utilities.py | 20 +++++ 3 files changed, 171 insertions(+), 2 deletions(-) diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py index ccc8a151e73c..ab9366cf879d 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py @@ -13,6 +13,7 @@ import hashlib import base64 import certifi +import ssl from urllib.parse import quote as url_parse_quote from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential, AccessToken @@ -39,9 +40,10 @@ CachedServiceBusResourceGroupPreparer, SERVICEBUS_ENDPOINT_SUFFIX, ) -from tests.utilities import get_logger, uamqp_transport as get_uamqp_transport, ArgPasserAsync +from tests.utilities import get_logger, uamqp_transport as get_uamqp_transport, ArgPasserAsync, SocketArgPasserAsync, socket_transport as get_socket_transport uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() +socket_transport_params, socket_transport_ids = get_socket_transport() _logger = get_logger(logging.DEBUG) @@ -553,6 +555,86 @@ async def test_azure_named_key_credential_async( async with client.get_queue_sender(servicebus_queue.name) as sender: await sender.send_messages(ServiceBusMessage("foo")) + @pytest.mark.asyncio + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedServiceBusResourceGroupPreparer() + @CachedServiceBusNamespacePreparer(name_prefix="servicebustest") + @CachedServiceBusQueuePreparer(name_prefix="servicebustest", dead_lettering_on_message_expiration=True) + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @pytest.mark.parametrize("socket_transport", socket_transport_params, ids=socket_transport_ids) + @SocketArgPasserAsync() + async def test_sb_client_with_ssl_context_async( + self, + uamqp_transport, + socket_transport, + *, + servicebus_namespace=None, + servicebus_queue=None, + **kwargs, + ): + fully_qualified_namespace = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}" + credential = get_credential(is_async=True) + + # Check that SSLContext with invalid/nonexistent cert file raises an error + context = ssl.SSLContext(cafile='fakecert.pem') + context.verify_mode = ssl.CERT_REQUIRED + client = ServiceBusClient( + fully_qualified_namespace=fully_qualified_namespace, + credential=credential, + uamqp_transport=uamqp_transport, + ssl_context=context, + transport_type=socket_transport, + retry_total=0, + ) + async with client: + with pytest.raises(ServiceBusConnectionError): + async with client.get_queue_sender(servicebus_queue.name) as sender: + await sender.send_messages(ServiceBusMessage("test")) + + with pytest.raises(ServiceBusConnectionError): + async with client.get_queue_receiver(servicebus_queue.name) as receiver: + messages = await receiver.receive_messages(max_message_count=1, max_wait_time=1) + + # Check that SSLContext with valid cert file doesn't raise an error + async def verify_context_async(): + # asyncio.to_thread only available in Python 3.9+ + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + await asyncio.to_thread(context.load_verify_locations, certifi.where()) + purpose = ssl.Purpose.SERVER_AUTH + await asyncio.to_thread(context.load_default_certs, purpose=purpose) + return context + + def verify_context(): # for Python 3.8 + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.load_verify_locations(certifi.where()) + purpose = ssl.Purpose.SERVER_AUTH + context.load_default_certs(purpose=purpose) + return context + + if hasattr(asyncio, "to_thread"): + context = await verify_context_async() + else: + context = verify_context() + + client = ServiceBusClient( + fully_qualified_namespace=fully_qualified_namespace, + credential=credential, + uamqp_transport=uamqp_transport, + ssl_context=context, + transport_type=socket_transport, + ) + + async with client: + async with client.get_queue_sender(servicebus_queue.name) as sender: + await sender.send_messages(ServiceBusMessage("test")) + await sender.send_messages(ServiceBusMessage("test")) + + async with client.get_queue_receiver(servicebus_queue.name) as receiver: + messages = await receiver.receive_messages(max_message_count=2, max_wait_time=10) + + assert len(messages) == 2 + @pytest.mark.asyncio @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) async def test_backoff_fixed_retry(self, uamqp_transport): diff --git a/sdk/servicebus/azure-servicebus/tests/test_sb_client.py b/sdk/servicebus/azure-servicebus/tests/test_sb_client.py index 4bf87ff9b978..086711c1faec 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_sb_client.py +++ b/sdk/servicebus/azure-servicebus/tests/test_sb_client.py @@ -14,6 +14,7 @@ import hashlib import base64 import certifi +import ssl try: from urllib.parse import quote as url_parse_quote @@ -52,9 +53,10 @@ CachedServiceBusResourceGroupPreparer, SERVICEBUS_ENDPOINT_SUFFIX, ) -from utilities import uamqp_transport as get_uamqp_transport, ArgPasser +from utilities import uamqp_transport as get_uamqp_transport, ArgPasser, socket_transport as get_socket_transport, SocketArgPasser uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() +socket_transport_params, socket_transport_ids = get_socket_transport() class TestServiceBusClient(AzureMgmtRecordedTestCase): @@ -568,6 +570,71 @@ def test_azure_named_key_credential( with client.get_queue_sender(servicebus_queue.name) as sender: sender.send_messages(ServiceBusMessage("foo")) + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedServiceBusResourceGroupPreparer() + @CachedServiceBusNamespacePreparer(name_prefix="servicebustest") + @CachedServiceBusQueuePreparer(name_prefix="servicebustest", dead_lettering_on_message_expiration=True) + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @pytest.mark.parametrize("socket_transport", socket_transport_params, ids=socket_transport_ids) + @SocketArgPasser() + def test_sb_client_with_ssl_context( + self, + uamqp_transport, + socket_transport, + *, + servicebus_namespace=None, + servicebus_queue=None, + **kwargs, + ): + fully_qualified_namespace = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}" + credential = get_credential() + + # Check that SSLContext with invalid/nonexistent cert file raises an error + context = ssl.SSLContext(cafile='fakecert.pem') + context.verify_mode = ssl.CERT_REQUIRED + client = ServiceBusClient( + fully_qualified_namespace=fully_qualified_namespace, + credential=credential, + uamqp_transport=uamqp_transport, + ssl_context=context, + transport_type=socket_transport, + retry_total=0, + ) + with client: + with pytest.raises(ServiceBusConnectionError): + with client.get_queue_sender(servicebus_queue.name) as sender: + sender.send_messages(ServiceBusMessage("test")) + + with pytest.raises(ServiceBusConnectionError): + with client.get_queue_receiver(servicebus_queue.name) as receiver: + messages = receiver.receive_messages(max_message_count=1, max_wait_time=1) + + # Check that SSLContext with valid cert file doesn't raise an error + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.load_verify_locations(certifi.where()) + purpose = ssl.Purpose.SERVER_AUTH + context.load_default_certs(purpose=purpose) + + client = ServiceBusClient( + fully_qualified_namespace=fully_qualified_namespace, + credential=credential, + uamqp_transport=uamqp_transport, + ssl_context=context, + transport_type=socket_transport, + ) + + with client: + with client.get_queue_sender(servicebus_queue.name) as sender: + sender.send_messages(ServiceBusMessage("test")) + sender.send_messages(ServiceBusMessage("test")) + + with client.get_queue_receiver(servicebus_queue.name) as receiver: + messages = receiver.receive_messages(max_message_count=2, max_wait_time=10) + + assert len(messages) == 2 + + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) def test_backoff_fixed_retry(self, uamqp_transport): diff --git a/sdk/servicebus/azure-servicebus/tests/utilities.py b/sdk/servicebus/azure-servicebus/tests/utilities.py index ada827fae6ef..0a7deb4998a8 100644 --- a/sdk/servicebus/azure-servicebus/tests/utilities.py +++ b/sdk/servicebus/azure-servicebus/tests/utilities.py @@ -15,6 +15,7 @@ uamqp_available = True except (ModuleNotFoundError, ImportError): uamqp_available = False +from azure.servicebus import TransportType from azure.servicebus._common.utils import utc_now # temporary - disable uamqp if China b/c of 8+ hr runtime @@ -76,6 +77,10 @@ def uamqp_transport(use_uamqp=uamqp_available, use_pyamqp=test_pyamqp): uamqp_transport_ids.append("pyamqp") return uamqp_transport_params, uamqp_transport_ids +def socket_transport(): + socket_transport_params = [TransportType.Amqp, TransportType.AmqpOverWebsocket] + socket_transport_ids = ["amqp", "ws"] + return socket_transport_params, socket_transport_ids class ArgPasser: def __call__(self, fn): @@ -91,3 +96,18 @@ async def _preparer(test_class, uamqp_transport, **kwargs): await fn(test_class, uamqp_transport=uamqp_transport, **kwargs) return _preparer + +class SocketArgPasser: + def __call__(self, fn): + def _preparer(test_class, uamqp_transport, socket_transport, **kwargs): + fn(test_class, uamqp_transport=uamqp_transport, socket_transport=socket_transport, **kwargs) + + return _preparer + + +class SocketArgPasserAsync: + def __call__(self, fn): + async def _preparer(test_class, uamqp_transport, socket_transport, **kwargs): + await fn(test_class, uamqp_transport=uamqp_transport, socket_transport=socket_transport, **kwargs) + + return _preparer From 0af59f0db67deb26f816bb06f098091ff9be4ec7 Mon Sep 17 00:00:00 2001 From: swathipil Date: Thu, 31 Oct 2024 23:58:47 -0700 Subject: [PATCH 10/11] changelog --- sdk/eventhub/azure-eventhub/CHANGELOG.md | 2 +- sdk/servicebus/azure-servicebus/CHANGELOG.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index 22e6f1a75e12..237d4e17af83 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -4,7 +4,7 @@ ### Features Added -- Added `ssl_context` parameter to the clients allow users to pass in the SSL context, in which case, `connection_verify` will be ignored if specified. +- Added `ssl_context` parameter to the clients to allow users to pass in the SSL context, in which case, `connection_verify` will be ignored if specified. ## 5.12.2 (2024-10-02) diff --git a/sdk/servicebus/azure-servicebus/CHANGELOG.md b/sdk/servicebus/azure-servicebus/CHANGELOG.md index 2ef5f7c4c25f..b90d72401b96 100644 --- a/sdk/servicebus/azure-servicebus/CHANGELOG.md +++ b/sdk/servicebus/azure-servicebus/CHANGELOG.md @@ -4,7 +4,7 @@ ### Features Added -- Added `ssl_context` parameter to the clients allow users to pass in the SSL context, in which case, `connection_verify` will be ignored if specified.([#37246](https://github.com/Azure/azure-sdk-for-python/issues/37246)) +- Added `ssl_context` parameter to the clients to allow users to pass in the SSL context, in which case, `connection_verify` will be ignored if specified.([#37246](https://github.com/Azure/azure-sdk-for-python/issues/37246)) ### Breaking Changes From 972bfe3c6b4119233cf265ef689dbf7fdc7d6565 Mon Sep 17 00:00:00 2001 From: swathipil Date: Fri, 1 Nov 2024 08:14:33 -0700 Subject: [PATCH 11/11] black --- .../azure/eventhub/_pyamqp/_transport.py | 6 +---- .../azure/eventhub/_pyamqp/client.py | 3 +-- sdk/eventhub/azure-eventhub/tests/conftest.py | 6 ++--- .../livetest/asynctests/test_auth_async.py | 24 +++++++------------ .../tests/livetest/synctests/test_auth.py | 14 +++++------ .../azure/servicebus/_pyamqp/_transport.py | 6 +---- .../azure/servicebus/_pyamqp/client.py | 3 +-- .../tests/async_tests/test_sb_client_async.py | 12 +++++++--- .../azure-servicebus/tests/test_sb_client.py | 16 ++++++++----- .../azure-servicebus/tests/utilities.py | 3 +++ 10 files changed, 44 insertions(+), 49 deletions(-) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index 3c31d40fe768..da81f1fcec1f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -501,11 +501,7 @@ def _setup_transport(self): self.sock = self._wrap_socket(self.sock, **self.sslopts) self._quick_recv = self.sock.recv - def _wrap_socket( - self, - sock, - **sslopts - ): + def _wrap_socket(self, sock, **sslopts): if "context" in sslopts: context = sslopts.pop("context") return context.wrap_socket(sock, **sslopts) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py index 0b6e3d506b82..88644f6e798f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py @@ -208,10 +208,9 @@ def __init__(self, hostname, **kwargs): self._ssl_opts = {} if ssl_context: self._ssl_opts["context"] = ssl_context - else: # str or None + else: # str or None self._ssl_opts["ca_certs"] = connection_verify or certifi.where() - # Emulator self._use_tls: bool = kwargs.get("use_tls", True) diff --git a/sdk/eventhub/azure-eventhub/tests/conftest.py b/sdk/eventhub/azure-eventhub/tests/conftest.py index 5e5bdcf4d71c..d00b169ef88f 100644 --- a/sdk/eventhub/azure-eventhub/tests/conftest.py +++ b/sdk/eventhub/azure-eventhub/tests/conftest.py @@ -60,13 +60,13 @@ def sleep(request): def uamqp_transport(request): return request.param - + @pytest.fixture(scope="session", params=socket_transports, ids=socket_transport_ids) def socket_transport(request): return request.param - -@pytest.fixture(scope="session") + +@pytest.fixture(scope="session") def storage_url(): try: account_name = os.environ["AZURE_STORAGE_ACCOUNT"] diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py index 42fc0a2ee876..b89b952b3588 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py @@ -169,17 +169,15 @@ async def test_client_azure_named_key_credential_async(live_eventhub, uamqp_tran credential.update(live_eventhub["key_name"], live_eventhub["access_key"]) assert (await consumer_client.get_eventhub_properties()) is not None + # New feature only for Pure Python AMQP, not uamqp. @pytest.mark.liveTest @pytest.mark.asyncio -async def test_client_with_ssl_context_async( - auth_credentials_async, - socket_transport -): +async def test_client_with_ssl_context_async(auth_credentials_async, socket_transport): fully_qualified_namespace, eventhub_name, credential = auth_credentials_async # Check that SSLContext with invalid/nonexistent cert file raises an error - context = ssl.SSLContext(cafile='fakecert.pem') + context = ssl.SSLContext(cafile="fakecert.pem") context.verify_mode = ssl.CERT_REQUIRED producer = EventHubProducerClient( @@ -193,12 +191,12 @@ async def test_client_with_ssl_context_async( async with producer: with pytest.raises(ConnectError): batch = await producer.create_batch() - + async def on_event(partition_context, event): on_event.called = True on_event.partition_id = partition_context.partition_id on_event.event = event - + async def on_error(partition_context, error): on_error.error = error await consumer.close() @@ -214,9 +212,7 @@ async def on_error(partition_context, error): ) on_error.error = None async with consumer: - task = asyncio.ensure_future( - consumer.receive(on_event, on_error=on_error, starting_position="-1") - ) + task = asyncio.ensure_future(consumer.receive(on_event, on_error=on_error, starting_position="-1")) await asyncio.sleep(15) await task assert isinstance(on_error.error, ConnectError) @@ -230,7 +226,7 @@ async def verify_context_async(): await asyncio.to_thread(context.load_default_certs, purpose=purpose) return context - def verify_context(): # for Python 3.8 + def verify_context(): # for Python 3.8 context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context.load_verify_locations(certifi.where()) purpose = ssl.Purpose.SERVER_AUTH @@ -257,7 +253,7 @@ def verify_context(): # for Python 3.8 async def on_event(partition_context, event): on_event.total += 1 - + async def on_error(partition_context, error): on_error.error = error await consumer.close() @@ -274,9 +270,7 @@ async def on_error(partition_context, error): on_error.error = None async with consumer: - task = asyncio.ensure_future( - consumer.receive(on_event, on_error=on_error, starting_position="-1") - ) + task = asyncio.ensure_future(consumer.receive(on_event, on_error=on_error, starting_position="-1")) await asyncio.sleep(15) await task assert on_event.total == 2 diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py index 5890646a46c4..0d2c7bdc1452 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py @@ -167,16 +167,14 @@ def test_client_azure_named_key_credential(live_eventhub, uamqp_transport): credential.update(live_eventhub["key_name"], live_eventhub["access_key"]) assert consumer_client.get_eventhub_properties() is not None + # New feature only for Pure Python AMQP, not uamqp. @pytest.mark.liveTest -def test_client_with_ssl_context( - auth_credentials, - socket_transport -): +def test_client_with_ssl_context(auth_credentials, socket_transport): fully_qualified_namespace, eventhub_name, credential = auth_credentials # Check that SSLContext with invalid/nonexistent cert file raises an error - context = ssl.SSLContext(cafile='fakecert.pem') + context = ssl.SSLContext(cafile="fakecert.pem") context.verify_mode = ssl.CERT_REQUIRED producer = EventHubProducerClient( @@ -190,12 +188,12 @@ def test_client_with_ssl_context( with producer: with pytest.raises(ConnectError): batch = producer.create_batch() - + def on_event(partition_context, event): on_event.called = True on_event.partition_id = partition_context.partition_id on_event.event = event - + def on_error(partition_context, error): on_error.error = error consumer.close() @@ -243,7 +241,7 @@ def on_error(partition_context, error): def on_event(partition_context, event): on_event.total += 1 - + def on_error(partition_context, error): on_error.error = error consumer.close() diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py index 3c31d40fe768..da81f1fcec1f 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py @@ -501,11 +501,7 @@ def _setup_transport(self): self.sock = self._wrap_socket(self.sock, **self.sslopts) self._quick_recv = self.sock.recv - def _wrap_socket( - self, - sock, - **sslopts - ): + def _wrap_socket(self, sock, **sslopts): if "context" in sslopts: context = sslopts.pop("context") return context.wrap_socket(sock, **sslopts) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py index 0b6e3d506b82..88644f6e798f 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py @@ -208,10 +208,9 @@ def __init__(self, hostname, **kwargs): self._ssl_opts = {} if ssl_context: self._ssl_opts["context"] = ssl_context - else: # str or None + else: # str or None self._ssl_opts["ca_certs"] = connection_verify or certifi.where() - # Emulator self._use_tls: bool = kwargs.get("use_tls", True) diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py index ab9366cf879d..38435f10957b 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py @@ -40,7 +40,13 @@ CachedServiceBusResourceGroupPreparer, SERVICEBUS_ENDPOINT_SUFFIX, ) -from tests.utilities import get_logger, uamqp_transport as get_uamqp_transport, ArgPasserAsync, SocketArgPasserAsync, socket_transport as get_socket_transport +from tests.utilities import ( + get_logger, + uamqp_transport as get_uamqp_transport, + ArgPasserAsync, + SocketArgPasserAsync, + socket_transport as get_socket_transport, +) uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() socket_transport_params, socket_transport_ids = get_socket_transport() @@ -577,7 +583,7 @@ async def test_sb_client_with_ssl_context_async( credential = get_credential(is_async=True) # Check that SSLContext with invalid/nonexistent cert file raises an error - context = ssl.SSLContext(cafile='fakecert.pem') + context = ssl.SSLContext(cafile="fakecert.pem") context.verify_mode = ssl.CERT_REQUIRED client = ServiceBusClient( fully_qualified_namespace=fully_qualified_namespace, @@ -605,7 +611,7 @@ async def verify_context_async(): await asyncio.to_thread(context.load_default_certs, purpose=purpose) return context - def verify_context(): # for Python 3.8 + def verify_context(): # for Python 3.8 context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context.load_verify_locations(certifi.where()) purpose = ssl.Purpose.SERVER_AUTH diff --git a/sdk/servicebus/azure-servicebus/tests/test_sb_client.py b/sdk/servicebus/azure-servicebus/tests/test_sb_client.py index 086711c1faec..14dbae7cfc3e 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_sb_client.py +++ b/sdk/servicebus/azure-servicebus/tests/test_sb_client.py @@ -53,7 +53,12 @@ CachedServiceBusResourceGroupPreparer, SERVICEBUS_ENDPOINT_SUFFIX, ) -from utilities import uamqp_transport as get_uamqp_transport, ArgPasser, socket_transport as get_socket_transport, SocketArgPasser +from utilities import ( + uamqp_transport as get_uamqp_transport, + ArgPasser, + socket_transport as get_socket_transport, + SocketArgPasser, +) uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() socket_transport_params, socket_transport_ids = get_socket_transport() @@ -591,7 +596,7 @@ def test_sb_client_with_ssl_context( credential = get_credential() # Check that SSLContext with invalid/nonexistent cert file raises an error - context = ssl.SSLContext(cafile='fakecert.pem') + context = ssl.SSLContext(cafile="fakecert.pem") context.verify_mode = ssl.CERT_REQUIRED client = ServiceBusClient( fully_qualified_namespace=fully_qualified_namespace, @@ -609,13 +614,13 @@ def test_sb_client_with_ssl_context( with pytest.raises(ServiceBusConnectionError): with client.get_queue_receiver(servicebus_queue.name) as receiver: messages = receiver.receive_messages(max_message_count=1, max_wait_time=1) - + # Check that SSLContext with valid cert file doesn't raise an error context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context.load_verify_locations(certifi.where()) purpose = ssl.Purpose.SERVER_AUTH context.load_default_certs(purpose=purpose) - + client = ServiceBusClient( fully_qualified_namespace=fully_qualified_namespace, credential=credential, @@ -631,9 +636,8 @@ def test_sb_client_with_ssl_context( with client.get_queue_receiver(servicebus_queue.name) as receiver: messages = receiver.receive_messages(max_message_count=2, max_wait_time=10) - - assert len(messages) == 2 + assert len(messages) == 2 @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) def test_backoff_fixed_retry(self, uamqp_transport): diff --git a/sdk/servicebus/azure-servicebus/tests/utilities.py b/sdk/servicebus/azure-servicebus/tests/utilities.py index 0a7deb4998a8..3d8a0652ae2e 100644 --- a/sdk/servicebus/azure-servicebus/tests/utilities.py +++ b/sdk/servicebus/azure-servicebus/tests/utilities.py @@ -77,11 +77,13 @@ def uamqp_transport(use_uamqp=uamqp_available, use_pyamqp=test_pyamqp): uamqp_transport_ids.append("pyamqp") return uamqp_transport_params, uamqp_transport_ids + def socket_transport(): socket_transport_params = [TransportType.Amqp, TransportType.AmqpOverWebsocket] socket_transport_ids = ["amqp", "ws"] return socket_transport_params, socket_transport_ids + class ArgPasser: def __call__(self, fn): def _preparer(test_class, uamqp_transport, **kwargs): @@ -97,6 +99,7 @@ async def _preparer(test_class, uamqp_transport, **kwargs): return _preparer + class SocketArgPasser: def __call__(self, fn): def _preparer(test_class, uamqp_transport, socket_transport, **kwargs):