Skip to content

feat: implement provider status #288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions openfeature/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
ErrorCode,
GeneralError,
OpenFeatureError,
ProviderFatalError,
ProviderNotReadyError,
TypeMismatchError,
)
from openfeature.flag_evaluation import (
Expand All @@ -24,7 +26,7 @@
before_hooks,
error_hooks,
)
from openfeature.provider import FeatureProvider
from openfeature.provider import FeatureProvider, ProviderStatus

logger = logging.getLogger("openfeature")

Expand Down Expand Up @@ -81,6 +83,10 @@ def __init__(
def provider(self) -> FeatureProvider:
return api._provider_registry.get_provider(self.domain)

def get_provider_status(self) -> ProviderStatus:
provider = api._provider_registry.get_provider(self.domain)
return api._provider_registry.get_provider_status(provider)

def get_metadata(self) -> ClientMetadata:
return ClientMetadata(domain=self.domain)

Expand Down Expand Up @@ -232,7 +238,7 @@ def get_object_details(
flag_evaluation_options,
)

def evaluate_flag_details(
def evaluate_flag_details( # noqa: PLR0915
self,
flag_type: FlagType,
flag_key: str,
Expand Down Expand Up @@ -282,6 +288,36 @@ def evaluate_flag_details(
reversed_merged_hooks = merged_hooks[:]
reversed_merged_hooks.reverse()

status = self.get_provider_status()
if status == ProviderStatus.NOT_READY:
error_hooks(
flag_type,
hook_context,
ProviderNotReadyError(),
reversed_merged_hooks,
hook_hints,
)
return FlagEvaluationDetails(
flag_key=flag_key,
value=default_value,
reason=Reason.ERROR,
error_code=ErrorCode.PROVIDER_NOT_READY,
)
if status == ProviderStatus.FATAL:
error_hooks(
flag_type,
hook_context,
ProviderFatalError(),
reversed_merged_hooks,
hook_hints,
)
return FlagEvaluationDetails(
flag_key=flag_key,
value=default_value,
reason=Reason.ERROR,
error_code=ErrorCode.PROVIDER_FATAL,
)

try:
# https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md
# Any resulting evaluation context from a before hook will overwrite
Expand Down
31 changes: 31 additions & 0 deletions openfeature/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

class ErrorCode(Enum):
PROVIDER_NOT_READY = "PROVIDER_NOT_READY"
PROVIDER_FATAL = "PROVIDER_FATAL"
FLAG_NOT_FOUND = "FLAG_NOT_FOUND"
PARSE_ERROR = "PARSE_ERROR"
TYPE_MISMATCH = "TYPE_MISMATCH"
Expand Down Expand Up @@ -31,6 +32,36 @@ def __init__(
self.error_code = error_code


class ProviderNotReadyError(OpenFeatureError):
"""
This exception should be raised when the provider is not ready to be used.
"""

def __init__(self, error_message: typing.Optional[str] = None):
"""
Constructor for the ProviderNotReadyError. The error code for this type of
exception is ErrorCode.PROVIDER_NOT_READY.
@param error_message: a string message representing why the error has been
raised
"""
super().__init__(ErrorCode.PROVIDER_NOT_READY, error_message)


class ProviderFatalError(OpenFeatureError):
"""
This exception should be raised when the provider encounters a fatal error.
"""

def __init__(self, error_message: typing.Optional[str] = None):
"""
Constructor for the ProviderFatalError. The error code for this type of
exception is ErrorCode.PROVIDER_FATAL.
@param error_message: a string message representing why the error has been
raised
"""
super().__init__(ErrorCode.PROVIDER_FATAL, error_message)


class FlagNotFoundError(OpenFeatureError):
"""
This exception should be raised when the provider cannot find a flag with the
Expand Down
9 changes: 9 additions & 0 deletions openfeature/provider/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typing
from enum import Enum

from openfeature.evaluation_context import EvaluationContext
from openfeature.flag_evaluation import FlagResolutionDetails
Expand All @@ -7,6 +8,14 @@
from .metadata import Metadata


class ProviderStatus(Enum):
NOT_READY = "NOT_READY"
READY = "READY"
ERROR = "ERROR"
STALE = "STALE"
FATAL = "FATAL"


class FeatureProvider(typing.Protocol): # pragma: no cover
def initialize(self, evaluation_context: EvaluationContext) -> None:
...
Expand Down
47 changes: 40 additions & 7 deletions openfeature/provider/registry.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import typing

from openfeature.evaluation_context import EvaluationContext
from openfeature.exception import GeneralError
from openfeature.provider import FeatureProvider
from openfeature.exception import ErrorCode, GeneralError, OpenFeatureError
from openfeature.provider import FeatureProvider, ProviderStatus
from openfeature.provider.no_op_provider import NoOpProvider


class ProviderRegistry:
_default_provider: FeatureProvider
_providers: typing.Dict[str, FeatureProvider]
_provider_status: typing.Dict[FeatureProvider, ProviderStatus]

def __init__(self) -> None:
self._default_provider = NoOpProvider()
self._providers = {}
self._provider_status = {}
self._set_provider_status(self._default_provider, ProviderStatus.NOT_READY)

def set_provider(self, domain: str, provider: FeatureProvider) -> None:
if provider is None:
Expand All @@ -22,9 +25,9 @@ def set_provider(self, domain: str, provider: FeatureProvider) -> None:
old_provider = providers[domain]
del providers[domain]
if old_provider not in providers.values():
old_provider.shutdown()
self._shutdown_provider(old_provider)
if provider not in providers.values():
provider.initialize(self._get_evaluation_context())
self._initialize_provider(provider)
providers[domain] = provider

def get_provider(self, domain: typing.Optional[str]) -> FeatureProvider:
Expand All @@ -36,9 +39,9 @@ def set_default_provider(self, provider: FeatureProvider) -> None:
if provider is None:
raise GeneralError(error_message="No provider")
if self._default_provider:
self._default_provider.shutdown()
self._shutdown_provider(self._default_provider)
self._default_provider = provider
provider.initialize(self._get_evaluation_context())
self._initialize_provider(provider)

def get_default_provider(self) -> FeatureProvider:
return self._default_provider
Expand All @@ -50,10 +53,40 @@ def clear_providers(self) -> None:

def shutdown(self) -> None:
for provider in {self._default_provider, *self._providers.values()}:
provider.shutdown()
self._shutdown_provider(provider)

def _get_evaluation_context(self) -> EvaluationContext:
# imported here to avoid circular imports
from openfeature.api import get_evaluation_context

return get_evaluation_context()

def _initialize_provider(self, provider: FeatureProvider) -> None:
try:
if hasattr(provider, "initialize"):
provider.initialize(self._get_evaluation_context())
self._set_provider_status(provider, ProviderStatus.READY)
except Exception as err:
if (
isinstance(err, OpenFeatureError)
and err.error_code == ErrorCode.PROVIDER_FATAL
):
self._set_provider_status(provider, ProviderStatus.FATAL)
else:
self._set_provider_status(provider, ProviderStatus.ERROR)

def _shutdown_provider(self, provider: FeatureProvider) -> None:
try:
if hasattr(provider, "shutdown"):
provider.shutdown()
self._set_provider_status(provider, ProviderStatus.NOT_READY)
except Exception:
self._set_provider_status(provider, ProviderStatus.FATAL)

def get_provider_status(self, provider: FeatureProvider) -> ProviderStatus:
return self._provider_status.get(provider, ProviderStatus.NOT_READY)

def _set_provider_status(
self, provider: FeatureProvider, status: ProviderStatus
) -> None:
self._provider_status[provider] = status
54 changes: 54 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from openfeature.exception import ErrorCode, OpenFeatureError
from openfeature.flag_evaluation import Reason
from openfeature.hook import Hook
from openfeature.provider import ProviderStatus
from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider
from openfeature.provider.no_op_provider import NoOpProvider

Expand Down Expand Up @@ -182,3 +183,56 @@ def test_should_call_api_level_hooks(no_op_provider_client):
# Then
api_hook.before.assert_called_once()
api_hook.after.assert_called_once()


# Requirement 1.7.5
def test_should_define_a_provider_status_accessor(no_op_provider_client):
# When
status = no_op_provider_client.get_provider_status()
# Then
assert status is not None
assert status == ProviderStatus.READY


# Requirement 1.7.6
def test_should_shortcircuit_if_provider_is_not_ready(
no_op_provider_client, monkeypatch
):
# Given
monkeypatch.setattr(
no_op_provider_client, "get_provider_status", lambda: ProviderStatus.NOT_READY
)
spy_hook = MagicMock(spec=Hook)
no_op_provider_client.add_hooks([spy_hook])
# When
flag_details = no_op_provider_client.get_boolean_details(
flag_key="Key", default_value=True
)
# Then
assert flag_details is not None
assert flag_details.value
assert flag_details.reason == Reason.ERROR
assert flag_details.error_code == ErrorCode.PROVIDER_NOT_READY
spy_hook.error.assert_called_once()


# Requirement 1.7.7
def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state(
no_op_provider_client, monkeypatch
):
# Given
monkeypatch.setattr(
no_op_provider_client, "get_provider_status", lambda: ProviderStatus.FATAL
)
spy_hook = MagicMock(spec=Hook)
no_op_provider_client.add_hooks([spy_hook])
# When
flag_details = no_op_provider_client.get_boolean_details(
flag_key="Key", default_value=True
)
# Then
assert flag_details is not None
assert flag_details.value
assert flag_details.reason == Reason.ERROR
assert flag_details.error_code == ErrorCode.PROVIDER_FATAL
spy_hook.error.assert_called_once()