from __future__ import annotations
import ast
import functools
import inspect
import logging
import textwrap
import threading
from collections.abc import AsyncGenerator, Awaitable, Callable, Generator, Mapping
from contextlib import contextmanager, suppress
from dataclasses import dataclass
from types import TracebackType
from typing import (
Annotated,
Any,
Literal,
TypeVar,
cast,
get_args,
get_origin,
overload,
)
from diwire._internal.autoregistration import ConcreteTypeAutoregistrationPolicy
from diwire._internal.injection import (
INJECT_RESOLVER_KWARG,
INJECT_WRAPPER_MARKER,
InjectedCallableInspector,
InjectedParameter,
)
from diwire._internal.integrations.pydantic_settings import is_pydantic_settings_subclass
from diwire._internal.lock_mode import LockMode
from diwire._internal.markers import (
Component,
ProviderMarker,
build_annotated_key,
component_base_key,
is_all_annotation,
is_maybe_annotation,
is_provider_annotation,
strip_all_annotation,
strip_maybe_annotation,
strip_non_component_annotation,
strip_provider_annotation,
)
from diwire._internal.open_generics import (
OpenGenericRegistry,
OpenGenericResolver,
canonicalize_open_key,
contains_typevar,
substitute_typevars,
)
from diwire._internal.policies import DependencyRegistrationPolicy, MissingPolicy
from diwire._internal.providers import (
ContextManagerProvider,
FactoryProvider,
GeneratorProvider,
Lifetime,
ProviderDependenciesExtractor,
ProviderDependency,
ProviderReturnTypeExtractor,
ProviderSpec,
ProvidersRegistrations,
)
from diwire._internal.resolver_context import (
ResolverContext,
resolver_context as default_resolver_context,
)
from diwire._internal.resolvers.manager import ResolversManager
from diwire._internal.resolvers.protocol import ResolverProtocol
from diwire._internal.scope import BaseScope, Scope
from diwire._internal.validators import DependecyRegistrationValidator
from diwire.exceptions import (
DIWireDependencyNotRegisteredError,
DIWireError,
DIWireInvalidRegistrationError,
DIWireScopeMismatchError,
)
# Generic value type; e.g. ``Container.add_instance(instance: T)``.
T = TypeVar("T")
# Callable type variable bounded to ``Callable[..., Any]``; presumably used to
# preserve decorated callables' precise types through injection helpers
# defined elsewhere in this module (usage not visible here — TODO confirm).
InjectableF = TypeVar("InjectableF", bound=Callable[..., Any])
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Private identity sentinel (a bare ``object()``). Presumably marks "no
# closed-generic injection value" where ``None`` may be a legitimate value;
# NOTE(review): confirm at its use sites, which are outside this view.
_MISSING_CLOSED_GENERIC_INJECTION = object()
[docs]
class Container:
"""Manage dependency registration, resolution, scoping, and cleanup.
Dependency keys are usually concrete types, protocols, or
``typing.Annotated`` tokens (for example ``Annotated[Db, Component("ro")]``).
Closed generic keys are also supported when matching open-generic
registrations.
Use registrations for explicit control, or keep autoregistration enabled to
auto-wire eligible concrete classes and their dependencies. Disable
autoregistration for strict mode where every dependency must be registered
explicitly.
Resolution happens through a compiled resolver graph. ``resolve`` runs sync
graphs, ``aresolve`` runs async graphs, and ``enter_scope`` creates nested
resolvers that own scoped caches and cleanup callbacks. Registration
mutations invalidate compilation automatically.
"""
# Hot-path methods rebound to the compiled root resolver to keep steady-state
# resolution/scope calls on generated fast paths. ``__init__`` snapshots the
# container's own bound methods for these names into
# ``self._container_entrypoints``.
_ENTRYPOINT_METHOD_NAMES: tuple[str, ...] = (
    "resolve",
    "aresolve",
    "enter_scope",
)
# Iteration cap for open-generic materialization; presumably a guard against
# non-terminating expansion (the loop using it is outside this view — TODO
# confirm).
_OPEN_GENERIC_MATERIALIZATION_MAX_ITERATIONS: int = 10_000
# Size of the trailing slice of materialization state that is kept/reported;
# NOTE(review): meaning inferred from the name only, confirm at its use site.
_OPEN_GENERIC_MATERIALIZATION_STATE_TAIL_SIZE: int = 3
[docs]
def __init__(
    self,
    root_scope: BaseScope = Scope.APP,
    default_lifetime: Lifetime = Lifetime.SCOPED,
    *,
    lock_mode: LockMode | Literal["auto"] = "auto",
    missing_policy: MissingPolicy = MissingPolicy.REGISTER_RECURSIVE,
    dependency_registration_policy: DependencyRegistrationPolicy = (
        DependencyRegistrationPolicy.REGISTER_RECURSIVE
    ),
    resolver_context: ResolverContext = default_resolver_context,
    use_resolver_context: bool = True,
) -> None:
    """Create a container and configure its default registration behavior.

    Out of the box the container recursively auto-registers eligible concrete
    dependencies, both when resolving and when registering providers. For
    full registration control (strict mode), opt in with
    ``missing_policy=MissingPolicy.ERROR`` together with
    ``dependency_registration_policy=DependencyRegistrationPolicy.IGNORE``.

    With ``lock_mode="auto"`` the container picks thread locks for sync-only
    cached paths and async locks once async resolution paths are present.

    Args:
        root_scope: Scope owned by the root resolver and used for
            root-scoped caches.
        default_lifetime: Lifetime applied to registrations that do not pass
            ``lifetime``.
        lock_mode: Default lock strategy for non-instance registrations;
            a ``LockMode`` or ``"auto"``.
        missing_policy: Default handling of dependencies that are missing at
            resolve time.
        dependency_registration_policy: Default handling of dependency
            autoregistration during registration/injection.
        resolver_context: Resolver context whose ``inject`` fallback
            behavior should use this container.
        use_resolver_context: Whether compiled resolvers are wrapped so that
            entering them as context managers binds into
            ``resolver_context``.

    Notes:
        Typical presets are auto-wiring (the default, both policies
        recursive), strict mode (opt-in), and root-only resolve mode
        (``missing_policy=MissingPolicy.REGISTER_ROOT``).

        Construction registers this container as the fallback container of
        ``resolver_context``.

    Examples:
        .. code-block:: python

            container = Container()
            strict_container = Container(
                missing_policy=MissingPolicy.ERROR,
                dependency_registration_policy=DependencyRegistrationPolicy.IGNORE,
            )
            threaded_container = Container(lock_mode=LockMode.THREAD)
    """
    # Container-wide configuration; both policies go through the
    # ``_resolve_container_*`` helpers before being stored.
    self._root_scope = root_scope
    self._default_lifetime = default_lifetime
    self._lock_mode = lock_mode
    self._missing_policy = self._resolve_container_missing_policy(missing_policy)
    registration_policy = self._resolve_container_dependency_registration_policy(
        dependency_registration_policy,
    )
    self._dependency_registration_policy = registration_policy
    self._resolver_context = resolver_context
    self._use_resolver_context = use_resolver_context

    # Stateless-looking collaborators used during registration and resolution.
    self._concrete_autoregistration_policy = ConcreteTypeAutoregistrationPolicy()
    self._provider_dependencies_extractor = ProviderDependenciesExtractor()
    self._provider_return_type_extractor = ProviderReturnTypeExtractor()
    self._dependency_registration_validator = DependecyRegistrationValidator()
    self._providers_registrations = ProvidersRegistrations()
    self._open_generic_registry = OpenGenericRegistry()
    self._resolvers_manager = ResolversManager()
    self._injected_callable_inspector = InjectedCallableInspector()

    # Compiled resolver graph plus registration-mutation bookkeeping.
    self._root_resolver: ResolverProtocol | None = None
    self._entered_root_resolver: ResolverProtocol | None = None
    self._graph_revision: int = 0
    self._registration_mutation_depth: int = 0
    self._registration_mutation_snapshot: _ContainerGraphSnapshot | None = None
    self._registration_mutation_failed: bool = False

    # Decoration rules/chains, injection scope contracts, and runtime state.
    self._decoration_rules_by_provides: dict[Any, list[_DecorationRule]] = {}
    self._decoration_chain_by_provides: dict[Any, _DecorationChain] = {}
    self._decoration_counter: int = 0
    self._injected_scope_contracts: list[_InjectedScopeContract] = []
    self._runtime_materialized_closed_keys: set[Any] = set()
    self._graph_state_lock: Any = threading.RLock()

    # Snapshot the container's own bound entrypoint methods by name.
    self._container_entrypoints: dict[str, Callable[..., Any]] = {
        entrypoint_name: getattr(self, entrypoint_name)
        for entrypoint_name in self._ENTRYPOINT_METHOD_NAMES
    }
    self._resolver_context.set_fallback_container(self)
# region Registration Methods
[docs]
def add_instance(
    self,
    instance: T,
    *,
    provides: Any | Literal["infer"] = "infer",
    component: Component | Any | None = None,
) -> None:
    """Bind an already constructed object as the provider for a key.

    Handy for configuration objects or pre-built clients. Registering the
    same dependency key again overrides the earlier spec.

    Args:
        instance: Object returned whenever the key is resolved.
        provides: Dependency key to bind; ``"infer"`` (the default) binds
            ``type(instance)``.
        component: Optional component marker value; when given, the key
            becomes ``Annotated[provides, Component(...)]``.

    Raises:
        DIWireInvalidRegistrationError: If ``provides`` is ``None``.

    Notes:
        The spec is always created with ``LockMode.NONE``, since nothing is
        constructed lazily for an instance binding.

    Examples:
        .. code-block:: python

            settings = Settings(api_url="https://api.example.com")
            container.add_instance(settings)
            resolved = container.resolve(Settings)
    """
    requested_key = cast("Any", provides)
    if requested_key is None:
        msg = "add_instance() parameter 'provides' must not be None; use 'infer'."
        raise DIWireInvalidRegistrationError(msg)
    base_key: Any = type(instance) if requested_key == "infer" else requested_key

    keyed_provides = self._resolve_registration_component_provides(
        provides=base_key,
        component=component,
        method_name="add_instance",
    )
    binding_key, has_decoration_chain = self._resolve_registration_target_provides(
        keyed_provides,
    )

    with self._registration_mutation():
        instance_spec = ProviderSpec(
            provides=binding_key,
            instance=instance,
            lifetime=self._default_lifetime,
            scope=self._root_scope,
            is_async=False,
            is_any_dependency_async=False,
            needs_cleanup=False,
            lock_mode=LockMode.NONE,
        )
        self._providers_registrations.add(instance_spec)
        self._finalize_registration_after_binding(
            original_provides=keyed_provides,
            has_decoration_chain=has_decoration_chain,
        )
[docs]
def add(
    self,
    concrete_type: type[Any],
    *,
    provides: Any | Literal["infer"] = "infer",
    component: Component | Any | None = None,
    scope: BaseScope | Literal["from_container"] = "from_container",
    lifetime: Lifetime | Literal["from_container"] = "from_container",
    dependencies: Mapping[Any, inspect.Parameter] | Literal["infer"] = "infer",
    lock_mode: LockMode | Literal["from_container"] = "from_container",
    dependency_registration_policy: DependencyRegistrationPolicy
    | Literal["from_container"] = "from_container",
) -> None:
    """Register a concrete type provider.

    ``provides`` may be a protocol, concrete type, annotated token, or open
    generic key. Dependencies are inferred from constructor annotations
    unless explicit dependencies are passed.

    Args:
        concrete_type: Concrete class to instantiate.
        provides: Dependency key produced by this provider. ``"infer"`` uses
            ``concrete_type`` directly.
        component: Optional component marker value used to register under
            ``Annotated[provides, Component(...)]``.
        scope: Provider scope, or ``"from_container"`` to inherit root scope.
        lifetime: Provider lifetime, or ``"from_container"`` to inherit
            container default.
        dependencies: Explicit dependency mapping from dependency key to
            provider parameter, or ``"infer"`` for annotation inference.
        lock_mode: Lock strategy, or ``"from_container"`` to inherit the
            container lock mode.
        dependency_registration_policy: Override dependency autoregistration
            for this registration.

    Raises:
        DIWireInvalidRegistrationError: If parameters are invalid or scope
            contracts cannot be satisfied.
        DIWireInvalidProviderSpecError: If explicit dependencies do not match
            provider signature.
        DIWireProviderDependencyInferenceError: If dependencies cannot be
            inferred from annotations.

    Notes:
        ``lock_mode="from_container"`` inherits the container-level mode.
        Open generic registration is enabled when ``provides`` contains
        TypeVars. For a closed key that requires injected arguments (as
        reported by ``_resolve_closed_concrete_generic_injections``), the
        class is registered through a generated factory instead of directly.

    Examples:
        .. code-block:: python

            container.add(SqlRepo, provides=Repo)
            container.add(CachedRepo, provides=Repo)
    """
    resolved_provides, resolved_concrete_type = self._resolve_concrete_registration_types(
        provides=provides,
        concrete_type=concrete_type,
    )
    resolved_provides_with_component = self._resolve_registration_component_provides(
        provides=resolved_provides,
        component=component,
        method_name="add",
    )
    resolved_scope = self._resolve_registration_scope(
        scope=scope,
        method_name="add",
    )
    resolved_lifetime = self._resolve_registration_lifetime(
        lifetime=lifetime,
        method_name="add",
    )
    explicit_dependencies = self._resolve_registration_dependencies(
        dependencies=dependencies,
        method_name="add",
    )
    resolved_dependency_registration_policy = (
        self._resolve_registration_dependency_registration_policy(
            dependency_registration_policy=dependency_registration_policy,
            method_name="add",
        )
    )
    self._dependency_registration_validator.validate_concrete_type(
        concrete_type=resolved_concrete_type,
    )
    dependencies_for_provider = self._resolve_concrete_registration_dependencies(
        concrete_type=resolved_concrete_type,
        explicit_dependencies=explicit_dependencies,
    )
    is_any_dependency_async = self._provider_return_type_extractor.is_any_dependency_async(
        dependencies_for_provider,
    )
    resolved_lock_mode = self._resolve_provider_lock_mode(lock_mode)
    registration_provides, has_decoration_chain = self._resolve_registration_target_provides(
        resolved_provides_with_component,
    )
    with self._registration_mutation():
        open_generic_registration = self._open_generic_registry.register(
            provides=registration_provides,
            provider_kind="concrete_type",
            provider=resolved_concrete_type,
            lifetime=resolved_lifetime,
            scope=resolved_scope,
            lock_mode=resolved_lock_mode,
            is_async=False,
            is_any_dependency_async=is_any_dependency_async,
            needs_cleanup=False,
            dependencies=dependencies_for_provider,
        )
        if open_generic_registration is None:
            # Closed key: it may still need injected arguments, which can also
            # rewrite the dependency set, so async-ness is recomputed below.
            (
                closed_generic_injections,
                dependencies_for_provider,
            ) = self._resolve_closed_concrete_generic_injections(
                provides=resolved_provides_with_component,
                dependencies=dependencies_for_provider,
            )
            is_any_dependency_async = self._provider_return_type_extractor.is_any_dependency_async(
                dependencies_for_provider,
            )
            if closed_generic_injections:
                # Injected arguments cannot be expressed as plain constructor
                # wiring, so build the instance through a generated factory.
                concrete_factory = self._build_closed_concrete_factory(
                    concrete_type=resolved_concrete_type,
                    injected_arguments=closed_generic_injections,
                )
                provider_spec = ProviderSpec(
                    provides=registration_provides,
                    factory=concrete_factory,
                    lifetime=resolved_lifetime,
                    scope=resolved_scope,
                    dependencies=dependencies_for_provider,
                    is_async=False,
                    is_any_dependency_async=is_any_dependency_async,
                    needs_cleanup=False,
                    lock_mode=resolved_lock_mode,
                )
            else:
                provider_spec = ProviderSpec(
                    provides=registration_provides,
                    concrete_type=resolved_concrete_type,
                    lifetime=resolved_lifetime,
                    scope=resolved_scope,
                    dependencies=dependencies_for_provider,
                    is_async=False,
                    is_any_dependency_async=is_any_dependency_async,
                    needs_cleanup=False,
                    lock_mode=resolved_lock_mode,
                )
            self._providers_registrations.add(provider_spec)
        # Shared tail for open-generic, closed-generic and plain registrations.
        self._autoregister_provider_dependencies(
            dependencies=dependencies_for_provider,
            scope=resolved_scope,
            lifetime=resolved_lifetime,
            dependency_registration_policy=self._resolve_dependency_registration_policy(
                resolved_dependency_registration_policy,
            ),
        )
        self._finalize_registration_after_binding(
            original_provides=resolved_provides_with_component,
            has_decoration_chain=has_decoration_chain,
        )
[docs]
def add_factory(
    self,
    factory: Callable[..., Any] | Callable[..., Awaitable[Any]],
    *,
    provides: Any | Literal["infer"] = "infer",
    component: Component | Any | None = None,
    scope: BaseScope | Literal["from_container"] = "from_container",
    lifetime: Lifetime | Literal["from_container"] = "from_container",
    dependencies: Mapping[Any, inspect.Parameter] | Literal["infer"] = "infer",
    lock_mode: LockMode | Literal["from_container"] = "from_container",
    dependency_registration_policy: DependencyRegistrationPolicy
    | Literal["from_container"] = "from_container",
) -> None:
    """Register a factory provider.

    ``provides`` may be a protocol, concrete type, annotated token, or open
    generic key. Dependencies are inferred from factory parameters unless
    explicit dependencies are passed.

    Args:
        factory: Provider function/callable.
        provides: Dependency key produced by the factory. ``"infer"`` uses
            the return annotation.
        component: Optional component marker value used to register under
            ``Annotated[provides, Component(...)]``.
        scope: Provider scope, or ``"from_container"``.
        lifetime: Provider lifetime, or ``"from_container"``.
        dependencies: Explicit dependency mapping, or ``"infer"``.
        lock_mode: Lock strategy, or ``"from_container"``.
        dependency_registration_policy: Override dependency autoregistration
            for this registration.

    Raises:
        DIWireInvalidRegistrationError: If configuration or annotations are
            invalid, or if ``factory`` is not callable.
        DIWireInvalidProviderSpecError: If explicit dependencies do not match
            factory parameters.
        DIWireProviderDependencyInferenceError: If required dependencies
            cannot be inferred.

    Notes:
        ``lock_mode="from_container"`` inherits the container-level mode.
        Open-generic factories can inject type arguments by accepting
        ``type[T]`` or ``T`` parameters in dependencies.
        With ``dependencies="infer"``, every required factory parameter needs
        a type annotation; pass explicit ``dependencies`` otherwise.

    Examples:
        .. code-block:: python

            def build_client(settings: Settings) -> Client:
                return Client(settings)

            container.add_factory(build_client)

            def build_box(value_type: type[T]) -> Box[T]:
                return Box(value_type)

            container.add_factory(build_box, provides=Box[T])
    """
    factory_value = cast("Any", factory)
    if not callable(factory_value):
        msg = "add_factory() parameter 'factory' must be callable."
        raise DIWireInvalidRegistrationError(msg)
    factory_provider = cast("FactoryProvider[Any]", factory_value)
    resolved_provides = self._resolve_registration_provides(
        provides=provides,
        method_name="add_factory",
        infer_from=lambda: self._provider_return_type_extractor.extract_from_factory(
            factory=factory_provider,
        ),
    )
    resolved_provides_with_component = self._resolve_registration_component_provides(
        provides=resolved_provides,
        component=component,
        method_name="add_factory",
    )
    resolved_scope = self._resolve_registration_scope(
        scope=scope,
        method_name="add_factory",
    )
    resolved_lifetime = self._resolve_registration_lifetime(
        lifetime=lifetime,
        method_name="add_factory",
    )
    explicit_dependencies = self._resolve_registration_dependencies(
        dependencies=dependencies,
        method_name="add_factory",
    )
    resolved_dependency_registration_policy = (
        self._resolve_registration_dependency_registration_policy(
            dependency_registration_policy=dependency_registration_policy,
            method_name="add_factory",
        )
    )
    dependencies_for_provider = self._resolve_factory_registration_dependencies(
        factory=factory_provider,
        explicit_dependencies=explicit_dependencies,
    )
    is_async = self._provider_return_type_extractor.is_factory_async(factory_provider)
    is_any_dependency_async = self._provider_return_type_extractor.is_any_dependency_async(
        dependencies_for_provider,
    )
    resolved_lock_mode = self._resolve_provider_lock_mode(lock_mode)
    self._register_non_concrete_provider(
        provides=resolved_provides_with_component,
        provider_kind="factory",
        provider=factory_provider,
        provider_field="factory",
        lifetime=resolved_lifetime,
        scope=resolved_scope,
        lock_mode=resolved_lock_mode,
        is_async=is_async,
        is_any_dependency_async=is_any_dependency_async,
        needs_cleanup=False,
        dependencies=dependencies_for_provider,
        resolved_dependency_registration_policy=resolved_dependency_registration_policy,
    )
[docs]
def add_generator(
    self,
    generator: (
        Callable[..., Generator[Any, None, None]] | Callable[..., AsyncGenerator[Any, None]]
    ),
    *,
    provides: Any | Literal["infer"] = "infer",
    component: Component | Any | None = None,
    scope: BaseScope | Literal["from_container"] = "from_container",
    lifetime: Lifetime | Literal["from_container"] = "from_container",
    dependencies: Mapping[Any, inspect.Parameter] | Literal["infer"] = "infer",
    lock_mode: LockMode | Literal["from_container"] = "from_container",
    dependency_registration_policy: DependencyRegistrationPolicy
    | Literal["from_container"] = "from_container",
    require_generator_finally: bool = True,
) -> None:
    """Register a generator or async-generator provider with cleanup.

    The yielded value is resolved as the dependency, and teardown runs when
    the owning resolver scope exits (or container closes for root scope).

    Args:
        generator: Generator provider.
        provides: Dependency key produced by the yield value.
        component: Optional component marker value used to register under
            ``Annotated[provides, Component(...)]``.
        scope: Provider scope, or ``"from_container"``.
        lifetime: Provider lifetime, or ``"from_container"``.
        dependencies: Explicit dependency mapping, or ``"infer"``.
        lock_mode: Lock strategy, or ``"from_container"``.
        dependency_registration_policy: Override dependency autoregistration.
        require_generator_finally: Validate that every ``yield`` /
            ``yield from`` appears inside the body of a ``try`` with a
            non-empty ``finally``. Pass ``False`` to skip this validation for
            intentionally unusual generator providers.

    Raises:
        DIWireInvalidRegistrationError: If registration arguments are invalid,
            or if ``require_generator_finally`` is enabled and the generator
            source cannot be inspected or contains an unguarded yield.
        DIWireInvalidProviderSpecError: If explicit dependencies are invalid.
        DIWireProviderDependencyInferenceError: If dependency inference fails.

    Notes:
        Cleanup is deterministic only when the owning resolver is closed
        (``with``/``async with`` or explicit ``close``/``aclose``).
        By default, registration validates generator source to enforce
        ``yield`` placement inside ``try/finally`` blocks; a ``yield`` that is
        only wrapped in a ``with`` statement does not satisfy this check.

    Examples:
        .. code-block:: python

            def open_session(engine: Engine) -> Generator[Session, None, None]:
                session = Session(engine)
                try:
                    yield session
                finally:
                    session.close()

            container.add_generator(
                open_session,
                scope=Scope.REQUEST,
                provides=Session,
            )
    """
    generator_value = cast("Any", generator)
    if not callable(generator_value):
        msg = "add_generator() parameter 'generator' must be callable."
        raise DIWireInvalidRegistrationError(msg)
    require_generator_finally_value = cast("Any", require_generator_finally)
    if not isinstance(require_generator_finally_value, bool):
        msg = "add_generator() parameter 'require_generator_finally' must be bool."
        raise DIWireInvalidRegistrationError(msg)
    generator_provider = cast("GeneratorProvider[Any]", generator_value)
    resolved_provides = self._resolve_registration_provides(
        provides=provides,
        method_name="add_generator",
        infer_from=lambda: self._provider_return_type_extractor.extract_from_generator(
            generator=generator_provider,
        ),
    )
    resolved_provides_with_component = self._resolve_registration_component_provides(
        provides=resolved_provides,
        component=component,
        method_name="add_generator",
    )
    resolved_scope = self._resolve_registration_scope(
        scope=scope,
        method_name="add_generator",
    )
    resolved_lifetime = self._resolve_registration_lifetime(
        lifetime=lifetime,
        method_name="add_generator",
    )
    explicit_dependencies = self._resolve_registration_dependencies(
        dependencies=dependencies,
        method_name="add_generator",
    )
    resolved_dependency_registration_policy = (
        self._resolve_registration_dependency_registration_policy(
            dependency_registration_policy=dependency_registration_policy,
            method_name="add_generator",
        )
    )
    dependencies_for_provider = self._resolve_generator_registration_dependencies(
        generator=generator_provider,
        explicit_dependencies=explicit_dependencies,
    )
    if require_generator_finally_value:
        self._validate_generator_provider_uses_try_finally(generator_provider)
    is_async = self._provider_return_type_extractor.is_generator_async(generator_provider)
    is_any_dependency_async = self._provider_return_type_extractor.is_any_dependency_async(
        dependencies_for_provider,
    )
    resolved_lock_mode = self._resolve_provider_lock_mode(lock_mode)
    self._register_non_concrete_provider(
        provides=resolved_provides_with_component,
        provider_kind="generator",
        provider=generator_provider,
        provider_field="generator",
        lifetime=resolved_lifetime,
        scope=resolved_scope,
        lock_mode=resolved_lock_mode,
        is_async=is_async,
        is_any_dependency_async=is_any_dependency_async,
        needs_cleanup=True,
        dependencies=dependencies_for_provider,
        resolved_dependency_registration_policy=resolved_dependency_registration_policy,
    )
def _validate_generator_provider_uses_try_finally(
    self,
    provider: GeneratorProvider[Any],
) -> None:
    """Ensure every ``yield`` of a generator provider is guarded by try/finally.

    Non-generator callables are accepted without inspection. For generator
    and async-generator functions, the source is loaded, parsed, and checked
    with ``_generator_yields_are_protected_by_try_finally``.

    Args:
        provider: Generator provider passed to ``add_generator``. Layers added
            by ``functools.wraps`` (``__wrapped__``) and ``functools.partial``
            are peeled off before inspection.

    Raises:
        DIWireInvalidRegistrationError: If the source cannot be inspected or
            parsed, or if some ``yield``/``yield from`` is not inside the body
            of a ``try`` with a non-empty ``finally``.
    """
    unwrapped_provider: Any = inspect.unwrap(provider)
    # ``inspect.isgeneratorfunction`` looks through ``functools.partial``, but
    # ``inspect.getsource`` does not; without peeling partials here, a
    # partial-wrapped generator would always fail with "could not inspect".
    while isinstance(unwrapped_provider, functools.partial):
        unwrapped_provider = inspect.unwrap(unwrapped_provider.func)
    if not (
        inspect.isgeneratorfunction(unwrapped_provider)
        or inspect.isasyncgenfunction(unwrapped_provider)
    ):
        return
    provider_name = self._callable_name(cast("Callable[..., Any]", unwrapped_provider))
    function_node: ast.FunctionDef | ast.AsyncFunctionDef
    try:
        source = inspect.getsource(unwrapped_provider)
        # Dedent so methods/nested functions parse as top-level definitions.
        parsed_source = ast.parse(textwrap.dedent(source))
        function_node = self._resolve_generator_function_ast_node(
            parsed_source=parsed_source,
            provider_name=provider_name,
            function_name=getattr(unwrapped_provider, "__name__", ""),
        )
    except (OSError, SyntaxError, TypeError, ValueError) as error:
        msg = (
            "add_generator() could not inspect generator provider "
            f"'{provider_name}' for try/finally validation; pass "
            "require_generator_finally=False if you intentionally want to skip "
            "this validation."
        )
        raise DIWireInvalidRegistrationError(msg) from error
    if self._generator_yields_are_protected_by_try_finally(function_node):
        return
    msg = (
        "add_generator() provider "
        f"'{provider_name}' must place every yield/yield from inside the body "
        "of a try block with a non-empty finally block; pass "
        "require_generator_finally=False if you intentionally want to skip this "
        "validation."
    )
    raise DIWireInvalidRegistrationError(msg)
def _resolve_generator_function_ast_node(
    self,
    *,
    parsed_source: ast.Module,
    provider_name: str,
    function_name: str,
) -> ast.FunctionDef | ast.AsyncFunctionDef:
    """Pick the function definition that corresponds to a generator provider.

    Only top-level statements of ``parsed_source`` are considered. A unique
    definition named ``function_name`` wins; otherwise a module with exactly
    one function definition yields that definition.

    Args:
        parsed_source: Module parsed from the provider's dedented source.
        provider_name: Display name used in the error message.
        function_name: ``__name__`` of the provider function.

    Returns:
        The matching ``FunctionDef`` or ``AsyncFunctionDef`` node.

    Raises:
        DIWireInvalidRegistrationError: If no unambiguous definition is found.
    """
    candidates = [
        statement
        for statement in parsed_source.body
        if isinstance(statement, (ast.FunctionDef, ast.AsyncFunctionDef))
    ]
    named_candidates = [
        candidate for candidate in candidates if candidate.name == function_name
    ]
    # Prefer an exact name match, then fall back to "the only function here".
    for pool in (named_candidates, candidates):
        if len(pool) == 1:
            return pool[0]
    msg = (
        "add_generator() could not inspect generator provider "
        f"'{provider_name}' for try/finally validation; pass "
        "require_generator_finally=False if you intentionally want to skip "
        "this validation."
    )
    raise DIWireInvalidRegistrationError(msg)
def _generator_yields_are_protected_by_try_finally(
    self,
    function_node: ast.FunctionDef | ast.AsyncFunctionDef,
) -> bool:
    """Report whether every ``yield`` in ``function_node`` is try/finally-guarded.

    A ``yield``/``yield from`` counts as guarded when it appears, at any
    depth, inside the ``body`` of a ``try``/``try*`` statement whose
    ``finally`` block is non-empty. Yields in ``except`` handlers, ``else``
    blocks, or the ``finally`` block itself inherit the guard status of the
    surrounding code rather than of that ``try``. Nested functions, lambdas,
    and classes are skipped because their yields belong to another scope.

    Args:
        function_node: Parsed definition of the generator provider.

    Returns:
        ``True`` when no unguarded yield is found, ``False`` otherwise.
    """
    # ``ast.TryStar`` (``try``/``except*``) only exists on Python 3.11+.
    try_star = cast("type[ast.AST] | None", getattr(ast, "TryStar", None))
    try_statement_types = (ast.Try,) if try_star is None else (ast.Try, try_star)
    nested_scope_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)

    def _is_guarded(node: ast.AST, guarded: bool) -> bool:
        if isinstance(node, (ast.Yield, ast.YieldFrom)):
            return guarded
        if isinstance(node, nested_scope_types):
            return True
        pending: list[tuple[ast.AST, bool]]
        if isinstance(node, try_statement_types):
            try_statement = cast("Any", node)
            body_guarded = guarded or bool(try_statement.finalbody)
            pending = [(statement, body_guarded) for statement in try_statement.body]
            pending.extend(
                (child, guarded)
                for child in (
                    *try_statement.handlers,
                    *try_statement.orelse,
                    *try_statement.finalbody,
                )
            )
        else:
            pending = [(child, guarded) for child in ast.iter_child_nodes(node)]
        # Plain loop (not ``all()``) keeps recursion at one frame per AST level.
        for child, child_guarded in pending:
            if not _is_guarded(child, child_guarded):
                return False
        return True

    return all(_is_guarded(statement, False) for statement in function_node.body)
[docs]
def add_context_manager(
    self,
    context_manager: ContextManagerProvider[Any],
    *,
    provides: Any | Literal["infer"] = "infer",
    component: Component | Any | None = None,
    scope: BaseScope | Literal["from_container"] = "from_container",
    lifetime: Lifetime | Literal["from_container"] = "from_container",
    dependencies: Mapping[Any, inspect.Parameter] | Literal["infer"] = "infer",
    lock_mode: LockMode | Literal["from_container"] = "from_container",
    dependency_registration_policy: DependencyRegistrationPolicy
    | Literal["from_container"] = "from_container",
) -> None:
    """Register a provider that returns a (sync or async) context manager.

    The value produced by entering the context manager is what gets resolved;
    ``__exit__`` / ``__aexit__`` runs when the owning resolver scope exits.

    Args:
        context_manager: Callable returning the context manager.
        provides: Dependency key produced by the entered value; ``"infer"``
            derives it from the provider.
        component: Optional component marker value; when given, the key
            becomes ``Annotated[provides, Component(...)]``.
        scope: Provider scope, or ``"from_container"``.
        lifetime: Provider lifetime, or ``"from_container"``.
        dependencies: Explicit dependency mapping, or ``"infer"``.
        lock_mode: Lock strategy, or ``"from_container"``.
        dependency_registration_policy: Override dependency autoregistration.

    Raises:
        DIWireInvalidRegistrationError: If registration arguments are invalid,
            including a non-callable ``context_manager``.
        DIWireInvalidProviderSpecError: If explicit dependencies are invalid.
        DIWireProviderDependencyInferenceError: If dependency inference fails.

    Notes:
        Cleanup happens at scope/container exit. For per-request resources,
        register under ``Scope.REQUEST`` and resolve inside a request scope.

    Examples:
        .. code-block:: python

            def session(engine: Engine) -> ContextManager[Session]:
                return Session(engine)

            container.add_context_manager(
                session,
                scope=Scope.REQUEST,
                provides=Session,
            )
    """
    candidate = cast("Any", context_manager)
    if not callable(candidate):
        msg = "add_context_manager() parameter 'context_manager' must be callable."
        raise DIWireInvalidRegistrationError(msg)
    provider = cast("ContextManagerProvider[Any]", candidate)
    method_name = "add_context_manager"

    provides_key = self._resolve_registration_provides(
        provides=provides,
        method_name=method_name,
        infer_from=lambda: self._provider_return_type_extractor.extract_from_context_manager(
            context_manager=provider,
        ),
    )
    keyed_provides = self._resolve_registration_component_provides(
        provides=provides_key,
        component=component,
        method_name=method_name,
    )
    provider_scope = self._resolve_registration_scope(scope=scope, method_name=method_name)
    provider_lifetime = self._resolve_registration_lifetime(
        lifetime=lifetime,
        method_name=method_name,
    )
    explicit = self._resolve_registration_dependencies(
        dependencies=dependencies,
        method_name=method_name,
    )
    registration_policy = self._resolve_registration_dependency_registration_policy(
        dependency_registration_policy=dependency_registration_policy,
        method_name=method_name,
    )
    provider_dependencies = self._resolve_context_manager_registration_dependencies(
        context_manager=provider,
        explicit_dependencies=explicit,
    )

    extractor = self._provider_return_type_extractor
    provider_is_async = extractor.is_context_manager_async(provider)
    any_dependency_async = extractor.is_any_dependency_async(provider_dependencies)
    provider_lock_mode = self._resolve_provider_lock_mode(lock_mode)

    self._register_non_concrete_provider(
        provides=keyed_provides,
        provider_kind="context_manager",
        provider=provider,
        provider_field="context_manager",
        lifetime=provider_lifetime,
        scope=provider_scope,
        lock_mode=provider_lock_mode,
        is_async=provider_is_async,
        is_any_dependency_async=any_dependency_async,
        needs_cleanup=True,
        dependencies=provider_dependencies,
        resolved_dependency_registration_policy=registration_policy,
    )
[docs]
def decorate(
    self,
    *,
    provides: Any,
    component: Component | Any | None = None,
    decorator: Callable[..., Any],
    inner_parameter: str | None = None,
) -> None:
    """Wrap the binding of a dependency key with a decorator factory.

    The rule is kept for the whole lifetime of the container. When ``provides``
    is already bound, the decoration takes effect immediately; otherwise the
    rule is stored and applied once that key is registered.

    Args:
        provides: Dependency key whose binding should be decorated. Must not be None.
        component: Optional component marker used to qualify ``provides``.
        decorator: Factory-style callable that receives the inner value and
            returns the decorated value.
        inner_parameter: Name of the decorator parameter that receives the inner
            value; inferred from the decorator's dependencies when omitted.

    Raises:
        DIWireInvalidRegistrationError: If ``provides`` is None, or the component
            or decorator is rejected during validation.
    """
    if provides is None:
        msg = "decorate() parameter 'provides' must not be None."
        raise DIWireInvalidRegistrationError(msg)
    key = self._normalize_decoration_provides_key(
        self._resolve_registration_component_provides(
            provides=provides,
            component=component,
            method_name="decorate",
        ),
    )
    with self._registration_mutation():
        self._register_decoration_rule(
            provides=key,
            decorator=decorator,
            inner_parameter=inner_parameter,
        )
        if self._decoration_chain_by_provides.get(key) is not None:
            # An existing chain just needs an extra layer key for the new rule.
            self._ensure_chain_keys(provides=key)
            self._rebuild_decoration_chain(provides=key)
        elif self._has_registered_binding(key):
            # First rule for an already-bound key: build the chain now.
            self._apply_pending_decorations(provides=key)
        else:
            # Nothing is bound yet; the stored rule is applied on registration.
            return
        self._invalidate_compilation()
def _register_decoration_rule(
    self,
    *,
    provides: Any,
    decorator: Callable[..., Any],
    inner_parameter: str | None,
) -> None:
    """Validate ``decorator`` and append it as the newest rule for ``provides``.

    The decorator's validation, dependency extraction and inner-parameter
    resolution all run before the rule list is touched.
    """
    validated = self._validate_decorator_callable(decorator)
    decorator_dependencies = self._extract_decoration_dependencies(decorator=validated)
    inner_name = self._resolve_decoration_inner_parameter(
        provides=provides,
        dependencies=decorator_dependencies,
        inner_parameter=inner_parameter,
        decorator=validated,
    )
    new_rule = _DecorationRule(
        decorator=validated,
        inner_parameter=inner_name,
        dependencies=tuple(decorator_dependencies),
        is_async=self._provider_return_type_extractor.is_factory_async(validated),
    )
    self._decoration_rules_by_provides.setdefault(provides, []).append(new_rule)
def _validate_decorator_callable(
self,
decorator: Any,
) -> Callable[..., Any]:
decorator_value = cast("Any", decorator)
if not callable(decorator_value):
msg = "decorate() parameter 'decorator' must be callable."
raise DIWireInvalidRegistrationError(msg)
unwrapped = inspect.unwrap(decorator_value)
if inspect.isgeneratorfunction(unwrapped) or inspect.isasyncgenfunction(unwrapped):
msg = (
"decorate() parameter 'decorator' must be a sync/async factory-style "
"callable, not a generator or async-generator function."
)
raise DIWireInvalidRegistrationError(msg)
return cast("Callable[..., Any]", decorator_value)
def _extract_decoration_dependencies(
    self,
    *,
    decorator: Callable[..., Any],
) -> list[ProviderDependency]:
    """Infer ``decorator``'s injectable parameters as provider dependencies.

    Raises:
        DIWireInvalidRegistrationError: Wrapping (and chained to) any
            ``DIWireError`` raised by the dependency extractor; the message
            names the decorator.
    """
    try:
        extracted = self._provider_dependencies_extractor.extract_from_factory(
            factory=decorator,
        )
    except DIWireError as error:
        decorator_name = self._callable_name(decorator)
        msg = (
            "decorate() could not infer dependencies for decorator "
            f"'{decorator_name}': {error}"
        )
        raise DIWireInvalidRegistrationError(msg) from error
    return extracted
def _resolve_decoration_inner_parameter(
self,
*,
provides: Any,
dependencies: list[ProviderDependency],
inner_parameter: str | None,
decorator: Callable[..., Any],
) -> str:
if inner_parameter is not None:
if any(dependency.parameter.name == inner_parameter for dependency in dependencies):
return inner_parameter
msg = (
"decorate() parameter 'inner_parameter' must match one of the decorator's "
"injectable parameters."
)
raise DIWireInvalidRegistrationError(msg)
matched_parameter_names = [
dependency.parameter.name
for dependency in dependencies
if dependency.provides == provides
]
if len(matched_parameter_names) == 1:
return matched_parameter_names[0]
if not matched_parameter_names:
msg = (
"decorate() could not infer the inner parameter for decorator "
f"'{self._callable_name(decorator)}' and provides {provides!r}. "
"Pass inner_parameter='...'."
)
raise DIWireInvalidRegistrationError(msg)
msg = (
"decorate() found multiple inner parameter candidates for decorator "
f"'{self._callable_name(decorator)}' and provides {provides!r}. "
"Pass inner_parameter='...'."
)
raise DIWireInvalidRegistrationError(msg)
def _finalize_registration_after_binding(
    self,
    *,
    original_provides: Any,
    has_decoration_chain: bool,
) -> None:
    """Reconcile decorations after a binding for ``original_provides`` was stored.

    When the key already has a chain, that chain is rebuilt on top of the new
    base binding. Otherwise, any pending rules for the key are applied. The
    compiled state is invalidated in every case.
    """
    key = self._normalize_decoration_provides_key(original_provides)
    if not has_decoration_chain:
        if self._decoration_rules_by_provides.get(key):
            self._apply_pending_decorations(provides=key)
    else:
        self._rebuild_decoration_chain(provides=key)
    self._invalidate_compilation()
def _resolve_registration_target_provides(self, provides: Any) -> tuple[Any, bool]:
    """Return ``(key_to_register_under, has_decoration_chain)`` for ``provides``.

    If ``provides`` already has a decoration chain, a new binding goes to the
    chain's base key so the existing decorators keep wrapping it.
    """
    existing_chain = self._decoration_chain_by_provides.get(
        self._normalize_decoration_provides_key(provides),
    )
    if existing_chain is not None:
        return existing_chain.base_key, True
    return provides, False
def _apply_pending_decorations(self, *, provides: Any) -> None:
    """Create a decoration chain for ``provides`` from its stored rules.

    Does nothing unless the key has at least one rule and a registered binding.
    The current binding is copied to a freshly allocated base key, the decorator
    layers are registered over it, and the chain is stored only after that
    rebuild returns.
    """
    pending_rules = self._decoration_rules_by_provides.get(provides)
    if not pending_rules or not self._has_registered_binding(provides):
        return
    new_chain = self._build_decoration_chain(
        provides=provides,
        rule_count=len(pending_rules),
    )
    self._move_current_binding_to_base_key(provides=provides, base_key=new_chain.base_key)
    self._rebuild_decoration_chain(provides=provides, chain=new_chain)
    self._decoration_chain_by_provides[provides] = new_chain
def _ensure_chain_keys(self, *, provides: Any) -> None:
    """Grow an existing chain's layer keys until they match its rule count.

    Each missing layer gets a new alias key inserted just before the final
    layer key. No-op when there is no chain or the counts already agree.

    Raises:
        DIWireInvalidRegistrationError: If the chain has more layers than rules.
    """
    chain = self._decoration_chain_by_provides.get(provides)
    if chain is None:
        return
    target_layers = len(self._decoration_rules_by_provides.get(provides, []))
    layer_keys = chain.layer_keys
    missing_layers = target_layers - len(layer_keys)
    if missing_layers < 0:
        msg = f"Decoration chain for {provides!r} has more layers than rules."
        raise DIWireInvalidRegistrationError(msg)
    for _ in range(missing_layers):
        position = max(len(layer_keys) - 1, 0)
        alias_key = self._create_decoration_alias_key(provides=provides, layer=position)
        layer_keys.insert(position, alias_key)
def _build_decoration_chain(
    self,
    *,
    provides: Any,
    rule_count: int,
) -> _DecorationChain:
    """Allocate the keys for a new decoration chain over ``provides``.

    The base key (later holding the undecorated binding) is allocated first,
    then ``rule_count - 1`` intermediate alias keys. ``provides`` itself is
    always the last layer key.
    """
    base_key = self._create_decoration_alias_key(provides=provides, layer=-1)
    intermediate_keys = [
        self._create_decoration_alias_key(provides=provides, layer=layer)
        for layer in range(rule_count - 1)
    ]
    return _DecorationChain(
        base_key=base_key,
        layer_keys=[*intermediate_keys, provides],
    )
def _move_current_binding_to_base_key(
    self,
    *,
    provides: Any,
    base_key: Any,
) -> None:
    """Copy the binding currently registered for ``provides`` to ``base_key``.

    The copy becomes the innermost (undecorated) layer of a decoration chain.
    The registration under ``provides`` is not removed here; the chain rebuild
    later registers the outermost decorator layer under that key.

    Args:
        provides: Decorated key whose current binding is copied.
        base_key: Alias key allocated for the chain's base binding.

    Raises:
        DIWireInvalidRegistrationError: If ``provides`` has no binding, or if the
            open-generic registry does not accept ``base_key``.
    """
    if self._is_open_generic_provides(provides):
        open_spec = self._open_generic_registry.find_exact(provides)
        if open_spec is None:
            msg = f"Cannot decorate {provides!r}: base open-generic binding is not registered."
            raise DIWireInvalidRegistrationError(msg)
        dependencies = [binding.dependency for binding in open_spec.bindings]
        moved_spec = self._open_generic_registry.register(
            provides=base_key,
            provider_kind=open_spec.provider_kind,
            provider=open_spec.provider,
            lifetime=open_spec.lifetime,
            scope=open_spec.scope,
            lock_mode=open_spec.lock_mode,
            is_async=open_spec.is_async,
            is_any_dependency_async=open_spec.is_any_dependency_async,
            needs_cleanup=open_spec.needs_cleanup,
            dependencies=dependencies,
        )
        # ``register`` returns None when it does not accept the key (the same
        # condition _register_open_generic_decorator_layer treats as an error).
        # Failing here avoids a later, vaguer "base binding is missing" error.
        if moved_spec is None:
            msg = (
                f"Cannot decorate {provides!r}: failed to register the open-generic "
                f"base binding under {base_key!r}."
            )
            raise DIWireInvalidRegistrationError(msg)
        return
    provider_spec = self._providers_registrations.find_by_type(provides)
    if provider_spec is None:
        msg = f"Cannot decorate {provides!r}: base binding is not registered."
        raise DIWireInvalidRegistrationError(msg)
    self._providers_registrations.add(
        self._copy_provider_spec_with_new_key(
            provider_spec=provider_spec,
            provides=base_key,
        ),
    )
def _copy_provider_spec_with_new_key(
    self,
    *,
    provider_spec: ProviderSpec,
    provides: Any,
) -> ProviderSpec:
    """Return a new ``ProviderSpec`` keyed by ``provides`` that mirrors ``provider_spec``.

    The provider fields, lifetime, scope, lock mode and async/cleanup flags are
    copied as-is. The dependency list is shallow-copied, so the two specs do not
    share one list object.
    """
    source = provider_spec
    return ProviderSpec(
        provides=provides,
        instance=source.instance,
        concrete_type=source.concrete_type,
        factory=source.factory,
        generator=source.generator,
        context_manager=source.context_manager,
        lifetime=source.lifetime,
        scope=source.scope,
        lock_mode=source.lock_mode,
        dependencies=[*source.dependencies],
        is_async=source.is_async,
        is_any_dependency_async=source.is_any_dependency_async,
        needs_cleanup=source.needs_cleanup,
    )
def _rebuild_decoration_chain(
    self,
    *,
    provides: Any,
    chain: _DecorationChain | None = None,
) -> None:
    """Register (or re-register) every decorator layer of ``provides``'s chain.

    Rule ``i`` is registered as a factory under ``layer_keys[i]``. Its inner
    parameter is wired to the previous layer's key, or to the chain's base key
    for the first rule. Every layer copies lifetime, scope and lock mode from the
    base binding. The method does nothing when there is no chain or no rules.

    Args:
        provides: Normalized decorated key.
        chain: Chain to rebuild; defaults to the chain stored for ``provides``.

    Raises:
        DIWireInvalidRegistrationError: If the layer count differs from the rule
            count, or the base binding is missing or has no lifetime.
    """
    active_chain = (
        chain if chain is not None else self._decoration_chain_by_provides.get(provides)
    )
    if active_chain is None:
        return
    rules = self._decoration_rules_by_provides.get(provides)
    if not rules:
        return
    if len(active_chain.layer_keys) != len(rules):
        msg = f"Decoration chain for {provides!r} is out of sync with rules."
        raise DIWireInvalidRegistrationError(msg)
    base_metadata = self._resolve_decoration_base_metadata(
        provides=provides,
        base_key=active_chain.base_key,
    )
    # Loop-invariant: with ``None`` this just returns the container default policy.
    dependency_registration_policy = self._resolve_dependency_registration_policy(None)
    inner_key = active_chain.base_key
    # Lengths were checked above, so strict zipping cannot fail here.
    for rule, out_key in zip(rules, active_chain.layer_keys, strict=True):
        dependencies = self._build_decorator_dependencies(
            rule=rule,
            inner_key=inner_key,
        )
        is_any_dependency_async = self._provider_return_type_extractor.is_any_dependency_async(
            dependencies,
        )
        if base_metadata.is_open_generic:
            self._register_open_generic_decorator_layer(
                provides=out_key,
                rule=rule,
                dependencies=dependencies,
                metadata=base_metadata,
                is_any_dependency_async=is_any_dependency_async,
            )
        else:
            self._providers_registrations.add(
                ProviderSpec(
                    provides=out_key,
                    factory=rule.decorator,
                    lifetime=base_metadata.lifetime,
                    scope=base_metadata.scope,
                    dependencies=dependencies,
                    is_async=rule.is_async,
                    is_any_dependency_async=is_any_dependency_async,
                    needs_cleanup=False,
                    lock_mode=base_metadata.lock_mode,
                ),
            )
        self._autoregister_provider_dependencies(
            dependencies=dependencies,
            scope=base_metadata.scope,
            lifetime=base_metadata.lifetime,
            dependency_registration_policy=dependency_registration_policy,
        )
        # The next rule wraps the layer just registered.
        inner_key = out_key
def _build_decorator_dependencies(
    self,
    *,
    rule: _DecorationRule,
    inner_key: Any,
) -> list[ProviderDependency]:
    """Return ``rule``'s dependencies with its inner parameter pointed at ``inner_key``.

    All other dependencies pass through unchanged, and their order is preserved.

    Raises:
        DIWireInvalidRegistrationError: If no dependency is named
            ``rule.inner_parameter``.
    """
    inner_name = rule.inner_parameter
    if not any(dependency.parameter.name == inner_name for dependency in rule.dependencies):
        msg = (
            "decorate() configured an unknown inner parameter "
            f"'{rule.inner_parameter}' for decorator '{self._callable_name(rule.decorator)}'."
        )
        raise DIWireInvalidRegistrationError(msg)
    return [
        ProviderDependency(provides=inner_key, parameter=dependency.parameter)
        if dependency.parameter.name == inner_name
        else dependency
        for dependency in rule.dependencies
    ]
def _register_open_generic_decorator_layer(
    self,
    *,
    provides: Any,
    rule: _DecorationRule,
    dependencies: list[ProviderDependency],
    metadata: _DecorationBaseMetadata,
    is_any_dependency_async: bool,
) -> None:
    """Register one layer of an open-generic decoration chain as a factory.

    Lifetime, scope and lock mode come from ``metadata``. Decorator layers never
    need cleanup.

    Raises:
        DIWireInvalidRegistrationError: If the open-generic registry returns None
            for ``provides``.
    """
    layer_spec = self._open_generic_registry.register(
        provides=provides,
        provider_kind="factory",
        provider=rule.decorator,
        dependencies=dependencies,
        lifetime=metadata.lifetime,
        scope=metadata.scope,
        lock_mode=metadata.lock_mode,
        is_async=rule.is_async,
        is_any_dependency_async=is_any_dependency_async,
        needs_cleanup=False,
    )
    if layer_spec is not None:
        return
    msg = f"Cannot register open-generic decorator layer for key {provides!r}."
    raise DIWireInvalidRegistrationError(msg)
def _resolve_decoration_base_metadata(
    self,
    *,
    provides: Any,
    base_key: Any,
) -> _DecorationBaseMetadata:
    """Read lifetime, scope and lock mode from the binding stored at ``base_key``.

    If ``provides`` is open-generic, the open-generic registry is searched;
    otherwise the closed provider registrations are.

    Raises:
        DIWireInvalidRegistrationError: If the base binding is missing, or a
            closed base binding has no lifetime.
    """
    is_open_generic = self._is_open_generic_provides(provides)
    if is_open_generic:
        base_spec = self._open_generic_registry.find_exact(base_key)
    else:
        base_spec = self._providers_registrations.find_by_type(base_key)
    if base_spec is None:
        msg = f"Decoration base binding for {provides!r} is missing."
        raise DIWireInvalidRegistrationError(msg)
    if not is_open_generic and base_spec.lifetime is None:
        msg = f"Decoration base binding for {provides!r} has no lifetime."
        raise DIWireInvalidRegistrationError(msg)
    return _DecorationBaseMetadata(
        lifetime=base_spec.lifetime,
        scope=base_spec.scope,
        lock_mode=base_spec.lock_mode,
        is_open_generic=is_open_generic,
    )
def _has_registered_binding(self, provides: Any) -> bool:
    """Return True if ``provides`` is bound in the registry matching its kind.

    Open-generic keys are looked up with ``find_exact``. All other keys are
    looked up in the closed provider registrations.
    """
    lookup = (
        self._open_generic_registry.find_exact
        if self._is_open_generic_provides(provides)
        else self._providers_registrations.find_by_type
    )
    return lookup(provides) is not None
def _normalize_decoration_provides_key(self, provides: Any) -> Any:
    """Return ``canonicalize_open_key(provides)``, falling back to ``provides``.

    The fallback is used when canonicalization yields None.
    """
    canonical = canonicalize_open_key(provides)
    return provides if canonical is None else canonical
def _create_decoration_alias_key(
    self,
    *,
    provides: Any,
    layer: int,
) -> Any:
    """Allocate a new, unique key for a hidden layer of a decoration chain.

    Every call increments ``self._decoration_counter``. Open-generic keys get an
    ``Annotated[provides, _OpenDecorationAlias(...)]`` alias. Closed keys get a
    freshly created empty class; ``layer`` is only used in the open-generic case.
    """
    self._decoration_counter += 1
    alias_id = self._decoration_counter
    if not self._is_open_generic_provides(provides):
        return type(f"_DIWireInner_{alias_id}", (), {})
    return Annotated[provides, _OpenDecorationAlias(id=alias_id, layer=layer)]
def _is_open_generic_provides(self, provides: Any) -> bool:
    """Return True when ``canonicalize_open_key`` yields a value for ``provides``."""
    canonical = canonicalize_open_key(provides)
    return canonical is not None
def _resolve_concrete_registration_types(
self,
*,
provides: Any | Literal["infer"],
concrete_type: Any,
) -> tuple[Any, type[Any]]:
provides_value = cast("Any", provides)
concrete_type_value = concrete_type
if provides_value == "infer":
resolved_provides = concrete_type_value
elif provides_value is not None:
resolved_provides = provides_value
else:
msg = "add() parameter 'provides' must not be None; use 'infer'."
raise DIWireInvalidRegistrationError(msg)
if concrete_type_value is None:
msg = "add() parameter 'concrete_type' must not be None; use 'infer'."
raise DIWireInvalidRegistrationError(msg)
return resolved_provides, concrete_type_value
def _resolve_registration_provides(
self,
*,
provides: Any,
method_name: str,
infer_from: Callable[[], Any],
) -> Any:
if provides == "infer":
return infer_from()
if provides is not None:
return provides
msg = f"{method_name}() parameter 'provides' must not be None; use 'infer'."
raise DIWireInvalidRegistrationError(msg)
def _resolve_registration_component_provides(
    self,
    *,
    provides: Any,
    component: Component | Any | None,
    method_name: str,
) -> Any:
    """Combine ``provides`` with an optional ``component`` into one registration key.

    ``provides`` first goes through ``strip_non_component_annotation``. Without a
    component, that result is returned. With one, it is annotated with the
    normalized component marker.

    Raises:
        DIWireInvalidRegistrationError: If ``provides`` is already
            component-qualified while ``component`` is also given, or if the
            component is not hashable.
    """
    base_provides = strip_non_component_annotation(provides)
    if component is None:
        return base_provides
    already_qualified = component_base_key(base_provides) is not None
    if already_qualified:
        msg = (
            f"{method_name}() received both a component-qualified 'provides' key "
            f"({base_provides!r}) and 'component'. Omit component=... or pass the "
            "non-component "
            "base key in provides=... and keep component=...."
        )
        raise DIWireInvalidRegistrationError(msg)
    marker = self._normalize_registration_component(
        component=component,
        method_name=method_name,
    )
    return build_annotated_key((base_provides, marker))
def _normalize_registration_component(
    self,
    *,
    component: Component | Any,
    method_name: str,
) -> Component:
    """Return ``component`` as a hashable ``Component`` marker.

    Raw values are wrapped in ``Component``; existing markers pass through.

    Raises:
        DIWireInvalidRegistrationError: If the marker is not hashable; the message
            reports the type of the underlying component value.
    """
    is_marker = isinstance(component, Component)
    marker = component if is_marker else Component(component)
    raw_value = component.value if is_marker else component
    try:
        hash(marker)
    except TypeError as error:
        msg = (
            f"{method_name}() parameter 'component' must be hashable; "
            f"got {type(raw_value).__name__}."
        )
        raise DIWireInvalidRegistrationError(msg) from error
    return marker
def _resolve_registration_scope(
self,
*,
scope: BaseScope | Literal["from_container"],
method_name: str,
) -> BaseScope:
scope_value = cast("Any", scope)
if scope_value == "from_container":
return self._root_scope
if isinstance(scope_value, BaseScope):
return scope_value
msg = f"{method_name}() parameter 'scope' must be BaseScope or 'from_container'."
raise DIWireInvalidRegistrationError(msg)
def _resolve_registration_lifetime(
self,
*,
lifetime: Lifetime | Literal["from_container"],
method_name: str,
) -> Lifetime:
lifetime_value = cast("Any", lifetime)
if lifetime_value == "from_container":
return self._default_lifetime
if isinstance(lifetime_value, Lifetime):
return lifetime_value
msg = f"{method_name}() parameter 'lifetime' must be Lifetime or 'from_container'."
raise DIWireInvalidRegistrationError(msg)
def _resolve_registration_dependencies(
self,
*,
dependencies: Mapping[Any, inspect.Parameter] | Literal["infer"],
method_name: str,
) -> list[ProviderDependency] | None:
dependencies_value = cast("Any", dependencies)
if dependencies_value == "infer":
return None
if not isinstance(dependencies_value, Mapping):
msg = (
f"{method_name}() parameter 'dependencies' must be a "
"mapping[Any, inspect.Parameter] or 'infer'."
)
raise DIWireInvalidRegistrationError(msg)
resolved_dependencies: list[ProviderDependency] = []
for provides_key, parameter in dependencies_value.items():
if not isinstance(parameter, inspect.Parameter):
msg = (
f"{method_name}() parameter 'dependencies' must be a "
"mapping[Any, inspect.Parameter] or 'infer'."
)
raise DIWireInvalidRegistrationError(msg)
resolved_dependencies.append(
ProviderDependency(
provides=provides_key,
parameter=parameter,
),
)
return resolved_dependencies
def _resolve_registration_dependency_registration_policy(
self,
*,
dependency_registration_policy: DependencyRegistrationPolicy | Literal["from_container"],
method_name: str,
) -> DependencyRegistrationPolicy | None:
dependency_registration_policy_value = cast("Any", dependency_registration_policy)
if dependency_registration_policy_value == "from_container":
return None
if isinstance(dependency_registration_policy_value, DependencyRegistrationPolicy):
return dependency_registration_policy_value
msg = (
f"{method_name}() parameter 'dependency_registration_policy' must be "
"DependencyRegistrationPolicy or 'from_container'."
)
raise DIWireInvalidRegistrationError(msg)
def _register_non_concrete_provider(
    self,
    *,
    provides: Any,
    provider_kind: Literal["factory", "generator", "context_manager"],
    provider: Any,
    provider_field: Literal["factory", "generator", "context_manager"],
    lifetime: Lifetime,
    scope: BaseScope,
    lock_mode: LockMode | Literal["auto"],
    is_async: bool,
    is_any_dependency_async: bool,
    needs_cleanup: bool,
    dependencies: list[ProviderDependency],
    resolved_dependency_registration_policy: DependencyRegistrationPolicy | None,
) -> None:
    """Register a factory, generator or context-manager provider for ``provides``.

    The binding is offered to the open-generic registry first. If that returns
    None, a closed ``ProviderSpec`` is stored with ``provider`` in the slot named
    by ``provider_field``. When ``provides`` already has a decoration chain, the
    binding is stored under the chain's base key. Either way, the dependencies
    are auto-registered and decorations are reconciled afterwards.
    """
    target_key, has_decoration_chain = self._resolve_registration_target_provides(
        provides,
    )
    with self._registration_mutation():
        open_spec = self._open_generic_registry.register(
            provides=target_key,
            provider_kind=provider_kind,
            provider=provider,
            lifetime=lifetime,
            scope=scope,
            lock_mode=lock_mode,
            is_async=is_async,
            is_any_dependency_async=is_any_dependency_async,
            needs_cleanup=needs_cleanup,
            dependencies=dependencies,
        )
        if open_spec is None:
            # Closed key. Anything other than "factory"/"generator" goes into
            # the context_manager slot.
            spec_field = (
                provider_field if provider_field in ("factory", "generator") else "context_manager"
            )
            self._providers_registrations.add(
                ProviderSpec(
                    provides=target_key,
                    lifetime=lifetime,
                    scope=scope,
                    dependencies=dependencies,
                    is_async=is_async,
                    is_any_dependency_async=is_any_dependency_async,
                    needs_cleanup=needs_cleanup,
                    lock_mode=lock_mode,
                    **{spec_field: provider},
                ),
            )
        self._autoregister_provider_dependencies(
            dependencies=dependencies,
            scope=scope,
            lifetime=lifetime,
            dependency_registration_policy=self._resolve_dependency_registration_policy(
                resolved_dependency_registration_policy,
            ),
        )
        self._finalize_registration_after_binding(
            original_provides=provides,
            has_decoration_chain=has_decoration_chain,
        )
def _resolve_concrete_registration_dependencies(
    self,
    *,
    concrete_type: type[Any],
    explicit_dependencies: list[ProviderDependency] | None,
) -> list[ProviderDependency]:
    """Return the dependency list for a concrete-type registration.

    Without explicit dependencies, they come from the extractor's
    ``extract_from_concrete_type``. Otherwise the explicit list is passed to
    ``validate_explicit_for_concrete_type``, and its result is returned.
    """
    extractor = self._provider_dependencies_extractor
    if explicit_dependencies is not None:
        return extractor.validate_explicit_for_concrete_type(
            concrete_type=concrete_type,
            dependencies=explicit_dependencies,
        )
    return extractor.extract_from_concrete_type(concrete_type=concrete_type)
def _resolve_closed_concrete_generic_injections(
    self,
    *,
    provides: Any,
    dependencies: list[ProviderDependency],
) -> tuple[dict[str, Any], list[ProviderDependency]]:
    """Split dependencies into TypeVar-derived injections and regular dependencies.

    A dependency whose annotation resolves through ``provides``'s TypeVar map is
    returned in the first element, keyed by parameter name, with its concrete
    value. Everything else is returned in the second element, in order.

    Returns:
        ``(injected_arguments, remaining_dependencies)``. If there is no TypeVar
        map for ``provides``, ``({}, dependencies)`` with the same list object.
    """
    typevar_map = self._closed_generic_typevar_map(provides=provides)
    if not typevar_map:
        return {}, dependencies
    injected: dict[str, Any] = {}
    remaining: list[ProviderDependency] = []
    for dependency in dependencies:
        value = self._resolve_closed_generic_injection_value(
            dependency_annotation=dependency.provides,
            typevar_map=typevar_map,
        )
        if value is _MISSING_CLOSED_GENERIC_INJECTION:
            remaining.append(dependency)
        else:
            injected[dependency.parameter.name] = value
    return injected, remaining
def _closed_generic_typevar_map(self, *, provides: Any) -> dict[TypeVar, Any]:
    """Map the generic origin's TypeVars to the type arguments of ``provides``.

    ``provides`` goes through ``strip_non_component_annotation``, and any
    remaining ``Annotated`` wrapper is unwrapped to its first argument.
    Returns an empty dict when there is no generic origin or no type arguments,
    or when the origin's TypeVar count differs from the argument count.
    """
    target = strip_non_component_annotation(provides)
    if get_origin(target) is Annotated:
        target = get_args(target)[0]
    origin = get_origin(target)
    type_arguments = get_args(target)
    if origin is None or not type_arguments:
        return {}
    typevars = [
        parameter
        for parameter in getattr(origin, "__parameters__", ())
        if isinstance(parameter, TypeVar)
    ]
    if len(typevars) != len(type_arguments):
        return {}
    return dict(zip(typevars, type_arguments, strict=True))
def _resolve_closed_generic_injection_value(
    self,
    *,
    dependency_annotation: Any,
    typevar_map: dict[TypeVar, Any],
) -> Any:
    """Return the concrete value for a bare ``T`` or ``type[T]`` dependency.

    Returns ``_MISSING_CLOSED_GENERIC_INJECTION`` for any other annotation, or
    when the TypeVar is not in ``typevar_map``.
    """
    if isinstance(dependency_annotation, TypeVar):
        typevar = dependency_annotation
    else:
        arguments = get_args(dependency_annotation)
        is_type_of_typevar = (
            get_origin(dependency_annotation) is type
            and len(arguments) == 1
            and isinstance(arguments[0], TypeVar)
        )
        if not is_type_of_typevar:
            return _MISSING_CLOSED_GENERIC_INJECTION
        typevar = arguments[0]
    return typevar_map.get(typevar, _MISSING_CLOSED_GENERIC_INJECTION)
def _build_closed_concrete_factory(
self,
*,
concrete_type: type[Any],
injected_arguments: dict[str, Any],
) -> Callable[..., Any]:
constructor_signature = inspect.signature(concrete_type)
factory_injected_arguments = dict(injected_arguments)
def _factory(*args: Any, **kwargs: Any) -> Any:
bound_arguments = constructor_signature.bind_partial(*args, **kwargs)
for argument_name, argument_value in factory_injected_arguments.items():
if argument_name in bound_arguments.arguments:
continue
bound_arguments.arguments[argument_name] = argument_value
return concrete_type(*bound_arguments.args, **bound_arguments.kwargs)
return _factory
def _materialize_closed_open_generic_spec(
    self,
    dependency: Any,
    match: Any,
) -> None:
    """Register a closed ``ProviderSpec`` for ``dependency`` built from an open-generic match.

    Runs while holding ``self._graph_state_lock``. Each binding of the matched
    open spec is one of two kinds. A ``"dependency"`` binding has its template
    specialized with ``match.typevar_map``. Any other binding is a TypeVar-valued
    argument whose concrete value is injected directly. The method returns
    silently without registering anything when a closed spec already exists, a
    specialized template still contains a TypeVar, or an injected binding has no
    TypeVar or its TypeVar is absent from the map.

    Args:
        dependency: The closed key to materialize a provider for.
        match: Open-generic match; only ``match.spec`` and ``match.typevar_map``
            are read here.
    """
    with self._graph_state_lock:
        # Already materialized (possibly by another caller before the lock was taken).
        if self._providers_registrations.find_by_type(dependency) is not None:
            return
        open_spec = match.spec
        typevar_map = match.typevar_map
        # Dependencies the resolver still supplies, with TypeVars substituted.
        specialized_dependencies: list[ProviderDependency] = []
        # TypeVar-valued parameters whose values come straight from typevar_map.
        injected_dependencies: list[ProviderDependency] = []
        injected_arguments: dict[str, Any] = {}
        for binding in open_spec.bindings:
            if binding.kind == "dependency":
                resolved_dependency = substitute_typevars(binding.template, mapping=typevar_map)
                # A free TypeVar remains: this match cannot be closed.
                if contains_typevar(resolved_dependency):
                    return
                specialized_dependencies.append(
                    ProviderDependency(
                        provides=resolved_dependency,
                        parameter=binding.dependency.parameter,
                    ),
                )
                continue
            typevar = binding.typevar
            if typevar is None:
                return
            argument_value = typevar_map.get(typevar, _MISSING_CLOSED_GENERIC_INJECTION)
            if argument_value is _MISSING_CLOSED_GENERIC_INJECTION:
                return
            injected_dependencies.append(binding.dependency)
            injected_arguments[binding.dependency.parameter.name] = argument_value
        provider_kind = cast("str", open_spec.provider_kind)
        provider_object = open_spec.provider
        if injected_dependencies:
            # Bind the TypeVar-derived arguments into a wrapper so only the
            # specialized dependencies are left for the resolver to supply.
            provider_object = self._build_materialized_provider_wrapper(
                provider=provider_object,
                injected_dependencies=tuple(injected_dependencies),
                injected_arguments=injected_arguments,
                has_runtime_dependencies=bool(specialized_dependencies),
                provider_is_inject_wrapper=open_spec.provider_is_inject_wrapper,
            )
            # Mark the new wrapper function the same way as the original provider.
            if open_spec.provider_is_inject_wrapper:
                provider_object.__dict__[INJECT_WRAPPER_MARKER] = True
            # The wrapper is a plain function, so a concrete type is now registered as a factory.
            if provider_kind == "concrete_type":
                provider_kind = "factory"
        # Put the (possibly wrapped) provider into the ProviderSpec slot matching its kind;
        # any kind other than concrete_type/factory/generator is treated as a context manager.
        if provider_kind == "concrete_type":
            provider_spec = ProviderSpec(
                provides=dependency,
                concrete_type=cast("type[Any]", provider_object),
                lifetime=open_spec.lifetime,
                scope=open_spec.scope,
                dependencies=specialized_dependencies,
                is_async=open_spec.is_async,
                is_any_dependency_async=open_spec.is_any_dependency_async,
                needs_cleanup=open_spec.needs_cleanup,
                lock_mode=open_spec.lock_mode,
            )
        elif provider_kind == "factory":
            provider_spec = ProviderSpec(
                provides=dependency,
                factory=cast("FactoryProvider[Any]", provider_object),
                lifetime=open_spec.lifetime,
                scope=open_spec.scope,
                dependencies=specialized_dependencies,
                is_async=open_spec.is_async,
                is_any_dependency_async=open_spec.is_any_dependency_async,
                needs_cleanup=open_spec.needs_cleanup,
                lock_mode=open_spec.lock_mode,
            )
        elif provider_kind == "generator":
            provider_spec = ProviderSpec(
                provides=dependency,
                generator=cast("GeneratorProvider[Any]", provider_object),
                lifetime=open_spec.lifetime,
                scope=open_spec.scope,
                dependencies=specialized_dependencies,
                is_async=open_spec.is_async,
                is_any_dependency_async=open_spec.is_any_dependency_async,
                needs_cleanup=open_spec.needs_cleanup,
                lock_mode=open_spec.lock_mode,
            )
        else:
            provider_spec = ProviderSpec(
                provides=dependency,
                context_manager=cast("ContextManagerProvider[Any]", provider_object),
                lifetime=open_spec.lifetime,
                scope=open_spec.scope,
                dependencies=specialized_dependencies,
                is_async=open_spec.is_async,
                is_any_dependency_async=open_spec.is_any_dependency_async,
                needs_cleanup=open_spec.needs_cleanup,
                lock_mode=open_spec.lock_mode,
            )
        self._providers_registrations.add(provider_spec)
        # Record that this closed key was created at runtime rather than registered by the user.
        self._runtime_materialized_closed_keys.add(dependency)
        # Compiled state is only invalidated for transient specs or specs scoped
        # deeper than the root scope.
        should_invalidate_compilation = (
            provider_spec.lifetime is Lifetime.TRANSIENT
            or provider_spec.scope.level > self._root_scope.level
        )
        if should_invalidate_compilation:
            self._invalidate_compilation()
def _build_materialized_provider_wrapper(
self,
*,
provider: Any,
injected_dependencies: tuple[ProviderDependency, ...],
injected_arguments: dict[str, Any],
has_runtime_dependencies: bool,
provider_is_inject_wrapper: bool,
) -> Callable[..., Any]:
simple_parameter_kinds = {
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
}
if (
not has_runtime_dependencies
and not provider_is_inject_wrapper
and all(
dependency.parameter.kind in simple_parameter_kinds
for dependency in injected_dependencies
)
):
provider_signature = inspect.signature(provider)
prebound_arguments = provider_signature.bind_partial()
for dependency in injected_dependencies:
argument_name = dependency.parameter.name
prebound_arguments.arguments[argument_name] = injected_arguments[argument_name]
prebuilt_args = prebound_arguments.args
prebuilt_kwargs = dict(prebound_arguments.kwargs)
return self._build_prebound_materialized_wrapper(
provider=provider,
prebuilt_args=prebuilt_args,
prebuilt_kwargs=prebuilt_kwargs,
)
supports_fast_kwargs = all(
dependency.parameter.kind
in {
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
}
for dependency in injected_dependencies
)
if supports_fast_kwargs:
captured_injected_arguments = dict(injected_arguments)
def _wrapper(*args: Any, **kwargs: Any) -> Any:
merged_kwargs = dict(kwargs)
for argument_name, argument_value in captured_injected_arguments.items():
merged_kwargs.setdefault(argument_name, argument_value)
return provider(*args, **merged_kwargs)
return _wrapper
provider_signature = inspect.signature(provider)
captured_dependencies = tuple(injected_dependencies)
captured_injected_arguments = dict(injected_arguments)
def _bound_wrapper(*args: Any, **kwargs: Any) -> Any:
bound_arguments = provider_signature.bind_partial(*args, **kwargs)
for dependency in captured_dependencies:
argument_name = dependency.parameter.name
if argument_name in bound_arguments.arguments:
continue
bound_arguments.arguments[argument_name] = captured_injected_arguments[
argument_name
]
return provider(*bound_arguments.args, **bound_arguments.kwargs)
return _bound_wrapper
def _build_prebound_materialized_wrapper(
self,
*,
provider: Any,
prebuilt_args: tuple[Any, ...],
prebuilt_kwargs: dict[str, Any],
) -> Callable[[], Any]:
if prebuilt_kwargs:
if prebuilt_args:
def _zero_runtime_wrapper_args_kwargs() -> Any:
return provider(*prebuilt_args, **prebuilt_kwargs)
return _zero_runtime_wrapper_args_kwargs
def _zero_runtime_wrapper_kwargs() -> Any:
return provider(**prebuilt_kwargs)
return _zero_runtime_wrapper_kwargs
if prebuilt_args:
one_argument = 1
two_arguments = 2
three_arguments = 3
four_arguments = 4
if len(prebuilt_args) == one_argument:
arg0 = prebuilt_args[0]
def _zero_runtime_wrapper_one_arg() -> Any:
return provider(arg0)
return _zero_runtime_wrapper_one_arg
if len(prebuilt_args) == two_arguments:
arg0 = prebuilt_args[0]
arg1 = prebuilt_args[1]
def _zero_runtime_wrapper_two_args() -> Any:
return provider(arg0, arg1)
return _zero_runtime_wrapper_two_args
if len(prebuilt_args) == three_arguments:
arg0 = prebuilt_args[0]
arg1 = prebuilt_args[1]
arg2 = prebuilt_args[2]
def _zero_runtime_wrapper_three_args() -> Any:
return provider(arg0, arg1, arg2)
return _zero_runtime_wrapper_three_args
if len(prebuilt_args) == four_arguments:
arg0 = prebuilt_args[0]
arg1 = prebuilt_args[1]
arg2 = prebuilt_args[2]
arg3 = prebuilt_args[3]
def _zero_runtime_wrapper_four_args() -> Any:
return provider(arg0, arg1, arg2, arg3)
return _zero_runtime_wrapper_four_args
def _zero_runtime_wrapper_var_args() -> Any:
return provider(*prebuilt_args)
return _zero_runtime_wrapper_var_args
def _zero_runtime_wrapper_no_args() -> Any:
return provider()
return _zero_runtime_wrapper_no_args
def _resolve_factory_registration_dependencies(
    self,
    *,
    factory: FactoryProvider[Any],
    explicit_dependencies: list[ProviderDependency] | None,
) -> list[ProviderDependency]:
    """Return the dependency list for a factory registration.

    Without explicit dependencies, they come from the extractor's
    ``extract_from_factory``. Otherwise the explicit list is passed to
    ``validate_explicit_for_factory``, and its result is returned.
    """
    extractor = self._provider_dependencies_extractor
    if explicit_dependencies is not None:
        return extractor.validate_explicit_for_factory(
            factory=factory,
            dependencies=explicit_dependencies,
        )
    return extractor.extract_from_factory(factory=factory)
def _resolve_generator_registration_dependencies(
    self,
    *,
    generator: GeneratorProvider[Any],
    explicit_dependencies: list[ProviderDependency] | None,
) -> list[ProviderDependency]:
    """Return the dependency list for a generator registration.

    Without explicit dependencies, they come from the extractor's
    ``extract_from_generator``. Otherwise the explicit list is passed to
    ``validate_explicit_for_generator``, and its result is returned.
    """
    extractor = self._provider_dependencies_extractor
    if explicit_dependencies is not None:
        return extractor.validate_explicit_for_generator(
            generator=generator,
            dependencies=explicit_dependencies,
        )
    return extractor.extract_from_generator(generator=generator)
def _resolve_context_manager_registration_dependencies(
self,
*,
context_manager: ContextManagerProvider[Any],
explicit_dependencies: list[ProviderDependency] | None,
) -> list[ProviderDependency]:
if explicit_dependencies is None:
return self._provider_dependencies_extractor.extract_from_context_manager(
context_manager=context_manager,
)
return self._provider_dependencies_extractor.validate_explicit_for_context_manager(
context_manager=context_manager,
dependencies=explicit_dependencies,
)
def _resolve_provider_lock_mode(
self,
lock_mode: LockMode | Literal["from_container"],
) -> LockMode | Literal["auto"]:
if lock_mode == "from_container":
return self._lock_mode
return lock_mode
def _resolve_dependency_registration_policy(
self,
dependency_registration_policy: DependencyRegistrationPolicy | None,
) -> DependencyRegistrationPolicy:
if dependency_registration_policy is None:
return self._dependency_registration_policy
return dependency_registration_policy
def _resolve_container_missing_policy(
    self,
    missing_policy: Any,
) -> MissingPolicy:
    """Validate the ``missing_policy`` argument given to ``Container()``.

    Raises:
        DIWireInvalidRegistrationError: If the value is not a ``MissingPolicy``.
    """
    if not isinstance(missing_policy, MissingPolicy):
        msg = "Container() parameter 'missing_policy' must be MissingPolicy."
        raise DIWireInvalidRegistrationError(msg)
    return missing_policy
def _resolve_container_dependency_registration_policy(
    self,
    dependency_registration_policy: Any,
) -> DependencyRegistrationPolicy:
    """Validate the ``dependency_registration_policy`` given to ``Container()``.

    Raises:
        DIWireInvalidRegistrationError: If the value is not a
            ``DependencyRegistrationPolicy``.
    """
    if not isinstance(dependency_registration_policy, DependencyRegistrationPolicy):
        msg = (
            "Container() parameter 'dependency_registration_policy' must be "
            "DependencyRegistrationPolicy."
        )
        raise DIWireInvalidRegistrationError(msg)
    return dependency_registration_policy
def _resolve_resolution_on_missing(
self,
*,
on_missing: MissingPolicy | Literal["from_container"],
method_name: str,
) -> MissingPolicy:
on_missing_value = cast("Any", on_missing)
if on_missing_value == "from_container":
return self._missing_policy
if isinstance(on_missing_value, MissingPolicy):
return on_missing_value
msg = f"{method_name}() parameter 'on_missing' must be MissingPolicy or 'from_container'."
raise DIWireInvalidRegistrationError(msg)
def _autoregister_provider_dependencies(
    self,
    *,
    dependencies: list[ProviderDependency],
    scope: BaseScope,
    lifetime: Lifetime,
    dependency_registration_policy: DependencyRegistrationPolicy,
) -> None:
    """Best-effort autoregistration of a provider's unresolved dependencies.

    Does nothing under ``DependencyRegistrationPolicy.IGNORE``. Otherwise every
    dependency key that has neither a concrete registration nor an
    open-generic match is autoregistered recursively with the given scope and
    lifetime. ``DIWireError`` failures are suppressed so one ineligible
    dependency does not abort the enclosing registration.
    """
    if dependency_registration_policy is DependencyRegistrationPolicy.IGNORE:
        return
    for provider_dependency in dependencies:
        key = self._normalize_dependency_identity_key(
            self._unwrap_provider_dependency_key(provider_dependency.provides),
        )
        if self._providers_registrations.find_by_type(
            key
        ) or self._open_generic_registry.has_match_for_dependency(key):
            continue
        with suppress(DIWireError):
            self._autoregister_dependency(
                dependency=key,
                scope=scope,
                lifetime=lifetime,
                dependency_registration_policy=DependencyRegistrationPolicy.REGISTER_RECURSIVE,
            )
def _autoregister_dependency(
    self,
    *,
    dependency: Any,
    scope: BaseScope,
    lifetime: Lifetime,
    dependency_registration_policy: DependencyRegistrationPolicy,
) -> None:
    """Autoregister a single dependency key if it is eligible.

    Pydantic settings subclasses are registered through a zero-argument
    factory at the root scope with ``Lifetime.SCOPED``. Keys accepted by the
    concrete autoregistration policy are registered via ``add`` with the given
    scope and lifetime. Anything else is ignored.
    """
    if is_pydantic_settings_subclass(dependency):
        # Settings are environment-backed configuration objects, so they are
        # registered once at the root scope rather than at the caller's scope.
        # The lambda binds the class through a default argument to avoid
        # late binding.
        self.add_factory(
            lambda dependency_type=dependency: dependency_type(),
            provides=dependency,
            scope=self._root_scope,
            lifetime=Lifetime.SCOPED,
            dependency_registration_policy=dependency_registration_policy,
        )
    elif self._concrete_autoregistration_policy.is_eligible_concrete(dependency):
        self.add(
            dependency,
            scope=scope,
            lifetime=lifetime,
            dependency_registration_policy=dependency_registration_policy,
        )
def _ensure_autoregistration(
    self,
    dependency: Any,
    *,
    on_missing: MissingPolicy | None = None,
) -> None:
    """Autoregister ``dependency`` at the root scope when policy allows it.

    ``on_missing`` overrides the container's missing policy. Nothing happens
    under ``MissingPolicy.ERROR``, or when the key already has a concrete
    registration or an open-generic match. Under
    ``MissingPolicy.REGISTER_RECURSIVE`` the dependency's own dependencies are
    registered recursively as well; otherwise they are ignored.
    """
    policy = on_missing if on_missing is not None else self._missing_policy
    if policy is MissingPolicy.ERROR:
        return
    key = self._normalize_dependency_identity_key(
        self._unwrap_provider_dependency_key(dependency),
    )
    if self._providers_registrations.find_by_type(
        key
    ) or self._open_generic_registry.has_match_for_dependency(key):
        return
    if policy is MissingPolicy.REGISTER_RECURSIVE:
        nested_policy = DependencyRegistrationPolicy.REGISTER_RECURSIVE
    else:
        nested_policy = DependencyRegistrationPolicy.IGNORE
    self._autoregister_dependency(
        dependency=key,
        scope=self._root_scope,
        lifetime=self._default_lifetime,
        dependency_registration_policy=nested_policy,
    )
def _inject_callable(
    self,
    *,
    callable_obj: InjectableF,
    scope: BaseScope | None,
    dependency_registration_policy: DependencyRegistrationPolicy | None,
    auto_open_scope: bool,
) -> InjectableF:
    """Wrap ``callable_obj`` so its injected parameters are resolved on each call.

    Args:
        callable_obj: Sync or coroutine function to wrap.
        scope: Explicit scope for the callable, or ``None`` to infer the target
            scope from its injected dependencies.
        dependency_registration_policy: Policy override; ``None`` falls back to
            the container default. Under ``REGISTER_RECURSIVE`` the injected
            dependencies are autoregistered at decoration time.
        auto_open_scope: Passed to the target-scope getter; when false the
            wrapper never enters a scope itself.

    Returns:
        A ``functools.wraps`` wrapper (``async def`` for coroutine functions).
        Its ``__signature__`` is the inspector's public signature, and it is
        tagged with ``INJECT_WRAPPER_MARKER`` in its ``__dict__``.

    Raises:
        DIWireInvalidRegistrationError: If the callable declares the reserved
            resolver keyword, or if an explicit ``scope`` is shallower than the
            scope level its dependencies require.
    """
    signature = inspect.signature(callable_obj)
    # The wrapper consumes INJECT_RESOLVER_KWARG itself, so the wrapped
    # callable may not declare a parameter with that name.
    if INJECT_RESOLVER_KWARG in signature.parameters:
        msg = (
            f"Callable '{self._callable_name(callable_obj)}' cannot declare reserved parameter "
            f"'{INJECT_RESOLVER_KWARG}'."
        )
        raise DIWireInvalidRegistrationError(msg)
    inspected_callable = self._injected_callable_inspector.inspect_callable(callable_obj)
    injected_parameters = inspected_callable.injected_parameters
    resolved_dependency_registration_policy = self._resolve_dependency_registration_policy(
        dependency_registration_policy,
    )
    # Autoregistration runs before scope inference so that the newly added
    # registrations are taken into account when inferring the scope level.
    if (
        resolved_dependency_registration_policy
        is DependencyRegistrationPolicy.REGISTER_RECURSIVE
    ):
        self._autoregister_injected_dependencies(
            injected_parameters=injected_parameters,
            scope=scope,
        )
    inferred_scope_level = self._infer_injected_scope_level(
        injected_parameters=injected_parameters,
    )
    callable_name = self._callable_name(callable_obj)
    if scope is not None and scope.level < inferred_scope_level:
        msg = (
            f"Callable '{callable_name}' scope level {scope.level} is "
            f"shallower than required dependency scope level {inferred_scope_level}."
        )
        raise DIWireInvalidRegistrationError(msg)
    # Explicit scopes are recorded so later registration mutations can
    # re-check them (_revalidate_injected_scope_contracts).
    if scope is not None:
        self._injected_scope_contracts.append(
            _InjectedScopeContract(
                callable_name=callable_name,
                injected_parameters=injected_parameters,
                scope=scope,
            ),
        )
    get_target_scope = self._build_injected_target_scope_getter(
        explicit_scope=scope,
        inferred_scope_level=inferred_scope_level,
        injected_parameters=injected_parameters,
        callable_name=callable_name,
        auto_open_scope=auto_open_scope,
    )
    if inspect.iscoroutinefunction(callable_obj):

        @functools.wraps(callable_obj)
        async def _async_injected(*args: Any, **kwargs: Any) -> Any:
            # The caller may pass a resolver via INJECT_RESOLVER_KWARG (popped
            # from kwargs); otherwise the compiled root resolver is used.
            base_resolver = self._resolve_inject_resolver(kwargs)
            target_scope = get_target_scope()
            maybe_scoped = self._enter_scope_if_needed(
                base_resolver=base_resolver,
                target_scope=target_scope,
            )
            # No new scope was entered: resolve directly on the base resolver.
            if maybe_scoped is base_resolver:
                bound_arguments = await self._resolve_async_injected_arguments(
                    resolver=maybe_scoped,
                    signature=signature,
                    args=args,
                    kwargs=kwargs,
                    injected_parameters=injected_parameters,
                )
                async_callable = cast("Callable[..., Awaitable[Any]]", callable_obj)
                return await async_callable(*bound_arguments.args, **bound_arguments.kwargs)
            # A new scoped resolver was entered: use it as an async context
            # manager for the duration of the call.
            async_scoped_resolver = cast("Any", maybe_scoped)
            async with async_scoped_resolver:
                bound_arguments = await self._resolve_async_injected_arguments(
                    resolver=maybe_scoped,
                    signature=signature,
                    args=args,
                    kwargs=kwargs,
                    injected_parameters=injected_parameters,
                )
                async_callable = cast("Callable[..., Awaitable[Any]]", callable_obj)
                return await async_callable(*bound_arguments.args, **bound_arguments.kwargs)

        wrapped_callable: Callable[..., Any] = _async_injected
    else:

        @functools.wraps(callable_obj)
        def _sync_injected(*args: Any, **kwargs: Any) -> Any:
            # Same flow as the async wrapper, using sync resolution and a
            # regular context manager for a newly entered scope.
            base_resolver = self._resolve_inject_resolver(kwargs)
            target_scope = get_target_scope()
            maybe_scoped = self._enter_scope_if_needed(
                base_resolver=base_resolver,
                target_scope=target_scope,
            )
            if maybe_scoped is base_resolver:
                bound_arguments = self._resolve_sync_injected_arguments(
                    resolver=maybe_scoped,
                    signature=signature,
                    args=args,
                    kwargs=kwargs,
                    injected_parameters=injected_parameters,
                )
                return callable_obj(*bound_arguments.args, **bound_arguments.kwargs)
            with maybe_scoped:
                bound_arguments = self._resolve_sync_injected_arguments(
                    resolver=maybe_scoped,
                    signature=signature,
                    args=args,
                    kwargs=kwargs,
                    injected_parameters=injected_parameters,
                )
                return callable_obj(*bound_arguments.args, **bound_arguments.kwargs)

        wrapped_callable = _sync_injected
    # Expose the inspector's public signature (which presumably hides the
    # injected parameters) and mark the wrapper as produced by injection.
    wrapped_callable.__signature__ = inspected_callable.public_signature  # type: ignore[attr-defined]
    wrapped_callable.__dict__[INJECT_WRAPPER_MARKER] = True
    return cast("InjectableF", wrapped_callable)
def _resolve_injected_dependency(self, *, annotation: Any) -> Any | None:
return self._injected_callable_inspector.resolve_injected_dependency(annotation=annotation)
def _autoregister_injected_dependencies(
self,
*,
injected_parameters: tuple[InjectedParameter, ...],
scope: BaseScope | None,
) -> None:
registration_scope = scope or self._root_scope
for injected_parameter in injected_parameters:
dependency_key = self._normalize_dependency_identity_key(
self._unwrap_provider_dependency_key(injected_parameter.dependency),
)
if self._providers_registrations.find_by_type(dependency_key):
continue
with suppress(DIWireError):
self._autoregister_dependency(
dependency=dependency_key,
scope=registration_scope,
lifetime=self._default_lifetime,
dependency_registration_policy=DependencyRegistrationPolicy.REGISTER_RECURSIVE,
)
def _infer_injected_scope_level(
self,
*,
injected_parameters: tuple[InjectedParameter, ...],
) -> int:
max_scope_level = self._root_scope.level
cache: dict[Any, int] = {}
for injected_parameter in injected_parameters:
max_scope_level = max(
max_scope_level,
self._infer_dependency_scope_level(
dependency=injected_parameter.dependency,
cache=cache,
in_progress=set(),
),
)
return max_scope_level
def _infer_dependency_scope_level(
    self,
    *,
    dependency: Any,
    cache: dict[Any, int],
    in_progress: set[Any],
) -> int:
    """Infer the deepest scope level needed to resolve ``dependency``.

    Behavior by key kind:

    * ``Maybe[...]`` and ``Provider[...]`` wrappers are unwrapped, and the
      inner key's level is used.
    * ``All[...]`` takes the maximum level over the base key (when registered)
      and every registered key whose component base key matches it. With no
      such keys it yields the root level.
    * Any other key takes the maximum of its provider's scope level and the
      levels of that provider's dependencies, recursively.
    * A key with no concrete registration uses the best open-generic match's
      scope level (that spec's dependencies are not inspected), or the root
      level when nothing matches.

    Args:
        dependency: Dependency key, possibly wrapped in Maybe/Provider/All or
            non-component annotations.
        cache: Memo of inferred levels, shared by the caller across calls.
            Both the original and the stripped key are stored.
        in_progress: Keys currently on the recursion stack, used to cut cycles.

    Returns:
        The inferred scope level.
    """
    original_dependency = dependency
    known_level = cache.get(original_dependency)
    if known_level is not None:
        return known_level
    # Maybe[X]: the scope requirement is that of X.
    if is_maybe_annotation(dependency):
        maybe_inner_dependency = strip_non_component_annotation(
            strip_maybe_annotation(dependency)
        )
        inferred_level = self._infer_dependency_scope_level(
            dependency=maybe_inner_dependency,
            cache=cache,
            in_progress=in_progress,
        )
        cache[original_dependency] = inferred_level
        return inferred_level
    # Provider[X]: the scope requirement is that of X.
    if is_provider_annotation(dependency):
        provider_inner_dependency = strip_non_component_annotation(
            strip_provider_annotation(dependency),
        )
        inferred_level = self._infer_dependency_scope_level(
            dependency=provider_inner_dependency,
            cache=cache,
            in_progress=in_progress,
        )
        cache[original_dependency] = inferred_level
        return inferred_level
    # All[X]: aggregate over the base registration and all component
    # registrations that share X as their base key.
    if is_all_annotation(dependency):
        if original_dependency in in_progress:
            return self._root_scope.level
        inner = strip_non_component_annotation(strip_all_annotation(dependency))
        collected_keys: list[Any] = []
        if self._providers_registrations.find_by_type(inner) is not None:
            collected_keys.append(inner)
        collected_keys.extend(
            spec.provides
            for spec in self._providers_registrations.values()
            if component_base_key(spec.provides) == inner
        )
        if not collected_keys:
            cache[original_dependency] = self._root_scope.level
            return self._root_scope.level
        in_progress.add(original_dependency)
        try:
            inferred_level = max(
                self._infer_dependency_scope_level(
                    dependency=collected_key,
                    cache=cache,
                    in_progress=in_progress,
                )
                for collected_key in collected_keys
            )
        finally:
            in_progress.remove(original_dependency)
        cache[original_dependency] = inferred_level
        return inferred_level
    # Plain key: strip non-component annotation metadata and re-check the memo
    # under the stripped key.
    dependency = strip_non_component_annotation(dependency)
    known_level = cache.get(dependency)
    if known_level is not None:
        cache[original_dependency] = known_level
        return known_level
    spec = self._providers_registrations.find_by_type(dependency)
    if spec is None:
        open_match = self._open_generic_registry.find_best_match(dependency)
        if open_match is not None:
            inferred_level = open_match.spec.scope.level
            cache[original_dependency] = inferred_level
            return inferred_level
        cache[original_dependency] = self._root_scope.level
        return self._root_scope.level
    # Cycle guard: an in-progress key contributes only its own provider scope.
    # NOTE(review): this early return is not cached, but levels of keys
    # computed while the cycle is open are cached. Those levels may
    # therefore omit deeper levels reachable through the in-progress key;
    # confirm that cyclic graphs are rejected elsewhere.
    if dependency in in_progress:
        return spec.scope.level
    in_progress.add(dependency)
    max_scope_level = spec.scope.level
    for nested_dependency in spec.dependencies:
        max_scope_level = max(
            max_scope_level,
            self._infer_dependency_scope_level(
                dependency=nested_dependency.provides,
                cache=cache,
                in_progress=in_progress,
            ),
        )
    in_progress.remove(dependency)
    cache[dependency] = max_scope_level
    cache[original_dependency] = max_scope_level
    return max_scope_level
def _unwrap_provider_dependency_key(self, dependency: Any) -> Any:
provider_inner_dependency = self._extract_provider_inner_dependency_fast(dependency)
if provider_inner_dependency is None:
return dependency
return provider_inner_dependency
def _normalize_dependency_identity_key(self, dependency: Any) -> Any:
    """Return the key used to identify ``dependency`` among registrations.

    ``All``/``Maybe``/``Provider`` annotations are kept intact. Any other key
    has its non-component annotation metadata stripped.
    """
    wrapper_checks = (is_all_annotation, is_maybe_annotation, is_provider_annotation)
    if any(check(dependency) for check in wrapper_checks):
        return dependency
    return strip_non_component_annotation(dependency)
def _extract_provider_inner_dependency_fast(self, dependency: Any) -> Any | None:
metadata = getattr(dependency, "__metadata__", None)
if metadata is None:
return None
for marker in metadata:
if isinstance(marker, ProviderMarker):
return marker.dependency_key
return None
def _resolve_sync_injected_arguments(
self,
*,
resolver: ResolverProtocol,
signature: inspect.Signature,
args: tuple[Any, ...],
kwargs: dict[str, Any],
injected_parameters: tuple[InjectedParameter, ...],
) -> inspect.BoundArguments:
bound_arguments = signature.bind_partial(*args, **kwargs)
for injected_parameter in injected_parameters:
if injected_parameter.name in bound_arguments.arguments:
continue
dependency = injected_parameter.dependency
if is_maybe_annotation(dependency):
inner_dependency = strip_maybe_annotation(dependency)
if is_provider_annotation(inner_dependency):
bound_arguments.arguments[injected_parameter.name] = resolver.resolve(
dependency,
)
continue
if not self._is_registered_in_resolver(
resolver=resolver,
dependency=inner_dependency,
):
parameter = signature.parameters[injected_parameter.name]
if parameter.default is inspect.Parameter.empty:
bound_arguments.arguments[injected_parameter.name] = None
continue
bound_arguments.arguments[injected_parameter.name] = resolver.resolve(
inner_dependency,
)
continue
bound_arguments.arguments[injected_parameter.name] = resolver.resolve(dependency)
return bound_arguments
async def _resolve_async_injected_arguments(
self,
*,
resolver: ResolverProtocol,
signature: inspect.Signature,
args: tuple[Any, ...],
kwargs: dict[str, Any],
injected_parameters: tuple[InjectedParameter, ...],
) -> inspect.BoundArguments:
bound_arguments = signature.bind_partial(*args, **kwargs)
for injected_parameter in injected_parameters:
if injected_parameter.name in bound_arguments.arguments:
continue
dependency = injected_parameter.dependency
if is_maybe_annotation(dependency):
inner_dependency = strip_maybe_annotation(dependency)
if is_provider_annotation(inner_dependency):
bound_arguments.arguments[injected_parameter.name] = await resolver.aresolve(
dependency,
)
continue
if not self._is_registered_in_resolver(
resolver=resolver,
dependency=inner_dependency,
):
parameter = signature.parameters[injected_parameter.name]
if parameter.default is inspect.Parameter.empty:
bound_arguments.arguments[injected_parameter.name] = None
continue
bound_arguments.arguments[injected_parameter.name] = await resolver.aresolve(
inner_dependency,
)
continue
bound_arguments.arguments[injected_parameter.name] = await resolver.aresolve(
dependency,
)
return bound_arguments
def _is_registered_in_resolver(
    self,
    *,
    resolver: ResolverProtocol,
    dependency: Any,
) -> bool:
    """Return whether ``dependency`` (or its stripped form) is resolvable.

    If ``resolver`` exposes a callable ``_is_registered_dependency``, that
    check alone decides: first on the key as given, then on the key with
    non-component annotations stripped, if stripping changed it. Otherwise
    the container's concrete registrations are consulted first and its
    open-generic registry second, each for both forms of the key.
    """
    checker = getattr(resolver, "_is_registered_dependency", None)
    stripped = strip_non_component_annotation(dependency)
    was_stripped = stripped is not dependency
    if callable(checker):
        if checker(dependency):
            return True
        return was_stripped and bool(checker(stripped))
    registrations = self._providers_registrations
    if registrations.find_by_type(dependency) is not None:
        return True
    if was_stripped and registrations.find_by_type(stripped) is not None:
        return True
    open_generics = self._open_generic_registry
    if open_generics.find_best_match(dependency) is not None:
        return True
    return was_stripped and open_generics.find_best_match(stripped) is not None
def _resolve_inject_resolver(self, kwargs: dict[str, Any]) -> ResolverProtocol:
    """Pick the resolver for one injected call.

    A resolver passed under ``INJECT_RESOLVER_KWARG`` is removed from
    ``kwargs`` (so it is not forwarded to the wrapped callable) and returned.
    Otherwise the compiled root resolver is used.
    """
    if INJECT_RESOLVER_KWARG not in kwargs:
        return self.compile()
    return cast("ResolverProtocol", kwargs.pop(INJECT_RESOLVER_KWARG))
def _build_injected_target_scope_getter(
self,
*,
explicit_scope: BaseScope | None,
inferred_scope_level: int,
injected_parameters: tuple[InjectedParameter, ...],
callable_name: str,
auto_open_scope: bool,
) -> Callable[[], BaseScope | None]:
if not auto_open_scope:
def _no_scope() -> BaseScope | None:
return None
return _no_scope
if explicit_scope is not None:
def _explicit_scope() -> BaseScope:
return explicit_scope
return _explicit_scope
revision = self._graph_revision
inferred_scope = self._find_scope_by_level(scope_level=inferred_scope_level)
cached_target_scope = inferred_scope
def _resolve_target_scope() -> BaseScope:
nonlocal cached_target_scope, revision
# Registrations can be mutated after decoration time (including after the wrapper is created),
# so infer target scope lazily on call and re-check when the graph revision changes.
if cached_target_scope is not None and revision == self._graph_revision:
return cached_target_scope
current_inferred_level = self._infer_injected_scope_level(
injected_parameters=injected_parameters,
)
candidate = self._find_scope_by_level(scope_level=current_inferred_level)
if candidate is None:
msg = (
f"Callable '{callable_name}' inferred scope level {current_inferred_level} has no "
"matching scope in the root scope owner."
)
raise DIWireInvalidRegistrationError(msg)
revision = self._graph_revision
cached_target_scope = candidate
return candidate
return _resolve_target_scope
def _find_scope_by_level(self, *, scope_level: int) -> BaseScope | None:
return next(
(candidate for candidate in self._root_scope.owner() if candidate.level == scope_level),
None,
)
def _enter_scope_if_needed(
self,
*,
base_resolver: ResolverProtocol,
target_scope: BaseScope | None,
) -> ResolverProtocol:
if target_scope is None:
return base_resolver
if target_scope.level == self._root_scope.level:
return base_resolver
try:
return base_resolver.enter_scope(target_scope)
except DIWireScopeMismatchError as error:
if self._is_already_deep_enough_scope_error(error):
return base_resolver
raise
@staticmethod
def _is_already_deep_enough_scope_error(error: DIWireScopeMismatchError) -> bool:
message = str(error)
return message.startswith("Cannot enter scope level ") and " from level " in message
def _callable_name(self, callable_obj: Callable[..., Any]) -> str:
return getattr(callable_obj, "__qualname__", repr(callable_obj))
def _materialize_registered_open_generic_dependencies(self) -> None:
    """Register closed specs for dependencies that only an open generic satisfies.

    Each pass scans the dependencies of every registered provider. A
    dependency key with no concrete registration but an open-generic match is
    materialized via ``_materialize_closed_open_generic_spec``. Passes repeat
    until one registers nothing new, because freshly materialized specs may
    introduce further such dependencies.

    Two guards stop runaway expansion:

    * a hard cap of ``_OPEN_GENERIC_MATERIALIZATION_MAX_ITERATIONS`` passes;
    * detection of a repeated per-pass state (registration count, growth in
      the pass, and the tail of materialized key reprs).

    When either guard trips, everything materialized by this call is rolled
    back and ``DIWireInvalidRegistrationError`` is raised.
    """
    if not self._open_generic_registry.has_specs():
        return
    iteration_count = 0
    seen_iteration_states: set[tuple[int, int, tuple[str, ...]]] = set()
    # Ordered lists (for rollback) plus companion sets (for O(1) dedup).
    materialized_registration_keys: list[Any] = []
    materialized_registration_keys_seen: set[Any] = set()
    materialized_closed_keys: list[Any] = []
    materialized_closed_keys_seen: set[Any] = set()
    while True:
        iteration_count += 1
        if iteration_count > self._OPEN_GENERIC_MATERIALIZATION_MAX_ITERATIONS:
            self._rollback_open_generic_materialization(
                materialized_registration_keys=materialized_registration_keys,
                materialized_closed_keys=materialized_closed_keys,
            )
            self._raise_non_converging_open_generic_materialization_error(
                iteration_count=iteration_count,
                max_iterations=self._OPEN_GENERIC_MATERIALIZATION_MAX_ITERATIONS,
            )
        materialized_any = False
        materialized_dependency_reprs: list[str] = []
        # NOTE(review): len(tuple(...values())) copies every registration; it is
        # evaluated twice per materialization attempt below, so a cheaper count
        # would help for large graphs if ProvidersRegistrations offers one.
        registration_count_iteration_start = len(tuple(self._providers_registrations.values()))
        # Iterate a snapshot: specs materialized during this pass are scanned
        # on the next pass.
        for provider_spec in tuple(self._providers_registrations.values()):
            for provider_dependency in provider_spec.dependencies:
                dependency_key = self._open_generic_materialization_dependency_key(
                    provider_dependency.provides,
                )
                if dependency_key is None:
                    continue
                if self._providers_registrations.find_by_type(dependency_key) is not None:
                    continue
                open_match = self._open_generic_registry.find_best_match(dependency_key)
                if open_match is None:
                    continue
                # Growth in the registration count is how a successful
                # materialization is detected.
                registration_count_before = len(tuple(self._providers_registrations.values()))
                self._materialize_closed_open_generic_spec(dependency_key, open_match)
                registration_count_after = len(tuple(self._providers_registrations.values()))
                if registration_count_after > registration_count_before:
                    materialized_any = True
                    materialized_dependency_reprs.append(repr(dependency_key))
                    if dependency_key not in materialized_registration_keys_seen:
                        materialized_registration_keys_seen.add(dependency_key)
                        materialized_registration_keys.append(dependency_key)
                    # Also track runtime-materialized closed keys so a
                    # rollback can discard them.
                    if (
                        dependency_key in self._runtime_materialized_closed_keys
                        and dependency_key not in materialized_closed_keys_seen
                    ):
                        materialized_closed_keys_seen.add(dependency_key)
                        materialized_closed_keys.append(dependency_key)
        # Fixpoint reached: nothing new was registered in this pass.
        if not materialized_any:
            return
        registration_count_iteration_end = len(tuple(self._providers_registrations.values()))
        iteration_state = (
            registration_count_iteration_end,
            registration_count_iteration_end - registration_count_iteration_start,
            tuple(
                materialized_dependency_reprs[
                    -self._OPEN_GENERIC_MATERIALIZATION_STATE_TAIL_SIZE :
                ]
            ),
        )
        # A repeated pass state means the expansion is cycling rather than
        # converging.
        if iteration_state in seen_iteration_states:
            self._rollback_open_generic_materialization(
                materialized_registration_keys=materialized_registration_keys,
                materialized_closed_keys=materialized_closed_keys,
            )
            self._raise_non_converging_open_generic_materialization_error(
                iteration_count=iteration_count,
                max_iterations=self._OPEN_GENERIC_MATERIALIZATION_MAX_ITERATIONS,
                repeated_state=iteration_state,
            )
        seen_iteration_states.add(iteration_state)
def _rollback_open_generic_materialization(
self,
*,
materialized_registration_keys: list[Any],
materialized_closed_keys: list[Any],
) -> None:
if not materialized_registration_keys and not materialized_closed_keys:
return
if materialized_registration_keys:
materialized_registration_keys_set = set(materialized_registration_keys)
providers_registrations = ProvidersRegistrations()
for spec in self._providers_registrations.values():
if spec.provides in materialized_registration_keys_set:
continue
providers_registrations.add(spec)
self._providers_registrations = providers_registrations
for materialized_closed_key in materialized_closed_keys:
self._runtime_materialized_closed_keys.discard(materialized_closed_key)
self._invalidate_compilation()
def _raise_non_converging_open_generic_materialization_error(
    self,
    *,
    iteration_count: int,
    max_iterations: int,
    repeated_state: tuple[int, int, tuple[str, ...]] | None = None,
) -> None:
    """Always raise ``DIWireInvalidRegistrationError`` for non-converging materialization.

    The message reports the iteration count, the cap, and the repeated pass
    state (if any). It also names the internals to inspect when debugging.
    """
    counters = (
        f"iterations={iteration_count}, max_iterations={max_iterations}, "
        f"repeated_iteration_state={repeated_state!r}. "
    )
    raise DIWireInvalidRegistrationError(
        "Open generic materialization did not converge while compiling container registrations. "
        + counters
        + "Debug context: _open_generic_materialization_dependency_key, "
        "_materialize_closed_open_generic_spec, _providers_registrations, "
        "_open_generic_registry."
    )
def _open_generic_materialization_dependency_key(self, dependency: Any) -> Any | None:
    """Return the closed key to materialize for ``dependency``, or ``None``.

    ``Maybe`` and then ``Provider`` wrappers are peeled off. ``All`` keys are
    never materialized (``None``). The remaining key has its non-component
    annotations stripped.
    """
    key = strip_maybe_annotation(dependency) if is_maybe_annotation(dependency) else dependency
    if is_provider_annotation(key):
        key = strip_provider_annotation(key)
    return None if is_all_annotation(key) else strip_non_component_annotation(key)
# endregion Registration Methods
# region Compilation
[docs]
def compile(self) -> ResolverProtocol:
    """Compile and cache the root resolver for current registrations.

    Compilation is lazy and invalidated by any registration mutation. In
    strict mode (opt-in, autoregistration disabled) with
    ``use_resolver_context=False``, hot-path entrypoints are rebound to the
    compiled resolver for lower call overhead.

    Returns:
        The compiled root resolver.

    Notes:
        Call this once after startup registrations when you want stable
        strict-mode (opt-in) hot-path behavior. Any registration mutation
        invalidates the compiled graph automatically.

    Examples:
        .. code-block:: python

            container.add(Service)
            container.compile()
            service = container.resolve(Service)
    """
    with self._graph_state_lock:
        if self._root_resolver is None:
            # Closed specs for dependencies satisfied only by open generics are
            # registered first so they are part of the graph built below.
            self._materialize_registered_open_generic_dependencies()
            root_resolver = self._resolvers_manager.build_root_resolver(
                root_scope=self._root_scope,
                registrations=self._providers_registrations,
            )
            if self._open_generic_registry.has_specs():
                # The wrapper is told whether any concrete or open-generic spec is
                # async, and receives a callback for materializing closed specs.
                # Presumably these serve keys not materialized up front -- confirm
                # in OpenGenericResolver.
                has_async_specs = any(
                    spec.is_async for spec in self._providers_registrations.values()
                ) or any(spec.is_async for spec in self._open_generic_registry.values())
                root_resolver = cast(
                    "ResolverProtocol",
                    OpenGenericResolver(
                        base_resolver=root_resolver,
                        registry=self._open_generic_registry,
                        root_scope=self._root_scope,
                        has_async_specs=has_async_specs,
                        scope_level=self._root_scope.level,
                        materialize_closed_callback=self._materialize_closed_open_generic_spec,
                    ),
                )
            if self._use_resolver_context:
                root_resolver = self._resolver_context._wrap_resolver(root_resolver)  # noqa: SLF001
            self._root_resolver = root_resolver
        # Strict mode without resolver context: rebind the container's hot-path
        # entrypoints to the compiled resolver. This runs on every call, including
        # cache hits.
        if self._missing_policy is MissingPolicy.ERROR and not self._use_resolver_context:
            self._bind_container_entrypoints(target=self._root_resolver)
        return self._root_resolver
def _invalidate_compilation(self) -> None:
"""Discard compiled resolver state and restore original container methods.
Any registration mutation can change the resolver graph, so cached compiled
entrypoints must be reverted back to container methods until recompilation.
"""
with self._graph_state_lock:
self._graph_revision += 1
self._root_resolver = None
self._restore_container_entrypoints()
def _revalidate_injected_scope_contracts(self) -> None:
for contract in self._injected_scope_contracts:
inferred_scope_level = self._infer_injected_scope_level(
injected_parameters=contract.injected_parameters,
)
if contract.scope.level < inferred_scope_level:
msg = (
f"Callable '{contract.callable_name}' scope level {contract.scope.level} is "
f"shallower than required dependency scope level {inferred_scope_level}."
)
raise DIWireInvalidRegistrationError(msg)
@contextmanager
def _registration_mutation(self) -> Generator[None, None, None]:
    """Run a (possibly nested) registration mutation transactionally.

    When the outermost level is entered:

    * runtime-materialized closed specs are purged;
    * a snapshot is taken of provider registrations, open-generic specs, and
      decoration rules, chains and counter.

    When the body of the outermost level completes, injected scope contracts
    are revalidated. If ``DIWireInvalidRegistrationError`` escapes at any
    level, the snapshot is restored and the compiled graph invalidated once
    the outermost level exits.

    The graph state lock is held for the whole mutation. It must be reentrant,
    because nested mutations and ``_invalidate_compilation`` re-acquire it
    while it is held here.
    """
    with self._graph_state_lock:
        if self._registration_mutation_depth == 0:
            self._purge_runtime_materialized_closed_specs()
            # Decoration lists/chains are copied so later in-place edits do not
            # leak into the snapshot.
            self._registration_mutation_snapshot = _ContainerGraphSnapshot(
                providers_registrations=self._providers_registrations.snapshot(),
                open_generic_registry=self._open_generic_registry.snapshot(),
                decoration_rules_by_provides={
                    provides: list(rules)
                    for provides, rules in self._decoration_rules_by_provides.items()
                },
                decoration_chain_by_provides={
                    provides: _DecorationChain(
                        base_key=chain.base_key,
                        layer_keys=list(chain.layer_keys),
                    )
                    for provides, chain in self._decoration_chain_by_provides.items()
                },
                decoration_counter=self._decoration_counter,
            )
            self._registration_mutation_failed = False
        self._registration_mutation_depth += 1
        try:
            yield
            # Depth 1 here means the outermost mutation just finished its body.
            if self._registration_mutation_depth == 1:
                self._revalidate_injected_scope_contracts()
        # NOTE(review): only DIWireInvalidRegistrationError marks the mutation as
        # failed; any other exception propagates without restoring the snapshot.
        # Confirm that partial state is acceptable in that case.
        except DIWireInvalidRegistrationError:
            self._registration_mutation_failed = True
            raise
        finally:
            self._registration_mutation_depth -= 1
            if self._registration_mutation_depth == 0:
                if self._registration_mutation_failed:
                    snapshot = cast(
                        "_ContainerGraphSnapshot", self._registration_mutation_snapshot
                    )
                    self._providers_registrations.restore(snapshot.providers_registrations)
                    self._open_generic_registry.restore(snapshot.open_generic_registry)
                    self._decoration_rules_by_provides = {
                        provides: list(rules)
                        for provides, rules in snapshot.decoration_rules_by_provides.items()
                    }
                    self._decoration_chain_by_provides = {
                        provides: _DecorationChain(
                            base_key=chain.base_key,
                            layer_keys=list(chain.layer_keys),
                        )
                        for provides, chain in snapshot.decoration_chain_by_provides.items()
                    }
                    self._decoration_counter = snapshot.decoration_counter
                    self._invalidate_compilation()
                self._registration_mutation_snapshot = None
                self._registration_mutation_failed = False
def _purge_runtime_materialized_closed_specs(self) -> None:
if not self._runtime_materialized_closed_keys:
return
preserved_specs = [
spec
for spec in self._providers_registrations.values()
if spec.provides not in self._runtime_materialized_closed_keys
]
providers_registrations = ProvidersRegistrations()
for spec in preserved_specs:
providers_registrations.add(spec)
self._providers_registrations = providers_registrations
self._runtime_materialized_closed_keys.clear()
self._invalidate_compilation()
def _bind_container_entrypoints(
self,
*,
target: ResolverProtocol,
) -> None:
"""Bind selected container entrypoints directly to resolver-bound methods."""
for method_name in self._ENTRYPOINT_METHOD_NAMES:
setattr(self, method_name, getattr(target, method_name))
def _restore_container_entrypoints(self) -> None:
"""Restore original container-bound entrypoint methods captured at init."""
for method_name, method in self._container_entrypoints.items():
setattr(self, method_name, method)
# endregion Compilation
# region Resolution and Scope Management
def _get_context_bound_resolver_or_none(self) -> ResolverProtocol | None:
if not self._use_resolver_context:
return None
return self._resolver_context._get_bound_resolver_or_none() # noqa: SLF001
# Typing overload: a class key resolves to an instance of that class.
@overload
def resolve(
    self,
    dependency: type[T],
    *,
    on_missing: MissingPolicy | Literal["from_container"] = "from_container",
) -> T: ...
# Typing overload: any other key (for example annotated or generic forms)
# resolves to ``Any``.
@overload
def resolve(
    self,
    dependency: Any,
    *,
    on_missing: MissingPolicy | Literal["from_container"] = "from_container",
) -> Any: ...
[docs]
def resolve(
    self,
    dependency: Any,
    *,
    on_missing: MissingPolicy | Literal["from_container"] = "from_container",
) -> Any:
    """Resolve a dependency synchronously.

    Args:
        dependency: Dependency key to resolve.
        on_missing: Resolve-time auto-registration policy
            for missing concrete dependencies, or ``"from_container"`` to
            inherit container defaults.

    Returns:
        Resolved dependency value.

    Raises:
        DIWireDependencyNotRegisteredError: If dependency is missing in
            strict mode and no open-generic match exists.
        DIWireScopeMismatchError: If dependency requires a deeper scope than
            the current resolver.
        DIWireAsyncDependencyInSyncContextError: If the selected graph
            requires async resolution or cleanup.
        DIWireInvalidGenericTypeArgumentError: If closed generic arguments
            violate TypeVar constraints.
        DIWireInvalidRegistrationError: If ``on_missing`` is neither a
            ``MissingPolicy`` nor ``"from_container"``.

    Notes:
        Typical fixes:

        1. Register missing dependencies (or enable autoregistration).
        2. Enter required scope before resolving scoped dependencies.
        3. Switch to ``aresolve`` for async dependency chains.
        4. Use compatible generic arguments for constrained TypeVars.

    Examples:
        .. code-block:: python

            container.add(Service)
            service = container.resolve(Service)
    """
    # Resolver precedence: context-bound resolver, then the cached root
    # resolver, then a fresh compile.
    resolver = self._get_context_bound_resolver_or_none()
    if resolver is None:
        resolver = self._root_resolver
    if resolver is None:
        resolver = self.compile()
    resolved_on_missing = self._resolve_resolution_on_missing(
        on_missing=on_missing,
        method_name="resolve",
    )
    # Strict mode: no autoregistration attempts.
    if resolved_on_missing is MissingPolicy.ERROR:
        return resolver.resolve(dependency)
    # Provider[X]: autoregister X up front. A changed graph revision means
    # registrations were added, so recompile before resolving.
    # NOTE(review): after recompiling, the freshly compiled resolver replaces
    # any context-bound (possibly scoped) resolver picked above. Confirm that
    # scoped resolution still behaves as intended here.
    provider_inner_dependency = self._extract_provider_inner_dependency_fast(dependency)
    if provider_inner_dependency is not None:
        graph_revision_before = self._graph_revision
        self._ensure_autoregistration(
            provider_inner_dependency,
            on_missing=resolved_on_missing,
        )
        if self._graph_revision != graph_revision_before:
            resolver = self.compile()
        return resolver.resolve(dependency)
    # Other keys: try to resolve first, and autoregister only on a miss. If
    # autoregistration did not change the graph, re-raise the original error.
    try:
        return resolver.resolve(dependency)
    except DIWireDependencyNotRegisteredError:
        graph_revision_before = self._graph_revision
        self._ensure_autoregistration(
            dependency,
            on_missing=resolved_on_missing,
        )
        if self._graph_revision == graph_revision_before:
            raise
        resolver = self.compile()
        return resolver.resolve(dependency)
@overload
async def aresolve(
    self,
    dependency: type[T],
    *,
    on_missing: MissingPolicy | Literal["from_container"] = "from_container",
) -> T:
    """Typed overload: resolving a class key produces an instance of ``T``."""
@overload
async def aresolve(
    self,
    dependency: Any,
    *,
    on_missing: MissingPolicy | Literal["from_container"] = "from_container",
) -> Any:
    """Fallback overload: any other dependency key produces ``Any``."""
[docs]
async def aresolve(
    self,
    dependency: Any,
    *,
    on_missing: MissingPolicy | Literal["from_container"] = "from_container",
) -> Any:
    """Resolve ``dependency`` asynchronously.

    The resolver is chosen in this order:

    1. The one bound to the current context.
    2. The cached root resolver.
    3. A freshly compiled resolver.

    Unless the effective missing-policy is ``MissingPolicy.ERROR``,
    ``_ensure_autoregistration`` is given a chance to extend the graph:

    - up front, for the inner dependency extracted by
      ``_extract_provider_inner_dependency_fast``, when one is found;
    - otherwise, after a ``DIWireDependencyNotRegisteredError``.

    The container is recompiled only when that call changed
    ``_graph_revision``.

    Args:
        dependency: Dependency key to resolve.
        on_missing: Resolve-time auto-registration policy for missing
            concrete dependencies, or ``"from_container"`` to inherit the
            container defaults.

    Returns:
        Resolved dependency value.

    Raises:
        DIWireDependencyNotRegisteredError: If the dependency is missing in
            strict mode and no open-generic match exists.
        DIWireScopeMismatchError: If the dependency requires a deeper scope
            than the current resolver.
        DIWireInvalidGenericTypeArgumentError: If closed generic arguments
            violate TypeVar constraints.

    Notes:
        Use this API whenever any part of the selected provider chain is
        async.

    Examples:
        .. code-block:: python

            container.add_factory(async_make_client, provides=Client)
            client = await container.aresolve(Client)
    """
    active = self._get_context_bound_resolver_or_none()
    if active is None:
        active = self._root_resolver
        if active is None:
            active = self.compile()

    policy = self._resolve_resolution_on_missing(
        on_missing=on_missing,
        method_name="aresolve",
    )
    if policy is MissingPolicy.ERROR:
        # Strict policy: no autoregistration attempt, resolve as-is.
        return await active.aresolve(dependency)

    inner = self._extract_provider_inner_dependency_fast(dependency)
    if inner is not None:
        # Register the wrapped dependency up front. Recompile only if that
        # actually changed the graph.
        revision = self._graph_revision
        self._ensure_autoregistration(inner, on_missing=policy)
        if self._graph_revision != revision:
            active = self.compile()
        return await active.aresolve(dependency)

    try:
        return await active.aresolve(dependency)
    except DIWireDependencyNotRegisteredError:
        revision = self._graph_revision
        self._ensure_autoregistration(dependency, on_missing=policy)
        if self._graph_revision != revision:
            # Autoregistration extended the graph: retry on a fresh compile.
            return await self.compile().aresolve(dependency)
        # Nothing new was registered: surface the original error.
        raise
[docs]
def enter_scope(
    self,
    scope: BaseScope | None = None,
) -> ResolverProtocol:
    """Open a deeper scope and hand back the resolver bound to it.

    The scope is entered from the resolver bound to the current context when
    there is one. Otherwise it is entered from the root resolver, compiling
    the container first if no root resolver exists yet.

    Args:
        scope: Target scope to enter. ``None`` lets DIWire pick the next
            deeper non-skippable scope.

    Returns:
        Resolver bound to the target scope.

    Raises:
        DIWireScopeMismatchError: If the requested transition is not valid
            from the current scope level.

    Examples:
        .. code-block:: python

            with container.enter_scope(Scope.REQUEST) as request_resolver:
                service = request_resolver.resolve(Service)
    """
    base = self._get_context_bound_resolver_or_none()
    if base is None:
        base = self._root_resolver
        if base is None:
            base = self.compile()
    return base.enter_scope(scope)
def __enter__(self) -> ResolverProtocol:
    """Compile the container and enter its root resolver.

    The compiled resolver is recorded in ``_entered_root_resolver`` so that
    exiting the context closes that same resolver.

    Returns:
        Whatever the compiled resolver's ``__enter__`` returns.
    """
    compiled = self.compile()
    self._entered_root_resolver = compiled
    entered = compiled.__enter__()
    return entered
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
"""Exit the resolver context and perform any necessary cleanup.
Cleanup will happen ONLY if the resolver created resources that need to be cleaned up.
Like context managers or generators.
"""
active_resolver = (
self._entered_root_resolver
if self._entered_root_resolver is not None
else self._root_resolver
)
if active_resolver is None:
msg = "Container context exit called without a matching enter."
raise RuntimeError(msg)
try:
return active_resolver.__exit__(exc_type, exc_value, traceback)
finally:
self._entered_root_resolver = None
async def __aenter__(self) -> ResolverProtocol:
    """Asynchronously enter the resolver context.

    Compiles the container and records the compiled resolver in
    ``_entered_root_resolver``, which the async exit prefers when closing.
    It then awaits the resolver's own ``__aenter__``.

    Returns:
        The value produced by awaiting the compiled resolver's
        ``__aenter__``.
    """
    resolver = self.compile()
    self._entered_root_resolver = resolver
    # Declared ``async`` and awaited here, so that the coroutine's result
    # really is what the ``ResolverProtocol`` annotation promises. This also
    # pairs the method with ``async def __aexit__``.
    return await resolver.__aenter__()
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
"""Asynchronously exit the resolver context and perform any necessary cleanup.
Cleanup will happen ONLY if the resolver created resources that need to be cleaned up.
Like context managers or generators.
"""
active_resolver = (
self._entered_root_resolver
if self._entered_root_resolver is not None
else self._root_resolver
)
if active_resolver is None:
msg = "Container async context exit called without a matching enter."
raise RuntimeError(msg)
try:
return await active_resolver.__aexit__(exc_type, exc_value, traceback)
finally:
self._entered_root_resolver = None
[docs]
def close(
self,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
"""Close the root resolver and run pending cleanup callbacks.
Args:
exc_type: Optional exception type propagated to cleanup callbacks.
exc_value: Optional exception instance propagated to callbacks.
traceback: Optional traceback propagated to callbacks.
Raises:
RuntimeError: If called before entering/compiling a resolver context.
Notes:
Cleanup runs only for graphs that created cleanup-enabled resources.
Prefer ``with container.enter_scope(...)`` for deterministic request
cleanup.
"""
return self.__exit__(exc_type, exc_value, traceback)
[docs]
async def aclose(
self,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
"""Asynchronously close the root resolver and run cleanup callbacks.
Args:
exc_type: Optional exception type propagated to cleanup callbacks.
exc_value: Optional exception instance propagated to callbacks.
traceback: Optional traceback propagated to callbacks.
Raises:
RuntimeError: If called before entering/compiling a resolver context.
Notes:
Prefer ``async with`` for scoped async workloads; use this when
owning a long-lived root resolver lifecycle explicitly.
"""
return await self.__aexit__(exc_type, exc_value, traceback)
# endregion Resolution and Scope Management
@dataclass(frozen=True, slots=True)
class _InjectedScopeContract:
callable_name: str
injected_parameters: tuple[InjectedParameter, ...]
scope: BaseScope
@dataclass(frozen=True, slots=True)
class _DecorationRule:
decorator: Callable[..., Any]
inner_parameter: str
dependencies: tuple[ProviderDependency, ...]
is_async: bool
@dataclass(slots=True)
class _DecorationChain:
base_key: Any
layer_keys: list[Any]
@dataclass(frozen=True, slots=True)
class _OpenDecorationAlias:
id: int
layer: int
@dataclass(frozen=True, slots=True)
class _DecorationBaseMetadata:
lifetime: Lifetime
scope: BaseScope
lock_mode: LockMode | Literal["auto"]
is_open_generic: bool
@dataclass(frozen=True, slots=True)
class _ContainerGraphSnapshot:
providers_registrations: ProvidersRegistrations.Snapshot
open_generic_registry: OpenGenericRegistry.Snapshot
decoration_rules_by_provides: dict[Any, list[_DecorationRule]]
decoration_chain_by_provides: dict[Any, _DecorationChain]
decoration_counter: int