diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index e5229ee5b8f9..237d4e17af83 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -4,15 +4,12 @@ ### Features Added -### Breaking Changes - -### Bugs Fixed - -### Other Changes +- Added `ssl_context` parameter to the clients to allow users to pass in the SSL context, in which case, `connection_verify` will be ignored if specified. ## 5.12.2 (2024-10-02) ### Bugs Fixed + - Implemented backpressure for async consumer to address a memory leak issue. ([#36398](https://github.com/Azure/azure-sdk-for-python/issues/36398)) ## 5.12.1 (2024-06-11) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py index fb507713617c..1ed02ffd947f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +import warnings from typing import Optional, Dict, Any, Union, TYPE_CHECKING from urllib.parse import urlparse @@ -9,6 +10,7 @@ from ._constants import TransportType, DEFAULT_AMQPS_PORT, DEFAULT_AMQP_WSS_PORT if TYPE_CHECKING: + from ssl import SSLContext from ._transport._base import AmqpTransport from .aio._transport._base_async import AmqpTransportAsync @@ -35,6 +37,7 @@ def __init__( send_timeout: int = 60, custom_endpoint_address: Optional[str] = None, connection_verify: Optional[str] = None, + ssl_context: Optional["SSLContext"] = None, use_tls: bool = True, **kwargs: Any ): @@ -54,6 +57,9 @@ def __init__( self.send_timeout = send_timeout self.custom_endpoint_address = custom_endpoint_address self.connection_verify = connection_verify + self.ssl_context = ssl_context + if self.ssl_context and self.connection_verify: + warnings.warn("ssl_context is specified, connection_verify will be ignored.") self.custom_endpoint_hostname = None self.hostname = hostname self.use_tls = use_tls diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py index 9d2e803a0d41..38f23d621b3e 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: + from ssl import SSLContext from ._eventprocessor.partition_context import PartitionContext from ._common import EventData from ._client_base import CredentialTypes @@ -122,6 +123,9 @@ class EventHubConsumerClient(ClientBase): # pylint: disable=client-accepts-api- authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: str or None + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -234,6 +238,7 @@ def from_connection_string( load_balancing_strategy: Union[str, LoadBalancingStrategy] = LoadBalancingStrategy.GREEDY, custom_endpoint_address: Optional[str] = None, connection_verify: Optional[str] = None, + ssl_context: Optional["SSLContext"] = None, uamqp_transport: bool = False, **kwargs: Any, ) -> "EventHubConsumerClient": @@ -309,6 +314,9 @@ def from_connection_string( authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: str or None + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -346,6 +354,7 @@ def from_connection_string( load_balancing_strategy=load_balancing_strategy, custom_endpoint_address=custom_endpoint_address, connection_verify=connection_verify, + ssl_context=ssl_context, uamqp_transport=uamqp_transport, **kwargs, ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py index a480a52625cd..2da6716af1b6 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py @@ -30,6 +30,7 @@ from .exceptions import ConnectError, EventHubError if TYPE_CHECKING: + from ssl import SSLContext from ._client_base import CredentialTypes SendEventTypes = List[Union[EventData, AmqpAnnotatedMessage]] @@ -128,6 +129,9 @@ class EventHubProducerClient(ClientBase): # pylint: disable=client-accepts-api- authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: Optional[str] + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -419,6 +423,7 @@ def from_connection_string( transport_type: Optional["TransportType"] = TransportType.Amqp, custom_endpoint_address: Optional[str] = None, connection_verify: Optional[str] = None, + ssl_context: Optional["SSLContext"] = None, uamqp_transport: bool = False, **kwargs: Any, ) -> "EventHubProducerClient": @@ -506,6 +511,9 @@ def from_connection_string( authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: Optional[str] + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -542,6 +550,7 @@ def from_connection_string( transport_type=transport_type, custom_endpoint_address=custom_endpoint_address, connection_verify=connection_verify, + ssl_context=ssl_context, uamqp_transport=uamqp_transport, **kwargs, ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py index 5c3c5277b1ec..da81f1fcec1f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py @@ -46,8 +46,6 @@ import logging from threading import Lock -import certifi - from ._platform import KNOWN_TCP_OPTS, SOL_TCP from ._encode import encode_frame from ._decode import decode_frame, decode_empty_frame @@ -503,18 +501,12 @@ def _setup_transport(self): self.sock = self._wrap_socket(self.sock, **self.sslopts) self._quick_recv = self.sock.recv - def _wrap_socket(self, sock, context=None, **sslopts): - if context: - return self._wrap_context(sock, sslopts, **context) + def _wrap_socket(self, sock, **sslopts): + if "context" in sslopts: + context = sslopts.pop("context") + return context.wrap_socket(sock, **sslopts) return self._wrap_socket_sni(sock, **sslopts) - def _wrap_context(self, sock, sslopts, check_hostname=None, **ctx_options): - ctx = ssl.create_default_context(**ctx_options) - ctx.verify_mode = ssl.CERT_REQUIRED - ctx.load_verify_locations(cafile=certifi.where()) - ctx.check_hostname = check_hostname - return ctx.wrap_socket(sock, **sslopts) - def _wrap_socket_sni( self, sock, diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py index 1ca7ba786654..6ab90f4659fe 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py @@ -12,7 +12,6 @@ from functools import partial from typing import Any, Callable, Coroutine, List, Dict, Optional, Tuple, Union, overload, cast from typing_extensions import Literal -import certifi from ..outcomes import Accepted, Modified, Received, Rejected, Released from ._connection_async import Connection @@ -126,9 +125,11 @@ class AMQPClientAsync(AMQPClientSync): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword float socket_timeout: The maximum time in seconds that the underlying socket in the transport should wait when reading or writing data before timing out. The default value is 0.2 (for transport type Amqp), and 1 for transport type AmqpOverWebsocket. @@ -238,7 +239,7 @@ async def open_async(self, connection=None): self._connection = Connection( "amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname, sasl_credential=self._auth.sasl, - ssl_opts={"ca_certs": self._connection_verify or certifi.where()}, + ssl_opts=self._ssl_opts, container_id=self._name, max_frame_size=self._max_frame_size, channel_max=self._channel_max, @@ -480,9 +481,11 @@ class SendClientAsync(SendClientSync, AMQPClientAsync): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ async def _client_ready_async(self): @@ -686,9 +689,11 @@ class ReceiveClientAsync(ReceiveClientSync, AMQPClientAsync): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ async def _client_ready_async(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py index 0778a89962eb..59cac3ba9338 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py @@ -42,8 +42,6 @@ import logging -import certifi - from .._platform import KNOWN_TCP_OPTS, SOL_TCP from .._encode import encode_frame from .._decode import decode_frame, decode_empty_frame @@ -167,7 +165,7 @@ def _build_ssl_opts(self, sslopts): return sslopts try: if "context" in sslopts: - return self._build_ssl_context(**sslopts.pop("context")) + return sslopts["context"] ssl_version = sslopts.get("ssl_version") if ssl_version is None: ssl_version = ssl.PROTOCOL_TLS_CLIENT @@ -206,13 +204,6 @@ def _build_ssl_opts(self, sslopts): except TypeError: raise TypeError("SSL configuration must be a dictionary, or the value True.") from None - def _build_ssl_context(self, check_hostname=None, **ctx_options): - ctx = ssl.create_default_context(**ctx_options) - ctx.verify_mode = ssl.CERT_REQUIRED - ctx.load_verify_locations(cafile=certifi.where()) - ctx.check_hostname = check_hostname - return ctx - class AsyncTransport(AsyncTransportMixin): # pylint: disable=too-many-instance-attributes """Common superclass for TCP and SSL transports.""" diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py index 34613ea6a0a2..88644f6e798f 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py @@ -143,9 +143,11 @@ class AMQPClient(object): # pylint: disable=too-many-instance-attributes If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword float socket_timeout: The maximum time in seconds that the underlying socket in the transport should wait when reading or writing data before timing out. The default value is 0.2 (for transport type Amqp), and 1 for transport type AmqpOverWebsocket. @@ -201,7 +203,13 @@ def __init__(self, hostname, **kwargs): # Custom Endpoint self._custom_endpoint_address = kwargs.get("custom_endpoint_address") - self._connection_verify = kwargs.get("connection_verify") + connection_verify = kwargs.get("connection_verify") + ssl_context = kwargs.get("ssl_context") + self._ssl_opts = {} + if ssl_context: + self._ssl_opts["context"] = ssl_context + else: # str or None + self._ssl_opts["ca_certs"] = connection_verify or certifi.where() # Emulator self._use_tls: bool = kwargs.get("use_tls", True) @@ -306,7 +314,7 @@ def open(self, connection=None): self._connection = Connection( "amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname, sasl_credential=self._auth.sasl, - ssl_opts={"ca_certs": self._connection_verify or certifi.where()}, + ssl_opts=self._ssl_opts, container_id=self._name, max_frame_size=self._max_frame_size, channel_max=self._channel_max, @@ -556,9 +564,11 @@ class SendClient(AMQPClient): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ def __init__(self, hostname, target, **kwargs): @@ -779,9 +789,11 @@ class ReceiveClient(AMQPClient): # pylint:disable=too-many-instance-attributes If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ def __init__(self, hostname, source, **kwargs): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py index abfa16844170..53ff6fc24088 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py @@ -332,6 +332,7 @@ def create_send_client( # pylint: disable=unused-argument target, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, transport_type=config.transport_type, http_proxy=config.http_proxy, socket_timeout=config.socket_timeout, @@ -474,6 +475,7 @@ def create_receive_client( transport_type=config.transport_type, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, socket_timeout=config.socket_timeout, auth=auth, idle_timeout=idle_timeout, @@ -550,7 +552,6 @@ def create_token_auth( timeout=config.auth_timeout, custom_endpoint_hostname=config.custom_endpoint_hostname, port=config.connection_port, - verify=config.connection_verify, ) # if update_token: # token_auth.update_token() # TODO: why don't we need to update in pyamqp? @@ -575,6 +576,7 @@ def create_mgmt_client(address, mgmt_auth, config): # pylint: disable=unused-ar http_proxy=config.http_proxy, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, use_tls=config.use_tls, ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py index 5a8162d418f4..5d127b6dbdb4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_consumer_client_async.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: + from ssl import SSLContext from ._client_base_async import CredentialTypes from ._eventprocessor.partition_context import PartitionContext from ._eventprocessor.checkpoint_store import CheckpointStore @@ -134,6 +135,9 @@ class EventHubConsumerClient(ClientBaseAsync): # pylint: disable=client-accepts authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: Optional[str] + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -246,6 +250,7 @@ def from_connection_string( load_balancing_strategy: Union[str, LoadBalancingStrategy] = LoadBalancingStrategy.GREEDY, custom_endpoint_address: Optional[str] = None, connection_verify: Optional[str] = None, + ssl_context: Optional["SSLContext"] = None, uamqp_transport: bool = False, **kwargs: Any, ) -> "EventHubConsumerClient": @@ -321,6 +326,9 @@ def from_connection_string( authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: Optional[str] + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -357,6 +365,7 @@ def from_connection_string( load_balancing_strategy=load_balancing_strategy, custom_endpoint_address=custom_endpoint_address, connection_verify=connection_verify, + ssl_context=ssl_context, uamqp_transport=uamqp_transport, **kwargs, ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py index 6f5e8bc1c32b..1b51e01f9aa4 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py @@ -20,6 +20,7 @@ from .._common import EventDataBatch, EventData if TYPE_CHECKING: + from ssl import SSLContext from ._client_base_async import CredentialTypes SendEventTypes = List[Union[EventData, AmqpAnnotatedMessage]] @@ -115,6 +116,9 @@ class EventHubProducerClient(ClientBaseAsync): # pylint: disable=client-accepts authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: Optional[str] + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -395,6 +399,7 @@ def from_connection_string( transport_type: TransportType = TransportType.Amqp, custom_endpoint_address: Optional[str] = None, connection_verify: Optional[str] = None, + ssl_context: Optional["SSLContext"] = None, uamqp_transport: bool = False, **kwargs: Any, ) -> "EventHubProducerClient": @@ -473,6 +478,9 @@ def from_connection_string( authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. :paramtype connection_verify: Optional[str] + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -508,6 +516,7 @@ def from_connection_string( transport_type=transport_type, custom_endpoint_address=custom_endpoint_address, connection_verify=connection_verify, + ssl_context=ssl_context, uamqp_transport=uamqp_transport, **kwargs, ) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py index 7571818cd127..e026667e21eb 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_transport/_pyamqp_transport_async.py @@ -127,6 +127,7 @@ def create_send_client( # pylint: disable=unused-argument target, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, transport_type=config.transport_type, http_proxy=config.http_proxy, socket_timeout=config.socket_timeout, @@ -211,6 +212,7 @@ def create_receive_client( transport_type=config.transport_type, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, socket_timeout=config.socket_timeout, auth=auth, idle_timeout=idle_timeout, @@ -364,7 +366,6 @@ async def create_token_auth_async( timeout=config.auth_timeout, custom_endpoint_hostname=config.custom_endpoint_hostname, port=config.connection_port, - verify=config.connection_verify, ) # if update_token: # token_auth.update_token() # TODO: why don't we need to update in pyamqp? @@ -389,6 +390,7 @@ def create_mgmt_client(address, mgmt_auth, config): # pylint: disable=unused-ar http_proxy=config.http_proxy, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, use_tls=config.use_tls, ) diff --git a/sdk/eventhub/azure-eventhub/tests/conftest.py b/sdk/eventhub/azure-eventhub/tests/conftest.py index 033721c178a9..d00b169ef88f 100644 --- a/sdk/eventhub/azure-eventhub/tests/conftest.py +++ b/sdk/eventhub/azure-eventhub/tests/conftest.py @@ -15,7 +15,7 @@ from azure.core.settings import settings from azure.mgmt.eventhub import EventHubManagementClient -from azure.eventhub import EventHubProducerClient +from azure.eventhub import EventHubProducerClient, TransportType from azure.eventhub._pyamqp import ReceiveClient from azure.eventhub._pyamqp.authentication import SASTokenAuth from azure.eventhub.extensions.checkpointstoreblob import BlobCheckpointStore @@ -30,6 +30,9 @@ uamqp_transport_params = [False] uamqp_transport_ids = ["pyamqp"] +socket_transports = [TransportType.Amqp, TransportType.AmqpOverWebsocket] +socket_transport_ids = ["amqp", "ws"] + from devtools_testutils import get_region_override, get_credential as get_devtools_credential from tracing_common import FakeSpan @@ -58,6 +61,11 @@ def uamqp_transport(request): return request.param +@pytest.fixture(scope="session", params=socket_transports, ids=socket_transport_ids) +def socket_transport(request): + return request.param + + @pytest.fixture(scope="session") def storage_url(): try: diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py index 40438f62452a..b89b952b3588 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py @@ -6,10 +6,13 @@ import pytest import asyncio import time +import ssl +import certifi from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential from azure.identity.aio import EnvironmentCredential from azure.eventhub import EventData +from azure.eventhub.exceptions import ConnectError from azure.eventhub.aio import ( EventHubConsumerClient, EventHubProducerClient, @@ -165,3 +168,110 @@ async def test_client_azure_named_key_credential_async(live_eventhub, uamqp_tran credential.update(live_eventhub["key_name"], live_eventhub["access_key"]) assert (await consumer_client.get_eventhub_properties()) is not None + + +# New feature only for Pure Python AMQP, not uamqp. +@pytest.mark.liveTest +@pytest.mark.asyncio +async def test_client_with_ssl_context_async(auth_credentials_async, socket_transport): + fully_qualified_namespace, eventhub_name, credential = auth_credentials_async + + # Check that SSLContext with invalid/nonexistent cert file raises an error + context = ssl.SSLContext(cafile="fakecert.pem") + context.verify_mode = ssl.CERT_REQUIRED + + producer = EventHubProducerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + transport_type=socket_transport, + ssl_context=context, + retry_total=0, + ) + async with producer: + with pytest.raises(ConnectError): + batch = await producer.create_batch() + + async def on_event(partition_context, event): + on_event.called = True + on_event.partition_id = partition_context.partition_id + on_event.event = event + + async def on_error(partition_context, error): + on_error.error = error + await consumer.close() + + consumer = EventHubConsumerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + consumer_group="$default", + transport_type=socket_transport, + ssl_context=context, + retry_total=0, + ) + on_error.error = None + async with consumer: + task = asyncio.ensure_future(consumer.receive(on_event, on_error=on_error, starting_position="-1")) + await asyncio.sleep(15) + await task + assert isinstance(on_error.error, ConnectError) + + # Check that SSLContext with valid cert file doesn't raise an error + async def verify_context_async(): + # asyncio.to_thread only available in Python 3.9+ + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + await asyncio.to_thread(context.load_verify_locations, certifi.where()) + purpose = ssl.Purpose.SERVER_AUTH + await asyncio.to_thread(context.load_default_certs, purpose=purpose) + return context + + def verify_context(): # for Python 3.8 + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.load_verify_locations(certifi.where()) + purpose = ssl.Purpose.SERVER_AUTH + context.load_default_certs(purpose=purpose) + return context + + if hasattr(asyncio, "to_thread"): + context = await verify_context_async() + else: + context = verify_context() + + producer = EventHubProducerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + transport_type=socket_transport, + ssl_context=context, + ) + async with producer: + batch = await producer.create_batch() + batch.add(EventData(body="A single message")) + batch.add(EventData(body="A second message")) + await producer.send_batch(batch) + + async def on_event(partition_context, event): + on_event.total += 1 + + async def on_error(partition_context, error): + on_error.error = error + await consumer.close() + + consumer = EventHubConsumerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + consumer_group="$default", + transport_type=socket_transport, + ssl_context=context, + ) + on_event.total = 0 + on_error.error = None + + async with consumer: + task = asyncio.ensure_future(consumer.receive(on_event, on_error=on_error, starting_position="-1")) + await asyncio.sleep(15) + await task + assert on_event.total == 2 + assert on_error.error is None diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py index 3c4762a00f7c..0d2c7bdc1452 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py @@ -6,14 +6,16 @@ import pytest import time import threading +import ssl +import certifi -from azure.identity import EnvironmentCredential from azure.eventhub import ( EventData, EventHubProducerClient, EventHubConsumerClient, EventHubSharedKeyCredential, ) +from azure.eventhub.exceptions import ConnectError from azure.eventhub._client_base import EventHubSASTokenCredential from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential @@ -164,3 +166,107 @@ def test_client_azure_named_key_credential(live_eventhub, uamqp_transport): credential.update(live_eventhub["key_name"], live_eventhub["access_key"]) assert consumer_client.get_eventhub_properties() is not None + + +# New feature only for Pure Python AMQP, not uamqp. +@pytest.mark.liveTest +def test_client_with_ssl_context(auth_credentials, socket_transport): + fully_qualified_namespace, eventhub_name, credential = auth_credentials + + # Check that SSLContext with invalid/nonexistent cert file raises an error + context = ssl.SSLContext(cafile="fakecert.pem") + context.verify_mode = ssl.CERT_REQUIRED + + producer = EventHubProducerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + transport_type=socket_transport, + ssl_context=context, + retry_total=0, + ) + with producer: + with pytest.raises(ConnectError): + batch = producer.create_batch() + + def on_event(partition_context, event): + on_event.called = True + on_event.partition_id = partition_context.partition_id + on_event.event = event + + def on_error(partition_context, error): + on_error.error = error + consumer.close() + + consumer = EventHubConsumerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + consumer_group="$default", + transport_type=socket_transport, + ssl_context=context, + retry_total=0, + ) + on_error.error = None + with consumer: + thread = threading.Thread( + target=consumer.receive, + args=(on_event,), + kwargs={"on_error": on_error, "starting_position": "-1"}, + ) + thread.daemon = True + thread.start() + time.sleep(15) + assert isinstance(on_error.error, ConnectError) + thread.join() + + # Check that SSLContext with valid cert file doesn't raise an error + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.load_verify_locations(certifi.where()) + purpose = ssl.Purpose.SERVER_AUTH + context.load_default_certs(purpose=purpose) + + producer = EventHubProducerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + transport_type=socket_transport, + ssl_context=context, + ) + with producer: + batch = producer.create_batch() + batch.add(EventData(body="A single message")) + batch.add(EventData(body="A second message")) + producer.send_batch(batch) + + def on_event(partition_context, event): + on_event.total += 1 + + def on_error(partition_context, error): + on_error.error = error + consumer.close() + + consumer = EventHubConsumerClient( + fully_qualified_namespace=fully_qualified_namespace, + eventhub_name=eventhub_name, + credential=credential(), + consumer_group="$default", + transport_type=socket_transport, + ssl_context=context, + ) + on_event.total = 0 + on_error.error = None + + with consumer: + thread2 = threading.Thread( + target=consumer.receive, + args=(on_event,), + kwargs={"on_error": on_error, "starting_position": "-1"}, + ) + thread2.daemon = True + thread2.start() + time.sleep(15) + assert on_event.total == 2 + assert on_error.error is None + + thread2.join() diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py index f2ace9755ef4..a30dc5322ce8 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py @@ -3,7 +3,6 @@ import threading import sys -from azure.core.tracing import SpanKind from azure.eventhub import EventData from azure.eventhub import EventHubConsumerClient from azure.eventhub._eventprocessor.in_memory_checkpoint_store import ( diff --git a/sdk/servicebus/azure-servicebus/CHANGELOG.md b/sdk/servicebus/azure-servicebus/CHANGELOG.md index b31709522598..b90d72401b96 100644 --- a/sdk/servicebus/azure-servicebus/CHANGELOG.md +++ b/sdk/servicebus/azure-servicebus/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +- Added `ssl_context` parameter to the clients to allow users to pass in the SSL context, in which case, `connection_verify` will be ignored if specified.([#37246](https://github.com/Azure/azure-sdk-for-python/issues/37246)) + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_configuration.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_configuration.py index 9a1391cdcc3f..79f66e463b01 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_configuration.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_common/_configuration.py @@ -9,6 +9,7 @@ from .constants import DEFAULT_AMQPS_PORT, DEFAULT_AMQP_WSS_PORT, TransportType if TYPE_CHECKING: + from ssl import SSLContext from .._transport._base import AmqpTransport from ..aio._transport._base_async import AmqpTransportAsync @@ -25,6 +26,7 @@ def __init__(self, **kwargs): self.custom_endpoint_address: Optional[str] = kwargs.get("custom_endpoint_address") self.connection_verify: Optional[str] = kwargs.get("connection_verify") + self.ssl_context: Optional["SSLContext"] = kwargs.get("ssl_context") self.connection_port = DEFAULT_AMQPS_PORT self.custom_endpoint_hostname = None self.hostname = kwargs.pop("hostname") diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py index 5c3c5277b1ec..da81f1fcec1f 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py @@ -46,8 +46,6 @@ import logging from threading import Lock -import certifi - from ._platform import KNOWN_TCP_OPTS, SOL_TCP from ._encode import encode_frame from ._decode import decode_frame, decode_empty_frame @@ -503,18 +501,12 @@ def _setup_transport(self): self.sock = self._wrap_socket(self.sock, **self.sslopts) self._quick_recv = self.sock.recv - def _wrap_socket(self, sock, context=None, **sslopts): - if context: - return self._wrap_context(sock, sslopts, **context) + def _wrap_socket(self, sock, **sslopts): + if "context" in sslopts: + context = sslopts.pop("context") + return context.wrap_socket(sock, **sslopts) return self._wrap_socket_sni(sock, **sslopts) - def _wrap_context(self, sock, sslopts, check_hostname=None, **ctx_options): - ctx = ssl.create_default_context(**ctx_options) - ctx.verify_mode = ssl.CERT_REQUIRED - ctx.load_verify_locations(cafile=certifi.where()) - ctx.check_hostname = check_hostname - return ctx.wrap_socket(sock, **sslopts) - def _wrap_socket_sni( self, sock, diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_client_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_client_async.py index 1ca7ba786654..6ab90f4659fe 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_client_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_client_async.py @@ -12,7 +12,6 @@ from functools import partial from typing import Any, Callable, Coroutine, List, Dict, Optional, Tuple, Union, overload, cast from typing_extensions import Literal -import certifi from ..outcomes import Accepted, Modified, Received, Rejected, Released from ._connection_async import Connection @@ -126,9 +125,11 @@ class AMQPClientAsync(AMQPClientSync): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword float socket_timeout: The maximum time in seconds that the underlying socket in the transport should wait when reading or writing data before timing out. The default value is 0.2 (for transport type Amqp), and 1 for transport type AmqpOverWebsocket. @@ -238,7 +239,7 @@ async def open_async(self, connection=None): self._connection = Connection( "amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname, sasl_credential=self._auth.sasl, - ssl_opts={"ca_certs": self._connection_verify or certifi.where()}, + ssl_opts=self._ssl_opts, container_id=self._name, max_frame_size=self._max_frame_size, channel_max=self._channel_max, @@ -480,9 +481,11 @@ class SendClientAsync(SendClientSync, AMQPClientAsync): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ async def _client_ready_async(self): @@ -686,9 +689,11 @@ class ReceiveClientAsync(ReceiveClientSync, AMQPClientAsync): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ async def _client_ready_async(self): diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py index 0778a89962eb..59cac3ba9338 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py @@ -42,8 +42,6 @@ import logging -import certifi - from .._platform import KNOWN_TCP_OPTS, SOL_TCP from .._encode import encode_frame from .._decode import decode_frame, decode_empty_frame @@ -167,7 +165,7 @@ def _build_ssl_opts(self, sslopts): return sslopts try: if "context" in sslopts: - return self._build_ssl_context(**sslopts.pop("context")) + return sslopts["context"] ssl_version = sslopts.get("ssl_version") if ssl_version is None: ssl_version = ssl.PROTOCOL_TLS_CLIENT @@ -206,13 +204,6 @@ def _build_ssl_opts(self, sslopts): except TypeError: raise TypeError("SSL configuration must be a dictionary, or the value True.") from None - def _build_ssl_context(self, check_hostname=None, **ctx_options): - ctx = ssl.create_default_context(**ctx_options) - ctx.verify_mode = ssl.CERT_REQUIRED - ctx.load_verify_locations(cafile=certifi.where()) - ctx.check_hostname = check_hostname - return ctx - class AsyncTransport(AsyncTransportMixin): # pylint: disable=too-many-instance-attributes """Common superclass for TCP and SSL transports.""" diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py index 34613ea6a0a2..88644f6e798f 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/client.py @@ -143,9 +143,11 @@ class AMQPClient(object): # pylint: disable=too-many-instance-attributes If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword float socket_timeout: The maximum time in seconds that the underlying socket in the transport should wait when reading or writing data before timing out. The default value is 0.2 (for transport type Amqp), and 1 for transport type AmqpOverWebsocket. @@ -201,7 +203,13 @@ def __init__(self, hostname, **kwargs): # Custom Endpoint self._custom_endpoint_address = kwargs.get("custom_endpoint_address") - self._connection_verify = kwargs.get("connection_verify") + connection_verify = kwargs.get("connection_verify") + ssl_context = kwargs.get("ssl_context") + self._ssl_opts = {} + if ssl_context: + self._ssl_opts["context"] = ssl_context + else: # str or None + self._ssl_opts["ca_certs"] = connection_verify or certifi.where() # Emulator self._use_tls: bool = kwargs.get("use_tls", True) @@ -306,7 +314,7 @@ def open(self, connection=None): self._connection = Connection( "amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname, sasl_credential=self._auth.sasl, - ssl_opts={"ca_certs": self._connection_verify or certifi.where()}, + ssl_opts=self._ssl_opts, container_id=self._name, max_frame_size=self._max_frame_size, channel_max=self._channel_max, @@ -556,9 +564,11 @@ class SendClient(AMQPClient): If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ def __init__(self, hostname, target, **kwargs): @@ -779,9 +789,11 @@ class ReceiveClient(AMQPClient): # pylint:disable=too-many-instance-attributes If port is not specified in the `custom_endpoint_address`, by default port 443 will be used. :paramtype custom_endpoint_address: str :keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to - authenticate the identity of the connection endpoint. - Default is None in which case `certifi.where()` will be used. + authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None + in which case `certifi.where()` will be used. :paramtype connection_verify: str + :keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored. + :paramtype ssl_context: ssl.SSLContext or None """ def __init__(self, hostname, source, **kwargs): diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py index 320e928fb20c..402d651a7c7d 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_servicebus_client.py @@ -86,6 +86,9 @@ class ServiceBusClient(object): # pylint: disable=client-accepts-api-version-ke :keyword str connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -145,6 +148,7 @@ def __init__( self._handlers: WeakSet = WeakSet() self._custom_endpoint_address = kwargs.get("custom_endpoint_address") self._connection_verify = kwargs.get("connection_verify") + self._ssl_context = kwargs.get("ssl_context") def __enter__(self) -> "ServiceBusClient": if self._connection_sharing: @@ -156,12 +160,17 @@ def __exit__(self, *args: Any) -> None: def _create_connection(self): auth = create_authentication(self) + if self._ssl_context: + ssl_opts = {"context": self._ssl_context} + else: + ssl_opts = {"ca_certs": self._connection_verify or certifi.where()} + self._connection = self._amqp_transport.create_connection( host=self.fully_qualified_namespace, auth=auth.sasl, network_trace=self._config.logging_enable, custom_endpoint_address=self._custom_endpoint_address, - ssl_opts={"ca_certs": self._connection_verify or certifi.where()}, + ssl_opts=ssl_opts, transport_type=self._config.transport_type, http_proxy=self._config.http_proxy, ) @@ -229,6 +238,9 @@ def from_connection_string( :keyword str connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -308,6 +320,7 @@ def get_queue_sender(self, queue_name: str, **kwargs: Any) -> ServiceBusSender: retry_backoff_max=self._config.retry_backoff_max, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) @@ -428,6 +441,7 @@ def get_queue_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) @@ -487,6 +501,7 @@ def get_topic_sender( retry_backoff_max=self._config.retry_backoff_max, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, client_identifier=client_identifier, socket_timeout=socket_timeout, @@ -611,6 +626,7 @@ def get_subscription_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) @@ -639,6 +655,7 @@ def get_subscription_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_pyamqp_transport.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_pyamqp_transport.py index fbb78214e9ec..e2000b42d083 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_pyamqp_transport.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_transport/_pyamqp_transport.py @@ -460,6 +460,7 @@ def create_send_client(config: "Configuration", **kwargs: Any) -> "SendClient": keep_alive_interval=config.keep_alive, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, transport_type=config.transport_type, http_proxy=config.http_proxy, socket_timeout=config.socket_timeout, @@ -583,6 +584,7 @@ def create_receive_client(receiver: "ServiceBusReceiver", **kwargs: "Any") -> "R transport_type=config.transport_type, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, receive_settle_mode=PyamqpTransport.ServiceBusToAMQPReceiveModeMap[receive_mode], send_settle_mode=( constants.SenderSettleMode.Settled @@ -851,7 +853,6 @@ def create_token_auth( timeout=config.auth_timeout, custom_endpoint_hostname=config.custom_endpoint_hostname, port=config.connection_port, - verify=config.connection_verify, ) @staticmethod diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py index 7187cb430263..0fa098a7d503 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_servicebus_client_async.py @@ -79,6 +79,9 @@ class ServiceBusClient(object): # pylint: disable=client-accepts-api-version-ke :keyword str connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -135,6 +138,7 @@ def __init__( self._handlers: WeakSet = WeakSet() self._custom_endpoint_address = kwargs.get("custom_endpoint_address") self._connection_verify = kwargs.get("connection_verify") + self._ssl_context = kwargs.get("ssl_context") async def __aenter__(self) -> "ServiceBusClient": if self._connection_sharing: @@ -146,12 +150,16 @@ async def __aexit__(self, *args: Any) -> None: async def _create_connection(self): auth = await create_authentication(self) + if self._ssl_context: + ssl_opts = {"context": self._ssl_context} + else: + ssl_opts = {"ca_certs": self._connection_verify or certifi.where()} self._connection = self._amqp_transport.create_connection_async( host=self.fully_qualified_namespace, auth=auth.sasl, network_trace=self._config.logging_enable, custom_endpoint_address=self._custom_endpoint_address, - ssl_opts={"ca_certs": self._connection_verify or certifi.where()}, + ssl_opts=ssl_opts, transport_type=self._config.transport_type, http_proxy=self._config.http_proxy, ) @@ -197,6 +205,9 @@ def from_connection_string( :keyword str connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to authenticate the identity of the connection endpoint. Default is None in which case `certifi.where()` will be used. + :keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified, + connection_verify will be ignored. + :paramtype ssl_context: ssl.SSLContext or None :keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is False and the Pure Python AMQP library will be used as the underlying transport. :paramtype uamqp_transport: bool @@ -297,6 +308,7 @@ def get_queue_sender(self, queue_name: str, **kwargs: Any) -> ServiceBusSender: retry_backoff_max=self._config.retry_backoff_max, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) @@ -417,6 +429,7 @@ def get_queue_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) @@ -469,6 +482,7 @@ def get_topic_sender(self, topic_name: str, **kwargs: Any) -> ServiceBusSender: retry_backoff_max=self._config.retry_backoff_max, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) @@ -591,6 +605,7 @@ def get_subscription_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) @@ -619,6 +634,7 @@ def get_subscription_receiver( prefetch_count=prefetch_count, custom_endpoint_address=self._custom_endpoint_address, connection_verify=self._connection_verify, + ssl_context=self._ssl_context, amqp_transport=self._amqp_transport, **kwargs, ) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_pyamqp_transport_async.py b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_pyamqp_transport_async.py index f3cf1f28bbd6..f90ee4d8db53 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_pyamqp_transport_async.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/aio/_transport/_pyamqp_transport_async.py @@ -112,6 +112,7 @@ def create_send_client_async(config: "Configuration", **kwargs: Any) -> "SendCli keep_alive_interval=config.keep_alive, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, transport_type=config.transport_type, http_proxy=config.http_proxy, socket_timeout=config.socket_timeout, @@ -209,6 +210,7 @@ def create_receive_client_async( transport_type=config.transport_type, custom_endpoint_address=config.custom_endpoint_address, connection_verify=config.connection_verify, + ssl_context=config.ssl_context, receive_settle_mode=PyamqpTransportAsync.ServiceBusToAMQPReceiveModeMap[receive_mode], send_settle_mode=( constants.SenderSettleMode.Settled @@ -398,7 +400,6 @@ async def create_token_auth_async( timeout=config.auth_timeout, custom_endpoint_hostname=config.custom_endpoint_hostname, port=config.connection_port, - verify=config.connection_verify, ) @staticmethod diff --git a/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py b/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py index ccc8a151e73c..38435f10957b 100644 --- a/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py +++ b/sdk/servicebus/azure-servicebus/tests/async_tests/test_sb_client_async.py @@ -13,6 +13,7 @@ import hashlib import base64 import certifi +import ssl from urllib.parse import quote as url_parse_quote from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential, AccessToken @@ -39,9 +40,16 @@ CachedServiceBusResourceGroupPreparer, SERVICEBUS_ENDPOINT_SUFFIX, ) -from tests.utilities import get_logger, uamqp_transport as get_uamqp_transport, ArgPasserAsync +from tests.utilities import ( + get_logger, + uamqp_transport as get_uamqp_transport, + ArgPasserAsync, + SocketArgPasserAsync, + socket_transport as get_socket_transport, +) uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() +socket_transport_params, socket_transport_ids = get_socket_transport() _logger = get_logger(logging.DEBUG) @@ -553,6 +561,86 @@ async def test_azure_named_key_credential_async( async with client.get_queue_sender(servicebus_queue.name) as sender: await sender.send_messages(ServiceBusMessage("foo")) + @pytest.mark.asyncio + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedServiceBusResourceGroupPreparer() + @CachedServiceBusNamespacePreparer(name_prefix="servicebustest") + @CachedServiceBusQueuePreparer(name_prefix="servicebustest", dead_lettering_on_message_expiration=True) + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @pytest.mark.parametrize("socket_transport", socket_transport_params, ids=socket_transport_ids) + @SocketArgPasserAsync() + async def test_sb_client_with_ssl_context_async( + self, + uamqp_transport, + socket_transport, + *, + servicebus_namespace=None, + servicebus_queue=None, + **kwargs, + ): + fully_qualified_namespace = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}" + credential = get_credential(is_async=True) + + # Check that SSLContext with invalid/nonexistent cert file raises an error + context = ssl.SSLContext(cafile="fakecert.pem") + context.verify_mode = ssl.CERT_REQUIRED + client = ServiceBusClient( + fully_qualified_namespace=fully_qualified_namespace, + credential=credential, + uamqp_transport=uamqp_transport, + ssl_context=context, + transport_type=socket_transport, + retry_total=0, + ) + async with client: + with pytest.raises(ServiceBusConnectionError): + async with client.get_queue_sender(servicebus_queue.name) as sender: + await sender.send_messages(ServiceBusMessage("test")) + + with pytest.raises(ServiceBusConnectionError): + async with client.get_queue_receiver(servicebus_queue.name) as receiver: + messages = await receiver.receive_messages(max_message_count=1, max_wait_time=1) + + # Check that SSLContext with valid cert file doesn't raise an error + async def verify_context_async(): + # asyncio.to_thread only available in Python 3.9+ + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + await asyncio.to_thread(context.load_verify_locations, certifi.where()) + purpose = ssl.Purpose.SERVER_AUTH + await asyncio.to_thread(context.load_default_certs, purpose=purpose) + return context + + def verify_context(): # for Python 3.8 + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.load_verify_locations(certifi.where()) + purpose = ssl.Purpose.SERVER_AUTH + context.load_default_certs(purpose=purpose) + return context + + if hasattr(asyncio, "to_thread"): + context = await verify_context_async() + else: + context = verify_context() + + client = ServiceBusClient( + fully_qualified_namespace=fully_qualified_namespace, + credential=credential, + uamqp_transport=uamqp_transport, + ssl_context=context, + transport_type=socket_transport, + ) + + async with client: + async with client.get_queue_sender(servicebus_queue.name) as sender: + await sender.send_messages(ServiceBusMessage("test")) + await sender.send_messages(ServiceBusMessage("test")) + + async with client.get_queue_receiver(servicebus_queue.name) as receiver: + messages = await receiver.receive_messages(max_message_count=2, max_wait_time=10) + + assert len(messages) == 2 + @pytest.mark.asyncio @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) async def test_backoff_fixed_retry(self, uamqp_transport): diff --git a/sdk/servicebus/azure-servicebus/tests/test_sb_client.py b/sdk/servicebus/azure-servicebus/tests/test_sb_client.py index 4bf87ff9b978..14dbae7cfc3e 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_sb_client.py +++ b/sdk/servicebus/azure-servicebus/tests/test_sb_client.py @@ -14,6 +14,7 @@ import hashlib import base64 import certifi +import ssl try: from urllib.parse import quote as url_parse_quote @@ -52,9 +53,15 @@ CachedServiceBusResourceGroupPreparer, SERVICEBUS_ENDPOINT_SUFFIX, ) -from utilities import uamqp_transport as get_uamqp_transport, ArgPasser +from utilities import ( + uamqp_transport as get_uamqp_transport, + ArgPasser, + socket_transport as get_socket_transport, + SocketArgPasser, +) uamqp_transport_params, uamqp_transport_ids = get_uamqp_transport() +socket_transport_params, socket_transport_ids = get_socket_transport() class TestServiceBusClient(AzureMgmtRecordedTestCase): @@ -568,6 +575,70 @@ def test_azure_named_key_credential( with client.get_queue_sender(servicebus_queue.name) as sender: sender.send_messages(ServiceBusMessage("foo")) + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedServiceBusResourceGroupPreparer() + @CachedServiceBusNamespacePreparer(name_prefix="servicebustest") + @CachedServiceBusQueuePreparer(name_prefix="servicebustest", dead_lettering_on_message_expiration=True) + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @pytest.mark.parametrize("socket_transport", socket_transport_params, ids=socket_transport_ids) + @SocketArgPasser() + def test_sb_client_with_ssl_context( + self, + uamqp_transport, + socket_transport, + *, + servicebus_namespace=None, + servicebus_queue=None, + **kwargs, + ): + fully_qualified_namespace = f"{servicebus_namespace.name}{SERVICEBUS_ENDPOINT_SUFFIX}" + credential = get_credential() + + # Check that SSLContext with invalid/nonexistent cert file raises an error + context = ssl.SSLContext(cafile="fakecert.pem") + context.verify_mode = ssl.CERT_REQUIRED + client = ServiceBusClient( + fully_qualified_namespace=fully_qualified_namespace, + credential=credential, + uamqp_transport=uamqp_transport, + ssl_context=context, + transport_type=socket_transport, + retry_total=0, + ) + with client: + with pytest.raises(ServiceBusConnectionError): + with client.get_queue_sender(servicebus_queue.name) as sender: + sender.send_messages(ServiceBusMessage("test")) + + with pytest.raises(ServiceBusConnectionError): + with client.get_queue_receiver(servicebus_queue.name) as receiver: + messages = receiver.receive_messages(max_message_count=1, max_wait_time=1) + + # Check that SSLContext with valid cert file doesn't raise an error + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.load_verify_locations(certifi.where()) + purpose = ssl.Purpose.SERVER_AUTH + context.load_default_certs(purpose=purpose) + + client = ServiceBusClient( + fully_qualified_namespace=fully_qualified_namespace, + credential=credential, + uamqp_transport=uamqp_transport, + ssl_context=context, + transport_type=socket_transport, + ) + + with client: + with client.get_queue_sender(servicebus_queue.name) as sender: + sender.send_messages(ServiceBusMessage("test")) + sender.send_messages(ServiceBusMessage("test")) + + with client.get_queue_receiver(servicebus_queue.name) as receiver: + messages = receiver.receive_messages(max_message_count=2, max_wait_time=10) + + assert len(messages) == 2 + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) def test_backoff_fixed_retry(self, uamqp_transport): diff --git a/sdk/servicebus/azure-servicebus/tests/utilities.py b/sdk/servicebus/azure-servicebus/tests/utilities.py index ada827fae6ef..3d8a0652ae2e 100644 --- a/sdk/servicebus/azure-servicebus/tests/utilities.py +++ b/sdk/servicebus/azure-servicebus/tests/utilities.py @@ -15,6 +15,7 @@ uamqp_available = True except (ModuleNotFoundError, ImportError): uamqp_available = False +from azure.servicebus import TransportType from azure.servicebus._common.utils import utc_now # temporary - disable uamqp if China b/c of 8+ hr runtime @@ -77,6 +78,12 @@ def uamqp_transport(use_uamqp=uamqp_available, use_pyamqp=test_pyamqp): return uamqp_transport_params, uamqp_transport_ids +def socket_transport(): + socket_transport_params = [TransportType.Amqp, TransportType.AmqpOverWebsocket] + socket_transport_ids = ["amqp", "ws"] + return socket_transport_params, socket_transport_ids + + class ArgPasser: def __call__(self, fn): def _preparer(test_class, uamqp_transport, **kwargs): @@ -91,3 +98,19 @@ async def _preparer(test_class, uamqp_transport, **kwargs): await fn(test_class, uamqp_transport=uamqp_transport, **kwargs) return _preparer + + +class SocketArgPasser: + def __call__(self, fn): + def _preparer(test_class, uamqp_transport, socket_transport, **kwargs): + fn(test_class, uamqp_transport=uamqp_transport, socket_transport=socket_transport, **kwargs) + + return _preparer + + +class SocketArgPasserAsync: + def __call__(self, fn): + async def _preparer(test_class, uamqp_transport, socket_transport, **kwargs): + await fn(test_class, uamqp_transport=uamqp_transport, socket_transport=socket_transport, **kwargs) + + return _preparer