Skip to content

Consolidated credentials #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions sdk/eventhub/azure-eventhubs/azure/eventhub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ._common import EventData, EventDataBatch, EventPosition
from ._producer_client import EventHubProducerClient
from ._consumer_client import EventHubConsumerClient
from ._common import EventHubSharedKeyCredential, EventHubSASTokenCredential
from ._client_base import EventHubSharedKeyCredential
from ._eventprocessor.partition_manager import PartitionManager
from ._eventprocessor.common import CloseReason, OwnershipLostError
from ._eventprocessor.partition_context import PartitionContext
Expand Down Expand Up @@ -38,7 +38,6 @@
"EventHubConsumerClient",
"TransportType",
"EventHubSharedKeyCredential",
"EventHubSASTokenCredential",
"PartitionManager",
"CloseReason",
"OwnershipLostError",
Expand Down
94 changes: 52 additions & 42 deletions sdk/eventhub/azure-eventhubs/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from .exceptions import _handle_exception, ClientClosedError
from ._configuration import Configuration
from ._utils import parse_sas_token, utc_from_timestamp
from ._common import EventHubSharedKeyCredential, EventHubSASTokenCredential
from ._connection_manager import get_connection_manager
from ._constants import (
CONTAINER_PREFIX,
Expand All @@ -44,6 +43,8 @@
from azure.core.credentials import TokenCredential # type: ignore

_LOGGER = logging.getLogger(__name__)
_Address = collections.namedtuple('Address', 'hostname path')
_AccessToken = collections.namedtuple('AccessToken', 'token expires_on')


def _parse_conn_str(conn_str):
Expand Down Expand Up @@ -75,19 +76,23 @@ def _generate_sas_token(uri, policy, key, expiry=None):
:returns: SAS token as string literal.
:rtype: str
"""
from base64 import b64encode, b64decode
from hashlib import sha256
from hmac import HMAC
if not expiry:
expiry = time.time() + 3600 # Default to 1 hour.
encoded_uri = quote_plus(uri)
ttl = int(expiry)
sign_key = '{}\n{}'.format(encoded_uri, ttl)
sign_key = '%s\n%d' % (encoded_uri, ttl)
signature = b64encode(HMAC(b64decode(key), sign_key.encode('utf-8'), sha256).digest())
result = {
'sr': uri,
'sig': signature,
'se': str(ttl)}
if policy:
result['skn'] = policy
return 'SharedAccessSignature ' + urlencode(result)
token = 'SharedAccessSignature ' + urlencode(result)
return _AccessToken(token=token, expires_on=ttl)


def _build_uri(address, entity):
Expand All @@ -100,7 +105,22 @@ def _build_uri(address, entity):
return address


_Address = collections.namedtuple('Address', 'hostname path')
class EventHubSharedKeyCredential(object):
"""
The shared access key credential used for authentication.

:param str policy: The name of the shared access policy.
:param str key: The shared access key.
"""
def __init__(self, policy, key):
self.policy = policy
self.key = key
self.token_type = b"servicebus.windows.net:sastoken"

def get_token(self, *scopes, **kwargs):
if not scopes:
raise ValueError("No token scope provided.")
return _generate_sas_token(scopes[0], self.policy, self.key)


class ClientBase(object): # pylint:disable=too-many-instance-attributes
Expand Down Expand Up @@ -141,39 +161,29 @@ def _create_auth(self):
Create an ~uamqp.authentication.SASTokenAuth instance to authenticate
the session.
"""
http_proxy = self._config.http_proxy
transport_type = self._config.transport_type
auth_timeout = self._config.auth_timeout

# TODO: the following code can be refactored to create auth from classes directly instead of using if-else
if isinstance(self._credential, EventHubSharedKeyCredential): # pylint:disable=no-else-return
username = self._credential.policy
password = self._credential.key
if "@sas.root" in username:
return authentication.SASLPlain(
self._address.hostname, username, password, http_proxy=http_proxy, transport_type=transport_type)
return authentication.SASTokenAuth.from_shared_access_key(
self._auth_uri, username, password, timeout=auth_timeout, http_proxy=http_proxy,
transport_type=transport_type)

elif isinstance(self._credential, EventHubSASTokenCredential):
token = self._credential.get_sas_token()
try:
expiry = int(parse_sas_token(token)['se'])
except (KeyError, TypeError, IndexError):
raise ValueError("Supplied SAS token has no valid expiry value.")
return authentication.SASTokenAuth(
self._auth_uri, self._auth_uri, token,
expires_at=expiry,
timeout=auth_timeout,
http_proxy=http_proxy,
transport_type=transport_type)

else: # Azure credential
get_jwt_token = functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE)
return authentication.JWTTokenAuth(self._auth_uri, self._auth_uri,
get_jwt_token, http_proxy=http_proxy,
transport_type=transport_type)
try:
token_type = self._credential.token_type
except AttributeError:
token_type = b'jwt'
if token_type == b"servicebus.windows.net:sastoken":
auth = authentication.JWTTokenAuth(
self._auth_uri,
self._auth_uri,
functools.partial(self._credential.get_token, self._auth_uri),
token_type=token_type,
timeout=self._config.auth_timeout,
http_proxy=self._config.http_proxy,
transport_type=self._config.transport_type)
auth.update_token()
return auth
return authentication.JWTTokenAuth(
self._auth_uri,
self._auth_uri,
functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE),
token_type=token_type,
timeout=self._config.auth_timeout,
http_proxy=self._config.http_proxy,
transport_type=self._config.transport_type)

def _close_connection(self):
self._conn_manager.reset_connection_if_broken()
Expand Down Expand Up @@ -317,11 +327,11 @@ def _open(self):
if not self.running:
if self._handler:
self._handler.close()
self._create_handler()
self._handler.open(connection=self._client._conn_manager.get_connection( # pylint: disable=protected-access
self._client._address.hostname,
self._client._create_auth()
))
auth = self._client._create_auth()
self._create_handler(auth)
self._handler.open(
connection=self._client._conn_manager.get_connection(self._client._address.hostname, auth) # pylint: disable=protected-access
)
while not self._handler.client_ready():
time.sleep(0.05)
self._max_message_size_on_link = self._handler.message_handler._link.peer_max_message_size \
Expand Down
31 changes: 0 additions & 31 deletions sdk/eventhub/azure-eventhubs/azure/eventhub/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,34 +385,3 @@ def _selector(self):
elif isinstance(self.value, six.integer_types):
return ("amqp.annotation.x-opt-sequence-number {} '{}'".format(operator, self.value)).encode('utf-8')
return ("amqp.annotation.x-opt-offset {} '{}'".format(operator, self.value)).encode('utf-8')


# TODO: move some behaviors to these two classes.
class EventHubSASTokenCredential(object):
"""
SAS token used for authentication.

:param token: A SAS token or function that returns a SAS token. If a function is supplied,
it will be used to retrieve subsequent tokens in the case of token expiry. The function should
take no arguments. The token can be type of str or Callable object.
"""
def __init__(self, token):
self.token = token

def get_sas_token(self):
if callable(self.token): # pylint:disable=no-else-return
return self.token()
else:
return self.token


class EventHubSharedKeyCredential(object):
"""
The shared access key credential used for authentication.

:param str policy: The name of the shared access policy.
:param str key: The shared access key.
"""
def __init__(self, policy, key):
self.policy = policy
self.key = key
4 changes: 2 additions & 2 deletions sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(self, client, source, **kwargs):
self._last_enqueued_event_properties = {}
self._last_received_event = None

def _create_handler(self):
def _create_handler(self, auth):
source = Source(self._source)
if self._offset is not None:
source.set_filter(self._offset._selector()) # pylint:disable=protected-access
Expand All @@ -113,7 +113,7 @@ def _create_handler(self):
properties = create_properties(self._client._config.user_agent) # pylint:disable=protected-access
self._handler = ReceiveClient(
source,
auth=self._client._create_auth(), # pylint:disable=protected-access
auth=auth,
debug=self._client._config.network_tracing, # pylint:disable=protected-access
prefetch=self._prefetch,
link_properties=self._link_properties,
Expand Down
10 changes: 5 additions & 5 deletions sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import threading
from typing import Any, Union, Dict, Tuple, TYPE_CHECKING, Callable, List

from ._common import EventHubSharedKeyCredential, EventHubSASTokenCredential, EventData
from ._client_base import ClientBase
from ._common import EventData
from ._client_base import ClientBase, EventHubSharedKeyCredential
from ._consumer import EventHubConsumer
from ._constants import ALL_PARTITIONS
from ._eventprocessor.event_processor import EventProcessor
Expand Down Expand Up @@ -40,8 +40,8 @@ class EventHubConsumerClient(ClientBase):
:param str event_hub_path: The path of the specific Event Hub to connect the client to.
:param credential: The credential object used for authentication which implements particular interface
of getting tokens. It accepts :class:`EventHubSharedKeyCredential<azure.eventhub.EventHubSharedKeyCredential>`,
:class:`EventHubSASTokenCredential<azure.eventhub.EventHubSASTokenCredential>`, or credential objects generated by
the azure-identity library and objects that implement `get_token(self, *scopes)` method.
or credential objects generated by the azure-identity library and objects that
implement `get_token(self, *scopes)` method.
:keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`.
:keyword float auth_timeout: The time in seconds to wait for a token to be authorized by the service.
The default value is 60 seconds. If set to 0, no timeout will be enforced from the client.
Expand Down Expand Up @@ -72,7 +72,7 @@ class EventHubConsumerClient(ClientBase):
"""

def __init__(self, host, event_hub_path, credential, **kwargs):
# type:(str, str, Union[EventHubSharedKeyCredential, EventHubSASTokenCredential, TokenCredential], Any) -> None
# type:(str, str, TokenCredential, Any) -> None
self._partition_manager = kwargs.pop("partition_manager", None)
self._load_balancing_interval = kwargs.pop("load_balancing_interval", 10)
network_tracing = kwargs.pop("logging_enable", False)
Expand Down
4 changes: 2 additions & 2 deletions sdk/eventhub/azure-eventhubs/azure/eventhub/_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ def __init__(self, client, target, **kwargs):
self._lock = threading.Lock()
self._link_properties = {types.AMQPSymbol(TIMEOUT_SYMBOL): types.AMQPLong(int(self._timeout * 1000))}

def _create_handler(self):
def _create_handler(self, auth):
self._handler = SendClient(
self._target,
auth=self._client._create_auth(), # pylint:disable=protected-access
auth=auth,
debug=self._client._config.network_tracing, # pylint:disable=protected-access
msg_timeout=self._timeout * 1000,
error_policy=self._retry_policy,
Expand Down
15 changes: 5 additions & 10 deletions sdk/eventhub/azure-eventhubs/azure/eventhub/_producer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,10 @@
from uamqp import constants # type:ignore

from .exceptions import ConnectError, EventHubError
from ._client_base import ClientBase
from ._client_base import ClientBase, EventHubSharedKeyCredential
from ._producer import EventHubProducer
from ._constants import ALL_PARTITIONS
from ._common import (
EventData,
EventHubSharedKeyCredential,
EventHubSASTokenCredential,
EventDataBatch
)
from ._common import EventData, EventDataBatch

if TYPE_CHECKING:
from azure.core.credentials import TokenCredential # type: ignore
Expand All @@ -34,8 +29,8 @@ class EventHubProducerClient(ClientBase):
:param str event_hub_path: The path of the specific Event Hub to connect the client to.
:param credential: The credential object used for authentication which implements particular interface
of getting tokens. It accepts :class:`EventHubSharedKeyCredential<azure.eventhub.EventHubSharedKeyCredential>`,
:class:`EventHubSASTokenCredential<azure.eventhub.EventHubSASTokenCredential>`, or credential objects generated by
the azure-identity library and objects that implement `get_token(self, *scopes)` method.
or credential objects generated by the azure-identity library and objects that
implement `get_token(self, *scopes)` method.
:keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`.
:keyword float auth_timeout: The time in seconds to wait for a token to be authorized by the service.
The default value is 60 seconds. If set to 0, no timeout will be enforced from the client.
Expand All @@ -60,7 +55,7 @@ class EventHubProducerClient(ClientBase):

"""
def __init__(self, host, event_hub_path, credential, **kwargs):
# type:(str, str, Union[EventHubSharedKeyCredential, EventHubSASTokenCredential, TokenCredential], Any) -> None
# type:(str, str, TokenCredential, Any) -> None
""""""
super(EventHubProducerClient, self).__init__(
host=host, event_hub_path=event_hub_path, credential=credential,
Expand Down
2 changes: 2 additions & 0 deletions sdk/eventhub/azure-eventhubs/azure/eventhub/aio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from ._client_base_async import EventHubSharedKeyCredential
from ._consumer_client_async import EventHubConsumerClient
from ._producer_client_async import EventHubProducerClient
from ._eventprocessor.partition_manager import PartitionManager
from ._eventprocessor.partition_context import PartitionContext

__all__ = [
"EventHubSharedKeyCredential",
"EventHubConsumerClient",
"EventHubProducerClient",
"PartitionManager",
Expand Down
Loading