diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index b1f01622..62e2a755 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -598,14 +598,14 @@ def _locate_dependent_closing_args(provider: providers.Provider) -> Dict[str, pr return {} closing_deps = {} - for arg in provider.args: + for arg in [*provider.args, *provider.kwargs.values()]: if not isinstance(arg, providers.Provider) or not hasattr(arg, "args"): continue - - if not arg.args and isinstance(arg, providers.Resource): + if isinstance(arg, providers.Resource): return {str(id(arg)): arg} - else: - closing_deps += _locate_dependent_closing_args(arg) + if arg.args or arg.kwargs: + closing_deps |= _locate_dependent_closing_args(arg) + return closing_deps diff --git a/tests/unit/samples/wiringstringids/resourceclosing.py b/tests/unit/samples/wiringstringids/resourceclosing.py index 6360e15c..b692c488 100644 --- a/tests/unit/samples/wiringstringids/resourceclosing.py +++ b/tests/unit/samples/wiringstringids/resourceclosing.py @@ -2,9 +2,14 @@ from dependency_injector.wiring import inject, Provide, Closing +class Singleton: + pass + + class Service: init_counter: int = 0 shutdown_counter: int = 0 + dependency: Singleton = None @classmethod def reset_counter(cls): @@ -12,7 +17,9 @@ def reset_counter(cls): cls.shutdown_counter = 0 @classmethod - def init(cls): + def init(cls, dependency: Singleton = None): + if dependency: + cls.dependency = dependency cls.init_counter += 1 @classmethod @@ -25,6 +32,11 @@ def __init__(self, service: Service): self.service = service +class NestedService: + def __init__(self, factory_service: FactoryService): + self.factory_service = factory_service + + def init_service(): service = Service() service.init() @@ -32,10 +44,37 @@ def init_service(): service.shutdown() +def init_service_with_singleton(singleton: Singleton): + service = Service() + service.init(singleton) + yield service + service.shutdown() + + class Container(containers.DeclarativeContainer): service = providers.Resource(init_service) factory_service = providers.Factory(FactoryService, service) + factory_service_kwargs = providers.Factory( + FactoryService, + service=service + ) + nested_service = providers.Factory(NestedService, factory_service) + + +class ContainerSingleton(containers.DeclarativeContainer): + + singleton = providers.Singleton(Singleton) + service = providers.Resource( + init_service_with_singleton, + singleton + ) + factory_service = providers.Factory(FactoryService, service) + factory_service_kwargs = providers.Factory( + FactoryService, + service=service + ) + nested_service = providers.Factory(NestedService, factory_service) @inject @@ -44,5 +83,21 @@ def test_function(service: Service = Closing[Provide["service"]]): @inject -def test_function_dependency(factory: FactoryService = Closing[Provide["factory_service"]]): +def test_function_dependency( + factory: FactoryService = Closing[Provide["factory_service"]] +): + return factory + + +@inject +def test_function_dependency_kwargs( + factory: FactoryService = Closing[Provide["factory_service_kwargs"]] +): return factory + + +@inject +def test_function_nested_dependency( + nested: NestedService = Closing[Provide["nested_service"]] +): + return nested diff --git a/tests/unit/wiring/string_ids/test_main_py36.py b/tests/unit/wiring/string_ids/test_main_py36.py index d4c49fe8..2861c3aa 100644 --- a/tests/unit/wiring/string_ids/test_main_py36.py +++ b/tests/unit/wiring/string_ids/test_main_py36.py @@ -33,9 +33,12 @@ def subcontainer(): container.unwire() -@fixture -def resourceclosing_container(): - container = resourceclosing.Container() +@fixture(params=[ + resourceclosing.Container, + resourceclosing.ContainerSingleton, +]) +def resourceclosing_container(request): + container = request.param() container.wire(modules=[resourceclosing]) yield container container.unwire() @@ -303,6 +306,36 @@ def test_closing_dependency_resource(): assert result_2.service.init_counter == 2 assert result_2.service.shutdown_counter == 2 + +@mark.usefixtures("resourceclosing_container") +def test_closing_dependency_resource_kwargs(): + resourceclosing.Service.reset_counter() + + result_1 = resourceclosing.test_function_dependency_kwargs() + assert isinstance(result_1, resourceclosing.FactoryService) + assert result_1.service.init_counter == 1 + assert result_1.service.shutdown_counter == 1 + + result_2 = resourceclosing.test_function_dependency_kwargs() + assert isinstance(result_2, resourceclosing.FactoryService) + assert result_2.service.init_counter == 2 + assert result_2.service.shutdown_counter == 2 + + +@mark.usefixtures("resourceclosing_container") +def test_closing_nested_dependency_resource(): + resourceclosing.Service.reset_counter() + + result_1 = resourceclosing.test_function_nested_dependency() + assert isinstance(result_1, resourceclosing.NestedService) + assert result_1.factory_service.service.init_counter == 1 + assert result_1.factory_service.service.shutdown_counter == 1 + + result_2 = resourceclosing.test_function_nested_dependency() + assert isinstance(result_2, resourceclosing.NestedService) + assert result_2.factory_service.service.init_counter == 2 + assert result_2.factory_service.service.shutdown_counter == 2 + assert result_1 is not result_2