Source code for diwire._internal.providers

from __future__ import annotations

import inspect
import threading
from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Generator
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from dataclasses import dataclass, field
from enum import Enum, auto
from inspect import Parameter
from types import TracebackType
from typing import (
    Annotated,
    Any,
    ClassVar,
    Literal,
    Protocol,
    TypeAlias,
    TypeVar,
    cast,
    get_args,
    get_origin,
    get_type_hints,
)

from diwire._internal.lock_mode import LockMode
from diwire._internal.scope import BaseScope
from diwire.exceptions import (
    DIWireInvalidProviderSpecError,
    DIWireInvalidRegistrationError,
    DIWireProviderDependencyInferenceError,
)

T = TypeVar("T")
_CMT_co = TypeVar("_CMT_co", covariant=True)

UserDependency: TypeAlias = Any
"""A dependency that been registered or trying to be resolved from the user's code."""

UserProviderObject: TypeAlias = Any
"""An object, function, or class provided by the user as a provider. Determined by ProviderKind."""

ProviderSlot: TypeAlias = int
"""A unique slot number assigned to each provider specification."""

ConcreteTypeProvider: TypeAlias = type[T]
"""A concrete type that can be instantiated to produce a dependency."""

FactoryProvider: TypeAlias = Callable[..., T] | Callable[..., Awaitable[T]]
"""A factory function or asynchronous function that produces a dependency."""

GeneratorProvider: TypeAlias = (
    Callable[..., Generator[T, None, None]] | Callable[..., AsyncGenerator[T, None]]
)
"""A generator function or asynchronous generator function that yields a dependency."""


class _ContextManagerLike(Protocol[_CMT_co]):
    def __enter__(self) -> _CMT_co: ...

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None: ...


class _AsyncContextManagerLike(Protocol[_CMT_co]):
    def __aenter__(self) -> Awaitable[_CMT_co]: ...

    def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> Awaitable[bool | None]: ...


ContextManagerProvider: TypeAlias = (
    Callable[..., _ContextManagerLike[T]] | Callable[..., _AsyncContextManagerLike[T]]
)

_MISSING_ANNOTATION: Any = object()
_IMPLICIT_FIRST_PARAMETER_NAMES = {"self", "cls"}
_COROUTINE_RESULT_INDEX = 2
_COROUTINE_ARGUMENT_COUNT = 3
_ASYNC_RESOLVER_ORIGINS: tuple[type[Any], ...] = (
    Awaitable,
    Coroutine,
    AsyncGenerator,
    AbstractAsyncContextManager,
)


@dataclass(kw_only=True)
class ProviderSpec:
    """Describe how a single dependency key is produced and cached.

    A spec is the normalized registration model used by resolver generation.
    Exactly one provider source is expected in user-facing registrations
    (instance, concrete type, factory, generator, or context manager), plus
    metadata describing scope, lifetime, async requirements, cleanup behavior,
    and locking.
    """

    SLOT_COUNTER: ClassVar[ProviderSlot] = 0
    SLOT_COUNTER_LOCK: ClassVar[Any] = threading.Lock()

    provides: UserDependency
    """The dependency type that this provider supplies."""

    instance: UserDependency | None = None
    """An optional instance of the provided dependency, if applicable."""
    concrete_type: ConcreteTypeProvider[Any] | None = None
    """An optional concrete type of the provided dependency, if applicable."""
    factory: FactoryProvider[Any] | None = None
    """An optional factory function to create the provided dependency, if applicable."""
    generator: GeneratorProvider[Any] | None = None
    """An optional generator function to yield the provided dependency, if applicable."""
    context_manager: ContextManagerProvider[Any] | None = None
    """An optional context manager function to manage the provided dependency, if applicable."""
    dependencies: list[ProviderDependency] = field(default_factory=list)
    """A list of dependencies required by this provider."""
    is_async: bool
    """Indicates whether the provider specification itself is asynchronous."""
    is_any_dependency_async: bool
    """Indicates whether any dependency requires asynchronous resolution."""
    needs_cleanup: bool
    """True if provider itself or any dependency requires cleanup."""
    lock_mode: LockMode | Literal["auto"] = "auto"
    """Resolved locking strategy for this provider."""

    lifetime: Lifetime | None = None
    """The lifetime of the provided dependency. Could be none in case of instance providers."""
    scope: BaseScope
    """The scope in which the provided dependency is valid."""

    slot: ProviderSlot = field(init=False)
    """A unique slot number assigned to this provider specification."""

    def __post_init__(self) -> None:
        with self.SLOT_COUNTER_LOCK:
            self.__class__.SLOT_COUNTER += 1
            self.slot = self.SLOT_COUNTER


[docs] class Lifetime(Enum): """Define cache behavior for provider results.""" TRANSIENT = auto() """Disable caching and build a new value for every resolution call.""" SCOPED = auto() """Cache per declared scope owner until that scope exits. For registrations declared on the container root scope, this behaves like a singleton for the container lifetime. Cleanup-enabled providers are released when the owning scope/resolver is closed. """
class ProvidersRegistrations: """Store provider specs indexed by dependency key and slot. Registration keys are unique: adding a spec for an existing dependency key replaces the previous spec. Slot indexing is maintained for resolver planning, and cleanup flags are recomputed after each mutation. """ def __init__(self) -> None: self._registrations_by_type: dict[UserDependency, ProviderSpec] = {} self._registrations_by_slot: dict[int, ProviderSpec] = {} @dataclass(frozen=True, slots=True) class Snapshot: """Capture provider registration state for transactional rollback.""" registrations_by_type: dict[UserDependency, ProviderSpec] registrations_by_slot: dict[int, ProviderSpec] def snapshot(self) -> Snapshot: """Capture current registrations for rollback.""" return self.Snapshot( registrations_by_type=dict(self._registrations_by_type), registrations_by_slot=dict(self._registrations_by_slot), ) def restore(self, snapshot: Snapshot) -> None: """Restore registrations from a previous snapshot. Args: snapshot: Previously captured snapshot state to restore into the registry. """ self._registrations_by_type = dict(snapshot.registrations_by_type) self._registrations_by_slot = dict(snapshot.registrations_by_slot) self._refresh_needs_cleanup_flags() def add(self, spec: ProviderSpec) -> None: """Add a new provider specification to the registrations. Args: spec: Provider specification to register. """ if previous_spec := self._registrations_by_type.get(spec.provides): self._registrations_by_slot.pop(previous_spec.slot, None) self._registrations_by_type[spec.provides] = spec self._registrations_by_slot[spec.slot] = spec self._refresh_needs_cleanup_flags() def get_by_type(self, dep_type: UserDependency) -> ProviderSpec: """Get a provider specification by the type of dependency it provides. Args: dep_type: Dependency type key to look up. """ return self._registrations_by_type[dep_type] def find_by_type(self, dep_type: UserDependency) -> ProviderSpec | None: """Get a provider specification by dependency type, if it exists. Args: dep_type: Dependency type key to look up. """ return self._registrations_by_type.get(dep_type) def get_by_slot(self, slot: int) -> ProviderSpec: """Get a provider specification by its unique slot number. Args: slot: Provider slot identifier to look up. """ return self._registrations_by_slot[slot] def get_by_scope(self, scope: BaseScope | None) -> list[ProviderSpec]: """Get all provider specifications registered for a specific scope. Args: scope: Scope value used to filter registrations or open nested resolution scope. """ return [spec for spec in self._registrations_by_type.values() if spec.scope == scope] def values(self) -> list[ProviderSpec]: """Get all provider specifications.""" return list(self._registrations_by_type.values()) def _refresh_needs_cleanup_flags(self) -> None: """Recompute cleanup requirements for all registered providers.""" # Dependencies can be registered after dependents, so recompute until stable. has_changes = True while has_changes: has_changes = False for spec in self.values(): dependency_needs_cleanup = any( dependency_spec.needs_cleanup for dependency in spec.dependencies if (dependency_spec := self.find_by_type(dependency.provides)) ) needs_cleanup = ( spec.generator is not None or spec.context_manager is not None or dependency_needs_cleanup ) if spec.needs_cleanup is not needs_cleanup: spec.needs_cleanup = needs_cleanup has_changes = True def __len__(self) -> int: return len(self._registrations_by_type) @dataclass(slots=True) class ProviderDependenciesExtractor: """Extracts dependencies from user-defined provider objects.""" def extract_from_concrete_type( self, concrete_type: ConcreteTypeProvider[Any], ) -> list[ProviderDependency]: """Extract dependencies from a concrete type-based provider. Args: concrete_type: Concrete class provider to inspect or validate. """ return self._extract_dependencies( provider=concrete_type, provider_name=concrete_type.__qualname__, skip_first_parameter=False, ) def extract_from_factory( self, factory: FactoryProvider[Any], ) -> list[ProviderDependency]: """Extract dependencies from a factory-based provider. Args: factory: Factory provider callable to inspect or validate. """ return self._extract_dependencies( provider=factory, provider_name=self._provider_name(factory), skip_first_parameter=False, ) def extract_from_generator( self, generator: GeneratorProvider[Any], ) -> list[ProviderDependency]: """Extract dependencies from a generator-based provider. Args: generator: Generator provider callable to inspect or validate. """ return self._extract_dependencies( provider=generator, provider_name=self._provider_name(generator), skip_first_parameter=False, ) def extract_from_context_manager( self, context_manager: ContextManagerProvider[Any], ) -> list[ProviderDependency]: """Extract dependencies from a context manager-based provider. Args: context_manager: Context manager provider callable to inspect or validate. """ return self._extract_dependencies( provider=context_manager, provider_name=self._provider_name(context_manager), skip_first_parameter=False, ) def validate_explicit_for_concrete_type( self, concrete_type: ConcreteTypeProvider[Any], dependencies: list[ProviderDependency], ) -> list[ProviderDependency]: """Validate explicit dependencies for a concrete type-based provider. Args: concrete_type: Concrete class provider to inspect or validate. dependencies: Explicit dependency declarations to validate or inspect for async requirements. """ return self._validate_explicit_dependencies( provider=concrete_type, provider_name=concrete_type.__qualname__, dependencies=dependencies, skip_first_parameter=False, ) def validate_explicit_for_factory( self, factory: FactoryProvider[Any], dependencies: list[ProviderDependency], ) -> list[ProviderDependency]: """Validate explicit dependencies for a factory-based provider. Args: factory: Factory provider callable to inspect or validate. dependencies: Explicit dependency declarations to validate or inspect for async requirements. """ return self._validate_explicit_dependencies( provider=factory, provider_name=self._provider_name(factory), dependencies=dependencies, skip_first_parameter=False, ) def validate_explicit_for_generator( self, generator: GeneratorProvider[Any], dependencies: list[ProviderDependency], ) -> list[ProviderDependency]: """Validate explicit dependencies for a generator-based provider. Args: generator: Generator provider callable to inspect or validate. dependencies: Explicit dependency declarations to validate or inspect for async requirements. """ return self._validate_explicit_dependencies( provider=generator, provider_name=self._provider_name(generator), dependencies=dependencies, skip_first_parameter=False, ) def validate_explicit_for_context_manager( self, context_manager: ContextManagerProvider[Any], dependencies: list[ProviderDependency], ) -> list[ProviderDependency]: """Validate explicit dependencies for a context manager-based provider. Args: context_manager: Context manager provider callable to inspect or validate. dependencies: Explicit dependency declarations to validate or inspect for async requirements. """ return self._validate_explicit_dependencies( provider=context_manager, provider_name=self._provider_name(context_manager), dependencies=dependencies, skip_first_parameter=False, ) def _extract_dependencies( self, *, provider: Callable[..., Any], provider_name: str, skip_first_parameter: bool, ) -> list[ProviderDependency]: parameters = self._provider_parameters( provider=provider, skip_first_parameter=skip_first_parameter, ) annotations, annotation_error = self._resolved_type_hints(provider) dependencies: list[ProviderDependency] = [] for parameter in parameters: provides = self._resolve_parameter_annotation( parameter=parameter, annotations=annotations, annotation_error=annotation_error, provider_name=provider_name, ) if provides is _MISSING_ANNOTATION: continue dependencies.append( ProviderDependency( provides=provides, parameter=parameter, ), ) return dependencies def _validate_explicit_dependencies( self, *, provider: Callable[..., Any], provider_name: str, dependencies: list[ProviderDependency], skip_first_parameter: bool, ) -> list[ProviderDependency]: parameters = self._provider_parameters( provider=provider, skip_first_parameter=skip_first_parameter, ) parameters_by_name = {parameter.name: parameter for parameter in parameters} validated_dependencies_by_name: dict[str, ProviderDependency] = {} for dependency in dependencies: parameter_name = dependency.parameter.name signature_parameter = parameters_by_name.get(parameter_name) if signature_parameter is None: msg = ( f"Explicit dependency for unknown parameter '{parameter_name}' in " f"provider '{provider_name}'." ) raise DIWireInvalidProviderSpecError(msg) if parameter_name in validated_dependencies_by_name: msg = ( f"Explicit dependency for parameter '{parameter_name}' in " f"provider '{provider_name}' is duplicated." ) raise DIWireInvalidProviderSpecError(msg) if dependency.parameter.kind is not signature_parameter.kind: msg = ( f"Explicit dependency for parameter '{parameter_name}' in provider " f"'{provider_name}' has kind '{dependency.parameter.kind}' but expected " f"'{signature_parameter.kind}'." ) raise DIWireInvalidProviderSpecError(msg) validated_dependencies_by_name[parameter_name] = ProviderDependency( provides=dependency.provides, parameter=signature_parameter, ) missing_required_parameters = [ parameter.name for parameter in parameters if self._is_required_parameter(parameter) and parameter.name not in validated_dependencies_by_name ] if missing_required_parameters: missing_parameters = ", ".join(f"'{name}'" for name in missing_required_parameters) msg = ( f"Explicit dependencies for provider '{provider_name}' are incomplete. " f"Missing required parameters: {missing_parameters}." ) raise DIWireProviderDependencyInferenceError(msg) return [ validated_dependencies_by_name[parameter.name] for parameter in parameters if parameter.name in validated_dependencies_by_name ] def _resolve_parameter_annotation( self, *, parameter: Parameter, annotations: dict[str, Any], annotation_error: Exception | None, provider_name: str, ) -> Any: annotation = annotations.get(parameter.name, _MISSING_ANNOTATION) if annotation is not _MISSING_ANNOTATION: return annotation raw_annotation = parameter.annotation if raw_annotation is not Parameter.empty and not isinstance(raw_annotation, str): return raw_annotation if not self._is_required_parameter(parameter): return _MISSING_ANNOTATION error_message = ( f"Unable to infer dependency for required parameter '{parameter.name}' " f"in provider '{provider_name}'. Add a type annotation or pass explicit dependencies." ) if annotation_error is None: raise DIWireProviderDependencyInferenceError(error_message) msg = f"{error_message} Original annotation error: {annotation_error}" raise DIWireProviderDependencyInferenceError(msg) from annotation_error def _provider_parameters( self, *, provider: Callable[..., Any], skip_first_parameter: bool, ) -> tuple[Parameter, ...]: parameters = tuple(inspect.signature(provider).parameters.values()) if ( skip_first_parameter and parameters and parameters[0].name in _IMPLICIT_FIRST_PARAMETER_NAMES ): return parameters[1:] return parameters def _resolved_type_hints( self, provider: Callable[..., Any], ) -> tuple[dict[str, Any], Exception | None]: annotations: dict[str, Any] = {} annotation_error: Exception | None = None try: annotations = get_type_hints(provider, include_extras=True) except (AttributeError, NameError, TypeError) as error: annotation_error = error if inspect.isclass(provider): annotations, annotation_error = self._merge_concrete_type_callable_hints( concrete_type=provider, annotations=annotations, annotation_error=annotation_error, ) return annotations, annotation_error def _merge_concrete_type_callable_hints( self, *, concrete_type: type[Any], annotations: dict[str, Any], annotation_error: Exception | None, ) -> tuple[dict[str, Any], Exception | None]: merged_annotations = dict(annotations) merged_error = annotation_error for callable_member_name in ("__call__", "__new__", "__init__"): callable_member = getattr(concrete_type, callable_member_name) try: member_annotations = get_type_hints(callable_member, include_extras=True) except (AttributeError, NameError, TypeError) as error: if merged_error is None: merged_error = error continue for parameter_name, parameter_annotation in member_annotations.items(): merged_annotations.setdefault(parameter_name, parameter_annotation) return merged_annotations, merged_error def _is_required_parameter(self, parameter: Parameter) -> bool: return ( parameter.default is Parameter.empty and parameter.kind is not Parameter.VAR_POSITIONAL and parameter.kind is not Parameter.VAR_KEYWORD ) def _provider_name(self, provider: Callable[..., Any]) -> str: return getattr(provider, "__qualname__", repr(provider)) @dataclass(slots=True) class ProviderDependency: """Represent a resolved dependency key bound to a provider parameter.""" provides: UserDependency parameter: Parameter @dataclass(slots=True) class ProviderReturnTypeExtractor: """Extracts return types from user-defined provider objects.""" def is_factory_async( self, factory: FactoryProvider[Any], ) -> bool: """Check whether a factory provider itself is asynchronous. Args: factory: Factory provider callable to inspect or validate. """ return inspect.iscoroutinefunction(factory) or self.return_annotation_matches_origins( provider=factory, expected_origins=(Awaitable, Coroutine), ) def is_generator_async( self, generator: GeneratorProvider[Any], ) -> bool: """Check whether a generator provider itself is asynchronous. Args: generator: Generator provider callable to inspect or validate. """ return inspect.isasyncgenfunction(generator) or self.return_annotation_matches_origins( provider=generator, expected_origins=(AsyncGenerator,), ) def is_context_manager_async( self, context_manager: ContextManagerProvider[Any], ) -> bool: """Check whether a context manager provider itself is asynchronous. Args: context_manager: Context manager provider callable to inspect or validate. """ unwrapped_context_manager = inspect.unwrap(context_manager) if inspect.isasyncgenfunction(unwrapped_context_manager): return True if inspect.iscoroutinefunction(unwrapped_context_manager): return True if self.return_annotation_matches_origins( provider=unwrapped_context_manager, expected_origins=(AbstractAsyncContextManager, AsyncGenerator), ): return True if isinstance(unwrapped_context_manager, type) and ( self._context_manager_mode(unwrapped_context_manager) == "async" ): return True return_annotation, _annotation_error = self._resolved_return_annotation( unwrapped_context_manager, ) unwrapped_return_annotation = self.unwrap_annotated(return_annotation) if isinstance(unwrapped_return_annotation, type): return self._context_manager_mode(unwrapped_return_annotation) == "async" return False def is_any_dependency_async( self, dependencies: list[ProviderDependency], ) -> bool: """Check whether provider dependencies require asynchronous resolution. Args: dependencies: Explicit dependency declarations to validate or inspect for async requirements. """ return any( self.annotation_matches_origins( annotation=dependency.provides, expected_origins=_ASYNC_RESOLVER_ORIGINS, ) for dependency in dependencies ) def extract_from_factory( self, factory: FactoryProvider[Any], ) -> Any: """Extract a return type from a factory-based provider. Args: factory: Factory provider callable to inspect or validate. """ return_annotation, annotation_error = self._resolved_return_annotation(factory) provider_name = self._provider_name(factory) if return_annotation is _MISSING_ANNOTATION: self._raise_missing_return_annotation_error( provider_kind="factory", provider_name=provider_name, annotation_error=annotation_error, ) unwrapped_return_type = self._unwrap_factory_return_type(return_annotation) if unwrapped_return_type is _MISSING_ANNOTATION: self._raise_invalid_return_annotation_error( provider_kind="factory", provider_name=provider_name, expected_annotations="`T`, `Awaitable[T]`, or `Coroutine[Any, Any, T]`", annotation_error=annotation_error, ) return unwrapped_return_type def extract_from_generator( self, generator: GeneratorProvider[Any], ) -> Any: """Extract a return type from a generator-based provider. Args: generator: Generator provider callable to inspect or validate. """ return_annotation, annotation_error = self._resolved_return_annotation(generator) provider_name = self._provider_name(generator) if return_annotation is _MISSING_ANNOTATION: self._raise_missing_return_annotation_error( provider_kind="generator", provider_name=provider_name, annotation_error=annotation_error, ) yielded_type = self._extract_yielded_or_managed_type( return_annotation=return_annotation, expected_origins=(Generator, AsyncGenerator), ) if yielded_type is _MISSING_ANNOTATION: self._raise_invalid_return_annotation_error( provider_kind="generator", provider_name=provider_name, expected_annotations="`Generator[T, None, None]` or `AsyncGenerator[T, None]`", annotation_error=annotation_error, ) return yielded_type def extract_from_context_manager( self, context_manager: ContextManagerProvider[Any], ) -> Any: """Extract a return type from a context manager-based provider. Args: context_manager: Context manager provider callable to inspect or validate. """ return_annotation, annotation_error = self._resolved_return_annotation(context_manager) provider_name = self._provider_name(context_manager) if return_annotation is not _MISSING_ANNOTATION: yielded_type = self._extract_yielded_or_managed_type( return_annotation=return_annotation, expected_origins=( AbstractContextManager, AbstractAsyncContextManager, Generator, AsyncGenerator, ), ) if yielded_type is not _MISSING_ANNOTATION: return yielded_type if isinstance(context_manager, type): managed_type = self._infer_managed_type_from_cm_type(context_manager) if managed_type is not _MISSING_ANNOTATION: return managed_type unwrapped_return_annotation = self.unwrap_annotated(return_annotation) if isinstance(unwrapped_return_annotation, type): managed_type = self._infer_managed_type_from_cm_type(unwrapped_return_annotation) if managed_type is not _MISSING_ANNOTATION: return managed_type msg = ( f"Unable to infer return type for context manager provider '{provider_name}'. " "Add a provider return annotation (`AbstractContextManager[T]`, " "`AbstractAsyncContextManager[T]`, `Generator[T, None, None]`, or " "`AsyncGenerator[T, None]`), annotate `__enter__` or `__aenter__` on the " "context manager class, or pass provides= explicitly." ) self._raise_invalid_registration_error(msg=msg, annotation_error=annotation_error) return _MISSING_ANNOTATION # pragma: no cover def return_annotation_matches_origins( self, *, provider: Callable[..., Any], expected_origins: tuple[type[Any], ...], ) -> bool: """Check whether provider return annotation origin matches one of expected origins. Args: provider: Provider callable or type whose return annotation is being checked. expected_origins: Allowed annotation origins for a successful match check. """ return_annotation, _annotation_error = self._resolved_return_annotation(provider) if return_annotation is _MISSING_ANNOTATION: return False return self.annotation_matches_origins( annotation=return_annotation, expected_origins=expected_origins, ) def annotation_matches_origins( self, *, annotation: Any, expected_origins: tuple[type[Any], ...], ) -> bool: """Check whether annotation origin matches one of expected origins. Args: annotation: Annotation value to inspect or normalize. expected_origins: Allowed annotation origins for a successful match check. """ annotation = self.unwrap_annotated(annotation) origin = get_origin(annotation) if origin in expected_origins: return True return annotation in expected_origins def unwrap_annotated( self, annotation: Any, ) -> Any: """Recursively unwrap Annotated[T, ...] into T. Args: annotation: Annotation value to inspect or normalize. """ if get_origin(annotation) is not Annotated: return annotation annotation_args = get_args(annotation) if ( not annotation_args ): # pragma: no cover - typing.Annotated always wraps at least one type return annotation return self.unwrap_annotated(annotation_args[0]) def _is_self_annotation( self, annotation: Any, ) -> bool: unwrapped_annotation = self.unwrap_annotated(annotation) annotation_module = getattr(unwrapped_annotation, "__module__", None) if annotation_module not in {"typing", "typing_extensions"}: return False annotation_name = getattr( unwrapped_annotation, "__qualname__", getattr(unwrapped_annotation, "__name__", None), ) return annotation_name == "Self" def _infer_managed_type_from_cm_type( self, cm_type: type[Any], ) -> Any: context_manager_mode = self._context_manager_mode(cm_type) if context_manager_mode == "none": return _MISSING_ANNOTATION enter_method_name = "__aenter__" if context_manager_mode == "async" else "__enter__" enter_method = cast("Callable[..., Any]", getattr(cm_type, enter_method_name)) return_annotation, _annotation_error = self._resolved_return_annotation(enter_method) if return_annotation is _MISSING_ANNOTATION: return _MISSING_ANNOTATION managed_type = return_annotation if context_manager_mode == "async": managed_type = self._unwrap_factory_return_type(managed_type) if managed_type is _MISSING_ANNOTATION: return _MISSING_ANNOTATION if self._is_self_annotation(managed_type): return cm_type return managed_type def _context_manager_mode( self, cm_type: type[Any], ) -> Literal["sync", "async", "none"]: has_sync_methods = callable(getattr(cm_type, "__enter__", None)) and callable( getattr(cm_type, "__exit__", None), ) has_async_methods = callable(getattr(cm_type, "__aenter__", None)) and callable( getattr(cm_type, "__aexit__", None), ) if has_async_methods and not has_sync_methods: return "async" if has_sync_methods: return "sync" return "none" def _resolved_return_annotation( self, provider: Callable[..., Any], ) -> tuple[Any, Exception | None]: try: return_type_hints = get_type_hints(provider, include_extras=True) annotation_error: Exception | None = None except (AttributeError, NameError, TypeError) as error: return_type_hints = {} annotation_error = error resolved_return_annotation = return_type_hints.get("return", _MISSING_ANNOTATION) if resolved_return_annotation is not _MISSING_ANNOTATION: return resolved_return_annotation, annotation_error try: raw_return_annotation = inspect.signature(provider).return_annotation except (TypeError, ValueError) as error: if annotation_error is None: annotation_error = error return _MISSING_ANNOTATION, annotation_error if raw_return_annotation is inspect.Signature.empty or isinstance( raw_return_annotation, str, ): return _MISSING_ANNOTATION, annotation_error return raw_return_annotation, annotation_error def _unwrap_factory_return_type( self, return_annotation: Any, ) -> Any: origin = get_origin(return_annotation) annotation_args = get_args(return_annotation) if origin is Awaitable: if len(annotation_args) != 1: return _MISSING_ANNOTATION return annotation_args[0] if origin is Coroutine: if len(annotation_args) != _COROUTINE_ARGUMENT_COUNT: return _MISSING_ANNOTATION return annotation_args[_COROUTINE_RESULT_INDEX] return return_annotation def _extract_yielded_or_managed_type( self, *, return_annotation: Any, expected_origins: tuple[type[Any], ...], ) -> Any: origin = get_origin(return_annotation) if origin not in expected_origins: return _MISSING_ANNOTATION annotation_args = get_args(return_annotation) if len(annotation_args) < 1: return _MISSING_ANNOTATION return annotation_args[0] def _raise_missing_return_annotation_error( self, *, provider_kind: str, provider_name: str, annotation_error: Exception | None, ) -> None: msg = ( f"Unable to infer return type for {provider_kind} provider '{provider_name}'. " "Add a return type annotation or pass provides= explicitly." ) self._raise_invalid_registration_error(msg=msg, annotation_error=annotation_error) def _raise_invalid_return_annotation_error( self, *, provider_kind: str, provider_name: str, expected_annotations: str, annotation_error: Exception | None, ) -> None: msg = ( f"Unable to infer return type for {provider_kind} provider '{provider_name}'. " f"Expected return annotation {expected_annotations}. " "Add a valid return annotation or pass provides= explicitly." ) self._raise_invalid_registration_error(msg=msg, annotation_error=annotation_error) def _raise_invalid_registration_error( self, *, msg: str, annotation_error: Exception | None, ) -> None: if annotation_error is None: raise DIWireInvalidRegistrationError(msg) full_msg = f"{msg} Original annotation error: {annotation_error}" raise DIWireInvalidRegistrationError(full_msg) from annotation_error def _provider_name(self, provider: Callable[..., Any]) -> str: return getattr(provider, "__qualname__", repr(provider))