Skip to content

Commit 4d2dfdc

Browse files
committed
refactor: move provider dict to a ProviderRegistry class
Signed-off-by: Federico Bond <[email protected]>
1 parent 90b25c1 commit 4d2dfdc

File tree

3 files changed

+68
-40
lines changed

3 files changed

+68
-40
lines changed

Diff for: openfeature/api.py

+8-39
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@
66
from openfeature.hook import Hook
77
from openfeature.provider import FeatureProvider
88
from openfeature.provider.metadata import Metadata
9-
from openfeature.provider.no_op_provider import NoOpProvider
10-
11-
_provider: FeatureProvider = NoOpProvider()
9+
from openfeature.provider.registry import ProviderRegistry
1210

1311
_evaluation_context = EvaluationContext()
1412

1513
_hooks: typing.List[Hook] = []
1614

17-
_providers: typing.Dict[str, FeatureProvider] = {}
15+
_provider_registry: ProviderRegistry = ProviderRegistry()
1816

1917

2018
def get_client(
@@ -26,46 +24,18 @@ def get_client(
2624
def set_provider(
2725
provider: FeatureProvider, domain: typing.Optional[str] = None
2826
) -> None:
29-
if provider is None:
30-
raise GeneralError(error_message="No provider")
31-
32-
if domain:
33-
_set_domain_provider(domain, provider)
34-
return
35-
36-
global _provider
37-
if _provider:
38-
_provider.shutdown()
39-
_provider = provider
40-
provider.initialize(_evaluation_context)
41-
42-
43-
def _set_domain_provider(domain: str, provider: FeatureProvider) -> None:
44-
if domain in _providers:
45-
old_provider = _providers[domain]
46-
del _providers[domain]
47-
if old_provider not in _providers.values():
48-
old_provider.shutdown()
49-
if provider not in _providers.values():
50-
provider.initialize(_evaluation_context)
51-
_providers[domain] = provider
52-
53-
54-
def _get_provider(domain: typing.Optional[str] = None) -> FeatureProvider:
55-
global _provider
5627
if domain is None:
57-
return _provider
58-
return _providers.get(domain, _provider)
28+
_provider_registry.set_default_provider(provider)
29+
else:
30+
_provider_registry.set_provider(domain, provider)
5931

6032

6133
def clear_providers() -> None:
62-
for provider in _providers.values():
63-
provider.shutdown()
64-
_providers.clear()
34+
return _provider_registry.clear_providers()
6535

6636

6737
def get_provider_metadata(domain: typing.Optional[str] = None) -> Metadata:
68-
return _get_provider(domain).get_metadata()
38+
return _provider_registry.get_provider(domain).get_metadata()
6939

7040

7141
def get_evaluation_context() -> EvaluationContext:
@@ -96,5 +66,4 @@ def get_hooks() -> typing.List[Hook]:
9666

9767

9868
def shutdown() -> None:
99-
for provider in {_provider, *_providers.values()}:
100-
provider.shutdown()
69+
_provider_registry.shutdown()

Diff for: openfeature/client.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(
7979

8080
@property
8181
def provider(self) -> FeatureProvider:
82-
return api._get_provider(domain=self.domain)
82+
return api._provider_registry.get_provider(self.domain)
8383

8484
def get_metadata(self) -> ClientMetadata:
8585
return ClientMetadata(domain=self.domain)

Diff for: openfeature/provider/registry.py

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import typing
2+
3+
from openfeature.evaluation_context import EvaluationContext
4+
from openfeature.exception import GeneralError
5+
from openfeature.provider import FeatureProvider
6+
from openfeature.provider.no_op_provider import NoOpProvider
7+
8+
9+
class ProviderRegistry:
10+
_default_provider: FeatureProvider
11+
_providers: typing.Dict[str, FeatureProvider]
12+
13+
def __init__(self) -> None:
14+
self._default_provider = NoOpProvider()
15+
self._providers = {}
16+
17+
def set_provider(self, domain: str, provider: FeatureProvider) -> None:
18+
if provider is None:
19+
raise GeneralError(error_message="No provider")
20+
providers = self._providers
21+
if domain in providers:
22+
old_provider = providers[domain]
23+
del providers[domain]
24+
if old_provider not in providers.values():
25+
old_provider.shutdown()
26+
if provider not in providers.values():
27+
provider.initialize(self._get_evaluation_context())
28+
providers[domain] = provider
29+
30+
def get_provider(self, domain: typing.Optional[str]) -> FeatureProvider:
31+
if domain is None:
32+
return self._default_provider
33+
return self._providers.get(domain, self._default_provider)
34+
35+
def set_default_provider(self, provider: FeatureProvider) -> None:
36+
if provider is None:
37+
raise GeneralError(error_message="No provider")
38+
if self._default_provider:
39+
self._default_provider.shutdown()
40+
self._default_provider = provider
41+
provider.initialize(self._get_evaluation_context())
42+
43+
def get_default_provider(self) -> FeatureProvider:
44+
return self._default_provider
45+
46+
def clear_providers(self) -> None:
47+
for provider in self._providers.values():
48+
provider.shutdown()
49+
self._providers.clear()
50+
51+
def shutdown(self) -> None:
52+
for provider in {self._default_provider, *self._providers.values()}:
53+
provider.shutdown()
54+
55+
def _get_evaluation_context(self) -> EvaluationContext:
56+
# imported here to avoid circular imports
57+
from openfeature.api import get_evaluation_context # noqa: PLC0415
58+
59+
return get_evaluation_context()

0 commit comments

Comments
 (0)