# Source code for diwire.container_context

"""Global container context using Python's contextvars for lazy proxying."""

from __future__ import annotations

import inspect
import threading
import types
from collections.abc import Callable, Coroutine
from contextvars import ContextVar, Token
from dataclasses import dataclass
from functools import wraps
from typing import TYPE_CHECKING, Any, TypeVar, get_origin, overload

from diwire.exceptions import DIWireContainerNotSetError
from diwire.types import Factory, Lifetime

if TYPE_CHECKING:
    from diwire.container import Container
    from diwire.container_scopes import ScopedContainer

# Import signature builder to exclude Injected parameters from signature
from diwire.container_helpers import _build_signature_without_injected

# Generic type variable used in the resolve/register signatures of this module.
T = TypeVar("T")
# Type variable restricted to classes (bound=type); used by the bare-class
# register overload so the decorated class type is preserved.
_C = TypeVar("_C", bound=type)

# Context-local slot for the current container. Defaults to None, which the
# lookup code treats as "not set in this context".
_current_container: ContextVar[Container | None] = ContextVar(
    "diwire_current_container",
    default=None,
)

# Thread-local fallback for when ContextVar is explicitly cleared or not set.
# Note: asyncio.run() does propagate ContextVar values; this fallback is primarily
# for cases where a new context is created without copying the parent context.
# Each thread gets its own fallback container to prevent cross-thread leakage.
# See: https://github.com/python/cpython/issues/102609
_thread_local_fallback: threading.local = threading.local()


@dataclass(slots=True)
class _DeferredRegistration:
    key: Any
    factory: Factory | None
    instance: Any | None
    lifetime: Lifetime
    scope: str | None
    is_async: bool | None
    concrete_class: type | None
    via_decorator: bool

    def apply(self, container: Container) -> None:
        if self.via_decorator:
            decorator = container.register(
                lifetime=self.lifetime,
                scope=self.scope,
                is_async=self.is_async,
                concrete_class=self.concrete_class,
            )
            decorator(self.key)
            return

        container.register(
            self.key,
            factory=self.factory,
            instance=self.instance,
            lifetime=self.lifetime,
            scope=self.scope,
            is_async=self.is_async,
            concrete_class=self.concrete_class,
        )


class _ContextInjected:
    """Sync callable that looks up the container only when it is invoked.

    Similar to Injected, except that the container is fetched from the proxy
    on every call instead of being fixed at decoration time.
    """

    def __init__(
        self,
        func: Callable[..., Any],
        proxy: _ContainerContextProxy,
    ) -> None:
        self._func = func
        self._proxy = proxy

        # Copy the wrapped callable's metadata onto this instance, then set the
        # name (falling back to repr for callables without __name__) and the
        # wrapped target explicitly.
        wraps(func)(self)
        self.__name__: str = getattr(func, "__name__", repr(func))
        self.__wrapped__: Callable[..., Any] = func

        # The signature is built once, at decoration time, with Injected
        # parameters left out. This allows frameworks like FastAPI to
        # correctly identify parameters.
        self.__signature__ = _build_signature_without_injected(func)

    def _get_injected(self) -> Any:
        """Resolve the wrapped function through the container that is current now."""
        return self._proxy.get_current().resolve(self._func)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Resolve a fresh wrapper and forward the call to it."""
        return self._get_injected()(*args, **kwargs)

    def __repr__(self) -> str:
        return f"_ContextInjected({self._func!r})"

    def __get__(self, obj: Any, objtype: type | None = None) -> Any:
        """Bind to ``obj`` on instance access; return the wrapper itself on class access."""
        return self if obj is None else types.MethodType(self, obj)


class _ContextScopedInjected:
    """Sync callable that resolves through a named scope of the context container.

    Similar to ScopedInjected, except that the container is fetched from the
    proxy on every call instead of being fixed at decoration time.
    """

    def __init__(
        self,
        func: Callable[..., Any],
        proxy: _ContainerContextProxy,
        scope_name: str,
    ) -> None:
        self._func = func
        self._proxy = proxy
        self._scope_name = scope_name

        # Copy the wrapped callable's metadata onto this instance, then set the
        # name (falling back to repr for callables without __name__) and the
        # wrapped target explicitly.
        wraps(func)(self)
        self.__name__: str = getattr(func, "__name__", repr(func))
        self.__wrapped__: Callable[..., Any] = func

        # The signature is built once, at decoration time, with Injected
        # parameters left out. This allows frameworks like FastAPI to
        # correctly identify parameters.
        self.__signature__ = _build_signature_without_injected(func)

    def _get_scoped_injected(self) -> Any:
        """Resolve the wrapped function, with the stored scope, via the current container."""
        current = self._proxy.get_current()
        return current.resolve(self._func, scope=self._scope_name)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Resolve a fresh scoped wrapper and forward the call to it."""
        return self._get_scoped_injected()(*args, **kwargs)

    def __repr__(self) -> str:
        return f"_ContextScopedInjected({self._func!r}, scope={self._scope_name!r})"

    def __get__(self, obj: Any, objtype: type | None = None) -> Any:
        """Bind to ``obj`` on instance access; return the wrapper itself on class access."""
        return self if obj is None else types.MethodType(self, obj)


class _AsyncContextInjected:
    """Async callable that looks up the container only when it is invoked.

    Similar to AsyncInjected, except that the container is fetched from the
    proxy on every call instead of being fixed at decoration time.
    """

    def __init__(
        self,
        func: Callable[..., Coroutine[Any, Any, Any]],
        proxy: _ContainerContextProxy,
    ) -> None:
        self._func = func
        self._proxy = proxy

        # Copy the wrapped callable's metadata onto this instance, then set the
        # name (falling back to repr for callables without __name__) and the
        # wrapped target explicitly.
        wraps(func)(self)
        self.__name__: str = getattr(func, "__name__", repr(func))
        self.__wrapped__: Callable[..., Coroutine[Any, Any, Any]] = func

        # The signature is built once, at decoration time, with Injected
        # parameters left out. This allows frameworks like FastAPI to
        # correctly identify parameters.
        self.__signature__ = _build_signature_without_injected(func)

    def _get_async_injected(self) -> Any:
        """Resolve the wrapped coroutine function through the container that is current now."""
        return self._proxy.get_current().resolve(self._func)

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Resolve a fresh wrapper, call it, and await its result."""
        target = self._get_async_injected()
        result = await target(*args, **kwargs)
        return result

    def __repr__(self) -> str:
        return f"_AsyncContextInjected({self._func!r})"

    def __get__(self, obj: Any, objtype: type | None = None) -> Any:
        """Bind to ``obj`` on instance access; return the wrapper itself on class access."""
        return self if obj is None else types.MethodType(self, obj)


class _AsyncContextScopedInjected:
    """Async callable that resolves through a named scope of the context container.

    Similar to AsyncScopedInjected, except that the container is fetched from
    the proxy on every call instead of being fixed at decoration time.
    """

    def __init__(
        self,
        func: Callable[..., Coroutine[Any, Any, Any]],
        proxy: _ContainerContextProxy,
        scope_name: str,
    ) -> None:
        self._func = func
        self._proxy = proxy
        self._scope_name = scope_name

        # Copy the wrapped callable's metadata onto this instance, then set the
        # name (falling back to repr for callables without __name__) and the
        # wrapped target explicitly.
        wraps(func)(self)
        self.__name__: str = getattr(func, "__name__", repr(func))
        self.__wrapped__: Callable[..., Coroutine[Any, Any, Any]] = func

        # The signature is built once, at decoration time, with Injected
        # parameters left out. This allows frameworks like FastAPI to
        # correctly identify parameters.
        self.__signature__ = _build_signature_without_injected(func)

    def _get_async_scoped_injected(self) -> Any:
        """Resolve the wrapped coroutine function, with the stored scope, via the current container."""
        current = self._proxy.get_current()
        return current.resolve(self._func, scope=self._scope_name)

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Resolve a fresh scoped wrapper, call it, and await its result."""
        target = self._get_async_scoped_injected()
        result = await target(*args, **kwargs)
        return result

    def __repr__(self) -> str:
        return f"_AsyncContextScopedInjected({self._func!r}, scope={self._scope_name!r})"

    def __get__(self, obj: Any, objtype: type | None = None) -> Any:
        """Bind to ``obj`` on instance access; return the wrapper itself on class access."""
        return self if obj is None else types.MethodType(self, obj)


class _ContainerContextProxy:
    """Lazy proxy that forwards calls to the current container from context.

    This allows setting up decorators before the container is configured,
    with the actual container lookup happening at call time.

    Resolution order in get_current():

    1. ContextVar (highest precedence - for explicit per-request containers)
    2. Thread-local (same thread, but a context in which the ContextVar is
       unset, e.g. a new context created without copying the parent context)
    3. Instance-level default (lowest precedence - for cross-thread access)

    The instance-level default exists because some frameworks (e.g.,
    FastAPI/Starlette) run sync endpoint handlers in a thread pool, meaning
    neither ContextVar nor thread-local storage can access a container set in
    the main thread.

    Registrations can be deferred until a container is set; they are applied
    the next time set_current() is called.
    """

    def __init__(self) -> None:
        self._default_container: Container | None = None
        self._deferred_registrations: list[_DeferredRegistration] = []

    def set_current(self, container: Container) -> Token[Container | None]:
        """Set the current container in the context.

        Sets the container in three storage mechanisms:

        1. ContextVar - for same async context access
        2. Thread-local - for the same thread when the ContextVar is unset
        3. Instance-level default - for thread pool access (different threads)

        Any registrations deferred while no container was available are
        applied to ``container`` before returning.

        Args:
            container: The container to set as current.

        Returns:
            A token that can be used to reset the container.

        """
        self._default_container = container
        _thread_local_fallback.container = container
        token = _current_container.set(container)
        self._flush_deferred(container)
        return token

    def _get_current_or_none(self) -> Container | None:
        """Return the current container if available, otherwise None."""
        container = _current_container.get()
        if container is not None:
            return container

        container = getattr(_thread_local_fallback, "container", None)
        if container is not None:
            return container

        return self._default_container

    def get_current(self) -> Container:
        """Get the current container from the context.

        Resolution order (first non-None wins):

        1. ContextVar - for per-context containers
        2. Thread-local - for the same thread when the ContextVar is unset
        3. Instance-level default - for thread pools (different thread entirely)

        Returns:
            The current container.

        Raises:
            DIWireContainerNotSetError: If no container has been set.

        """
        # Single source of truth for the lookup order lives in the helper.
        container = self._get_current_or_none()
        if container is None:
            raise DIWireContainerNotSetError
        return container

    def _flush_deferred(self, container: Container) -> None:
        """Apply every queued registration to ``container`` and empty the queue."""
        if not self._deferred_registrations:
            return

        # Swap the queue out before iterating so the list being walked is not
        # the one that later register() calls would append to.
        pending = self._deferred_registrations
        self._deferred_registrations = []
        for registration in pending:
            registration.apply(container)

    def _defer(  # noqa: PLR0913
        self,
        key: Any,
        *,
        lifetime: Lifetime,
        scope: str | None,
        is_async: bool | None,
        factory: Factory | None = None,
        instance: Any | None = None,
        concrete_class: type | None = None,
        via_decorator: bool = False,
    ) -> None:
        """Queue one registration to be applied by the next set_current()."""
        self._deferred_registrations.append(
            _DeferredRegistration(
                key=key,
                factory=factory,
                instance=instance,
                lifetime=lifetime,
                scope=scope,
                is_async=is_async,
                concrete_class=concrete_class,
                via_decorator=via_decorator,
            ),
        )

    def reset(self, token: Token[Container | None]) -> None:
        """Reset the container to its previous value.

        The thread-local fallback and the instance-level default are brought
        in line with whatever the ContextVar holds after the reset.

        Args:
            token: The token returned by set_current.

        """
        _current_container.reset(token)
        current = _current_container.get()
        if current is None:
            if hasattr(_thread_local_fallback, "container"):
                del _thread_local_fallback.container
            self._default_container = None
        else:
            _thread_local_fallback.container = current
            self._default_container = current

    # Decorator overloads
    @overload
    def resolve(
        self,
        key: None = None,
        *,
        scope: str,
    ) -> Callable[[Callable[..., Any]], Any]: ...

    @overload
    def resolve(
        self,
        key: None = None,
        *,
        scope: None = None,
    ) -> Callable[[Callable[..., Any]], Any]: ...

    @overload
    def resolve(self, key: type[T], *, scope: None = None) -> T: ...

    @overload
    def resolve(self, key: type[T], *, scope: str) -> T: ...

    @overload
    def resolve(self, key: Callable[..., Any], *, scope: None = None) -> Any: ...

    @overload
    def resolve(self, key: Callable[..., Any], *, scope: str) -> Any: ...

    @overload
    def resolve(self, key: Any, *, scope: str | None = None) -> Any: ...

    def resolve(self, key: Any | None = None, *, scope: str | None = None) -> Any:
        """Resolve a service or create a dependency-injected wrapper.

        When called with key=None, returns a decorator that can be applied to
        functions to enable dependency injection with lazy container lookup.

        When called with a non-class callable, returns a lazy wrapper without
        touching the container; the container is looked up when the wrapper is
        called.

        When called with a type (or any other key), resolves and returns a
        service instance from the current container.

        Args:
            key: The service key to resolve, or None for decorator usage.
            scope: Optional scope name for scoped resolution.

        Returns:
            A service instance, or a wrapper for function decoration.

        Raises:
            DIWireContainerNotSetError: If a non-callable key is resolved while
                no container has been set.

        Examples:
            .. code-block:: python

                # Decorator usage (container looked up at call time):
                @container_context.resolve(scope="request")
                async def handler(service: Annotated[Service, Injected()]) -> dict:
                    ...

                # Direct resolution:
                service = container_context.resolve(Service)

        """
        # DECORATOR PATTERN: resolve(scope="...") or resolve() returns decorator
        if key is None:

            def decorator(func: Callable[..., Any]) -> Any:
                return self.resolve(func, scope=scope)

            return decorator

        # For callable types (functions), create lazy wrappers
        if callable(key) and not isinstance(key, type):
            is_async_func = inspect.iscoroutinefunction(key)

            if scope is not None:
                if is_async_func:
                    return _AsyncContextScopedInjected(key, self, scope)
                return _ContextScopedInjected(key, self, scope)

            if is_async_func:
                return _AsyncContextInjected(key, self)
            return _ContextInjected(key, self)

        # For types and other keys, delegate to the current container
        return self.get_current().resolve(key, scope=scope)

    def aresolve(self, key: type[T], *, scope: str | None = None) -> Coroutine[Any, Any, T]:
        """Asynchronously resolve a service from the current container.

        Args:
            key: The service key to resolve.
            scope: Optional scope name for scoped resolution.

        Returns:
            A coroutine that resolves to the service instance.

        Raises:
            DIWireContainerNotSetError: If no container has been set.

        """
        return self.get_current().aresolve(key, scope=scope)

    # Register overloads mirror Container.register for API parity.
    @overload
    def register(self, key: _C, /) -> _C: ...

    @overload
    def register(self, key: Callable[..., T], /) -> Callable[..., T]: ...

    @overload
    def register(
        self,
        key: None = None,
        /,
        factory: None = None,
        instance: None = None,
        lifetime: Lifetime = ...,
        scope: str | None = ...,
        is_async: bool | None = ...,  # noqa: FBT001
        concrete_class: type | None = ...,
    ) -> Callable[[T], T]: ...

    @overload
    def register(
        self,
        key: type,
        /,
        *,
        lifetime: Lifetime = ...,
        scope: str | None = ...,
        is_async: bool | None = ...,
    ) -> Callable[[T], T]: ...

    @overload
    def register(
        self,
        key: str,
        /,
        *,
        lifetime: Lifetime = ...,
        scope: str | None = ...,
        is_async: bool | None = ...,
    ) -> Callable[[T], T]: ...

    @overload
    def register(
        self,
        key: Any,
        /,
        factory: Factory | None = ...,
        instance: Any | None = ...,
        lifetime: Lifetime = ...,
        scope: str | None = ...,
        is_async: bool | None = ...,  # noqa: FBT001
        concrete_class: type | None = ...,
    ) -> None: ...

    def register(  # noqa: PLR0913, C901, PLR0911
        self,
        key: Any | None = None,
        /,
        factory: Factory | None = None,
        instance: Any | None = None,
        lifetime: Lifetime = Lifetime.TRANSIENT,
        scope: str | None = None,
        is_async: bool | None = None,  # noqa: FBT001
        concrete_class: type | None = None,
    ) -> Any:
        """Register a service with the current container.

        Supports the same decorator and direct-call patterns as
        Container.register. If a container is available the call is forwarded
        to it unchanged. If no container is set, registration is deferred
        until set_current().

        Returns:
            Whatever the container's register returns when a container is
            available. In the deferred case: a decorator when ``key`` is None
            or when ``key`` is a class combined with non-default options; the
            key itself for a bare class, function or staticmethod with all
            options at their defaults; otherwise None.

        """
        container = self._get_current_or_none()
        if container is not None:
            return container.register(
                key,
                factory=factory,
                instance=instance,
                lifetime=lifetime,
                scope=scope,
                is_async=is_async,
                concrete_class=concrete_class,
            )

        all_params_at_defaults = (
            factory is None
            and instance is None
            and lifetime == Lifetime.TRANSIENT
            and scope is None
            and is_async is None
            and concrete_class is None
        )

        if key is None:

            def decorator(target: T) -> T:
                # A container may have been set between register() and the
                # moment the decorator is applied, so look it up again.
                current = self._get_current_or_none()
                if current is not None:
                    register_decorator = current.register(
                        lifetime=lifetime,
                        scope=scope,
                        is_async=is_async,
                        concrete_class=concrete_class,
                    )
                    register_decorator(target)
                    return target

                self._defer(
                    target,
                    lifetime=lifetime,
                    scope=scope,
                    is_async=is_async,
                    concrete_class=concrete_class,
                    via_decorator=True,
                )
                return target

            return decorator

        # Case: Type as key (could be bare decorator, interface decorator, or factory)
        # When container is available, delegate to it. Otherwise defer registration.
        if (
            isinstance(key, type)
            and factory is None
            and instance is None
            and concrete_class is None
        ):
            # If container is available, delegate to it (uses proxy pattern)
            # Note: This path is tested in TestDeferredRegistrationWithTypeKey but
            # coverage measurement seems to have timing issues with it.
            current = self._get_current_or_none()
            if current is not None:  # pragma: no cover
                return current.register(
                    key,
                    lifetime=lifetime,
                    scope=scope,
                    is_async=is_async,
                )

            # No container - use deferred registration
            # For bare decorator (all defaults), defer and return the class directly
            if all_params_at_defaults:
                self._defer(key, lifetime=lifetime, scope=scope, is_async=is_async)
                return key

            # Non-default params - return a decorator for interface/factory patterns
            interface_key = key

            def type_decorator(target: T) -> T:
                current = self._get_current_or_none()
                if current is not None:
                    # Delegate to container which has the smart decorator logic
                    register_decorator = current.register(
                        interface_key,
                        lifetime=lifetime,
                        scope=scope,
                        is_async=is_async,
                    )
                    register_decorator(target)
                    return target

                # Deferred registration - determine what to register
                if isinstance(target, type):
                    if target is interface_key:
                        # Bare decorator: @container_context.register on the same class
                        self._defer(
                            target,
                            lifetime=lifetime,
                            scope=scope,
                            is_async=is_async,
                        )
                    else:
                        # Interface registration: different class
                        self._defer(
                            interface_key,
                            lifetime=lifetime,
                            scope=scope,
                            is_async=is_async,
                            concrete_class=target,
                        )
                else:
                    # Factory function - need to defer with factory
                    self._defer(
                        interface_key,
                        lifetime=lifetime,
                        scope=scope,
                        is_async=is_async,
                        factory=target,  # type: ignore[arg-type]
                    )
                return target

            return type_decorator

        is_factory_function = (
            callable(key)
            and not isinstance(key, type)
            and get_origin(key) is None
            and (
                inspect.isfunction(key)
                or inspect.ismethod(key)
                or inspect.iscoroutinefunction(key)
            )
        )
        is_decorator_target = all_params_at_defaults and (
            isinstance(key, staticmethod) or is_factory_function
        )

        # The queued registration is the same either way; only the return value
        # differs, so that bare-decorator usage hands the target back.
        self._defer(
            key,
            factory=factory,
            instance=instance,
            lifetime=lifetime,
            scope=scope,
            is_async=is_async,
            concrete_class=concrete_class,
        )
        return key if is_decorator_target else None

    def enter_scope(self, scope_name: str | None = None) -> ScopedContainer:
        """Start a new scope on the current container.

        Args:
            scope_name: Optional name for the scope.

        Returns:
            A ScopedContainer context manager.

        """
        return self.get_current().enter_scope(scope_name)

    def compile(self) -> None:
        """Compile the current container for optimized resolution."""
        return self.get_current().compile()

    def close(self) -> None:
        """Close the current container.

        Closes all active scopes and marks the container as closed.
        After calling this method, any attempt to resolve services or
        start new scopes will raise DIWireContainerClosedError.
        """
        return self.get_current().close()

    async def aclose(self) -> None:
        """Asynchronously close the current container.

        Use this method when scopes contain async generator factories
        that need proper async cleanup.
        """
        return await self.get_current().aclose()

    def close_scope(self, scope_name: str) -> None:
        """Close all active scopes that contain the given scope name.

        This closes the named scope and all its child scopes in LIFO order
        (children first, then parents).

        Args:
            scope_name: The name of the scope to close.

        """
        return self.get_current().close_scope(scope_name)

    async def aclose_scope(self, scope_name: str) -> None:
        """Asynchronously close all active scopes that contain the given scope name.

        This closes the named scope and all its child scopes in LIFO order
        (children first, then parents).

        Args:
            scope_name: The name of the scope to close.

        """
        return await self.get_current().aclose_scope(scope_name)


# Module-level proxy instance. It is created at import time with no default
# container and an empty deferred-registration queue.
container_context = _ContainerContextProxy()