from __future__ import annotations
import asyncio
import contextlib
import inspect
import itertools
import threading
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator, Iterator, MutableMapping
from contextlib import AsyncExitStack, ExitStack
from types import FunctionType, MethodType
from typing import (
Any,
ClassVar,
TypeVar,
cast,
get_origin,
overload,
)
from diwire.compiled_providers import (
ArgsTypeProvider,
CompiledProvider,
FactoryProvider,
InstanceProvider,
PositionalArgsTypeProvider,
ScopedSingletonArgsProvider,
ScopedSingletonPositionalArgsProvider,
ScopedSingletonProvider,
SingletonArgsTypeProvider,
SingletonFactoryProvider,
SingletonPositionalArgsTypeProvider,
SingletonTypeProvider,
TypeProvider,
)
from diwire.container_helpers import (
_get_generic_origin_and_args,
_get_return_annotation,
_is_any_type,
_is_async_factory,
_is_method_descriptor,
_is_typevar,
_is_union_type,
_OpenGenericRegistration,
_ResolvedDependencies,
_type_arg_matches_constraint,
_unwrap_method_descriptor,
)
from diwire.container_injection import (
_AsyncInjectedFunction,
_AsyncScopedInjectedFunction,
_InjectedFunction,
_ScopedInjectedFunction,
)
from diwire.container_locks import LockManager
from diwire.container_resolution_stack import _get_resolution_stack
from diwire.container_scopes import ScopedContainer, _current_scope, _ScopeId
from diwire.defaults import (
DEFAULT_AUTOREGISTER_IGNORES,
DEFAULT_AUTOREGISTER_LIFETIME,
DEFAULT_AUTOREGISTER_REGISTRATION_FACTORIES,
)
from diwire.dependencies import DependenciesExtractor, ParameterInfo
from diwire.exceptions import (
DIWireAsyncCleanupWithoutEventLoopError,
DIWireAsyncDependencyInSyncContextError,
DIWireAsyncGeneratorFactoryDidNotYieldError,
DIWireAsyncGeneratorFactoryWithoutScopeError,
DIWireCircularDependencyError,
DIWireComponentSpecifiedError,
DIWireConcreteClassRequiresClassError,
DIWireContainerClosedError,
DIWireDecoratorFactoryMissingReturnAnnotationError,
DIWireError,
DIWireGeneratorFactoryDidNotYieldError,
DIWireGeneratorFactoryUnsupportedLifetimeError,
DIWireGeneratorFactoryWithoutScopeError,
DIWireIgnoredServiceError,
DIWireInvalidGenericTypeArgumentError,
DIWireMissingDependenciesError,
DIWireNotAClassError,
DIWireOpenGenericRegistrationError,
DIWireOpenGenericResolutionError,
DIWireScopedWithoutScopeError,
DIWireScopeMismatchError,
DIWireServiceNotRegisteredError,
DIWireUnionTypeError,
)
from diwire.registry import Registration
from diwire.service_key import Component, ServiceKey
from diwire.types import Factory, Lifetime
T = TypeVar("T", bound=Any)
_C = TypeVar("_C", bound=type) # For class decorator
class _ScopedCacheView(MutableMapping[ServiceKey, Any]):
"""View for scoped caches backed by a per-scope cache."""
__slots__ = ("_cache", "_lock", "_type_cache")
def __init__(
self,
cache: dict[ServiceKey, Any],
type_cache: dict[type, Any],
lock: threading.RLock | None,
) -> None:
self._cache = cache
self._type_cache = type_cache
self._lock = lock
def get(self, key: ServiceKey, default: Any | None = None) -> Any | None:
if key.is_type_key:
type_cache = self._type_cache
cached = type_cache.get(key.value)
if cached is not None:
return cached
cached = self._cache.get(key)
if cached is not None:
type_cache[key.value] = cached
return cached
return default
return self._cache.get(key, default)
def __getitem__(self, key: ServiceKey) -> Any:
if key.is_type_key:
type_cache = self._type_cache
cached = type_cache.get(key.value)
if cached is not None:
return cached
value = self._cache[key]
type_cache[key.value] = value
return value
return self._cache[key]
def __setitem__(self, key: ServiceKey, value: Any) -> None:
if key.is_type_key:
self._type_cache[key.value] = value
self._cache[key] = value
def __delitem__(self, key: ServiceKey) -> None:
del self._cache[key]
if key.is_type_key:
self._type_cache.pop(key.value, None)
def __iter__(self) -> Iterator[ServiceKey]:
return iter(self._cache)
def __len__(self) -> int:
return len(self._cache)
def get_or_create(self, key: ServiceKey, factory: Callable[[], Any]) -> Any:
cache = self._cache
type_cache: dict[type, Any] | None = None
if key.is_type_key:
type_cache = self._type_cache
cached = type_cache.get(key.value)
if cached is not None:
return cached
cached = cache.get(key)
if cached is not None:
type_cache[key.value] = cached
return cached
else:
cached = cache.get(key)
if cached is not None:
return cached
if self._lock is None:
instance = factory()
cache[key] = instance
if type_cache is not None:
type_cache[key.value] = instance
return instance
with self._lock:
if type_cache is not None:
cached = type_cache.get(key.value)
if cached is None:
cached = cache.get(key)
if cached is not None:
type_cache[key.value] = cached
return cached
else:
return cached
else:
cached = cache.get(key)
if cached is not None:
return cached
instance = factory()
cache[key] = instance
if type_cache is not None:
type_cache[key.value] = instance
return instance
def get_or_create_positional(
self,
key: ServiceKey,
constructor: type,
providers: tuple[CompiledProvider, ...],
singletons: dict[ServiceKey, Any],
) -> Any:
cache = self._cache
type_cache: dict[type, Any] | None = None
if key.is_type_key:
type_cache = self._type_cache
cached = type_cache.get(key.value)
if cached is not None:
return cached
cached = cache.get(key)
if cached is not None:
type_cache[key.value] = cached
return cached
else:
cached = cache.get(key)
if cached is not None:
return cached
if self._lock is None:
instance = constructor(*[provider(singletons, self) for provider in providers])
cache[key] = instance
if type_cache is not None:
type_cache[key.value] = instance
return instance
with self._lock:
if type_cache is not None:
cached = type_cache.get(key.value)
if cached is None:
cached = cache.get(key)
if cached is not None:
type_cache[key.value] = cached
return cached
else:
return cached
else:
cached = cache.get(key)
if cached is not None:
return cached
instance = constructor(*[provider(singletons, self) for provider in providers])
cache[key] = instance
if type_cache is not None:
type_cache[key.value] = instance
return instance
def get_or_create_kwargs(
self,
key: ServiceKey,
constructor: type,
items: tuple[tuple[str, CompiledProvider], ...],
singletons: dict[ServiceKey, Any],
) -> Any:
cache = self._cache
type_cache: dict[type, Any] | None = None
if key.is_type_key:
type_cache = self._type_cache
cached = type_cache.get(key.value)
if cached is not None:
return cached
cached = cache.get(key)
if cached is not None:
type_cache[key.value] = cached
return cached
else:
cached = cache.get(key)
if cached is not None:
return cached
if self._lock is None:
args = {name: provider(singletons, self) for name, provider in items}
instance = constructor(**args)
cache[key] = instance
if type_cache is not None:
type_cache[key.value] = instance
return instance
with self._lock:
if type_cache is not None:
cached = type_cache.get(key.value)
if cached is None:
cached = cache.get(key)
if cached is not None:
type_cache[key.value] = cached
return cached
else:
return cached
else:
cached = cache.get(key)
if cached is not None:
return cached
args = {name: provider(singletons, self) for name, provider in items}
instance = constructor(**args)
cache[key] = instance
if type_cache is not None:
type_cache[key.value] = instance
return instance
[docs]
class Container:
"""Dependency injection container for registering and resolving services.
Supports automatic registration, lifetime singleton/transient, and factory patterns.
"""
# Class-level counter for generating unique scope IDs (faster than UUID)
_scope_counter: ClassVar[itertools.count[int]] = itertools.count()
__slots__ = (
"_active_scopes",
"_active_scopes_lock",
"_async_deps_cache",
"_async_scope_exit_stacks",
"_auto_compile",
"_autoregister",
"_autoregister_default_lifetime",
"_autoregister_ignores",
"_autoregister_registration_factories",
"_cleanup_tasks",
"_closed",
"_compiled_providers",
"_dependencies_extractor",
"_has_scoped_registrations",
"_is_compiled",
"_locks",
"_multithreaded",
"_open_generic_registry",
"_registry",
"_scope_cache_locks",
"_scope_caches",
"_scope_exit_stacks",
"_scope_type_caches",
"_scoped_cache_views",
"_scoped_cache_views_nolock",
"_scoped_compiled_providers",
"_scoped_compiled_providers_by_scope",
"_scoped_open_generic_registry",
"_scoped_registry",
"_scoped_type_providers",
"_scoped_type_providers_by_scope",
"_singletons",
"_thread_id",
"_type_providers",
"_type_singletons",
)
[docs]
def __init__(
self,
*,
autoregister: bool = True,
autoregister_ignores: set[type[Any]] | None = None,
autoregister_registration_factories: dict[type[Any], Callable[[Any], Registration]]
| None = None,
autoregister_default_lifetime: Lifetime = DEFAULT_AUTOREGISTER_LIFETIME,
auto_compile: bool = True,
) -> None:
self._autoregister = autoregister
self._autoregister_ignores = autoregister_ignores or DEFAULT_AUTOREGISTER_IGNORES
self._autoregister_registration_factories = (
autoregister_registration_factories or DEFAULT_AUTOREGISTER_REGISTRATION_FACTORIES
)
self._autoregister_default_lifetime = autoregister_default_lifetime
self._auto_compile = auto_compile
self._singletons: dict[ServiceKey, Any] = {}
self._scoped_cache_views: dict[tuple[tuple[str | None, int], ...], _ScopedCacheView] = {}
self._scoped_cache_views_nolock: dict[
tuple[tuple[str | None, int], ...],
_ScopedCacheView,
] = {}
self._scope_caches: dict[tuple[tuple[str | None, int], ...], dict[ServiceKey, Any]] = {}
self._scope_type_caches: dict[tuple[tuple[str | None, int], ...], dict[type, Any]] = {}
self._scope_cache_locks: dict[tuple[tuple[str | None, int], ...], threading.RLock] = {}
self._registry: dict[ServiceKey, Registration] = {}
self._scoped_registry: dict[tuple[ServiceKey, str], Registration] = {}
self._scoped_compiled_providers_by_scope: dict[str, dict[ServiceKey, CompiledProvider]] = {}
self._scoped_type_providers_by_scope: dict[str, dict[type, CompiledProvider]] = {}
self._open_generic_registry: dict[
tuple[type, Component | None],
_OpenGenericRegistration,
] = {}
self._scoped_open_generic_registry: dict[
tuple[type, Component | None, str],
_OpenGenericRegistration,
] = {}
# Scope exit stacks keyed by tuple for consistency
self._scope_exit_stacks: dict[tuple[tuple[str | None, int], ...], ExitStack] = {}
self._async_scope_exit_stacks: dict[tuple[tuple[str | None, int], ...], AsyncExitStack] = {}
# Background cleanup tasks (to prevent garbage collection)
self._cleanup_tasks: set[asyncio.Task[None]] = set()
self._dependencies_extractor = DependenciesExtractor()
# Compiled providers for optimized resolution
self._compiled_providers: dict[ServiceKey, CompiledProvider] = {}
# Compiled scoped providers: (service_key, scope_name) -> provider
self._scoped_compiled_providers: dict[tuple[ServiceKey, str], CompiledProvider] = {}
self._scoped_type_providers: dict[tuple[type, str], CompiledProvider] = {}
self._is_compiled: bool = False
# Fast type-based lookup caches (bypasses ServiceKey creation for simple types)
self._type_singletons: dict[type, Any] = {}
self._type_providers: dict[type, CompiledProvider] = {}
# Track if any scoped registrations exist to skip ContextVar lookups
self._has_scoped_registrations: bool = False
# Cache for async dependency info (Phase 4 optimization)
self._async_deps_cache: dict[ServiceKey, frozenset[ServiceKey]] = {}
# Lock manager for singleton and scoped singleton resolution
self._locks = LockManager()
# Track thread usage for locking decisions
self._thread_id = threading.get_ident()
self._multithreaded = False
# Track active scopes for imperative close()
self._active_scopes: list[ScopedContainer] = []
self._active_scopes_lock = threading.Lock()
self._closed = False
self.register(type(self), instance=self, lifetime=Lifetime.SINGLETON)
# Overload 1: Bare class decorator - @container.register
# Must come first to match the direct decorator pattern
@overload
def register(self, key: _C, /) -> _C: ...
# Overload 2: Bare factory function decorator - @container.register
@overload
def register(self, key: Callable[..., T], /) -> Callable[..., T]: ...
# Overload 3: Parameterized decorator without key - @container.register(lifetime=...)
# Returns a decorator that accepts classes or functions
@overload
def register(
self,
key: None = None,
/,
factory: None = None,
instance: None = None,
lifetime: Lifetime = ...,
scope: str | None = ...,
is_async: bool | None = ...,
concrete_class: type | None = ...,
) -> Callable[[T], T]: ...
# Overload 4: Interface decorator - @container.register(Interface, lifetime=...)
# When a type is passed with optional keyword args (no factory/instance/concrete_class),
# returns a decorator
@overload
def register(
self,
key: type,
/,
*,
lifetime: Lifetime = ...,
scope: str | None = ...,
is_async: bool | None = ...,
) -> Callable[[T], T]: ...
# Overload 5: String key decorator - @container.register("key", lifetime=...)
# When a string is passed as key with optional keyword args (no factory/instance/concrete_class),
# returns a decorator
@overload
def register(
self,
key: str,
/,
*,
lifetime: Lifetime = ...,
scope: str | None = ...,
is_async: bool | None = ...,
) -> Callable[[T], T]: ...
# Overload 6: Direct call with explicit key - container.register(Interface, concrete_class=...)
@overload
def register(
self,
key: Any,
/,
factory: Factory | None = ...,
instance: Any | None = ...,
lifetime: Lifetime = ...,
scope: str | None = ...,
is_async: bool | None = ...,
concrete_class: type | None = ...,
) -> None: ...
[docs]
def register(
self,
key: Any | None = None,
/,
factory: Factory | None = None,
instance: Any | None = None,
lifetime: Lifetime = Lifetime.TRANSIENT,
scope: str | None = None,
is_async: bool | None = None,
concrete_class: type | None = None,
) -> Any:
"""Register a service with the container.
Can be used as:
- Bare class decorator: @container.register
- Parameterized decorator: @container.register(lifetime=Lifetime.SINGLETON)
- Interface decorator: @container.register(IService) on a class
- Factory function decorator: @container.register (with return annotation)
- Direct call: container.register(IService, concrete_class=MyService)
Args:
key: The service key (interface/type) to register under. When None, returns
a decorator. When used with @container.register(Interface) on a class,
the decorated class becomes the implementation.
factory: Optional factory to create instances. Generator factories
(Generator[T, None, None] or AsyncGenerator[T, None]) are supported for
resource cleanup - the container calls close()/aclose() when the scope exits.
instance: Optional pre-created instance.
lifetime: The lifetime of the service. This default applies only to explicit
registrations via `register`; auto-registration uses
`autoregister_default_lifetime` from container configuration.
scope: Optional scope name for SCOPED services.
is_async: Whether the factory is async. If None, auto-detected from factory.
concrete_class: Optional concrete implementation class. When specified, `key`
is used as the interface and `concrete_class` is the implementation.
Returns:
- When used as a bare decorator on a class: returns the class unchanged
- When used as a parameterized decorator: returns a decorator function
- When used as a direct call: returns None
Raises:
DIWireScopedWithoutScopeError: If lifetime is SCOPED but
no scope is provided.
DIWireConcreteClassRequiresClassError: If `concrete_class` is not a class type.
DIWireDecoratorFactoryMissingReturnAnnotationError: If used as a factory decorator
but the function has no return annotation and no explicit key.
Note:
When using generator factories for cleanup, wrap cleanup code in try/finally:
.. code-block:: python
def my_factory() -> Generator[Resource, None, None]:
resource = acquire_resource()
try:
yield resource
finally:
resource.close() # MUST be in finally block
Without try/finally, cleanup code after yield will not execute when the
scope exits, as close()/aclose() raises GeneratorExit at the yield point.
"""
# Check if all optional params are at defaults (for bare decorator detection)
all_params_at_defaults = (
factory is None
and instance is None
and lifetime == Lifetime.TRANSIENT
and scope is None
and is_async is None
and concrete_class is None
)
# Case 1: Parameterized decorator without key - @container.register(lifetime=...)
if key is None:
return self._make_decorator(
lifetime=lifetime,
scope=scope,
is_async=is_async,
interface_key=None,
)
# Case 2: Open generic decorator - @container.register(MyGeneric[T])
if factory is None and instance is None and concrete_class is None:
origin, args = _get_generic_origin_and_args(key)
if origin is not None and any(_is_typevar(arg) for arg in args):
return self._make_decorator(
lifetime=lifetime,
scope=scope,
is_async=is_async,
interface_key=key,
)
# Case 2b: Concrete generic alias - @container.register(MyGeneric[int])
# This is a generic alias with all concrete type arguments (no TypeVars),
# used as an interface key for registering a concrete class or factory.
if origin is not None and args:
return self._make_decorator(
lifetime=lifetime,
scope=scope,
is_async=is_async,
interface_key=key,
)
# Case 3+4 merged: Type as key (could be bare decorator, interface decorator, or factory)
# Ambiguous case - we can't tell if this is:
# - @container.register on class (bare decorator) -> should register and return class
# - @container.register(Type) on function (factory) -> should register function as factory
# - @container.register(Type) on class (interface) -> should register class under Type
#
# Solution: Create a proxy class that acts as both the original class and a decorator,
# then register BOTH the original and proxy so resolution works with either key.
if (
isinstance(key, type)
and factory is None
and instance is None
and concrete_class is None
):
# Create a proxy class that inherits from original and can act as decorator
proxy_class = self._make_class_proxy_decorator(
original_class=key,
lifetime=lifetime,
scope=scope,
is_async=is_async,
)
# Register the proxy class (for decorator usage where proxy becomes the class)
self._do_register(
key=proxy_class,
factory=None,
instance=None,
lifetime=lifetime,
scope=scope,
is_async=is_async,
concrete_class=None,
)
# Also register the original class (for direct call usage)
# This allows container.register(Type, ...) and container.resolve(Type) to work
self._do_register(
key=key,
factory=None,
instance=None,
lifetime=lifetime,
scope=scope,
is_async=is_async,
concrete_class=proxy_class, # Use proxy as implementation
)
return proxy_class
# Case 3: Bare decorator on a function - @container.register
# Check that key is a proper function/method, not a generic alias like Annotated[T, ...]
is_factory_function = (
callable(key)
and not isinstance(key, type)
and get_origin(key) is None # Not a generic alias
and (
inspect.isfunction(key) or inspect.ismethod(key) or inspect.iscoroutinefunction(key)
)
)
if is_factory_function and all_params_at_defaults:
service_type = _get_return_annotation(key)
if service_type is None:
raise DIWireDecoratorFactoryMissingReturnAnnotationError(key)
self._do_register(
key=service_type,
factory=key,
instance=None,
lifetime=lifetime,
scope=scope,
is_async=is_async,
concrete_class=None,
)
return key
# Case 4: Bare decorator on a staticmethod - @staticmethod @container.register
if _is_method_descriptor(key) and all_params_at_defaults:
# _is_method_descriptor guarantees key is staticmethod,
# so unwrapped_func is always non-None
unwrapped_func, _ = _unwrap_method_descriptor(key)
service_type = _get_return_annotation(unwrapped_func) # type: ignore[arg-type]
if service_type is None:
raise DIWireDecoratorFactoryMissingReturnAnnotationError(unwrapped_func)
# unwrapped_func is guaranteed non-None since _is_method_descriptor passed
factory_func = cast("Callable[..., Any]", unwrapped_func)
self._do_register(
key=service_type,
factory=factory_func,
instance=None,
lifetime=lifetime,
scope=scope,
is_async=is_async,
concrete_class=None,
)
return key
# Case 5: Non-type key as decorator - @container.register("string_key")
# This handles string keys or other hashable values used as service identifiers.
# When key is not a type, function, method descriptor, or generic alias,
# and no factory/instance/concrete_class is provided, return a decorator.
if factory is None and instance is None and concrete_class is None:
return self._make_decorator(
lifetime=lifetime,
scope=scope,
is_async=is_async,
interface_key=key,
)
# Case 6: Direct call - container.register(Interface, concrete_class=Impl)
self._do_register(
key=key,
factory=factory,
instance=instance,
lifetime=lifetime,
scope=scope,
is_async=is_async,
concrete_class=concrete_class,
)
return None
def _make_class_proxy_decorator(
self,
original_class: type,
lifetime: Lifetime,
scope: str | None,
is_async: bool | None,
) -> type:
"""Create a class proxy that works as both the original class and a decorator.
This enables the ambiguous pattern where `@container.register(Type)` can be:
- A bare decorator on the Type itself (returns the class)
- A decorator factory for interface/factory registration (acts as decorator)
The proxy inherits from the original class so isinstance/issubclass checks work,
but its __new__ is overridden to handle decorator invocation.
"""
container = self
class _ClassProxyDecorator(original_class): # type: ignore[valid-type, misc]
"""Proxy class that inherits from original and can act as a decorator.
When instantiated with a type or callable (decorator pattern), it performs
registration. Otherwise, it creates instances of the proxy class (not original)
so isinstance checks work correctly.
"""
# Store reference to avoid closure issues
_original_class = original_class
_lifetime = lifetime
_scope = scope
_is_async = is_async
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
# Check if this is decorator invocation (single positional arg that's a type/callable)
if len(args) == 1 and not kwargs:
target = args[0]
if isinstance(target, type):
if target is cls._original_class:
# Same class as interface key - just return it (already registered)
return target
# Interface registration: @proxy(ImplClass)
container._do_register( # noqa: SLF001
key=cls._original_class,
factory=None,
instance=None,
lifetime=cls._lifetime,
scope=cls._scope,
is_async=cls._is_async,
concrete_class=target,
)
return target
if callable(target):
# Factory registration: @proxy(factory_func)
container._do_register( # noqa: SLF001
key=cls._original_class,
factory=target,
instance=None,
lifetime=cls._lifetime,
scope=cls._scope,
is_async=cls._is_async,
concrete_class=None,
)
return target
# Normal instantiation - create instance of THIS proxy class (not original)
# so that isinstance(instance, proxy_class) returns True
return object.__new__(cls)
@classmethod
def __class_getitem__(cls, item: Any) -> Any:
"""Forward generic subscripting to original class."""
return original_class[item] # type: ignore[index]
# Preserve class metadata
_ClassProxyDecorator.__name__ = original_class.__name__
_ClassProxyDecorator.__qualname__ = original_class.__qualname__
_ClassProxyDecorator.__module__ = original_class.__module__
_ClassProxyDecorator.__doc__ = original_class.__doc__
return _ClassProxyDecorator
def _make_decorator(
self,
lifetime: Lifetime,
scope: str | None,
is_async: bool | None,
interface_key: Any | None,
) -> Callable[[T], T]:
"""Create a decorator function for parameterized @container.register(...) usage.
Args:
lifetime: The lifetime of the service.
scope: Optional scope name for SCOPED services.
is_async: Whether the factory is async.
interface_key: If provided, the decorated class/factory will be registered
under this key (interface registration pattern).
"""
def decorator(target: T) -> T:
if isinstance(target, type):
# Class decoration
if interface_key is not None: # pragma: no cover
# Interface registration: @container.register(Interface, ...) on a different class
# Note: This path is only reachable for open generics, but open generics
# with concrete_class raise DIWireOpenGenericRegistrationError, making
# this effectively unreachable. Kept for API completeness.
self._do_register(
key=interface_key,
factory=None,
instance=None,
lifetime=lifetime,
scope=scope,
is_async=is_async,
concrete_class=target,
)
else:
# Regular class registration (interface_key is None)
self._do_register(
key=target,
factory=None,
instance=None,
lifetime=lifetime,
scope=scope,
is_async=is_async,
concrete_class=None,
)
elif _is_method_descriptor(target):
# staticmethod decoration
# _is_method_descriptor guarantees target is staticmethod,
# so unwrapped_func is always non-None
unwrapped_func, _ = _unwrap_method_descriptor(target)
service_type = interface_key
if service_type is None:
service_type = _get_return_annotation(unwrapped_func) # type: ignore[arg-type]
if service_type is None:
raise DIWireDecoratorFactoryMissingReturnAnnotationError(
unwrapped_func,
)
# unwrapped_func is guaranteed non-None since _is_method_descriptor passed
method_factory = cast("Callable[..., Any]", unwrapped_func)
self._do_register(
key=service_type,
factory=method_factory,
instance=None,
lifetime=lifetime,
scope=scope,
is_async=is_async,
concrete_class=None,
)
else:
# Factory function decoration - infer type from return annotation
service_type = interface_key
if service_type is None:
service_type = _get_return_annotation(target) # type: ignore[arg-type]
if service_type is None:
raise DIWireDecoratorFactoryMissingReturnAnnotationError(target)
self._do_register(
key=service_type,
factory=target,
instance=None,
lifetime=lifetime,
scope=scope,
is_async=is_async,
concrete_class=None,
)
return target
return decorator
def _get_open_generic_info_for_registration(
self,
service_key: ServiceKey,
) -> tuple[type, tuple[Any, ...]] | None:
origin, args = _get_generic_origin_and_args(service_key.value)
if origin is None or not args:
return None
if not any(_is_typevar(arg) for arg in args):
return None
if not all(_is_typevar(arg) for arg in args):
raise DIWireOpenGenericRegistrationError(
service_key,
"Open generic registrations must use only TypeVar parameters.",
)
return origin, tuple(args)
def _register_open_generic(
self,
*,
origin: type,
service_key: ServiceKey,
registration: Registration,
typevars: tuple[Any, ...],
) -> None:
entry = _OpenGenericRegistration(
service_key=service_key,
registration=registration,
typevars=typevars,
)
if registration.scope is not None:
self._scoped_open_generic_registry[
(origin, service_key.component, registration.scope)
] = entry
self._has_scoped_registrations = True
else:
self._open_generic_registry[(origin, service_key.component)] = entry
def _do_register(
self,
key: Any,
factory: Factory | None,
instance: Any | None,
lifetime: Lifetime,
scope: str | None,
is_async: bool | None,
concrete_class: type | None,
) -> None:
"""Perform the actual registration logic."""
# Determine service_key and concrete_type based on concrete_class parameter
if concrete_class is not None:
# Interface registration: key is the interface, concrete_class is the implementation
if not isinstance(concrete_class, type):
raise DIWireConcreteClassRequiresClassError(concrete_class)
service_key = ServiceKey.from_value(key)
concrete_type: type | None = concrete_class
else:
service_key = ServiceKey.from_value(key)
concrete_type = key if isinstance(key, type) else None
if lifetime == Lifetime.SCOPED and scope is None:
raise DIWireScopedWithoutScopeError(service_key)
# Auto-detect if factory is async when not explicitly specified
detected_is_async = False
if is_async is not None:
detected_is_async = is_async
elif factory is not None:
detected_is_async = _is_async_factory(factory)
open_generic_info = self._get_open_generic_info_for_registration(service_key)
if open_generic_info is not None:
if concrete_class is not None:
raise DIWireOpenGenericRegistrationError(
service_key,
"Open generic registrations with 'concrete_class' are not supported.",
)
if instance is not None:
raise DIWireOpenGenericRegistrationError(
service_key,
"Open generic registrations do not support instances.",
)
origin, typevars = open_generic_info
registration = Registration(
service_key=service_key,
factory=factory,
instance=instance,
lifetime=lifetime,
scope=scope,
is_async=detected_is_async,
concrete_type=concrete_type,
typevar_map=None,
)
self._register_open_generic(
origin=origin,
service_key=service_key,
registration=registration,
typevars=typevars,
)
# Track scoped registrations
if lifetime == Lifetime.SCOPED:
self._has_scoped_registrations = True
self._is_compiled = False
return
registration = Registration(
service_key=service_key,
factory=factory,
instance=instance,
lifetime=lifetime,
scope=scope,
is_async=detected_is_async,
concrete_type=concrete_type,
)
# If registering with an instance (non-scoped), update the singleton cache immediately
# This ensures re-registration overwrites any previously cached value
if instance is not None and scope is None:
self._singletons[service_key] = instance
# Also clear type cache for re-registration
if service_key.is_type_key:
self._type_singletons[service_key.value] = instance
if scope is not None:
# Store in scoped registry for scope-specific lookup
self._scoped_registry[(service_key, scope)] = registration
# Track that we have scoped registrations
self._has_scoped_registrations = True
else:
# Store in global registry
self._registry[service_key] = registration
# Track scoped singleton registrations
if lifetime == Lifetime.SCOPED:
self._has_scoped_registrations = True
# Invalidate compiled state when registrations change
self._is_compiled = False
[docs]
def enter_scope(self, scope_name: str | None = None) -> ScopedContainer:
"""Start a new scope for resolving SCOPED dependencies.
The scope is activated immediately upon creation, allowing imperative usage:
scope = container.enter_scope("request")
# ... use the scope ...
scope.close() # or container.close() to close all scopes
Context manager usage is also supported:
with container.enter_scope("request") as scope:
# ... use the scope ...
Args:
scope_name: Optional name for the scope. If not provided, an integer ID is generated.
Returns:
A ScopedContainer that is already activated.
Note:
Nested scopes inherit from parent scopes. A scope started within
another scope will have access to dependencies registered for the
parent scope.
"""
self._check_not_closed()
# Generate unique instance ID for each scope (integer is faster than UUID)
instance_id = next(self._scope_counter)
# Create new segment as tuple
new_segment = (scope_name, instance_id)
# Build scope by appending to current scope's segments
current = _current_scope.get()
segments = (*current.segments, new_segment) if current is not None else (new_segment,)
scope_id = _ScopeId(segments=segments)
return ScopedContainer(_container=self, _scope_id=scope_id)
def _clear_scope(self, scope_id: _ScopeId) -> None:
"""Clear cached instances for a scope.
Args:
scope_id: The scope ID to clear.
"""
scope_key = scope_id.segments
scope_exit_stack = self._scope_exit_stacks.pop(scope_key, None)
if scope_exit_stack is not None:
scope_exit_stack.close()
# Close async exit stack (if any async generators were resolved in this scope)
# Peek first without removing - only remove after successfully scheduling cleanup
async_exit_stack = self._async_scope_exit_stacks.get(scope_key)
if async_exit_stack is not None:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# No running event loop - leave stack in place for later _aclear_scope() call
scope_name = scope_key[-1][0] if scope_key else None
raise DIWireAsyncCleanupWithoutEventLoopError(scope_name) from None
# Event loop is running - schedule cleanup as a task
task = loop.create_task(async_exit_stack.aclose())
self._cleanup_tasks.add(task)
task.add_done_callback(self._cleanup_tasks.discard)
# Only remove after successfully scheduling cleanup
del self._async_scope_exit_stacks[scope_key]
self._scope_caches.pop(scope_key, None)
self._scope_type_caches.pop(scope_key, None)
self._scoped_cache_views.pop(scope_key, None)
self._scoped_cache_views_nolock.pop(scope_key, None)
self._scope_cache_locks.pop(scope_key, None)
self._locks.clear_scope_locks(scope_key)
def _get_scope_exit_stack(
self,
scope_key: tuple[tuple[str | None, int], ...],
) -> ExitStack:
scope_exit_stack = self._scope_exit_stacks.get(scope_key)
if scope_exit_stack is None:
scope_exit_stack = ExitStack()
self._scope_exit_stacks[scope_key] = scope_exit_stack
return scope_exit_stack
def _get_scoped_cache_view(
self,
scope_key: tuple[tuple[str | None, int], ...],
*,
use_lock: bool = True,
) -> _ScopedCacheView:
if use_lock:
cache = self._get_scope_cache(scope_key)
type_cache = self._get_scope_type_cache(scope_key)
lock = self._get_scope_cache_lock(scope_key)
return self._scoped_cache_views.setdefault(
scope_key,
_ScopedCacheView(cache, type_cache, lock),
)
cache = self._get_scope_cache(scope_key)
type_cache = self._get_scope_type_cache(scope_key)
return self._scoped_cache_views_nolock.setdefault(
scope_key,
_ScopedCacheView(cache, type_cache, None),
)
def _get_scope_cache(
self,
scope_key: tuple[tuple[str | None, int], ...],
) -> dict[ServiceKey, Any]:
return self._scope_caches.setdefault(scope_key, {})
def _get_scope_type_cache(
self,
scope_key: tuple[tuple[str | None, int], ...],
) -> dict[type, Any]:
return self._scope_type_caches.setdefault(scope_key, {})
def _get_scope_cache_lock(
self,
scope_key: tuple[tuple[str | None, int], ...],
) -> threading.RLock:
return self._scope_cache_locks.setdefault(scope_key, threading.RLock())
[docs]
def compile(self) -> None:
"""Compile all registered services into optimized providers.
This pre-compiles the dependency graph into specialized provider objects
that eliminate runtime reflection and minimize dict lookups. Call this
after all services have been registered for maximum performance.
"""
self._compiled_providers.clear()
self._scoped_compiled_providers.clear()
self._scoped_compiled_providers_by_scope.clear()
self._scoped_type_providers.clear()
self._scoped_type_providers_by_scope.clear()
self._type_providers.clear()
self._type_singletons.clear()
self._async_deps_cache.clear()
# Iterate over a copy since _compile_or_get_provider may add to registry
for service_key, registration in list(self._registry.items()):
provider = self._compile_registration(service_key, registration)
if provider is not None:
self._compiled_providers[service_key] = provider
scoped_scopes_by_key: dict[ServiceKey, set[str]] = {}
for service_key, scope_name in self._scoped_registry:
scoped_scopes_by_key.setdefault(service_key, set()).add(scope_name)
# Compile scoped registrations
for (service_key, scope_name), registration in list(self._scoped_registry.items()):
if (service_key, scope_name) in self._scoped_compiled_providers:
continue
provider = self._compile_scoped_registration(
service_key,
registration,
scope_name,
scoped_scopes_by_key,
)
if provider is not None:
self._scoped_compiled_providers[(service_key, scope_name)] = provider
self._scoped_compiled_providers_by_scope.setdefault(scope_name, {})[service_key] = (
provider
)
if service_key.is_type_key:
self._scoped_type_providers[(service_key.value, scope_name)] = provider
self._scoped_type_providers_by_scope.setdefault(scope_name, {})[
service_key.value
] = provider
# Build async dependency cache for faster async resolution
self._build_async_deps_cache()
# Pre-warm fast type caches for direct type lookups
for service_key, provider in list(self._compiled_providers.items()):
if service_key.is_type_key:
self._type_providers[service_key.value] = provider
# Also cache any already-resolved singletons
if service_key in self._singletons:
self._type_singletons[service_key.value] = self._singletons[service_key]
self._is_compiled = True
def _build_async_deps_cache(self) -> None:
"""Build a cache of which service keys have async dependencies.
This eliminates registry lookups in the async resolution path.
"""
for service_key in list(self._registry):
if not isinstance(service_key.value, type):
continue
async_deps: set[ServiceKey] = set()
try:
deps = self._dependencies_extractor.get_dependencies_with_defaults(service_key)
for param_info in deps.values():
if param_info.typevar is not None:
continue
dep_reg = self._registry.get(param_info.service_key)
if dep_reg is not None and dep_reg.is_async:
async_deps.add(param_info.service_key)
except DIWireError:
continue
if async_deps:
self._async_deps_cache[service_key] = frozenset(async_deps)
def _compile_registration(
self,
service_key: ServiceKey,
registration: Registration,
) -> CompiledProvider | None:
"""Compile a single registration into an optimized provider."""
# Skip scoped registrations (handled separately)
if registration.scope is not None:
return None
if registration.is_async:
return None
if registration.typevar_map is not None:
return None
# Handle pre-created instances
if registration.instance is not None:
return InstanceProvider(registration.instance)
# Handle factory registrations
if registration.factory is not None:
if isinstance(registration.factory, type):
# Factory is a class - compile it as a provider
factory_key = ServiceKey.from_value(registration.factory)
factory_provider = self._compile_or_get_provider(factory_key)
if factory_provider is None:
return None
elif isinstance(registration.factory, FunctionType | MethodType):
# Functions/methods need resolution - skip compilation for now
# They may have Injected parameters that need injection
return None
else:
# Factory is a built-in callable (e.g., ContextVar.get) - wrap directly
factory_provider = InstanceProvider(registration.factory)
result_handler = self._make_compiled_factory_result_handler(
service_key,
registration.lifetime,
registration.scope,
)
if registration.lifetime == Lifetime.SINGLETON:
return SingletonFactoryProvider(service_key, factory_provider, result_handler)
return FactoryProvider(factory_provider, result_handler)
# Use concrete_type if registered with provides parameter
instantiation_type = registration.concrete_type or service_key.value
# Handle type registrations - compile dependencies
if not isinstance(instantiation_type, type):
return None
# Use concrete type's service key for dependency extraction
instantiation_key = (
ServiceKey.from_value(instantiation_type)
if registration.concrete_type is not None
else service_key
)
try:
deps = self._dependencies_extractor.get_dependencies_with_defaults(instantiation_key)
except DIWireError:
return None
# Filter out ignored types with defaults
filtered_deps: dict[str, ServiceKey] = {}
for name, param_info in deps.items():
if param_info.typevar is not None:
return None
if param_info.service_key.value in self._autoregister_ignores:
if param_info.has_default:
continue
# Can't compile - missing required dependency
return None
filtered_deps[name] = param_info.service_key
if not filtered_deps:
# No dependencies - use simple provider
if registration.lifetime == Lifetime.SINGLETON:
return SingletonTypeProvider(instantiation_type, service_key)
return TypeProvider(instantiation_type)
positional_order = self._get_positional_dependency_order(instantiation_type, filtered_deps)
if positional_order is None:
use_positional = False
param_names = list(filtered_deps.keys())
else:
use_positional = True
param_names = list(positional_order)
# Compile dependency providers
dep_providers: list[CompiledProvider] = []
for name in param_names:
dep_key = filtered_deps[name]
dep_provider = self._compile_or_get_provider(dep_key)
if dep_provider is None:
return None
dep_providers.append(dep_provider)
if registration.lifetime == Lifetime.SINGLETON:
if use_positional:
return SingletonPositionalArgsTypeProvider(
instantiation_type,
service_key,
tuple(dep_providers),
)
return SingletonArgsTypeProvider(
instantiation_type,
service_key,
tuple(param_names),
tuple(dep_providers),
)
if use_positional:
return PositionalArgsTypeProvider(
instantiation_type,
tuple(dep_providers),
)
return ArgsTypeProvider(instantiation_type, tuple(param_names), tuple(dep_providers))
def _make_compiled_factory_result_handler(
self,
service_key: ServiceKey,
lifetime: Lifetime,
scope: str | None,
) -> Callable[[Any], Any]:
def handler(result: Any) -> Any:
return self._handle_compiled_factory_result(
result,
service_key,
lifetime,
scope,
)
return handler
def _get_positional_dependency_order(
self,
instantiation_type: type,
dependencies: dict[str, ServiceKey],
) -> tuple[str, ...] | None:
if not dependencies:
return ()
signature_type = getattr(instantiation_type, "_original_class", None)
if not isinstance(signature_type, type):
signature_type = instantiation_type
try:
sig = inspect.signature(signature_type)
except (TypeError, ValueError):
return None
params = [param for param in sig.parameters.values() if param.name != "self"]
for param in params:
if param.kind in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
return None
if (
param.kind == inspect.Parameter.KEYWORD_ONLY
and param.default is inspect.Parameter.empty
):
return None
if param.kind == inspect.Parameter.KEYWORD_ONLY and param.name in dependencies:
return None
positional_names = [
param.name
for param in params
if param.kind
in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
]
if any(name not in positional_names for name in dependencies):
return None
included_indices = [
index for index, name in enumerate(positional_names) if name in dependencies
]
if not included_indices:
return ()
last_index = max(included_indices)
for index in range(last_index + 1):
if positional_names[index] not in dependencies:
return None
return tuple(name for name in positional_names if name in dependencies)
def _handle_compiled_factory_result(
self,
result: Any,
service_key: ServiceKey,
lifetime: Lifetime,
scope: str | None,
) -> Any:
if inspect.iscoroutine(result):
result.close()
raise DIWireAsyncDependencyInSyncContextError(service_key, service_key)
if isinstance(result, AsyncGenerator):
raise DIWireAsyncDependencyInSyncContextError(service_key, service_key)
if isinstance(result, Generator):
current_scope = _current_scope.get() if self._has_scoped_registrations else None
cache_scope = self._get_cache_scope(current_scope, scope)
if cache_scope is None:
raise DIWireGeneratorFactoryWithoutScopeError(service_key)
if lifetime == Lifetime.SINGLETON:
raise DIWireGeneratorFactoryUnsupportedLifetimeError(service_key)
try:
instance = next(result)
except StopIteration as exc:
raise DIWireGeneratorFactoryDidNotYieldError(service_key) from exc
self._get_scope_exit_stack(cache_scope).callback(result.close)
return instance
return result
def _compile_or_get_provider(self, service_key: ServiceKey) -> CompiledProvider | None:
"""Get an existing compiled provider or compile a new one."""
# Check if already compiled
if service_key in self._compiled_providers:
return self._compiled_providers[service_key]
# Check registry
registration = self._registry.get(service_key)
if registration is not None:
provider = self._compile_registration(service_key, registration)
if provider is not None:
self._compiled_providers[service_key] = provider
return provider
# Auto-register if enabled
if self._autoregister:
try:
registration = self._get_auto_registration(service_key)
self._registry[service_key] = registration
provider = self._compile_registration(service_key, registration)
if provider is not None:
self._compiled_providers[service_key] = provider
return provider
except DIWireError:
return None
return None
def _compile_or_get_scoped_provider(
self,
service_key: ServiceKey,
scope_name: str,
scoped_scopes_by_key: dict[ServiceKey, set[str]],
) -> CompiledProvider | None:
"""Get an existing compiled scoped provider or compile a new one."""
scoped_key = (service_key, scope_name)
existing = self._scoped_compiled_providers.get(scoped_key)
if existing is not None:
return existing
registration = self._scoped_registry.get(scoped_key)
if registration is None:
return None
provider = self._compile_scoped_registration(
service_key,
registration,
scope_name,
scoped_scopes_by_key,
)
if provider is not None:
self._scoped_compiled_providers[scoped_key] = provider
return provider
def _compile_scoped_registration(
self,
service_key: ServiceKey,
registration: Registration,
scope_name: str,
scoped_scopes_by_key: dict[ServiceKey, set[str]],
) -> CompiledProvider | None:
"""Compile a scoped registration into an optimized provider.
Uses ScopedSingletonProvider for scoped singletons.
"""
# Skip non-type registrations (instances, factories)
# These need special handling for scope lifecycle
if registration.instance is not None or registration.factory is not None:
return None
if registration.typevar_map is not None:
return None
# Use concrete_type if registered with provides parameter
instantiation_type = registration.concrete_type or service_key.value
if not isinstance(instantiation_type, type):
return None
# Use concrete type's service key for dependency extraction
instantiation_key = (
ServiceKey.from_value(instantiation_type)
if registration.concrete_type is not None
else service_key
)
try:
deps = self._dependencies_extractor.get_dependencies_with_defaults(instantiation_key)
except DIWireError:
return None
# Filter out ignored types with defaults
filtered_deps: dict[str, ServiceKey] = {}
for name, param_info in deps.items():
if param_info.typevar is not None:
return None
if param_info.service_key.value in self._autoregister_ignores:
if param_info.has_default:
continue
return None
filtered_deps[name] = param_info.service_key
if not filtered_deps:
if registration.lifetime == Lifetime.TRANSIENT:
return TypeProvider(instantiation_type)
# No dependencies - use simple scoped provider
return ScopedSingletonProvider(instantiation_type, service_key)
positional_order = self._get_positional_dependency_order(instantiation_type, filtered_deps)
if positional_order is None:
use_positional = False
param_names = list(filtered_deps.keys())
else:
use_positional = True
param_names = list(positional_order)
# Compile dependency providers
dep_providers: list[CompiledProvider] = []
for name in param_names:
dep_key = filtered_deps[name]
dep_scopes = scoped_scopes_by_key.get(dep_key)
if dep_scopes:
if dep_scopes != {scope_name}:
return None
dep_provider = self._compile_or_get_scoped_provider(
dep_key,
scope_name,
scoped_scopes_by_key,
)
else:
dep_provider = self._compile_or_get_provider(dep_key)
if dep_provider is None:
return None
dep_providers.append(dep_provider)
if registration.lifetime == Lifetime.TRANSIENT:
if use_positional:
return PositionalArgsTypeProvider(
instantiation_type,
tuple(dep_providers),
)
return ArgsTypeProvider(
instantiation_type,
tuple(param_names),
tuple(dep_providers),
)
if use_positional:
return ScopedSingletonPositionalArgsProvider(
instantiation_type,
service_key,
tuple(dep_providers),
)
return ScopedSingletonArgsProvider(
instantiation_type,
service_key,
tuple(param_names),
tuple(dep_providers),
)
# Decorator overloads (key=None) - returns a decorator that wraps functions
@overload
def resolve(
self,
key: None = None,
*,
scope: str,
) -> Callable[[Callable[..., Any]], Any]: ...
@overload
def resolve(
self,
key: None = None,
*,
scope: None = None,
) -> Callable[[Callable[..., Any]], Any]: ...
@overload
def resolve(self, key: type[T], *, scope: None = None) -> T: ...
@overload
def resolve(self, key: type[T], *, scope: str) -> T: ...
@overload
def resolve(
self,
key: Callable[..., Coroutine[Any, Any, T]],
*,
scope: None = None,
) -> _AsyncInjectedFunction[T]: ...
@overload
def resolve(
self,
key: Callable[..., Coroutine[Any, Any, T]],
*,
scope: str,
) -> _AsyncScopedInjectedFunction[T]: ...
@overload
def resolve(self, key: Callable[..., T], *, scope: None = None) -> _InjectedFunction[T]: ...
@overload
def resolve(self, key: Callable[..., T], *, scope: str) -> _ScopedInjectedFunction[T]: ...
@overload
def resolve(self, key: ServiceKey, *, scope: str | None = None) -> Any: ...
@overload
def resolve(self, key: Any, *, scope: str | None = None) -> Any: ...
[docs]
def resolve(self, key: Any | None = None, *, scope: str | None = None) -> Any: # noqa: PLR0915
"""Resolve and return a service instance by its key.
When called with key=None, returns a decorator that can be applied to
functions to enable dependency injection.
Args:
key: The service key to resolve. If None, returns a decorator.
scope: Optional scope name. If provided and key is a function,
returns a ScopedInjected that creates a new scope per call.
Examples:
.. code-block:: python
# Direct usage:
injected = container.resolve(my_func, scope="request")
# Decorator usage:
@container.resolve(scope="request")
async def handler(service: Annotated[Service, Injected()]) -> dict: ...
"""
self._check_not_closed()
# DECORATOR PATTERN: resolve(scope="...") or resolve() returns decorator
if key is None:
def decorator(func: Callable[..., Any]) -> Any:
return self.resolve(func, scope=scope)
return decorator
# FAST PATH for simple types (most common case)
# Bypasses ServiceKey creation entirely for cached singletons
# Only use fast path when not in a scope (scoped registrations may override)
if (
isinstance(key, type)
and scope is None
and (not self._has_scoped_registrations or _current_scope.get() is None)
):
# Direct singleton lookup - fastest path
cached = self._type_singletons.get(key)
if cached is not None:
return cached
# Direct provider lookup (only when compiled)
if self._is_compiled:
provider = self._type_providers.get(key)
if provider is not None:
result = provider(self._singletons, None)
# Cache singleton results for next time (singleton providers have _instance)
if hasattr(provider, "_instance"):
self._type_singletons[key] = result
return result
service_key = ServiceKey.from_value(key)
if isinstance(service_key.value, FunctionType | MethodType):
# Determine scope: explicit parameter takes precedence
# If scope is None, try to detect from dependencies (may fail with NameError
# if using `from __future__ import annotations` with forward references)
effective_scope = scope
if effective_scope is None:
try:
injected_deps = self._dependencies_extractor.get_injected_dependencies(
service_key=service_key,
)
effective_scope = self._find_scope_in_dependencies(injected_deps)
except NameError:
# Forward reference not resolvable yet (e.g., with PEP 563)
# Default to no scope - user should provide explicit scope parameter
effective_scope = None
# Check if the function is async
is_async_func = inspect.iscoroutinefunction(service_key.value)
if effective_scope is not None:
if is_async_func:
return _AsyncScopedInjectedFunction(
func=service_key.value,
container=self,
dependencies_extractor=self._dependencies_extractor,
service_key=service_key,
scope_name=effective_scope,
)
return _ScopedInjectedFunction(
func=service_key.value,
container=self,
dependencies_extractor=self._dependencies_extractor,
service_key=service_key,
scope_name=effective_scope,
)
if is_async_func:
return _AsyncInjectedFunction(
func=service_key.value,
container=self,
dependencies_extractor=self._dependencies_extractor,
service_key=service_key,
)
return _InjectedFunction(
func=service_key.value,
container=self,
dependencies_extractor=self._dependencies_extractor,
service_key=service_key,
)
# Skip ContextVar lookup when no scoped registrations exist
current_scope = _current_scope.get() if self._has_scoped_registrations else None
# Auto-compile on first resolve if enabled and not in a scope
if self._auto_compile and not self._is_compiled and current_scope is None:
self.compile()
# Fast path: use compiled providers when available and not in a scope
if self._is_compiled and current_scope is None:
provider = self._compiled_providers.get(service_key)
if provider is not None:
return provider(self._singletons, None)
# Check for scoped registration FIRST when inside a scope
scoped_registration = None
scoped_scope_name: str | None = None
if current_scope is not None:
scoped_registration = self._get_scoped_registration(service_key, current_scope)
if scoped_registration is not None:
scoped_scope_name = scoped_registration.scope
# Return cached global singleton ONLY if no scoped registration matches
if scoped_registration is None and service_key in self._singletons:
return self._singletons[service_key]
# Fast path for compiled scoped providers
if (
self._is_compiled
and scoped_registration is not None
and scoped_scope_name is not None
and current_scope is not None
):
scoped_provider = self._scoped_compiled_providers.get((service_key, scoped_scope_name))
if scoped_provider is not None:
cache_scope = current_scope.get_cache_key_for_scope(scoped_scope_name)
if cache_scope is not None:
use_lock = self._is_multithreaded()
compiled_scoped_cache = self._get_scoped_cache_view(
cache_scope,
use_lock=use_lock,
)
return scoped_provider(self._singletons, compiled_scoped_cache)
# Inline circular dependency tracking (avoids context manager overhead)
stack = _get_resolution_stack()
if service_key in stack:
raise DIWireCircularDependencyError(service_key, list(stack))
stack.append(service_key)
try:
# Use scoped registration if found, otherwise get from registry
registration = (
scoped_registration
if scoped_registration is not None
else self._get_registration(service_key, current_scope)
)
# Validate scope if service is registered with a specific scope
if registration.scope is not None and (
current_scope is None or not self._scope_matches(current_scope, registration.scope)
):
raise DIWireScopeMismatchError(
service_key,
registration.scope,
current_scope.path if current_scope else None,
)
# Check for async dependencies - raise early with helpful error
if registration.is_async:
raise DIWireAsyncDependencyInSyncContextError(service_key, service_key)
# Determine the scope key to use for caching
cache_scope = self._get_cache_scope(current_scope, registration.scope)
scoped_cache: MutableMapping[ServiceKey, Any] | None = None
cache_key: tuple[tuple[tuple[str | None, int], ...], ServiceKey] | None = None
type_cache: dict[type, Any] | None = None
is_type_key = service_key.is_type_key
# Check scoped instance cache using flat dict (single lookup)
if cache_scope is not None:
if is_type_key:
type_cache = self._get_scope_type_cache(cache_scope)
cached = type_cache.get(service_key.value)
if cached is not None:
return cached
scoped_cache = self._get_scope_cache(cache_scope)
cache_key = (cache_scope, service_key)
cached = scoped_cache.get(service_key)
if cached is not None:
if type_cache is not None:
type_cache[service_key.value] = cached
return cached
scoped_lock: threading.RLock | None = None # type: ignore[no-redef]
scoped_lock_acquired = False
# Skip lock contention in single-threaded scenarios.
if (
registration.lifetime == Lifetime.SCOPED
and scoped_cache is not None
and cache_key is not None
and registration.instance is None
and self._is_multithreaded()
):
scoped_lock = self._get_scope_cache_lock(cache_key[0])
scoped_lock.acquire()
scoped_lock_acquired = True
# Double-check cache after acquiring lock
cached = None
if type_cache is not None:
cached = type_cache.get(service_key.value)
if cached is None:
cached = scoped_cache.get(service_key)
if cached is not None: # pragma: no cover - race timing dependent
scoped_lock.release() # pragma: no cover
scoped_lock_acquired = False # pragma: no cover
if type_cache is not None:
type_cache[service_key.value] = cached # pragma: no cover
return cached # pragma: no cover
if registration.instance is not None:
if (
registration.scope is not None
and scoped_cache is not None
and cache_key is not None
):
scoped_cache[service_key] = registration.instance
if type_cache is not None:
type_cache[service_key.value] = registration.instance
else:
self._singletons[service_key] = registration.instance
if scoped_lock is not None and scoped_lock_acquired: # pragma: no cover - defensive
scoped_lock.release() # pragma: no cover
scoped_lock_acquired = False # pragma: no cover
return registration.instance
# For singletons, use lock to prevent race conditions in threaded resolution
is_global_singleton = (
registration.lifetime == Lifetime.SINGLETON and scoped_registration is None
)
singleton_lock: threading.Lock | None = None
if is_global_singleton:
singleton_lock = self._locks.get_sync_singleton_lock(service_key)
singleton_lock.acquire()
# Double-check: re-check cache after acquiring lock
if service_key in self._singletons:
singleton_lock.release()
return self._singletons[service_key]
try:
if registration.factory is not None:
if isinstance(registration.factory, type):
# Factory is a class - resolve via container to instantiate
factory: Any = self.resolve(registration.factory)
instance = factory()
elif isinstance(registration.factory, FunctionType | MethodType):
# Function/method factory - resolve ALL deps and call directly
# This allows factory functions to have all params auto-injected
factory_key = ServiceKey.from_value(registration.factory)
resolved = self._get_resolved_dependencies(
factory_key,
typevar_map=registration.typevar_map,
)
if resolved.missing:
raise DIWireMissingDependenciesError(factory_key, resolved.missing)
instance = registration.factory(**resolved.dependencies)
else:
# Factory is a built-in callable (e.g., ContextVar.get) - use directly
instance = registration.factory()
if isinstance(instance, Generator):
if cache_scope is None:
raise DIWireGeneratorFactoryWithoutScopeError(service_key)
if registration.lifetime == Lifetime.SINGLETON:
raise DIWireGeneratorFactoryUnsupportedLifetimeError(service_key)
try:
generated_instance = next(instance)
except StopIteration as exc:
raise DIWireGeneratorFactoryDidNotYieldError(service_key) from exc
self._get_scope_exit_stack(cache_scope).callback(instance.close)
instance = generated_instance # type: ignore[possibly-undefined]
if registration.lifetime == Lifetime.SINGLETON:
self._singletons[service_key] = instance
elif (
registration.lifetime == Lifetime.SCOPED
and scoped_cache is not None
and cache_key is not None
):
scoped_cache[service_key] = instance
if type_cache is not None:
type_cache[service_key.value] = instance
return instance
# Use concrete_type if registered with provides parameter
instantiation_type = registration.concrete_type or service_key.value
instantiation_key = (
ServiceKey.from_value(instantiation_type)
if registration.concrete_type is not None
else service_key
)
resolved_dependencies = self._get_resolved_dependencies(
service_key=instantiation_key,
typevar_map=registration.typevar_map,
)
if resolved_dependencies.missing:
raise DIWireMissingDependenciesError(service_key, resolved_dependencies.missing)
instance = instantiation_type(**resolved_dependencies.dependencies)
if registration.lifetime == Lifetime.SINGLETON:
self._singletons[service_key] = instance
elif (
registration.lifetime == Lifetime.SCOPED
and scoped_cache is not None
and cache_key is not None
):
scoped_cache[service_key] = instance
if type_cache is not None:
type_cache[service_key.value] = instance
return instance
finally:
if singleton_lock is not None and singleton_lock.locked():
singleton_lock.release()
if scoped_lock is not None and scoped_lock_acquired:
scoped_lock.release()
scoped_lock_acquired = False
finally:
stack.pop()
def _resolve_scoped_compiled(
self,
key: Any,
scope_id: _ScopeId,
) -> tuple[bool, Any]:
"""Fast-path for compiled resolution within an explicit scope."""
if not self._is_compiled or not isinstance(key, type):
return False, None
named_scopes_desc = scope_id.named_scopes_desc
scoped_type_providers_by_scope = self._scoped_type_providers_by_scope
if scoped_type_providers_by_scope:
for scope_name in named_scopes_desc:
type_providers = scoped_type_providers_by_scope.get(scope_name)
if type_providers is None:
continue
provider = type_providers.get(key)
if provider is not None:
cache_scope = scope_id.get_cache_key_for_scope(scope_name)
if cache_scope is None: # pragma: no cover - scope_id invariant
return False, None
use_lock = self._is_multithreaded()
scoped_cache = self._get_scoped_cache_view(
cache_scope,
use_lock=use_lock,
)
return True, provider(self._singletons, scoped_cache)
service_key = ServiceKey.from_value(key)
scoped_compiled_by_scope = self._scoped_compiled_providers_by_scope
if scoped_compiled_by_scope:
for scope_name in named_scopes_desc:
scoped_providers = scoped_compiled_by_scope.get(scope_name)
if scoped_providers is None:
continue
provider = scoped_providers.get(service_key)
if provider is not None:
cache_scope = scope_id.get_cache_key_for_scope(scope_name)
if cache_scope is None: # pragma: no cover - scope_id invariant
return False, None
use_lock = self._is_multithreaded()
scoped_cache = self._get_scoped_cache_view(
cache_scope,
use_lock=use_lock,
)
return True, provider(self._singletons, scoped_cache)
if (
not self._has_scoped_registrations
or self._get_scoped_registration(service_key, scope_id) is None
):
provider = self._compiled_providers.get(service_key)
if provider is not None:
return True, provider(self._singletons, None)
return False, None
def _is_multithreaded(self) -> bool:
if self._multithreaded:
return True
if threading.get_ident() != self._thread_id:
self._multithreaded = True
return True
return False
@overload
async def aresolve(self, key: type[T], *, scope: None = None) -> T: ...
@overload
async def aresolve(self, key: type[T], *, scope: str) -> T: ...
@overload
async def aresolve(
self,
key: Callable[..., Coroutine[Any, Any, T]],
*,
scope: None = None,
) -> _AsyncInjectedFunction[T]: ...
@overload
async def aresolve(
self,
key: Callable[..., Coroutine[Any, Any, T]],
*,
scope: str,
) -> _AsyncScopedInjectedFunction[T]: ...
@overload
async def aresolve(self, key: ServiceKey, *, scope: str | None = None) -> Any: ...
@overload
async def aresolve(self, key: Any, *, scope: str | None = None) -> Any: ...
[docs]
async def aresolve(self, key: Any, *, scope: str | None = None) -> Any: # noqa: PLR0915
"""Asynchronously resolve and return a service instance by its key.
This method supports async factories and async generator factories.
Use this method when resolving services that have async dependencies.
Note:
For decorator usage, use the synchronous `.resolve()` method which
handles both sync and async functions correctly.
Args:
key: The service key to resolve.
scope: Optional scope name. If provided and key is a function,
returns an AsyncScopedInjected that creates a new scope per call.
Raises:
DIWireAsyncGeneratorFactoryWithoutScopeError: If an async generator factory
is used without an active scope.
Examples:
# Direct usage:
injected = await container.aresolve(my_func, scope="request")
"""
self._check_not_closed()
# FAST PATH for cached singletons (same as sync resolve)
# Only use fast path when not in a scope (scoped registrations may override)
if (
isinstance(key, type)
and scope is None
and (not self._has_scoped_registrations or _current_scope.get() is None)
):
cached = self._type_singletons.get(key)
if cached is not None:
return cached
service_key = ServiceKey.from_value(key)
if isinstance(service_key.value, FunctionType | MethodType):
# Determine scope: explicit parameter takes precedence
# If scope is None, try to detect from dependencies (may fail with NameError
# if using `from __future__ import annotations` with forward references)
effective_scope = scope
if effective_scope is None:
try:
injected_deps = self._dependencies_extractor.get_injected_dependencies(
service_key=service_key,
)
effective_scope = self._find_scope_in_dependencies(injected_deps)
except NameError:
# Forward reference not resolvable yet (e.g., with PEP 563)
# Default to no scope - user should provide explicit scope parameter
effective_scope = None
# Check if the function is async
is_async_func = inspect.iscoroutinefunction(service_key.value)
if effective_scope is not None:
if is_async_func:
return _AsyncScopedInjectedFunction(
func=service_key.value,
container=self,
dependencies_extractor=self._dependencies_extractor,
service_key=service_key,
scope_name=effective_scope,
)
return _ScopedInjectedFunction(
func=service_key.value,
container=self,
dependencies_extractor=self._dependencies_extractor,
service_key=service_key,
scope_name=effective_scope,
)
if is_async_func:
return _AsyncInjectedFunction(
func=service_key.value,
container=self,
dependencies_extractor=self._dependencies_extractor,
service_key=service_key,
)
return _InjectedFunction(
func=service_key.value,
container=self,
dependencies_extractor=self._dependencies_extractor,
service_key=service_key,
)
# Skip ContextVar lookup when no scoped registrations exist
current_scope = _current_scope.get() if self._has_scoped_registrations else None
# Auto-compile on first resolve if enabled and not in a scope
if self._auto_compile and not self._is_compiled and current_scope is None:
self.compile()
# Return cached global singleton if available and no scoped registration
scoped_registration = None
if current_scope is not None:
scoped_registration = self._get_scoped_registration(service_key, current_scope)
if scoped_registration is None and service_key in self._singletons:
return self._singletons[service_key]
# Inline circular dependency tracking
stack = _get_resolution_stack()
if service_key in stack:
raise DIWireCircularDependencyError(service_key, list(stack))
stack.append(service_key)
try:
# Use scoped registration if found, otherwise get from registry
registration = (
scoped_registration
if scoped_registration is not None
else self._get_registration(service_key, current_scope)
)
# Validate scope if service is registered with a specific scope
if registration.scope is not None and (
current_scope is None or not self._scope_matches(current_scope, registration.scope)
):
raise DIWireScopeMismatchError(
service_key,
registration.scope,
current_scope.path if current_scope else None,
)
# Determine the scope key to use for caching
cache_scope = self._get_cache_scope(current_scope, registration.scope)
cache_key = (cache_scope, service_key) if cache_scope is not None else None
scoped_cache: MutableMapping[ServiceKey, Any] | None = None
type_cache: dict[type, Any] | None = None
is_type_key = service_key.is_type_key
# Check scoped instance cache using flat dict (single lookup)
if cache_scope is not None:
if is_type_key:
type_cache = self._get_scope_type_cache(cache_scope)
cached = type_cache.get(service_key.value)
if cached is not None:
return cached
scoped_cache = self._get_scope_cache(cache_scope)
cached = scoped_cache.get(service_key)
if cached is not None:
if type_cache is not None:
type_cache[service_key.value] = cached
return cached
scoped_lock: asyncio.Lock | None = None
if (
registration.lifetime == Lifetime.SCOPED
and cache_key is not None
and registration.instance is None
):
scoped_lock = await self._locks.get_scoped_singleton_lock(cache_key)
await scoped_lock.acquire()
# Double-check cache after acquiring lock
cached = None
if type_cache is not None:
cached = type_cache.get(service_key.value)
if cached is None and scoped_cache is not None:
cached = scoped_cache.get(service_key)
if cached is not None:
if type_cache is not None:
type_cache[service_key.value] = cached
scoped_lock.release()
return cached
if registration.instance is not None:
if (
registration.scope is not None
and scoped_cache is not None
and cache_key is not None
):
scoped_cache[service_key] = registration.instance
if type_cache is not None:
type_cache[service_key.value] = registration.instance
else:
self._singletons[service_key] = registration.instance
if (
scoped_lock is not None and scoped_lock.locked()
): # pragma: no cover - defensive, lock only acquired when instance is None
scoped_lock.release() # pragma: no cover
return registration.instance
# For singletons, use lock to prevent race conditions in async resolution
# The lock is acquired here (after getting registration) and released in finally
is_global_singleton = (
registration.lifetime == Lifetime.SINGLETON and scoped_registration is None
)
singleton_lock: asyncio.Lock | None = None
if is_global_singleton:
singleton_lock = await self._locks.get_singleton_lock(service_key)
await singleton_lock.acquire()
# Double-check: re-check cache after acquiring lock
# This path is hit when another coroutine resolved while we were waiting for the lock
if service_key in self._singletons: # pragma: no cover - race timing dependent
singleton_lock.release()
return self._singletons[service_key]
try:
if registration.factory is not None:
# Call the factory based on its type
if isinstance(registration.factory, type):
# Factory is a class - resolve via container to instantiate
factory: Any = await self.aresolve(registration.factory)
result = factory()
elif isinstance(registration.factory, FunctionType | MethodType):
# Function/method factory - resolve ALL deps and call directly
# This allows factory functions to have all params auto-injected
factory_key = ServiceKey.from_value(registration.factory)
resolved = await self._aget_resolved_dependencies(
factory_key,
typevar_map=registration.typevar_map,
)
if resolved.missing:
raise DIWireMissingDependenciesError(factory_key, resolved.missing)
result = registration.factory(**resolved.dependencies)
else:
# Factory is a built-in callable (e.g., ContextVar.get) - use directly
result = registration.factory()
# Handle async factories
if inspect.iscoroutine(result):
instance = await result
elif isinstance(result, AsyncGenerator):
# Async generator factory
if cache_scope is None:
raise DIWireAsyncGeneratorFactoryWithoutScopeError(service_key)
if registration.lifetime == Lifetime.SINGLETON:
raise DIWireGeneratorFactoryUnsupportedLifetimeError(service_key)
try:
instance = await result.__anext__()
except StopAsyncIteration as exc:
raise DIWireAsyncGeneratorFactoryDidNotYieldError(service_key) from exc
# Register cleanup
async_exit_stack = self._get_async_scope_exit_stack(cache_scope)
async_exit_stack.push_async_callback(result.aclose)
elif isinstance(result, Generator):
# Sync generator factory
if cache_scope is None:
raise DIWireGeneratorFactoryWithoutScopeError(service_key)
if registration.lifetime == Lifetime.SINGLETON:
raise DIWireGeneratorFactoryUnsupportedLifetimeError(service_key)
try:
instance = next(result)
except StopIteration as exc:
raise DIWireGeneratorFactoryDidNotYieldError(service_key) from exc
self._get_scope_exit_stack(cache_scope).callback(result.close)
else:
instance = result
if registration.lifetime == Lifetime.SINGLETON:
self._singletons[service_key] = instance # type: ignore[possibly-undefined]
elif (
registration.lifetime == Lifetime.SCOPED
and scoped_cache is not None
and cache_key is not None
):
scoped_cache[service_key] = instance # type: ignore[possibly-undefined]
if type_cache is not None:
type_cache[service_key.value] = instance # type: ignore[possibly-undefined]
return instance # type: ignore[possibly-undefined]
# Use concrete_type if registered with provides parameter
instantiation_type = registration.concrete_type or service_key.value
instantiation_key = (
ServiceKey.from_value(instantiation_type)
if registration.concrete_type is not None
else service_key
)
# Resolve dependencies
resolved_dependencies = await self._aget_resolved_dependencies(
service_key=instantiation_key,
typevar_map=registration.typevar_map,
)
if resolved_dependencies.missing:
raise DIWireMissingDependenciesError(service_key, resolved_dependencies.missing)
instance = instantiation_type(**resolved_dependencies.dependencies)
if registration.lifetime == Lifetime.SINGLETON:
self._singletons[service_key] = instance
elif (
registration.lifetime == Lifetime.SCOPED
and scoped_cache is not None
and cache_key is not None
):
scoped_cache[service_key] = instance
if type_cache is not None:
type_cache[service_key.value] = instance
return instance
finally:
if singleton_lock is not None and singleton_lock.locked():
singleton_lock.release()
if scoped_lock is not None and scoped_lock.locked():
scoped_lock.release()
finally:
stack.pop()
async def _aget_resolved_dependencies(
self,
service_key: ServiceKey,
*,
typevar_map: dict[Any, Any] | None = None,
) -> _ResolvedDependencies:
"""Asynchronously resolve dependencies for a service."""
resolved_dependencies = _ResolvedDependencies()
dependencies = self._dependencies_extractor.get_dependencies_with_defaults(
service_key=service_key,
)
# Use pre-computed async deps cache when available (avoids registry lookups)
async_deps = self._async_deps_cache.get(service_key)
# Collect sync and async resolution tasks
sync_deps: dict[str, Any] = {}
async_tasks: list[tuple[str, Coroutine[Any, Any, Any]]] = []
for name, param_info in dependencies.items():
dep_key = param_info.service_key
if self._handle_typevar_dependency(
service_key=service_key,
name=name,
param_info=param_info,
typevar_map=typevar_map,
resolved_dependencies=resolved_dependencies,
):
continue
# Skip ignored types that aren't explicitly registered
if dep_key.value in self._autoregister_ignores:
# Check both global and scoped registries before marking as missing
is_registered = dep_key in self._registry
if not is_registered and self._has_scoped_registrations:
current_scope = _current_scope.get()
if current_scope is not None:
is_registered = (
self._get_scoped_registration(dep_key, current_scope) is not None
)
if not is_registered:
if param_info.has_default:
continue
resolved_dependencies.missing.append(dep_key)
continue
try:
# Fast path: use cached async deps info when compiled
if async_deps is not None and dep_key in async_deps:
async_tasks.append((name, self.aresolve(dep_key)))
else:
# Try sync resolution first
# For uncompiled containers, fall back to registry check
if not self._is_compiled:
registration = self._registry.get(dep_key)
if registration is not None and registration.is_async:
async_tasks.append((name, self.aresolve(dep_key)))
continue
# Sync resolution (will raise DIWireAsyncDependencyInSyncContextError if truly async)
try:
sync_deps[name] = self.resolve(dep_key)
except DIWireAsyncDependencyInSyncContextError:
async_tasks.append((name, self.aresolve(dep_key)))
except (DIWireCircularDependencyError, DIWireScopeMismatchError):
raise
except DIWireError:
if not param_info.has_default:
resolved_dependencies.missing.append(dep_key)
# Resolve async dependencies
if async_tasks:
if len(async_tasks) == 1:
# Single async dependency - await directly (skip gather overhead)
name, coro = async_tasks[0]
resolved_dependencies.dependencies[name] = await coro
else:
# Multiple async dependencies - resolve in parallel
# Wrap in create_task() so each coroutine gets its own context copy
names, coros = zip(*async_tasks, strict=True)
tasks = [asyncio.create_task(coro) for coro in coros]
results = await asyncio.gather(*tasks)
for name, result in zip(names, results, strict=True):
resolved_dependencies.dependencies[name] = result
# Add sync dependencies
resolved_dependencies.dependencies.update(sync_deps)
return resolved_dependencies
def _get_async_scope_exit_stack(
self,
scope_key: tuple[tuple[str | None, int], ...],
) -> AsyncExitStack:
"""Get or create an AsyncExitStack for the given scope."""
async_exit_stack = self._async_scope_exit_stacks.get(scope_key)
if async_exit_stack is None:
async_exit_stack = AsyncExitStack()
self._async_scope_exit_stacks[scope_key] = async_exit_stack
return async_exit_stack
async def _aclear_scope(self, scope_id: _ScopeId) -> None:
"""Asynchronously clear cached instances for a scope.
This properly cleans up async generators registered in the scope.
Args:
scope_id: The scope ID to clear.
"""
scope_key = scope_id.segments
# Close sync exit stack
scope_exit_stack = self._scope_exit_stacks.pop(scope_key, None)
if scope_exit_stack is not None:
scope_exit_stack.close()
# Close async exit stack
async_exit_stack = self._async_scope_exit_stacks.pop(scope_key, None)
if async_exit_stack is not None:
await async_exit_stack.aclose()
self._scope_caches.pop(scope_key, None)
self._scope_type_caches.pop(scope_key, None)
self._scoped_cache_views.pop(scope_key, None)
self._scoped_cache_views_nolock.pop(scope_key, None)
self._scope_cache_locks.pop(scope_key, None)
self._locks.clear_scope_locks(scope_key)
def _register_active_scope(self, scope: ScopedContainer) -> None:
"""Register a scope as active for imperative close()."""
if self._is_multithreaded():
with self._active_scopes_lock:
if self._closed:
raise DIWireContainerClosedError
self._active_scopes.append(scope)
return
if self._closed:
raise DIWireContainerClosedError
self._active_scopes.append(scope)
def _unregister_active_scope(self, scope: ScopedContainer) -> None:
"""Unregister a scope when it is closed."""
if self._is_multithreaded():
with self._active_scopes_lock, contextlib.suppress(ValueError):
self._active_scopes.remove(scope)
return
with contextlib.suppress(ValueError):
self._active_scopes.remove(scope)
def _check_not_closed(self) -> None:
"""Raise an error if the container is closed."""
if self._closed:
raise DIWireContainerClosedError
[docs]
def close(self) -> None:
"""Close all active scopes and mark the container as closed.
After calling this method, any attempt to resolve services or start
new scopes will raise DIWireContainerClosedError.
Scopes are closed in LIFO order (newest first).
This method is idempotent - calling it multiple times is safe.
If a scope's close() fails, that scope remains in _active_scopes
and the exception is re-raised.
"""
with self._active_scopes_lock:
if self._closed:
return
self._closed = True
while True:
with self._active_scopes_lock:
if not self._active_scopes:
break
scope = self._active_scopes[-1]
scope.close()
with self._active_scopes_lock:
if self._active_scopes and self._active_scopes[-1] is scope:
self._active_scopes.pop()
[docs]
async def aclose(self) -> None:
"""Asynchronously close all active scopes and mark the container as closed.
Use this method when scopes contain async generator factories that
need proper async cleanup.
After calling this method, any attempt to resolve services or start
new scopes will raise DIWireContainerClosedError.
Scopes are closed in LIFO order (newest first).
This method is idempotent - calling it multiple times is safe.
This method will drain remaining scopes even if the container is
already marked as closed. If a scope's aclose() fails, that scope
remains in _active_scopes and the exception is re-raised.
"""
with self._active_scopes_lock:
self._closed = True
while True:
with self._active_scopes_lock:
if not self._active_scopes:
break
scope = self._active_scopes[-1]
await scope.aclose()
with self._active_scopes_lock:
if self._active_scopes and self._active_scopes[-1] is scope:
self._active_scopes.pop()
[docs]
def close_scope(self, scope_name: str) -> None:
"""Close all active scopes that contain the given scope name.
This closes the named scope and all its child scopes in LIFO order
(children first, then parents).
Args:
scope_name: The name of the scope to close.
Example:
# Given hierarchy: app -> session -> request
container.close_scope("session") # Closes both "request" and "session"
"""
while True:
scope_to_close: ScopedContainer | None = None
with self._active_scopes_lock:
# Find scopes containing the scope_name, process from end (LIFO)
for i in range(len(self._active_scopes) - 1, -1, -1):
scope = self._active_scopes[i]
if scope._scope_id.contains_scope(scope_name): # noqa: SLF001
scope_to_close = scope
break
if scope_to_close is None:
return
scope_to_close.close()
[docs]
async def aclose_scope(self, scope_name: str) -> None:
"""Asynchronously close all active scopes that contain the given scope name.
This closes the named scope and all its child scopes in LIFO order
(children first, then parents).
Args:
scope_name: The name of the scope to close.
"""
while True:
scope_to_close: ScopedContainer | None = None
with self._active_scopes_lock:
for i in range(len(self._active_scopes) - 1, -1, -1):
scope = self._active_scopes[i]
if scope._scope_id.contains_scope(scope_name): # noqa: SLF001
scope_to_close = scope
break
if scope_to_close is None:
return
await scope_to_close.aclose()
def _get_scoped_registration(
self,
service_key: ServiceKey,
current_scope: _ScopeId,
) -> Registration | None:
"""Get a scoped registration for a service, if one exists.
Only checks the scoped registry, does not fall back to global registry.
Uses tuple iteration instead of string split/join for performance.
"""
# Check from most specific to least specific (named scopes only)
for name in current_scope.named_scopes_desc:
scoped_reg = self._scoped_registry.get((service_key, name))
if scoped_reg is not None:
return scoped_reg
return None
def _get_scoped_open_generic_registration(
self,
origin: type,
component: Component | None,
current_scope: _ScopeId,
) -> _OpenGenericRegistration | None:
"""Get a scoped open generic registration for a matching scope, if any."""
for name in current_scope.named_scopes_desc:
scoped_reg = self._scoped_open_generic_registry.get((origin, component, name))
if scoped_reg is not None:
return scoped_reg
return None
def _validate_typevar_map(
self,
service_key: ServiceKey,
typevar_map: dict[Any, Any],
) -> None:
"""Validate TypeVar bounds and constraints for a closed generic."""
for typevar, arg in typevar_map.items():
constraints = getattr(typevar, "__constraints__", ())
bound = getattr(typevar, "__bound__", None)
if constraints:
if not any(
_type_arg_matches_constraint(arg, constraint) for constraint in constraints
):
raise DIWireInvalidGenericTypeArgumentError(
service_key,
typevar,
arg,
f"Expected one of {constraints!r}.",
)
continue
if (
bound is not None
and not _is_any_type(bound)
and not _type_arg_matches_constraint(arg, bound)
):
raise DIWireInvalidGenericTypeArgumentError(
service_key,
typevar,
arg,
f"Expected bound {bound!r}.",
)
def _get_typevar_argument(
self,
typevar: Any,
typevar_map: dict[Any, Any],
) -> Any | None:
"""Lookup the concrete argument for a TypeVar from a map."""
if typevar in typevar_map:
return typevar_map[typevar]
return None
def _handle_typevar_dependency(
self,
*,
service_key: ServiceKey,
name: str,
param_info: ParameterInfo,
typevar_map: dict[Any, Any] | None,
resolved_dependencies: _ResolvedDependencies,
) -> bool:
"""Inject TypeVar-bound arguments if present; return True when handled."""
if param_info.typevar is None:
return False
if typevar_map is not None:
type_arg = self._get_typevar_argument(param_info.typevar, typevar_map)
if type_arg is not None:
resolved_dependencies.dependencies[name] = type_arg
return True
if param_info.has_default:
return True
raise DIWireOpenGenericResolutionError(
service_key,
f"Type argument for {getattr(param_info.typevar, '__name__', param_info.typevar)!r} "
"is missing.",
)
def _resolve_open_generic_registration(
self,
service_key: ServiceKey,
current_scope: _ScopeId | None,
) -> Registration | None:
origin, args = _get_generic_origin_and_args(service_key.value)
if origin is None or not args:
return None
if any(_is_typevar(arg) for arg in args):
raise DIWireOpenGenericResolutionError(
service_key,
"Type arguments must be concrete.",
)
open_registration: _OpenGenericRegistration | None = None
if current_scope is not None:
open_registration = self._get_scoped_open_generic_registration(
origin,
service_key.component,
current_scope,
)
if open_registration is None:
open_registration = self._open_generic_registry.get((origin, service_key.component))
if open_registration is None:
return None
if len(args) != len(open_registration.typevars):
raise DIWireOpenGenericResolutionError(
service_key,
f"Expected {len(open_registration.typevars)} type argument(s), got {len(args)}.",
)
typevar_map = dict(zip(open_registration.typevars, args, strict=True))
self._validate_typevar_map(service_key, typevar_map)
base = open_registration.registration
registration = Registration(
service_key=service_key,
factory=base.factory,
instance=base.instance,
lifetime=base.lifetime,
scope=base.scope,
is_async=base.is_async,
concrete_type=base.concrete_type,
typevar_map=typevar_map,
)
if base.scope is not None:
self._scoped_registry[(service_key, base.scope)] = registration
self._has_scoped_registrations = True
else:
self._registry[service_key] = registration
return registration
def _get_registration(
self,
service_key: ServiceKey,
current_scope: _ScopeId | None,
) -> Registration:
"""Get the registration for a service, checking scoped registry first.
Looks for a matching scoped registration based on the current scope hierarchy,
then falls back to the global registry, then auto-registration.
"""
# Check scoped registry - find the most specific matching scope
if current_scope is not None:
scoped_reg = self._get_scoped_registration(service_key, current_scope)
if scoped_reg is not None:
return scoped_reg
# Fall back to global registry
registration = self._registry.get(service_key)
if registration is not None:
return registration
registration = self._resolve_open_generic_registration(service_key, current_scope)
if registration is not None:
return registration
# Auto-register if enabled
if not self._autoregister:
raise DIWireServiceNotRegisteredError(service_key)
# Check if there's any scoped registration for this key before auto-registering
if self._has_scoped_registrations:
scoped_reg = self._find_any_scoped_registration(service_key)
if scoped_reg is not None:
raise DIWireScopeMismatchError(
service_key,
scoped_reg.scope, # type: ignore[arg-type]
current_scope.path if current_scope else None,
)
registration = self._get_auto_registration(service_key=service_key)
self._registry[service_key] = registration
return registration
def _find_any_scoped_registration(self, service_key: ServiceKey) -> Registration | None:
"""Find any scoped registration for the given service key, regardless of scope."""
for (sk, _scope_name), reg in self._scoped_registry.items():
if sk == service_key:
return reg
return None
def _get_cache_scope(
self,
current_scope: _ScopeId | None,
registered_scope: str | None,
) -> tuple[tuple[str | None, int], ...] | None:
"""Get the scope key to use for caching scoped instances.
Returns the tuple key up to and including the registered scope segment.
E.g., current=_ScopeId((("request", 1), ("child", 2))), registered="request"
-> (("request", 1),)
"""
if current_scope is None:
return None
if registered_scope is None:
return current_scope.segments
# Find segments up to and including the registered scope name
return current_scope.get_cache_key_for_scope(registered_scope)
def _scope_matches(self, current_scope: _ScopeId, registered_scope: str) -> bool:
"""Check if the current scope matches or contains the registered scope.
Uses tuple iteration instead of string operations for performance.
"""
return current_scope.contains_scope(registered_scope)
def _find_scope_in_dependencies(
self,
deps: dict[str, ServiceKey],
visited: set[ServiceKey] | None = None,
) -> str | None:
"""Find a scope from registered dependencies (recursively)."""
if visited is None:
visited = set()
for dep_key in deps.values():
if dep_key in visited:
continue
visited.add(dep_key)
# Collect all scopes from both registries
found_scopes: set[str] = set()
# Check global registry
registration = self._registry.get(dep_key)
if registration is not None and registration.scope is not None:
found_scopes.add(registration.scope)
# Check scoped registry for all entries matching this dep_key
for (service_key, _scope_name), scoped_reg in self._scoped_registry.items():
if service_key == dep_key and scoped_reg.scope is not None:
found_scopes.add(scoped_reg.scope)
# If we found exactly one unique scope, return it
if len(found_scopes) == 1:
return next(iter(found_scopes))
# If multiple different scopes (ambiguous), skip and check nested deps
# If no scopes found, also check nested deps
# Check nested dependencies (skip if extraction fails for non-class types)
try:
nested_deps = self._dependencies_extractor.get_dependencies(dep_key)
nested_scope = self._find_scope_in_dependencies(nested_deps, visited)
if nested_scope is not None:
return nested_scope
except DIWireError:
continue
return None
def _get_auto_registration(self, service_key: ServiceKey) -> Registration:
if service_key.component is not None:
raise DIWireComponentSpecifiedError(service_key)
if service_key.value in self._autoregister_ignores:
raise DIWireIgnoredServiceError(service_key)
if _is_union_type(service_key.value):
raise DIWireUnionTypeError(service_key)
if not isinstance(service_key.value, type):
raise DIWireNotAClassError(service_key)
for base_cls, registration_factory in self._autoregister_registration_factories.items():
if issubclass(service_key.value, base_cls):
return registration_factory(service_key.value)
return Registration(
service_key=service_key,
lifetime=self._autoregister_default_lifetime,
)
def _get_resolved_dependencies(
self,
service_key: ServiceKey,
*,
typevar_map: dict[Any, Any] | None = None,
) -> _ResolvedDependencies:
resolved_dependencies = _ResolvedDependencies()
dependencies = self._dependencies_extractor.get_dependencies_with_defaults(
service_key=service_key,
)
for name, param_info in dependencies.items():
if self._handle_typevar_dependency(
service_key=service_key,
name=name,
param_info=param_info,
typevar_map=typevar_map,
resolved_dependencies=resolved_dependencies,
):
continue
# Skip ignored types that aren't explicitly registered
if param_info.service_key.value in self._autoregister_ignores:
# Check both global and scoped registries before marking as missing
is_registered = param_info.service_key in self._registry
if not is_registered and self._has_scoped_registrations:
current_scope = _current_scope.get()
if current_scope is not None:
is_registered = (
self._get_scoped_registration(param_info.service_key, current_scope)
is not None
)
if not is_registered:
if param_info.has_default:
continue
resolved_dependencies.missing.append(param_info.service_key)
continue
try:
resolved_dependencies.dependencies[name] = self.resolve(param_info.service_key)
except (
DIWireCircularDependencyError,
DIWireScopeMismatchError,
DIWireAsyncDependencyInSyncContextError,
):
raise
except DIWireError:
if not param_info.has_default:
resolved_dependencies.missing.append(param_info.service_key)
return resolved_dependencies