diff --git a/openfeature/api.py b/openfeature/api.py index db9e1476..902e8b3b 100644 --- a/openfeature/api.py +++ b/openfeature/api.py @@ -25,7 +25,10 @@ def set_provider(provider: AbstractProvider): global _provider if provider is None: raise GeneralError(error_message="No provider") + if _provider: + _provider.shutdown() _provider = provider + provider.initialize(_evaluation_context) def get_provider() -> typing.Optional[AbstractProvider]: @@ -63,3 +66,7 @@ def clear_hooks(): def get_hooks() -> typing.List[Hook]: global _hooks return _hooks + + +def shutdown(): + _provider.shutdown() diff --git a/openfeature/provider/provider.py b/openfeature/provider/provider.py index 6a59e9cf..73ce37b5 100644 --- a/openfeature/provider/provider.py +++ b/openfeature/provider/provider.py @@ -8,6 +8,12 @@ class AbstractProvider: + def initialize(self, evaluation_context: EvaluationContext): + pass + + def shutdown(self): + pass + @abstractmethod def get_metadata(self) -> Metadata: pass diff --git a/tests/test_api.py b/tests/test_api.py index 5ffc823f..040db910 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -12,12 +12,14 @@ get_provider_metadata, set_evaluation_context, set_provider, + shutdown, ) from openfeature.evaluation_context import EvaluationContext from openfeature.exception import ErrorCode, GeneralError from openfeature.hook import Hook from openfeature.provider.metadata import Metadata from openfeature.provider.no_op_provider import NoOpProvider +from openfeature.provider.provider import AbstractProvider def test_should_not_raise_exception_with_noop_client(): @@ -56,6 +58,32 @@ def test_should_try_set_provider_and_fail_if_none_provided(): assert ge.value.error_code == ErrorCode.GENERAL +def test_should_invoke_provider_initialize_function_on_newly_registered_provider(): + # Given + evaluation_context = EvaluationContext("targeting_key", {"attr1": "val1"}) + provider = MagicMock(spec=AbstractProvider) + + # When + set_evaluation_context(evaluation_context) + set_provider(provider) + + # Then + provider.initialize.assert_called_with(evaluation_context) + + +def test_should_invoke_provider_shutdown_function_once_provider_is_no_longer_in_use(): + # Given + provider_1 = MagicMock(spec=AbstractProvider) + provider_2 = MagicMock(spec=AbstractProvider) + + # When + set_provider(provider_1) + set_provider(provider_2) + + # Then + assert provider_1.shutdown.called + + def test_should_return_a_provider_if_setup_correctly(): # Given set_provider(NoOpProvider()) @@ -116,3 +144,15 @@ def test_should_add_hooks_to_api_hooks(): # Then assert get_hooks() == [hook_1, hook_2] + + +def test_should_call_provider_shutdown_on_api_shutdown(): + # Given + provider = MagicMock(spec=AbstractProvider) + set_provider(provider) + + # When + shutdown() + + # Then + assert provider.shutdown.called