Skip to content

Fix broken wiring of sync inject-decorated methods #673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 22 additions & 27 deletions src/dependency_injector/_cwiring.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,39 @@

import asyncio
import collections.abc
import functools
import inspect
import types

from . import providers
from .wiring import _Marker, PatchedCallable
from .wiring import _Marker

from .providers cimport Provider
from .providers cimport Provider, Resource


def _get_sync_patched(fn, patched: PatchedCallable):
@functools.wraps(fn)
def _patched(*args, **kwargs):
cdef object result
cdef dict to_inject
cdef object arg_key
cdef Provider provider
def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
cdef object result
cdef dict to_inject
cdef object arg_key
cdef Provider provider

to_inject = kwargs.copy()
for arg_key, provider in patched.injections.items():
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
to_inject[arg_key] = provider()
to_inject = kwargs.copy()
for arg_key, provider in injections.items():
if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker):
to_inject[arg_key] = provider()

result = fn(*args, **to_inject)
result = fn(*args, **to_inject)

if patched.closing:
for arg_key, provider in patched.closing.items():
if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker):
continue
if not isinstance(provider, providers.Resource):
continue
provider.shutdown()
if closings:
for arg_key, provider in closings.items():
if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker):
continue
if not isinstance(provider, Resource):
continue
provider.shutdown()

return result
return _patched
return result


async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings):
async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /):
cdef object result
cdef dict to_inject
cdef list to_inject_await = []
Expand Down Expand Up @@ -69,7 +64,7 @@ async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dic
for arg_key, provider in closings.items():
if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker):
continue
if not isinstance(provider, providers.Resource):
if not isinstance(provider, Resource):
continue
shutdown = provider.shutdown()
if _isawaitable(shutdown):
Expand Down
17 changes: 15 additions & 2 deletions src/dependency_injector/wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ def is_loader_installed() -> bool:
_loader = AutoLoader()

# Optimizations
from ._cwiring import _get_sync_patched # noqa
from ._cwiring import _sync_inject # noqa
from ._cwiring import _async_inject # noqa


Expand All @@ -1047,4 +1047,17 @@ async def _patched(*args, **kwargs):
patched.closing,
)

return _patched
return cast(F, _patched)


def _get_sync_patched(fn: F, patched: PatchedCallable) -> F:
@functools.wraps(fn)
def _patched(*args, **kwargs):
return _sync_inject(
fn,
args,
kwargs,
patched.injections,
patched.closing,
)
return cast(F, _patched)
7 changes: 7 additions & 0 deletions tests/unit/wiring/test_introspection_py36.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
from dependency_injector.wiring import inject


def test_isfunction():
@inject
def foo(): ...

assert inspect.isfunction(foo)


def test_asyncio_iscoroutinefunction():
@inject
async def foo():
Expand Down