Skip to content

Commit 3b76a0d

Browse files
authored
Allow Closing to detect dependent resources (#636)
1 parent a79ea17 commit 3b76a0d

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

Diff for: src/dependency_injector/wiring.py

+19
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,22 @@ def _fetch_reference_injections( # noqa: C901
593593
return injections, closing
594594

595595

596+
def _locate_dependent_closing_args(provider: providers.Provider) -> dict[str, providers.Provider]:
597+
if not hasattr(provider, "args"):
598+
return {}
599+
600+
closing_deps = {}
601+
for arg in provider.args:
602+
if not isinstance(arg, providers.Provider) or not hasattr(arg, "args"):
603+
continue
604+
605+
if not arg.args and isinstance(arg, providers.Resource):
606+
return {str(id(arg)): arg}
607+
else:
608+
closing_deps += _locate_dependent_closing_args(arg)
609+
return closing_deps
610+
611+
596612
def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None:
597613
patched_callable = _patched_registry.get_callable(fn)
598614
if patched_callable is None:
@@ -614,6 +630,9 @@ def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Non
614630

615631
if injection in patched_callable.reference_closing:
616632
patched_callable.add_closing(injection, provider)
633+
deps = _locate_dependent_closing_args(provider)
634+
for key, dep in deps.items():
635+
patched_callable.add_closing(key, dep)
617636

618637

619638
def _unbind_injections(fn: Callable[..., Any]) -> None:

Diff for: tests/unit/samples/wiringstringids/resourceclosing.py

+11
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ def shutdown(cls):
2020
cls.shutdown_counter += 1
2121

2222

23+
class FactoryService:
24+
def __init__(self, service: Service):
25+
self.service = service
26+
27+
2328
def init_service():
2429
service = Service()
2530
service.init()
@@ -30,8 +35,14 @@ def init_service():
3035
class Container(containers.DeclarativeContainer):
3136

3237
service = providers.Resource(init_service)
38+
factory_service = providers.Factory(FactoryService, service)
3339

3440

3541
@inject
3642
def test_function(service: Service = Closing[Provide["service"]]):
3743
return service
44+
45+
46+
@inject
47+
def test_function_dependency(factory: FactoryService = Closing[Provide["factory_service"]]):
48+
return factory

Diff for: tests/unit/wiring/string_ids/test_main_py36.py

+17
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,23 @@ def test_closing_resource():
289289
assert result_1 is not result_2
290290

291291

292+
@mark.usefixtures("resourceclosing_container")
293+
def test_closing_dependency_resource():
294+
resourceclosing.Service.reset_counter()
295+
296+
result_1 = resourceclosing.test_function_dependency()
297+
assert isinstance(result_1, resourceclosing.FactoryService)
298+
assert result_1.service.init_counter == 1
299+
assert result_1.service.shutdown_counter == 1
300+
301+
result_2 = resourceclosing.test_function_dependency()
302+
assert isinstance(result_2, resourceclosing.FactoryService)
303+
assert result_2.service.init_counter == 2
304+
assert result_2.service.shutdown_counter == 2
305+
306+
assert result_1 is not result_2
307+
308+
292309
@mark.usefixtures("resourceclosing_container")
293310
def test_closing_resource_bypass_marker_injection():
294311
resourceclosing.Service.reset_counter()

0 commit comments

Comments
 (0)