Skip to content

[EventHub] add ssl_context kwarg to clients #37702

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Nov 4, 2024
7 changes: 2 additions & 5 deletions sdk/eventhub/azure-eventhub/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@

### Features Added

### Breaking Changes

### Bugs Fixed

### Other Changes
- Added `ssl_context` parameter to the clients to allow users to pass in the SSL context, in which case, `connection_verify` will be ignored if specified.

## 5.12.2 (2024-10-02)

### Bugs Fixed

- Implemented backpressure for async consumer to address a memory leak issue. ([#36398](https://github.com/Azure/azure-sdk-for-python/issues/36398))

## 5.12.1 (2024-06-11)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import warnings
from typing import Optional, Dict, Any, Union, TYPE_CHECKING
from urllib.parse import urlparse

from azure.core.pipeline.policies import RetryMode
from ._constants import TransportType, DEFAULT_AMQPS_PORT, DEFAULT_AMQP_WSS_PORT

if TYPE_CHECKING:
from ssl import SSLContext
from ._transport._base import AmqpTransport
from .aio._transport._base_async import AmqpTransportAsync

Expand All @@ -35,6 +37,7 @@ def __init__(
send_timeout: int = 60,
custom_endpoint_address: Optional[str] = None,
connection_verify: Optional[str] = None,
ssl_context: Optional["SSLContext"] = None,
use_tls: bool = True,
**kwargs: Any
):
Expand All @@ -54,6 +57,9 @@ def __init__(
self.send_timeout = send_timeout
self.custom_endpoint_address = custom_endpoint_address
self.connection_verify = connection_verify
self.ssl_context = ssl_context
if self.ssl_context and self.connection_verify:
warnings.warn("ssl_context is specified, connection_verify will be ignored.")
self.custom_endpoint_hostname = None
self.hostname = hostname
self.use_tls = use_tls
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


if TYPE_CHECKING:
from ssl import SSLContext
from ._eventprocessor.partition_context import PartitionContext
from ._common import EventData
from ._client_base import CredentialTypes
Expand Down Expand Up @@ -122,6 +123,9 @@ class EventHubConsumerClient(ClientBase): # pylint: disable=client-accepts-api-
authenticate the identity of the connection endpoint.
Default is None in which case `certifi.where()` will be used.
:paramtype connection_verify: str or None
:keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified,
connection_verify will be ignored.
:paramtype ssl_context: ssl.SSLContext or None
:keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is
False and the Pure Python AMQP library will be used as the underlying transport.
:paramtype uamqp_transport: bool
Expand Down Expand Up @@ -234,6 +238,7 @@ def from_connection_string(
load_balancing_strategy: Union[str, LoadBalancingStrategy] = LoadBalancingStrategy.GREEDY,
custom_endpoint_address: Optional[str] = None,
connection_verify: Optional[str] = None,
ssl_context: Optional["SSLContext"] = None,
uamqp_transport: bool = False,
**kwargs: Any,
) -> "EventHubConsumerClient":
Expand Down Expand Up @@ -309,6 +314,9 @@ def from_connection_string(
authenticate the identity of the connection endpoint.
Default is None in which case `certifi.where()` will be used.
:paramtype connection_verify: str or None
:keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified,
connection_verify will be ignored.
:paramtype ssl_context: ssl.SSLContext or None
:keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is
False and the Pure Python AMQP library will be used as the underlying transport.
:paramtype uamqp_transport: bool
Expand Down Expand Up @@ -346,6 +354,7 @@ def from_connection_string(
load_balancing_strategy=load_balancing_strategy,
custom_endpoint_address=custom_endpoint_address,
connection_verify=connection_verify,
ssl_context=ssl_context,
uamqp_transport=uamqp_transport,
**kwargs,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .exceptions import ConnectError, EventHubError

if TYPE_CHECKING:
from ssl import SSLContext
from ._client_base import CredentialTypes

SendEventTypes = List[Union[EventData, AmqpAnnotatedMessage]]
Expand Down Expand Up @@ -128,6 +129,9 @@ class EventHubProducerClient(ClientBase): # pylint: disable=client-accepts-api-
authenticate the identity of the connection endpoint.
Default is None in which case `certifi.where()` will be used.
:paramtype connection_verify: Optional[str]
:keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified,
connection_verify will be ignored.
:paramtype ssl_context: ssl.SSLContext or None
:keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is
False and the Pure Python AMQP library will be used as the underlying transport.
:paramtype uamqp_transport: bool
Expand Down Expand Up @@ -419,6 +423,7 @@ def from_connection_string(
transport_type: Optional["TransportType"] = TransportType.Amqp,
custom_endpoint_address: Optional[str] = None,
connection_verify: Optional[str] = None,
ssl_context: Optional["SSLContext"] = None,
uamqp_transport: bool = False,
**kwargs: Any,
) -> "EventHubProducerClient":
Expand Down Expand Up @@ -506,6 +511,9 @@ def from_connection_string(
authenticate the identity of the connection endpoint.
Default is None in which case `certifi.where()` will be used.
:paramtype connection_verify: Optional[str]
:keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified,
connection_verify will be ignored.
:paramtype ssl_context: ssl.SSLContext or None
:keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is
False and the Pure Python AMQP library will be used as the underlying transport.
:paramtype uamqp_transport: bool
Expand Down Expand Up @@ -542,6 +550,7 @@ def from_connection_string(
transport_type=transport_type,
custom_endpoint_address=custom_endpoint_address,
connection_verify=connection_verify,
ssl_context=ssl_context,
uamqp_transport=uamqp_transport,
**kwargs,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@
import logging
from threading import Lock

import certifi

from ._platform import KNOWN_TCP_OPTS, SOL_TCP
from ._encode import encode_frame
from ._decode import decode_frame, decode_empty_frame
Expand Down Expand Up @@ -503,18 +501,12 @@ def _setup_transport(self):
self.sock = self._wrap_socket(self.sock, **self.sslopts)
self._quick_recv = self.sock.recv

def _wrap_socket(self, sock, context=None, **sslopts):
if context:
return self._wrap_context(sock, sslopts, **context)
def _wrap_socket(self, sock, **sslopts):
if "context" in sslopts:
context = sslopts.pop("context")
return context.wrap_socket(sock, **sslopts)
return self._wrap_socket_sni(sock, **sslopts)

def _wrap_context(self, sock, sslopts, check_hostname=None, **ctx_options):
ctx = ssl.create_default_context(**ctx_options)
ctx.verify_mode = ssl.CERT_REQUIRED
ctx.load_verify_locations(cafile=certifi.where())
ctx.check_hostname = check_hostname
return ctx.wrap_socket(sock, **sslopts)

def _wrap_socket_sni(
self,
sock,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from functools import partial
from typing import Any, Callable, Coroutine, List, Dict, Optional, Tuple, Union, overload, cast
from typing_extensions import Literal
import certifi

from ..outcomes import Accepted, Modified, Received, Rejected, Released
from ._connection_async import Connection
Expand Down Expand Up @@ -126,9 +125,11 @@ class AMQPClientAsync(AMQPClientSync):
If port is not specified in the `custom_endpoint_address`, by default port 443 will be used.
:paramtype custom_endpoint_address: str
:keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to
authenticate the identity of the connection endpoint.
Default is None in which case `certifi.where()` will be used.
authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None
in which case `certifi.where()` will be used.
:paramtype connection_verify: str
:keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored.
:paramtype ssl_context: ssl.SSLContext or None
:keyword float socket_timeout: The maximum time in seconds that the underlying socket in the transport should
wait when reading or writing data before timing out. The default value is 0.2 (for transport type Amqp),
and 1 for transport type AmqpOverWebsocket.
Expand Down Expand Up @@ -238,7 +239,7 @@ async def open_async(self, connection=None):
self._connection = Connection(
"amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname,
sasl_credential=self._auth.sasl,
ssl_opts={"ca_certs": self._connection_verify or certifi.where()},
ssl_opts=self._ssl_opts,
container_id=self._name,
max_frame_size=self._max_frame_size,
channel_max=self._channel_max,
Expand Down Expand Up @@ -480,9 +481,11 @@ class SendClientAsync(SendClientSync, AMQPClientAsync):
If port is not specified in the `custom_endpoint_address`, by default port 443 will be used.
:paramtype custom_endpoint_address: str
:keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to
authenticate the identity of the connection endpoint.
Default is None in which case `certifi.where()` will be used.
authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None
in which case `certifi.where()` will be used.
:paramtype connection_verify: str
:keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored.
:paramtype ssl_context: ssl.SSLContext or None
"""

async def _client_ready_async(self):
Expand Down Expand Up @@ -686,9 +689,11 @@ class ReceiveClientAsync(ReceiveClientSync, AMQPClientAsync):
If port is not specified in the `custom_endpoint_address`, by default port 443 will be used.
:paramtype custom_endpoint_address: str
:keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to
authenticate the identity of the connection endpoint.
Default is None in which case `certifi.where()` will be used.
authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None
in which case `certifi.where()` will be used.
:paramtype connection_verify: str
:keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored.
:paramtype ssl_context: ssl.SSLContext or None
"""

async def _client_ready_async(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@
import logging


import certifi

from .._platform import KNOWN_TCP_OPTS, SOL_TCP
from .._encode import encode_frame
from .._decode import decode_frame, decode_empty_frame
Expand Down Expand Up @@ -167,7 +165,7 @@ def _build_ssl_opts(self, sslopts):
return sslopts
try:
if "context" in sslopts:
return self._build_ssl_context(**sslopts.pop("context"))
return sslopts["context"]
ssl_version = sslopts.get("ssl_version")
if ssl_version is None:
ssl_version = ssl.PROTOCOL_TLS_CLIENT
Expand Down Expand Up @@ -206,13 +204,6 @@ def _build_ssl_opts(self, sslopts):
except TypeError:
raise TypeError("SSL configuration must be a dictionary, or the value True.") from None

def _build_ssl_context(self, check_hostname=None, **ctx_options):
ctx = ssl.create_default_context(**ctx_options)
ctx.verify_mode = ssl.CERT_REQUIRED
ctx.load_verify_locations(cafile=certifi.where())
ctx.check_hostname = check_hostname
return ctx


class AsyncTransport(AsyncTransportMixin): # pylint: disable=too-many-instance-attributes
"""Common superclass for TCP and SSL transports."""
Expand Down
28 changes: 20 additions & 8 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,11 @@ class AMQPClient(object): # pylint: disable=too-many-instance-attributes
If port is not specified in the `custom_endpoint_address`, by default port 443 will be used.
:paramtype custom_endpoint_address: str
:keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to
authenticate the identity of the connection endpoint.
Default is None in which case `certifi.where()` will be used.
authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None
in which case `certifi.where()` will be used.
:paramtype connection_verify: str
:keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored.
:paramtype ssl_context: ssl.SSLContext or None
:keyword float socket_timeout: The maximum time in seconds that the underlying socket in the transport should
wait when reading or writing data before timing out. The default value is 0.2 (for transport type Amqp),
and 1 for transport type AmqpOverWebsocket.
Expand Down Expand Up @@ -201,7 +203,13 @@ def __init__(self, hostname, **kwargs):

# Custom Endpoint
self._custom_endpoint_address = kwargs.get("custom_endpoint_address")
self._connection_verify = kwargs.get("connection_verify")
connection_verify = kwargs.get("connection_verify")
ssl_context = kwargs.get("ssl_context")
self._ssl_opts = {}
if ssl_context:
self._ssl_opts["context"] = ssl_context
else: # str or None
self._ssl_opts["ca_certs"] = connection_verify or certifi.where()

# Emulator
self._use_tls: bool = kwargs.get("use_tls", True)
Expand Down Expand Up @@ -306,7 +314,7 @@ def open(self, connection=None):
self._connection = Connection(
"amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname,
sasl_credential=self._auth.sasl,
ssl_opts={"ca_certs": self._connection_verify or certifi.where()},
ssl_opts=self._ssl_opts,
container_id=self._name,
max_frame_size=self._max_frame_size,
channel_max=self._channel_max,
Expand Down Expand Up @@ -556,9 +564,11 @@ class SendClient(AMQPClient):
If port is not specified in the `custom_endpoint_address`, by default port 443 will be used.
:paramtype custom_endpoint_address: str
:keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to
authenticate the identity of the connection endpoint.
Default is None in which case `certifi.where()` will be used.
authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None
in which case `certifi.where()` will be used.
:paramtype connection_verify: str
:keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored.
:paramtype ssl_context: ssl.SSLContext or None
"""

def __init__(self, hostname, target, **kwargs):
Expand Down Expand Up @@ -779,9 +789,11 @@ class ReceiveClient(AMQPClient): # pylint:disable=too-many-instance-attributes
If port is not specified in the `custom_endpoint_address`, by default port 443 will be used.
:paramtype custom_endpoint_address: str
:keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to
authenticate the identity of the connection endpoint.
Default is None in which case `certifi.where()` will be used.
authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None
in which case `certifi.where()` will be used.
:paramtype connection_verify: str
:keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored.
:paramtype ssl_context: ssl.SSLContext or None
"""

def __init__(self, hostname, source, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def create_send_client( # pylint: disable=unused-argument
target,
custom_endpoint_address=config.custom_endpoint_address,
connection_verify=config.connection_verify,
ssl_context=config.ssl_context,
transport_type=config.transport_type,
http_proxy=config.http_proxy,
socket_timeout=config.socket_timeout,
Expand Down Expand Up @@ -474,6 +475,7 @@ def create_receive_client(
transport_type=config.transport_type,
custom_endpoint_address=config.custom_endpoint_address,
connection_verify=config.connection_verify,
ssl_context=config.ssl_context,
socket_timeout=config.socket_timeout,
auth=auth,
idle_timeout=idle_timeout,
Expand Down Expand Up @@ -550,7 +552,6 @@ def create_token_auth(
timeout=config.auth_timeout,
custom_endpoint_hostname=config.custom_endpoint_hostname,
port=config.connection_port,
verify=config.connection_verify,
)
# if update_token:
# token_auth.update_token() # TODO: why don't we need to update in pyamqp?
Expand All @@ -575,6 +576,7 @@ def create_mgmt_client(address, mgmt_auth, config): # pylint: disable=unused-ar
http_proxy=config.http_proxy,
custom_endpoint_address=config.custom_endpoint_address,
connection_verify=config.connection_verify,
ssl_context=config.ssl_context,
use_tls=config.use_tls,
)

Expand Down
Loading