Skip to content

Commit db29e02

Browse files
swathipill0lawrence
authored andcommitted
[EventHub] add ssl_context kwarg to clients (Azure#37702)
* remove verify from pyamqp JWTToken * add ssl_context kwarg * add tests * fix merge * fix failing/lint/mypy * lint * separate ssl_context property from conn verify * make sb changes * add tests * changelog * black
1 parent 179148a commit db29e02

29 files changed

+579
-92
lines changed

sdk/eventhub/azure-eventhub/CHANGELOG.md

+2-5
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,14 @@
44

55
### Features Added
66

7-
### Breaking Changes
8-
9-
### Bugs Fixed
10-
11-
### Other Changes
7+
- Added `ssl_context` parameter to the clients to allow users to pass in the SSL context, in which case, `connection_verify` will be ignored if specified.
128

139
- Added logging to track received messages.
1410

1511
## 5.12.2 (2024-10-02)
1612

1713
### Bugs Fixed
14+
1815
- Implemented backpressure for async consumer to address a memory leak issue. ([#36398](https://github.com/Azure/azure-sdk-for-python/issues/36398))
1916

2017
## 5.12.1 (2024-06-11)

sdk/eventhub/azure-eventhub/azure/eventhub/_configuration.py

+6
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License. See License.txt in the project root for license information.
44
# --------------------------------------------------------------------------------------------
5+
import warnings
56
from typing import Optional, Dict, Any, Union, TYPE_CHECKING
67
from urllib.parse import urlparse
78

89
from azure.core.pipeline.policies import RetryMode
910
from ._constants import TransportType, DEFAULT_AMQPS_PORT, DEFAULT_AMQP_WSS_PORT
1011

1112
if TYPE_CHECKING:
13+
from ssl import SSLContext
1214
from ._transport._base import AmqpTransport
1315
from .aio._transport._base_async import AmqpTransportAsync
1416

@@ -35,6 +37,7 @@ def __init__(
3537
send_timeout: int = 60,
3638
custom_endpoint_address: Optional[str] = None,
3739
connection_verify: Optional[str] = None,
40+
ssl_context: Optional["SSLContext"] = None,
3841
use_tls: bool = True,
3942
**kwargs: Any
4043
):
@@ -54,6 +57,9 @@ def __init__(
5457
self.send_timeout = send_timeout
5558
self.custom_endpoint_address = custom_endpoint_address
5659
self.connection_verify = connection_verify
60+
self.ssl_context = ssl_context
61+
if self.ssl_context and self.connection_verify:
62+
warnings.warn("ssl_context is specified, connection_verify will be ignored.")
5763
self.custom_endpoint_hostname = None
5864
self.hostname = hostname
5965
self.use_tls = use_tls

sdk/eventhub/azure-eventhub/azure/eventhub/_consumer_client.py

+9
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616

1717
if TYPE_CHECKING:
18+
from ssl import SSLContext
1819
from ._eventprocessor.partition_context import PartitionContext
1920
from ._common import EventData
2021
from ._client_base import CredentialTypes
@@ -122,6 +123,9 @@ class EventHubConsumerClient(ClientBase): # pylint: disable=client-accepts-api-
122123
authenticate the identity of the connection endpoint.
123124
Default is None in which case `certifi.where()` will be used.
124125
:paramtype connection_verify: str or None
126+
:keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified,
127+
connection_verify will be ignored.
128+
:paramtype ssl_context: ssl.SSLContext or None
125129
:keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is
126130
False and the Pure Python AMQP library will be used as the underlying transport.
127131
:paramtype uamqp_transport: bool
@@ -234,6 +238,7 @@ def from_connection_string(
234238
load_balancing_strategy: Union[str, LoadBalancingStrategy] = LoadBalancingStrategy.GREEDY,
235239
custom_endpoint_address: Optional[str] = None,
236240
connection_verify: Optional[str] = None,
241+
ssl_context: Optional["SSLContext"] = None,
237242
uamqp_transport: bool = False,
238243
**kwargs: Any,
239244
) -> "EventHubConsumerClient":
@@ -309,6 +314,9 @@ def from_connection_string(
309314
authenticate the identity of the connection endpoint.
310315
Default is None in which case `certifi.where()` will be used.
311316
:paramtype connection_verify: str or None
317+
:keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified,
318+
connection_verify will be ignored.
319+
:paramtype ssl_context: ssl.SSLContext or None
312320
:keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is
313321
False and the Pure Python AMQP library will be used as the underlying transport.
314322
:paramtype uamqp_transport: bool
@@ -346,6 +354,7 @@ def from_connection_string(
346354
load_balancing_strategy=load_balancing_strategy,
347355
custom_endpoint_address=custom_endpoint_address,
348356
connection_verify=connection_verify,
357+
ssl_context=ssl_context,
349358
uamqp_transport=uamqp_transport,
350359
**kwargs,
351360
)

sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py

+9
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .exceptions import ConnectError, EventHubError
3131

3232
if TYPE_CHECKING:
33+
from ssl import SSLContext
3334
from ._client_base import CredentialTypes
3435

3536
SendEventTypes = List[Union[EventData, AmqpAnnotatedMessage]]
@@ -128,6 +129,9 @@ class EventHubProducerClient(ClientBase): # pylint: disable=client-accepts-api-
128129
authenticate the identity of the connection endpoint.
129130
Default is None in which case `certifi.where()` will be used.
130131
:paramtype connection_verify: Optional[str]
132+
:keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified,
133+
connection_verify will be ignored.
134+
:paramtype ssl_context: ssl.SSLContext or None
131135
:keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is
132136
False and the Pure Python AMQP library will be used as the underlying transport.
133137
:paramtype uamqp_transport: bool
@@ -419,6 +423,7 @@ def from_connection_string(
419423
transport_type: Optional["TransportType"] = TransportType.Amqp,
420424
custom_endpoint_address: Optional[str] = None,
421425
connection_verify: Optional[str] = None,
426+
ssl_context: Optional["SSLContext"] = None,
422427
uamqp_transport: bool = False,
423428
**kwargs: Any,
424429
) -> "EventHubProducerClient":
@@ -506,6 +511,9 @@ def from_connection_string(
506511
authenticate the identity of the connection endpoint.
507512
Default is None in which case `certifi.where()` will be used.
508513
:paramtype connection_verify: Optional[str]
514+
:keyword ssl_context: The SSLContext object to use in the underlying Pure Python AMQP transport. If specified,
515+
connection_verify will be ignored.
516+
:paramtype ssl_context: ssl.SSLContext or None
509517
:keyword uamqp_transport: Whether to use the `uamqp` library as the underlying transport. The default value is
510518
False and the Pure Python AMQP library will be used as the underlying transport.
511519
:paramtype uamqp_transport: bool
@@ -542,6 +550,7 @@ def from_connection_string(
542550
transport_type=transport_type,
543551
custom_endpoint_address=custom_endpoint_address,
544552
connection_verify=connection_verify,
553+
ssl_context=ssl_context,
545554
uamqp_transport=uamqp_transport,
546555
**kwargs,
547556
)

sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py

+4-12
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@
4646
import logging
4747
from threading import Lock
4848

49-
import certifi
50-
5149
from ._platform import KNOWN_TCP_OPTS, SOL_TCP
5250
from ._encode import encode_frame
5351
from ._decode import decode_frame, decode_empty_frame
@@ -503,18 +501,12 @@ def _setup_transport(self):
503501
self.sock = self._wrap_socket(self.sock, **self.sslopts)
504502
self._quick_recv = self.sock.recv
505503

506-
def _wrap_socket(self, sock, context=None, **sslopts):
507-
if context:
508-
return self._wrap_context(sock, sslopts, **context)
504+
def _wrap_socket(self, sock, **sslopts):
505+
if "context" in sslopts:
506+
context = sslopts.pop("context")
507+
return context.wrap_socket(sock, **sslopts)
509508
return self._wrap_socket_sni(sock, **sslopts)
510509

511-
def _wrap_context(self, sock, sslopts, check_hostname=None, **ctx_options):
512-
ctx = ssl.create_default_context(**ctx_options)
513-
ctx.verify_mode = ssl.CERT_REQUIRED
514-
ctx.load_verify_locations(cafile=certifi.where())
515-
ctx.check_hostname = check_hostname
516-
return ctx.wrap_socket(sock, **sslopts)
517-
518510
def _wrap_socket_sni(
519511
self,
520512
sock,

sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_client_async.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from functools import partial
1313
from typing import Any, Callable, Coroutine, List, Dict, Optional, Tuple, Union, overload, cast
1414
from typing_extensions import Literal
15-
import certifi
1615

1716
from ..outcomes import Accepted, Modified, Received, Rejected, Released
1817
from ._connection_async import Connection
@@ -126,9 +125,11 @@ class AMQPClientAsync(AMQPClientSync):
126125
If port is not specified in the `custom_endpoint_address`, by default port 443 will be used.
127126
:paramtype custom_endpoint_address: str
128127
:keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to
129-
authenticate the identity of the connection endpoint.
130-
Default is None in which case `certifi.where()` will be used.
128+
authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None
129+
in which case `certifi.where()` will be used.
131130
:paramtype connection_verify: str
131+
:keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored.
132+
:paramtype ssl_context: ssl.SSLContext or None
132133
:keyword float socket_timeout: The maximum time in seconds that the underlying socket in the transport should
133134
wait when reading or writing data before timing out. The default value is 0.2 (for transport type Amqp),
134135
and 1 for transport type AmqpOverWebsocket.
@@ -238,7 +239,7 @@ async def open_async(self, connection=None):
238239
self._connection = Connection(
239240
"amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname,
240241
sasl_credential=self._auth.sasl,
241-
ssl_opts={"ca_certs": self._connection_verify or certifi.where()},
242+
ssl_opts=self._ssl_opts,
242243
container_id=self._name,
243244
max_frame_size=self._max_frame_size,
244245
channel_max=self._channel_max,
@@ -480,9 +481,11 @@ class SendClientAsync(SendClientSync, AMQPClientAsync):
480481
If port is not specified in the `custom_endpoint_address`, by default port 443 will be used.
481482
:paramtype custom_endpoint_address: str
482483
:keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to
483-
authenticate the identity of the connection endpoint.
484-
Default is None in which case `certifi.where()` will be used.
484+
authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None
485+
in which case `certifi.where()` will be used.
485486
:paramtype connection_verify: str
487+
:keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored.
488+
:paramtype ssl_context: ssl.SSLContext or None
486489
"""
487490

488491
async def _client_ready_async(self):
@@ -686,9 +689,11 @@ class ReceiveClientAsync(ReceiveClientSync, AMQPClientAsync):
686689
If port is not specified in the `custom_endpoint_address`, by default port 443 will be used.
687690
:paramtype custom_endpoint_address: str
688691
:keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to
689-
authenticate the identity of the connection endpoint.
690-
Default is None in which case `certifi.where()` will be used.
692+
authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None
693+
in which case `certifi.where()` will be used.
691694
:paramtype connection_verify: str
695+
:keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored.
696+
:paramtype ssl_context: ssl.SSLContext or None
692697
"""
693698

694699
async def _client_ready_async(self):

sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@
4242
import logging
4343

4444

45-
import certifi
46-
4745
from .._platform import KNOWN_TCP_OPTS, SOL_TCP
4846
from .._encode import encode_frame
4947
from .._decode import decode_frame, decode_empty_frame
@@ -167,7 +165,7 @@ def _build_ssl_opts(self, sslopts):
167165
return sslopts
168166
try:
169167
if "context" in sslopts:
170-
return self._build_ssl_context(**sslopts.pop("context"))
168+
return sslopts["context"]
171169
ssl_version = sslopts.get("ssl_version")
172170
if ssl_version is None:
173171
ssl_version = ssl.PROTOCOL_TLS_CLIENT
@@ -206,13 +204,6 @@ def _build_ssl_opts(self, sslopts):
206204
except TypeError:
207205
raise TypeError("SSL configuration must be a dictionary, or the value True.") from None
208206

209-
def _build_ssl_context(self, check_hostname=None, **ctx_options):
210-
ctx = ssl.create_default_context(**ctx_options)
211-
ctx.verify_mode = ssl.CERT_REQUIRED
212-
ctx.load_verify_locations(cafile=certifi.where())
213-
ctx.check_hostname = check_hostname
214-
return ctx
215-
216207

217208
class AsyncTransport(AsyncTransportMixin): # pylint: disable=too-many-instance-attributes
218209
"""Common superclass for TCP and SSL transports."""

sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/client.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,11 @@ class AMQPClient(object): # pylint: disable=too-many-instance-attributes
143143
If port is not specified in the `custom_endpoint_address`, by default port 443 will be used.
144144
:paramtype custom_endpoint_address: str
145145
:keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to
146-
authenticate the identity of the connection endpoint.
147-
Default is None in which case `certifi.where()` will be used.
146+
authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None
147+
in which case `certifi.where()` will be used.
148148
:paramtype connection_verify: str
149+
:keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored.
150+
:paramtype ssl_context: ssl.SSLContext or None
149151
:keyword float socket_timeout: The maximum time in seconds that the underlying socket in the transport should
150152
wait when reading or writing data before timing out. The default value is 0.2 (for transport type Amqp),
151153
and 1 for transport type AmqpOverWebsocket.
@@ -201,7 +203,13 @@ def __init__(self, hostname, **kwargs):
201203

202204
# Custom Endpoint
203205
self._custom_endpoint_address = kwargs.get("custom_endpoint_address")
204-
self._connection_verify = kwargs.get("connection_verify")
206+
connection_verify = kwargs.get("connection_verify")
207+
ssl_context = kwargs.get("ssl_context")
208+
self._ssl_opts = {}
209+
if ssl_context:
210+
self._ssl_opts["context"] = ssl_context
211+
else: # str or None
212+
self._ssl_opts["ca_certs"] = connection_verify or certifi.where()
205213

206214
# Emulator
207215
self._use_tls: bool = kwargs.get("use_tls", True)
@@ -306,7 +314,7 @@ def open(self, connection=None):
306314
self._connection = Connection(
307315
"amqps://" + self._hostname if self._use_tls else "amqp://" + self._hostname,
308316
sasl_credential=self._auth.sasl,
309-
ssl_opts={"ca_certs": self._connection_verify or certifi.where()},
317+
ssl_opts=self._ssl_opts,
310318
container_id=self._name,
311319
max_frame_size=self._max_frame_size,
312320
channel_max=self._channel_max,
@@ -556,9 +564,11 @@ class SendClient(AMQPClient):
556564
If port is not specified in the `custom_endpoint_address`, by default port 443 will be used.
557565
:paramtype custom_endpoint_address: str
558566
:keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to
559-
authenticate the identity of the connection endpoint.
560-
Default is None in which case `certifi.where()` will be used.
567+
authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None
568+
in which case `certifi.where()` will be used.
561569
:paramtype connection_verify: str
570+
:keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored.
571+
:paramtype ssl_context: ssl.SSLContext or None
562572
"""
563573

564574
def __init__(self, hostname, target, **kwargs):
@@ -779,9 +789,11 @@ class ReceiveClient(AMQPClient): # pylint:disable=too-many-instance-attributes
779789
If port is not specified in the `custom_endpoint_address`, by default port 443 will be used.
780790
:paramtype custom_endpoint_address: str
781791
:keyword connection_verify: Path to the custom CA_BUNDLE file of the SSL certificate which is used to
782-
authenticate the identity of the connection endpoint.
783-
Default is None in which case `certifi.where()` will be used.
792+
authenticate the identity of the connection endpoint. Ignored if ssl_context passed in. Default is None
793+
in which case `certifi.where()` will be used.
784794
:paramtype connection_verify: str
795+
:keyword ssl_context: An instance of ssl.SSLContext to be used. If this is specified, connection_verify is ignored.
796+
:paramtype ssl_context: ssl.SSLContext or None
785797
"""
786798

787799
def __init__(self, hostname, source, **kwargs):

sdk/eventhub/azure-eventhub/azure/eventhub/_transport/_pyamqp_transport.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ def create_send_client( # pylint: disable=unused-argument
332332
target,
333333
custom_endpoint_address=config.custom_endpoint_address,
334334
connection_verify=config.connection_verify,
335+
ssl_context=config.ssl_context,
335336
transport_type=config.transport_type,
336337
http_proxy=config.http_proxy,
337338
socket_timeout=config.socket_timeout,
@@ -474,6 +475,7 @@ def create_receive_client(
474475
transport_type=config.transport_type,
475476
custom_endpoint_address=config.custom_endpoint_address,
476477
connection_verify=config.connection_verify,
478+
ssl_context=config.ssl_context,
477479
socket_timeout=config.socket_timeout,
478480
auth=auth,
479481
idle_timeout=idle_timeout,
@@ -550,7 +552,6 @@ def create_token_auth(
550552
timeout=config.auth_timeout,
551553
custom_endpoint_hostname=config.custom_endpoint_hostname,
552554
port=config.connection_port,
553-
verify=config.connection_verify,
554555
)
555556
# if update_token:
556557
# token_auth.update_token() # TODO: why don't we need to update in pyamqp?
@@ -575,6 +576,7 @@ def create_mgmt_client(address, mgmt_auth, config): # pylint: disable=unused-ar
575576
http_proxy=config.http_proxy,
576577
custom_endpoint_address=config.custom_endpoint_address,
577578
connection_verify=config.connection_verify,
579+
ssl_context=config.ssl_context,
578580
use_tls=config.use_tls,
579581
)
580582

0 commit comments

Comments
 (0)