# Source code for diwire._internal.resolver_context

from __future__ import annotations

import functools
import inspect
from collections.abc import Awaitable, Callable, Mapping
from contextvars import ContextVar, Token
from dataclasses import dataclass
from types import TracebackType
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, overload

from diwire._internal.injection import (
    INJECT_CONTEXT_KWARG,
    INJECT_RESOLVER_KWARG,
    INJECT_WRAPPER_MARKER,
    InjectedCallableInspector,
)
from diwire._internal.policies import DependencyRegistrationPolicy
from diwire._internal.resolvers.protocol import ResolverProtocol
from diwire._internal.scope import BaseScope
from diwire.exceptions import DIWireInvalidRegistrationError, DIWireResolverNotSetError

if TYPE_CHECKING:
    from typing_extensions import Self

    from diwire._internal.container import Container


T = TypeVar("T")
InjectableF = TypeVar("InjectableF", bound=Callable[..., Any])


@dataclass(frozen=True, slots=True)
class _InjectInvocationState:
    source: Literal["explicit", "context", "fallback"]
    context_resolver: ResolverProtocol | None


@dataclass(frozen=True, slots=True)
class _InjectWrapperConfig:
    callable_obj: Callable[..., Any]
    scope: BaseScope | None
    dependency_registration_policy: DependencyRegistrationPolicy | None
    auto_open_scope: bool


class _ResolverBoundResolver:
    """Resolver wrapper that synchronizes resolver context with ResolverContext.

    ``resolve``/``aresolve``/``close``/``aclose`` forward straight to the
    wrapped resolver, and any attribute not defined on this class is looked up
    on it as well. Using the wrapper as a (sync or async) context manager also
    pushes the wrapper through ``push_resolver`` after the wrapped resolver has
    been entered, and pops it through ``pop_resolver`` on exit. The pop runs
    even if the wrapped resolver's exit raises.
    """

    def __init__(
        self,
        *,
        resolver: ResolverProtocol,
        resolver_context: ResolverContext,
        push_resolver: Callable[[ResolverProtocol], None],
        pop_resolver: Callable[[], None],
    ) -> None:
        """Store the wrapped resolver and the context push/pop callbacks.

        Args:
            resolver: Resolver that performs the actual resolution work.
            resolver_context: Owning context; handed on to wrappers created by
                :meth:`enter_scope`.
            push_resolver: Called with this wrapper after the wrapped resolver
                has been entered.
            pop_resolver: Called after the wrapped resolver has been exited.
        """
        self._resolver = resolver
        self._resolver_context = resolver_context
        self._push_resolver = push_resolver
        self._pop_resolver = pop_resolver

    def __getattr__(self, name: str) -> Any:
        """Delegate unknown attribute lookups to the wrapped resolver.

        Raises:
            AttributeError: If the wrapped resolver lacks ``name``, or if this
                instance has no wrapped resolver (not initialized via ``__init__``).
        """
        # ``__getattr__`` only runs after normal lookup fails. On an instance
        # created without ``__init__`` (``cls.__new__``, as used by ``copy`` and
        # ``pickle``), reading ``self._resolver`` would re-enter this method
        # forever and surface as RecursionError, which ``hasattr`` and
        # ``getattr(obj, name, default)`` do not catch.
        if name == "_resolver":
            msg = f"{type(self).__name__!r} object has no attribute {name!r}"
            raise AttributeError(msg)
        return getattr(self._resolver, name)

    @overload
    def resolve(self, dependency: type[T]) -> T: ...

    @overload
    def resolve(self, dependency: Any) -> Any: ...

    def resolve(self, dependency: Any) -> Any:
        """Resolve ``dependency`` with the wrapped resolver."""
        return self._resolver.resolve(dependency)

    @overload
    async def aresolve(self, dependency: type[T]) -> T: ...

    @overload
    async def aresolve(self, dependency: Any) -> Any: ...

    async def aresolve(self, dependency: Any) -> Any:
        """Asynchronously resolve ``dependency`` with the wrapped resolver."""
        return await self._resolver.aresolve(dependency)

    def enter_scope(
        self,
        scope: BaseScope | None = None,
        *,
        context: Mapping[Any, Any] | None = None,
    ) -> _ResolverBoundResolver:
        """Enter a scope on the wrapped resolver and keep the result context-bound.

        Args:
            scope: Scope forwarded to the wrapped resolver's ``enter_scope``.
            context: Context mapping forwarded to the wrapped resolver.

        Returns:
            This same wrapper when the wrapped resolver's ``enter_scope`` returns
            the wrapped resolver itself. Otherwise, a new wrapper around the
            scoped resolver that shares this wrapper's context and push/pop
            callbacks.
        """
        scoped_resolver = self._resolver.enter_scope(scope, context=context)
        if scoped_resolver is self._resolver:
            return self
        return _ResolverBoundResolver(
            resolver=scoped_resolver,
            resolver_context=self._resolver_context,
            push_resolver=self._push_resolver,
            pop_resolver=self._pop_resolver,
        )

    def __enter__(self) -> Self:
        """Enter the wrapped resolver, then push this wrapper onto the context."""
        self._resolver.__enter__()
        self._push_resolver(cast("ResolverProtocol", self))
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Exit the wrapped resolver, then pop this wrapper from the context.

        The pop happens even if the wrapped resolver's ``__exit__`` raises. The
        wrapped resolver's return value is discarded, so this wrapper never
        suppresses the in-flight exception.
        """
        try:
            self._resolver.__exit__(exc_type, exc_value, traceback)
        finally:
            self._pop_resolver()

    async def __aenter__(self) -> Self:
        """Await the wrapped resolver's ``__aenter__``, then push this wrapper."""
        await cast("Any", self._resolver).__aenter__()
        self._push_resolver(cast("ResolverProtocol", self))
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Await the wrapped resolver's ``__aexit__``, then pop this wrapper.

        The pop happens even if the wrapped resolver's ``__aexit__`` raises.
        """
        try:
            await self._resolver.__aexit__(exc_type, exc_value, traceback)
        finally:
            self._pop_resolver()

    def close(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: TracebackType | None = None,
    ) -> None:
        """Close the wrapped resolver; the resolver context is not touched."""
        self._resolver.close(exc_type, exc_value, traceback)

    async def aclose(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: TracebackType | None = None,
    ) -> None:
        """Asynchronously close the wrapped resolver; the context is not touched."""
        await self._resolver.aclose(exc_type, exc_value, traceback)


class ResolverContext:
    """Task/thread-safe context for resolver-bound injection and resolution.

    The currently bound resolver is kept in a :class:`contextvars.ContextVar`,
    so bindings are scoped to the active ``contextvars`` context (for example a
    thread or an asyncio task) rather than shared process-wide.
    ``resolve``/``aresolve``/``enter_scope`` use the resolver bound in the
    current context first. If none is bound, they use a resolver compiled from
    the fallback container set via :meth:`set_fallback_container`.
    """

    __slots__ = (
        "_current_resolver_var",
        "_fallback_container",
        "_injected_callable_inspector",
        "_token_stack_var",
    )

    def __init__(self) -> None:
        """Create an empty context with no bound resolver and no fallback."""
        # Resolver bound in the current context; ``None`` means nothing is bound.
        self._current_resolver_var: ContextVar[ResolverProtocol | None] = ContextVar(
            "diwire_resolver_context_resolver",
            default=None,
        )
        # Tokens returned by ``_current_resolver_var.set``, in push order, so that
        # nested enters can be undone in LIFO order by ``_pop``.
        self._token_stack_var: ContextVar[tuple[Token[ResolverProtocol | None], ...]] = ContextVar(
            "diwire_resolver_context_tokens",
            default=(),
        )
        # Container consulted when no resolver is bound in the current context.
        self._fallback_container: Container | None = None
        # Used by ``inject`` to compute the public signature of wrapped callables.
        self._injected_callable_inspector = InjectedCallableInspector()

    def set_fallback_container(self, container: Container) -> None:
        """Set the fallback container used when no resolver is bound.

        Resolver-bound contexts always take precedence over this fallback.

        Args:
            container: Container used to compile fallback resolvers and injected
                callables when no context resolver is active.
        """
        self._set_fallback_container(container)

    def _push(self, resolver: ResolverProtocol) -> None:
        """Bind ``resolver`` in the current context and record its reset token.

        Passed to :class:`_ResolverBoundResolver` as ``push_resolver`` and called
        when a bound resolver is entered.
        """
        token = self._current_resolver_var.set(resolver)
        tokens = self._token_stack_var.get()
        # Store a new tuple instead of mutating the existing one; a copied
        # context would otherwise observe the mutation through the shared object.
        self._token_stack_var.set((*tokens, token))

    def _pop(self) -> None:
        """Undo the most recent :meth:`_push` recorded in the current context.

        Does nothing when no push is recorded. Passed to
        :class:`_ResolverBoundResolver` as ``pop_resolver``.
        """
        tokens = self._token_stack_var.get()
        if not tokens:
            return
        token = tokens[-1]
        self._token_stack_var.set(tokens[:-1])
        # ``reset`` restores whatever value was bound before the matching ``set``.
        self._current_resolver_var.reset(token)

    def _set_fallback_container(self, container: Container) -> None:
        """Store ``container`` as the fallback; it is compiled lazily on use."""
        self._fallback_container = container

    def _get_bound_resolver_or_none(self) -> ResolverProtocol | None:
        """Return the resolver bound in the current context, or ``None``."""
        return self._current_resolver_var.get()

    def _require_context_or_fallback_resolver(self) -> ResolverProtocol:
        """Return the bound resolver, else a wrapped fallback resolver.

        Raises:
            DIWireResolverNotSetError: If no resolver is bound and no fallback
                container is set.
        """
        resolver = self._get_bound_resolver_or_none()
        if resolver is not None:
            return resolver
        fallback_resolver = self._get_fallback_resolver_or_none()
        if fallback_resolver is not None:
            return fallback_resolver
        msg = (
            "Resolver is not set for resolver_context. Enter a compiled resolver context "
            "before using resolver_context."
        )
        raise DIWireResolverNotSetError(msg)

    def _get_fallback_resolver_or_none(self) -> ResolverProtocol | None:
        """Compile the fallback container and wrap the result, or return ``None``.

        ``compile()`` is called on every invocation. Whether that is cheap
        (cached) depends on ``Container`` and is not visible here.
        """
        fallback_container = self._fallback_container
        if fallback_container is None:
            return None
        fallback_resolver = fallback_container.compile()
        return self._wrap_resolver(fallback_resolver)

    def _wrap_resolver(self, resolver: ResolverProtocol) -> ResolverProtocol:
        """Wrap ``resolver`` in a :class:`_ResolverBoundResolver` bound to this context.

        Already-wrapped resolvers are returned unchanged, so wrapping is idempotent.
        """
        resolver_any = cast("Any", resolver)
        if isinstance(resolver_any, _ResolverBoundResolver):
            return cast("ResolverProtocol", resolver_any)
        return cast(
            "ResolverProtocol",
            _ResolverBoundResolver(
                resolver=resolver,
                resolver_context=self,
                push_resolver=self._push,
                pop_resolver=self._pop,
            ),
        )

    @overload
    def resolve(self, dependency: type[T]) -> T: ...

    @overload
    def resolve(self, dependency: Any) -> Any: ...

    def resolve(self, dependency: Any) -> Any:
        """Resolve a dependency from the active resolver or fallback container.

        Args:
            dependency: Dependency key to resolve.

        Raises:
            DIWireResolverNotSetError: If no resolver is bound and no fallback
                container is configured.
        """
        return self._require_context_or_fallback_resolver().resolve(dependency)

    @overload
    async def aresolve(self, dependency: type[T]) -> T: ...

    @overload
    async def aresolve(self, dependency: Any) -> Any: ...

    async def aresolve(self, dependency: Any) -> Any:
        """Asynchronously resolve a dependency from context or fallback.

        Args:
            dependency: Dependency key to resolve.

        Raises:
            DIWireResolverNotSetError: If no resolver is bound and no fallback
                container is configured.
        """
        return await self._require_context_or_fallback_resolver().aresolve(dependency)

    def enter_scope(
        self,
        scope: BaseScope | None = None,
        *,
        context: Mapping[Any, Any] | None = None,
    ) -> ResolverProtocol:
        """Enter a child scope on the active resolver or fallback resolver.

        Args:
            scope: Target scope to enter. ``None`` keeps the current scope.
            context: Optional context payload merged into the entered scope.

        Raises:
            DIWireResolverNotSetError: If no resolver is bound and no fallback
                container is configured.
        """
        return self._require_context_or_fallback_resolver().enter_scope(scope, context=context)

    @overload
    def inject(self, func: InjectableF) -> InjectableF: ...

    @overload
    def inject(
        self,
        func: Literal["from_decorator"] = "from_decorator",
        *,
        scope: BaseScope | Literal["infer"] = "infer",
        dependency_registration_policy: (
            DependencyRegistrationPolicy | Literal["from_container"]
        ) = "from_container",
        auto_open_scope: bool = True,
    ) -> Callable[[InjectableF], InjectableF]: ...

    def inject(
        self,
        func: InjectableF | Literal["from_decorator"] = "from_decorator",
        *,
        scope: BaseScope | Literal["infer"] = "infer",
        dependency_registration_policy: (
            DependencyRegistrationPolicy | Literal["from_container"]
        ) = "from_container",
        auto_open_scope: bool = True,
    ) -> InjectableF | Callable[[InjectableF], InjectableF]:
        """Wrap callables so ``Injected[...]`` parameters resolve at invocation.

        Resolution precedence at call time is:

        1. an explicit ``diwire_resolver``;
        2. an active context-bound resolver;
        3. the configured fallback container, when it is configured with
           ``use_resolver_context=True``.

        The wrapper is always built by the current fallback container. If one is
        already set when the decorator is applied, the wrapper for it is built
        immediately; otherwise it is built on first call.

        Args:
            func: Callable to wrap directly, or ``"from_decorator"`` when used as
                ``@resolver_context.inject(...)``.
            scope: Explicit scope for wrapper generation, or ``"infer"`` to infer
                from injected dependencies.
            dependency_registration_policy: Dependency autoregistration policy for
                wrapper generation, or ``"from_container"`` to inherit the fallback
                container setting.
            auto_open_scope: Whether invocation should auto-enter scopes when needed.
                When ``True`` (default), scope entry is attempted only when moving
                into a deeper target scope is valid and required. If the target scope
                is already open, no additional scope is entered. If the current
                resolver is already deeper than the target scope, no additional scope
                is entered. Resolution then proceeds from the current resolver,
                including its existing scope-context chain.

        Raises:
            DIWireInvalidRegistrationError: If inject configuration values are invalid,
                ``func`` is not callable, or the callable uses reserved parameter names.
            DIWireResolverNotSetError: At invocation time, in either of two cases:
                there is no explicit resolver, no active context resolver, and no
                fallback container eligible for inject fallback; or a resolver is
                available but no fallback container has been set, because the
                wrapper is built by that container.
        """
        # Validate configuration eagerly so misuse fails at decoration time.
        resolved_scope = self._resolve_inject_scope(scope)
        resolved_dependency_registration_policy = (
            self._resolve_inject_dependency_registration_policy(
                dependency_registration_policy=dependency_registration_policy,
            )
        )

        def decorator(callable_obj: InjectableF) -> InjectableF:
            self._validate_injected_callable_signature(callable_obj)
            inspected_callable = self._injected_callable_inspector.inspect_callable(callable_obj)
            # Per-decorated-callable cache of wrappers built by each container.
            cache: dict[Container, Callable[..., Any]] = {}
            wrapper_config = _InjectWrapperConfig(
                callable_obj=callable_obj,
                scope=resolved_scope,
                dependency_registration_policy=resolved_dependency_registration_policy,
                auto_open_scope=auto_open_scope,
            )
            fallback_container = self._fallback_container
            if fallback_container is not None:
                # Pre-build the wrapper for the current fallback container.
                self._get_cached_injected_callable(
                    cache=cache,
                    container=fallback_container,
                    wrapper_config=wrapper_config,
                )
            invocation = self._build_injection_invoker(
                cache=cache,
                wrapper_config=wrapper_config,
            )
            wrapped_callable = self._wrap_injected_callable(
                callable_obj=callable_obj,
                invocation=invocation,
            )
            # Expose the inspector's public signature instead of the raw one.
            wrapped_callable.__signature__ = inspected_callable.public_signature  # type: ignore[attr-defined]
            wrapped_callable.__dict__[INJECT_WRAPPER_MARKER] = True
            return cast("InjectableF", wrapped_callable)

        func_value = cast("Any", func)
        # ``@resolver_context.inject(...)`` form: return the decorator itself.
        if func_value == "from_decorator":
            return decorator
        if not callable(func_value):
            msg = "inject() parameter 'func' must be callable or 'from_decorator'."
            raise DIWireInvalidRegistrationError(msg)
        # ``@resolver_context.inject`` / ``inject(func)`` form: wrap immediately.
        return decorator(func_value)

    def _resolve_inject_scope(
        self,
        scope: BaseScope | Literal["infer"],
    ) -> BaseScope | None:
        """Map the ``scope`` argument of :meth:`inject` to a scope or ``None``.

        Raises:
            DIWireInvalidRegistrationError: If ``scope`` is neither ``"infer"``
                nor a :class:`BaseScope`.
        """
        scope_value = cast("Any", scope)
        if scope_value == "infer":
            return None
        if isinstance(scope_value, BaseScope):
            return scope_value
        msg = "inject() parameter 'scope' must be BaseScope or 'infer'."
        raise DIWireInvalidRegistrationError(msg)

    def _resolve_inject_dependency_registration_policy(
        self,
        *,
        dependency_registration_policy: (DependencyRegistrationPolicy | Literal["from_container"]),
    ) -> DependencyRegistrationPolicy | None:
        """Map the policy argument of :meth:`inject` to a policy or ``None``.

        ``None`` means "inherit from the container".

        Raises:
            DIWireInvalidRegistrationError: If the value is neither
                ``"from_container"`` nor a :class:`DependencyRegistrationPolicy`.
        """
        dependency_registration_policy_value = cast("Any", dependency_registration_policy)
        if dependency_registration_policy_value == "from_container":
            return None
        if isinstance(dependency_registration_policy_value, DependencyRegistrationPolicy):
            return dependency_registration_policy_value
        msg = (
            "inject() parameter 'dependency_registration_policy' must be "
            "DependencyRegistrationPolicy or 'from_container'."
        )
        raise DIWireInvalidRegistrationError(msg)

    def _validate_injected_callable_signature(self, callable_obj: Callable[..., Any]) -> None:
        """Reject callables that declare the reserved injection keyword names.

        Raises:
            DIWireInvalidRegistrationError: If ``callable_obj`` declares
                ``INJECT_RESOLVER_KWARG`` or ``INJECT_CONTEXT_KWARG`` as a parameter.
        """
        signature = inspect.signature(callable_obj)
        if INJECT_RESOLVER_KWARG in signature.parameters:
            msg = (
                f"Callable '{self._callable_name(callable_obj)}' cannot declare reserved "
                f"parameter '{INJECT_RESOLVER_KWARG}'."
            )
            raise DIWireInvalidRegistrationError(msg)
        if INJECT_CONTEXT_KWARG in signature.parameters:
            msg = (
                f"Callable '{self._callable_name(callable_obj)}' cannot declare reserved "
                f"parameter '{INJECT_CONTEXT_KWARG}'."
            )
            raise DIWireInvalidRegistrationError(msg)

    def _build_injection_invoker(
        self,
        *,
        cache: dict[Container, Callable[..., Any]],
        wrapper_config: _InjectWrapperConfig,
    ) -> Callable[..., Any]:
        """Build the per-call dispatcher used by the injected wrapper.

        On each call it picks the resolution source and fetches (or builds) the
        wrapper for the current fallback container. For the ``"context"``
        source, it also passes the bound resolver as ``INJECT_RESOLVER_KWARG``.
        """

        def _invoke(*args: Any, **kwargs: Any) -> Any:
            state = self._resolve_injection_state(kwargs)
            # Required for every source: the fallback container builds the wrapper.
            fallback_container = self._require_inject_fallback_container()
            injected_callable = self._get_cached_injected_callable(
                cache=cache,
                container=fallback_container,
                wrapper_config=wrapper_config,
            )
            runtime_kwargs = kwargs
            if state.source == "context":
                # Copy so the resolver kwarg is added without mutating ``kwargs``.
                runtime_kwargs = dict(kwargs)
                runtime_kwargs[INJECT_RESOLVER_KWARG] = cast(
                    "ResolverProtocol",
                    state.context_resolver,
                )
            return injected_callable(*args, **runtime_kwargs)

        return _invoke

    def _resolve_injection_state(self, kwargs: dict[str, Any]) -> _InjectInvocationState:
        """Decide which resolver source an injected call should use.

        Order: explicit ``INJECT_RESOLVER_KWARG`` in ``kwargs``, then the
        context-bound resolver, then an eligible fallback container.

        Raises:
            DIWireResolverNotSetError: If none of the sources is available.
        """
        if INJECT_RESOLVER_KWARG in kwargs:
            return _InjectInvocationState(source="explicit", context_resolver=None)
        context_resolver = self._get_bound_resolver_or_none()
        if context_resolver is not None:
            return _InjectInvocationState(source="context", context_resolver=context_resolver)
        fallback_container = self._fallback_container
        if fallback_container is not None:
            # A fallback container is only eligible when it opts into resolver
            # context use; the attribute is assumed ``True`` when absent.
            fallback_uses_resolver_context = bool(
                getattr(fallback_container, "_use_resolver_context", True)
            )
            if fallback_uses_resolver_context:
                # Return value unused here. NOTE(review): presumably this ensures
                # the container is compiled before invocation; confirm intent.
                fallback_container.compile()
                return _InjectInvocationState(source="fallback", context_resolver=None)
        msg = (
            "Resolver is not set for resolver_context.inject. Pass "
            f"'{INJECT_RESOLVER_KWARG}' explicitly, enter a resolver context, "
            "or initialize a fallback container with use_resolver_context=True."
        )
        raise DIWireResolverNotSetError(msg)

    def _require_inject_fallback_container(self) -> Container:
        """Return the fallback container, which injected wrappers are built by.

        Raises:
            DIWireResolverNotSetError: If no fallback container is set.
        """
        fallback_container = self._fallback_container
        if fallback_container is None:
            msg = (
                "ResolverContext.inject requires a fallback container. Initialize a container "
                "with this ResolverContext before decorating callables."
            )
            raise DIWireResolverNotSetError(msg)
        return fallback_container

    def _get_cached_injected_callable(
        self,
        *,
        cache: dict[Container, Callable[..., Any]],
        container: Container,
        wrapper_config: _InjectWrapperConfig,
    ) -> Callable[..., Any]:
        """Return the wrapper ``container`` built for this config, building it once.

        Results are memoized in ``cache`` keyed by container.
        """
        injected_callable = cache.get(container)
        if injected_callable is not None:
            return injected_callable
        injected_callable = container._inject_callable(  # noqa: SLF001
            callable_obj=wrapper_config.callable_obj,
            scope=wrapper_config.scope,
            dependency_registration_policy=wrapper_config.dependency_registration_policy,
            auto_open_scope=wrapper_config.auto_open_scope,
        )
        cache[container] = injected_callable
        return injected_callable

    def _wrap_injected_callable(
        self,
        *,
        callable_obj: InjectableF,
        invocation: Callable[..., Any],
    ) -> Callable[..., Any]:
        """Wrap ``invocation`` with ``callable_obj``'s metadata via ``functools.wraps``.

        Coroutine functions get an ``async def`` wrapper that awaits the
        invocation's result. As a result, the wrapper is itself a coroutine
        function.
        """
        if inspect.iscoroutinefunction(callable_obj):

            @functools.wraps(callable_obj)
            async def _async_injected(*args: Any, **kwargs: Any) -> Any:
                result = invocation(*args, **kwargs)
                return await cast("Awaitable[Any]", result)

            return _async_injected

        @functools.wraps(callable_obj)
        def _sync_injected(*args: Any, **kwargs: Any) -> Any:
            return invocation(*args, **kwargs)

        return _sync_injected

    def _callable_name(self, callable_obj: Callable[..., Any]) -> str:
        """Return ``__qualname__`` for error messages, falling back to ``repr``."""
        return getattr(callable_obj, "__qualname__", repr(callable_obj))
# Shared module-level ResolverContext instance. Error messages raised by
# ResolverContext refer to it by this name (e.g. "resolver_context.inject").
resolver_context: ResolverContext = ResolverContext()