diff --git a/openfeature/client.py b/openfeature/client.py index deac93c8..1ccee33b 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -8,6 +8,8 @@ ErrorCode, GeneralError, OpenFeatureError, + ProviderFatalError, + ProviderNotReadyError, TypeMismatchError, ) from openfeature.flag_evaluation import ( @@ -24,7 +26,7 @@ before_hooks, error_hooks, ) -from openfeature.provider import FeatureProvider +from openfeature.provider import FeatureProvider, ProviderStatus logger = logging.getLogger("openfeature") @@ -81,6 +83,10 @@ def __init__( def provider(self) -> FeatureProvider: return api._provider_registry.get_provider(self.domain) + def get_provider_status(self) -> ProviderStatus: + provider = api._provider_registry.get_provider(self.domain) + return api._provider_registry.get_provider_status(provider) + def get_metadata(self) -> ClientMetadata: return ClientMetadata(domain=self.domain) @@ -232,7 +238,7 @@ def get_object_details( flag_evaluation_options, ) - def evaluate_flag_details( + def evaluate_flag_details( # noqa: PLR0915 self, flag_type: FlagType, flag_key: str, @@ -282,6 +288,36 @@ def evaluate_flag_details( reversed_merged_hooks = merged_hooks[:] reversed_merged_hooks.reverse() + status = self.get_provider_status() + if status == ProviderStatus.NOT_READY: + error_hooks( + flag_type, + hook_context, + ProviderNotReadyError(), + reversed_merged_hooks, + hook_hints, + ) + return FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.PROVIDER_NOT_READY, + ) + if status == ProviderStatus.FATAL: + error_hooks( + flag_type, + hook_context, + ProviderFatalError(), + reversed_merged_hooks, + hook_hints, + ) + return FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.PROVIDER_FATAL, + ) + try: # https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md # Any resulting evaluation context from a before hook will overwrite diff --git a/openfeature/exception.py b/openfeature/exception.py index e8a4768d..e6ad2456 100644 --- a/openfeature/exception.py +++ b/openfeature/exception.py @@ -4,6 +4,7 @@ class ErrorCode(Enum): PROVIDER_NOT_READY = "PROVIDER_NOT_READY" + PROVIDER_FATAL = "PROVIDER_FATAL" FLAG_NOT_FOUND = "FLAG_NOT_FOUND" PARSE_ERROR = "PARSE_ERROR" TYPE_MISMATCH = "TYPE_MISMATCH" @@ -31,6 +32,36 @@ def __init__( self.error_code = error_code +class ProviderNotReadyError(OpenFeatureError): + """ + This exception should be raised when the provider is not ready to be used. + """ + + def __init__(self, error_message: typing.Optional[str] = None): + """ + Constructor for the ProviderNotReadyError. The error code for this type of + exception is ErrorCode.PROVIDER_NOT_READY. + @param error_message: a string message representing why the error has been + raised + """ + super().__init__(ErrorCode.PROVIDER_NOT_READY, error_message) + + +class ProviderFatalError(OpenFeatureError): + """ + This exception should be raised when the provider encounters a fatal error. + """ + + def __init__(self, error_message: typing.Optional[str] = None): + """ + Constructor for the ProviderFatalError. The error code for this type of + exception is ErrorCode.PROVIDER_FATAL. + @param error_message: a string message representing why the error has been + raised + """ + super().__init__(ErrorCode.PROVIDER_FATAL, error_message) + + class FlagNotFoundError(OpenFeatureError): """ This exception should be raised when the provider cannot find a flag with the diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index d5ddacff..edb94ae1 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -1,4 +1,5 @@ import typing +from enum import Enum from openfeature.evaluation_context import EvaluationContext from openfeature.flag_evaluation import FlagResolutionDetails @@ -7,6 +8,14 @@ from .metadata import Metadata +class ProviderStatus(Enum): + NOT_READY = "NOT_READY" + READY = "READY" + ERROR = "ERROR" + STALE = "STALE" + FATAL = "FATAL" + + class FeatureProvider(typing.Protocol): # pragma: no cover def initialize(self, evaluation_context: EvaluationContext) -> None: ... diff --git a/openfeature/provider/registry.py b/openfeature/provider/registry.py index 55b59931..779ee569 100644 --- a/openfeature/provider/registry.py +++ b/openfeature/provider/registry.py @@ -1,18 +1,21 @@ import typing from openfeature.evaluation_context import EvaluationContext -from openfeature.exception import GeneralError -from openfeature.provider import FeatureProvider +from openfeature.exception import ErrorCode, GeneralError, OpenFeatureError +from openfeature.provider import FeatureProvider, ProviderStatus from openfeature.provider.no_op_provider import NoOpProvider class ProviderRegistry: _default_provider: FeatureProvider _providers: typing.Dict[str, FeatureProvider] + _provider_status: typing.Dict[FeatureProvider, ProviderStatus] def __init__(self) -> None: self._default_provider = NoOpProvider() self._providers = {} + self._provider_status = {} + self._set_provider_status(self._default_provider, ProviderStatus.NOT_READY) def set_provider(self, domain: str, provider: FeatureProvider) -> None: if provider is None: @@ -22,9 +25,9 @@ def set_provider(self, domain: str, provider: FeatureProvider) -> None: old_provider = providers[domain] del providers[domain] if old_provider not in providers.values(): - old_provider.shutdown() + self._shutdown_provider(old_provider) if provider not in providers.values(): - provider.initialize(self._get_evaluation_context()) + self._initialize_provider(provider) providers[domain] = provider def get_provider(self, domain: typing.Optional[str]) -> FeatureProvider: @@ -36,9 +39,9 @@ def set_default_provider(self, provider: FeatureProvider) -> None: if provider is None: raise GeneralError(error_message="No provider") if self._default_provider: - self._default_provider.shutdown() + self._shutdown_provider(self._default_provider) self._default_provider = provider - provider.initialize(self._get_evaluation_context()) + self._initialize_provider(provider) def get_default_provider(self) -> FeatureProvider: return self._default_provider @@ -50,10 +53,40 @@ def clear_providers(self) -> None: def shutdown(self) -> None: for provider in {self._default_provider, *self._providers.values()}: - provider.shutdown() + self._shutdown_provider(provider) def _get_evaluation_context(self) -> EvaluationContext: # imported here to avoid circular imports from openfeature.api import get_evaluation_context return get_evaluation_context() + + def _initialize_provider(self, provider: FeatureProvider) -> None: + try: + if hasattr(provider, "initialize"): + provider.initialize(self._get_evaluation_context()) + self._set_provider_status(provider, ProviderStatus.READY) + except Exception as err: + if ( + isinstance(err, OpenFeatureError) + and err.error_code == ErrorCode.PROVIDER_FATAL + ): + self._set_provider_status(provider, ProviderStatus.FATAL) + else: + self._set_provider_status(provider, ProviderStatus.ERROR) + + def _shutdown_provider(self, provider: FeatureProvider) -> None: + try: + if hasattr(provider, "shutdown"): + provider.shutdown() + self._set_provider_status(provider, ProviderStatus.NOT_READY) + except Exception: + self._set_provider_status(provider, ProviderStatus.FATAL) + + def get_provider_status(self, provider: FeatureProvider) -> ProviderStatus: + return self._provider_status.get(provider, ProviderStatus.NOT_READY) + + def _set_provider_status( + self, provider: FeatureProvider, status: ProviderStatus + ) -> None: + self._provider_status[provider] = status diff --git a/tests/test_client.py b/tests/test_client.py index 71873405..5f710609 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -7,6 +7,7 @@ from openfeature.exception import ErrorCode, OpenFeatureError from openfeature.flag_evaluation import Reason from openfeature.hook import Hook +from openfeature.provider import ProviderStatus from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider from openfeature.provider.no_op_provider import NoOpProvider @@ -182,3 +183,56 @@ def test_should_call_api_level_hooks(no_op_provider_client): # Then api_hook.before.assert_called_once() api_hook.after.assert_called_once() + + +# Requirement 1.7.5 +def test_should_define_a_provider_status_accessor(no_op_provider_client): + # When + status = no_op_provider_client.get_provider_status() + # Then + assert status is not None + assert status == ProviderStatus.READY + + +# Requirement 1.7.6 +def test_should_shortcircuit_if_provider_is_not_ready( + no_op_provider_client, monkeypatch +): + # Given + monkeypatch.setattr( + no_op_provider_client, "get_provider_status", lambda: ProviderStatus.NOT_READY + ) + spy_hook = MagicMock(spec=Hook) + no_op_provider_client.add_hooks([spy_hook]) + # When + flag_details = no_op_provider_client.get_boolean_details( + flag_key="Key", default_value=True + ) + # Then + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == ErrorCode.PROVIDER_NOT_READY + spy_hook.error.assert_called_once() + + +# Requirement 1.7.7 +def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state( + no_op_provider_client, monkeypatch +): + # Given + monkeypatch.setattr( + no_op_provider_client, "get_provider_status", lambda: ProviderStatus.FATAL + ) + spy_hook = MagicMock(spec=Hook) + no_op_provider_client.add_hooks([spy_hook]) + # When + flag_details = no_op_provider_client.get_boolean_details( + flag_key="Key", default_value=True + ) + # Then + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == ErrorCode.PROVIDER_FATAL + spy_hook.error.assert_called_once()