from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Annotated, Any, NamedTuple, TypeVar, Union, get_args, get_origin
# Generic placeholder for the dependency type wrapped by the marker aliases and classes below.
T = TypeVar("T")
# Smallest ``get_args()`` length the helpers below accept for an ``Annotated`` form:
# the base type at index 0 plus at least one metadata item.
_ANNOTATED_MARKER_MIN_ARGS = 2
[docs]
class Component(NamedTuple):
    """Differentiate multiple providers for the same base type.

    Attach ``Component`` metadata to ``typing.Annotated`` so DIWire treats each
    annotated key as distinct at runtime. Being a ``NamedTuple``, two components
    with equal ``value`` compare equal.

    Examples:
        .. code-block:: python

            from typing import Annotated, TypeAlias


            class Database: ...


            ReplicaDb: TypeAlias = Annotated[Database, Component("replica")]
            PrimaryDb: TypeAlias = Annotated[Database, Component("primary")]
    """

    # Discriminator that tells apart keys sharing one base type (e.g. ``"replica"``).
    value: Any
class InjectedMarker:
    """A marker used to indicate a parameter should be injected from the DI container.

    Used to identify parameters that need to be removed from callable signatures.

    The marker carries no state, so every instance of the same marker class is
    interchangeable: instances compare equal and share one hash. This keeps two
    independently built ``Annotated[T, InjectedMarker()]`` forms equal, matching the
    value semantics of the ``NamedTuple``-based markers in this module.
    """

    def __eq__(self, other: object) -> bool:
        """Return True for any other instance of exactly the same marker class."""
        # Exact type match (not isinstance) keeps equality consistent with ``__hash__``
        # should a subclass ever be introduced.
        if type(other) is type(self):
            return True
        return NotImplemented

    def __hash__(self) -> int:
        """Return a hash shared by all instances of this marker class."""
        return hash(type(self))

    def __repr__(self) -> str:
        """Return a stable, address-free representation such as ``InjectedMarker()``."""
        return f"{type(self).__name__}()"
class MaybeMarker:
    """Marker that indicates dependency is optional and may resolve to ``None``.

    The marker carries no state, so every instance of the same marker class is
    interchangeable: instances compare equal and share one hash. This keeps two
    independently built ``Annotated[T, MaybeMarker()]`` forms equal, matching the
    value semantics of the ``NamedTuple``-based markers in this module.
    """

    def __eq__(self, other: object) -> bool:
        """Return True for any other instance of exactly the same marker class."""
        # Exact type match (not isinstance) keeps equality consistent with ``__hash__``
        # should a subclass ever be introduced.
        if type(other) is type(self):
            return True
        return NotImplemented

    def __hash__(self) -> int:
        """Return a hash shared by all instances of this marker class."""
        return hash(type(self))

    def __repr__(self) -> str:
        """Return a stable, address-free representation such as ``MaybeMarker()``."""
        return f"{type(self).__name__}()"
class ProviderMarker(NamedTuple):
    """Marker for lazy resolver-bound provider callables.

    Instances are attached as ``Annotated`` metadata by ``Provider[...]`` and
    ``AsyncProvider[...]``.
    """

    # The item the provider annotation was subscripted with (kept whole, including any
    # ``Annotated`` metadata it already carried).
    dependency_key: Any
    # ``True`` for ``AsyncProvider[...]``, ``False`` for ``Provider[...]``.
    is_async: bool
class AllMarker(NamedTuple):
    """Marker for collecting all implementations registered for a base dependency key.

    Instances are attached as ``Annotated`` metadata by ``All[...]``.
    """

    # Base type whose registrations should be collected; ``All[...]`` strips any
    # ``Annotated`` wrapper from its argument before storing it here.
    dependency_key: Any
if TYPE_CHECKING:
    # Static-analysis-only aliases: type checkers see these plain generic aliases, while the
    # runtime classes of the same names live in the ``else`` branch of this statement.

    # NOTE(review): ``Union[T, T]`` presumably exists to give checkers a subscriptable
    # alias equivalent to plain ``T`` (hence the lint suppressions) - confirm before changing.
    Injected = Union[T, T]  # noqa: UP007,PYI016
    """Mark a parameter for container-driven injection.

    At runtime ``Injected[T]`` becomes ``Annotated[T, InjectedMarker()]``.
    Container wrappers hide these parameters from the public callable signature.

    Examples:
        .. code-block:: python

            @resolver_context.inject
            def run(service: Injected[Service], value: int) -> str:
                return service.handle(value)
    """

    # Checkers see the resolved value: a zero-argument callable returning ``T``.
    Provider = Callable[[], T]
    """Mark a dependency as a resolver-bound lazy provider callable.

    At runtime ``Provider[T]`` becomes ``Annotated[T, ProviderMarker(...)]`` and resolves
    to ``Callable[[], T]`` bound to the injection-time resolver.
    """

    # Checkers see the resolved value: a zero-argument callable returning an awaitable of ``T``.
    AsyncProvider = Callable[[], Awaitable[T]]
    """Mark a dependency as a resolver-bound async lazy provider callable.

    At runtime ``AsyncProvider[T]`` becomes ``Annotated[T, ProviderMarker(...)]`` and resolves
    to ``Callable[[], Awaitable[T]]`` bound to the injection-time resolver.
    """

    # Checkers see the resolved value: a variable-length homogeneous tuple.
    All = tuple[T, ...]
    """Resolve all implementations registered for a base dependency key.

    ``All[T]`` type-checks as ``tuple[T, ...]`` and always resolves to a tuple.
    It returns an empty tuple when no matching registrations exist.
    """

    # Checkers see an optional ``T``.
    Maybe = T | None  # type: ignore[misc]
    """Mark a dependency as explicitly optional.

    At runtime ``Maybe[T]`` becomes ``Annotated[T, MaybeMarker()]``.
    """
else:
[docs]
class Injected:
    """Runtime form of the marker for parameters the container should inject.

    Subscripting, ``Injected[T]``, yields ``Annotated[T, InjectedMarker()]``.

    Examples:
        .. code-block:: python

            @resolver_context.inject
            def run(service: Injected[Service], value: int) -> str:
                return service.handle(value)
    """

    def __class_getitem__(cls, item: T) -> Annotated[T, InjectedMarker]:
        """Wrap ``item`` in ``Annotated`` with a fresh ``InjectedMarker`` appended.

        Args:
            item: Dependency type. If it is already ``Annotated[...]``, its base type and
                existing metadata are kept and the marker is added after them.
        """
        injected = InjectedMarker()
        if get_origin(item) is Annotated:
            # get_args() of an Annotated form is (base, *metadata); extend it in place.
            annotated_params = (*get_args(item), injected)
        else:
            annotated_params = (item, injected)
        return _build_annotated(annotated_params)
[docs]
class Maybe:
    """Runtime form of the marker for explicitly optional dependencies.

    Subscripting, ``Maybe[T]``, yields ``Annotated[T, MaybeMarker()]``.
    """

    def __class_getitem__(cls, item: T) -> Annotated[T, MaybeMarker]:
        """Wrap ``item`` in ``Annotated`` with a fresh ``MaybeMarker`` appended.

        Args:
            item: Dependency type. If it is already ``Annotated[...]``, its base type and
                existing metadata are kept and the marker is added after them.
        """
        optional = MaybeMarker()
        if get_origin(item) is Annotated:
            # get_args() of an Annotated form is (base, *metadata); extend it in place.
            annotated_params = (*get_args(item), optional)
        else:
            annotated_params = (item, optional)
        return _build_annotated(annotated_params)
[docs]
class Provider:
    """Mark a dependency for lazy resolver-bound sync provider injection."""

    def __class_getitem__(cls, item: T) -> Annotated[T, ProviderMarker]:
        """Build the provider annotation for ``item`` with ``is_async=False``.

        Args:
            item: Dependency key the provider callable stands for.
        """
        return _build_provider_annotation(item=item, is_async=False)
[docs]
class AsyncProvider:
    """Mark a dependency for lazy resolver-bound async provider injection."""

    def __class_getitem__(cls, item: T) -> Annotated[T, ProviderMarker]:
        """Build the provider annotation for ``item`` with ``is_async=True``.

        Args:
            item: Dependency key the async provider callable stands for.
        """
        return _build_provider_annotation(item=item, is_async=True)
[docs]
class All:
    """Request every implementation registered for a base dependency key.

    At runtime ``All[T]`` resolves to ``Annotated[T, AllMarker(dependency_key=T)]`` and is
    detected by the resolver dispatch to collect the plain registration for ``T`` (if any)
    plus all component-qualified registrations keyed as ``Annotated[T, Component(...)]``.

    Notes:
        If you pass an ``Annotated[...]`` token (for example
        ``All[Annotated[T, Component('x')]]``), DIWire strips it to the base type ``T`` and
        produces the same token as ``All[T]``. Prefer using ``All[BaseType]``.
    """

    def __class_getitem__(cls, item: Any) -> Any:
        """Return ``Annotated[base, AllMarker(dependency_key=base)]`` for ``item``.

        Args:
            item: Base type, or an ``Annotated[...]`` form whose metadata is discarded.
        """
        # Only the base type matters here: any Annotated metadata is dropped.
        base_key = get_args(item)[0] if get_origin(item) is Annotated else item
        collect_marker = AllMarker(dependency_key=base_key)
        return _build_annotated((base_key, collect_marker))
def is_maybe_annotation(annotation: Any) -> bool:
    """Tell whether ``annotation`` is an ``Annotated`` form carrying a ``MaybeMarker``.

    Args:
        annotation: Annotation value to inspect.

    Returns:
        ``True`` when at least one metadata entry is a ``MaybeMarker`` instance;
        ``False`` otherwise, including for every non-``Annotated`` value.
    """
    if get_origin(annotation) is Annotated:
        parts = get_args(annotation)
        if len(parts) >= _ANNOTATED_MARKER_MIN_ARGS:
            # parts[0] is the base type; everything after it is metadata.
            for extra in parts[1:]:
                if isinstance(extra, MaybeMarker):
                    return True
    return False
def strip_maybe_annotation(annotation: Any) -> Any:
    """Remove every ``MaybeMarker`` from ``annotation`` and keep all other metadata.

    Args:
        annotation: Annotation value to normalize.

    Returns:
        ``annotation`` unchanged when it carries no ``MaybeMarker``; the bare base type
        when the marker was its only metadata; otherwise a rebuilt ``Annotated`` form
        with the remaining metadata in its original order.
    """
    if is_maybe_annotation(annotation):
        # A positive check guarantees a base type followed by at least one metadata item.
        base, *extras = get_args(annotation)
        kept = [extra for extra in extras if not isinstance(extra, MaybeMarker)]
        if kept:
            return _build_annotated((base, *kept))
        return base
    return annotation
def is_provider_annotation(annotation: Any) -> bool:
    """Tell whether ``annotation`` is an ``Annotated`` form carrying a ``ProviderMarker``.

    Args:
        annotation: Annotation value to inspect.

    Returns:
        ``True`` when a ``ProviderMarker`` is found in the metadata, else ``False``.
    """
    found = _extract_provider_marker(annotation)
    return found is not None
def is_all_annotation(annotation: Any) -> bool:
    """Tell whether ``annotation`` is an ``Annotated`` form carrying an ``AllMarker``.

    Args:
        annotation: Annotation value to inspect.

    Returns:
        ``True`` when an ``AllMarker`` is found in the metadata, else ``False``.
    """
    found = _extract_all_marker(annotation)
    return found is not None
def strip_provider_annotation(annotation: Any) -> Any:
    """Unwrap a ``Provider``/``AsyncProvider`` annotation to its dependency key.

    Args:
        annotation: Annotation value to normalize.

    Returns:
        The ``dependency_key`` stored on the first ``ProviderMarker`` found, or
        ``annotation`` itself when it carries no such marker.
    """
    provider = _extract_provider_marker(annotation)
    return annotation if provider is None else provider.dependency_key
def strip_all_annotation(annotation: Any) -> Any:
    """Unwrap an ``All[...]`` annotation to its dependency key.

    Args:
        annotation: Annotation value to normalize.

    Returns:
        The ``dependency_key`` stored on the first ``AllMarker`` found, or
        ``annotation`` itself when it carries no such marker.
    """
    collector = _extract_all_marker(annotation)
    return annotation if collector is None else collector.dependency_key
def is_async_provider_annotation(annotation: Any) -> bool:
    """Tell whether ``annotation`` carries a ``ProviderMarker`` flagged as async.

    Args:
        annotation: Annotation value to inspect.

    Returns:
        The marker's ``is_async`` flag, or ``False`` when there is no provider marker.
    """
    provider = _extract_provider_marker(annotation)
    return provider is not None and provider.is_async
def component_base_key(annotation: Any) -> Any | None:
    """Extract the base type from an ``Annotated[Base, Component(...)]`` key.

    Args:
        annotation: Annotation value to inspect.

    Returns:
        The base type when ``annotation`` is an ``Annotated`` form whose metadata
        includes at least one ``Component``; ``None`` in every other case.
    """
    if get_origin(annotation) is Annotated:
        parts = get_args(annotation)
        if len(parts) >= _ANNOTATED_MARKER_MIN_ARGS:
            # parts[0] is the base type; everything after it is metadata.
            for extra in parts[1:]:
                if isinstance(extra, Component):
                    return parts[0]
    return None
def strip_non_component_annotation(annotation: Any) -> Any:
    """Reduce an ``Annotated`` dependency key to its base type plus ``Component`` metadata.

    Args:
        annotation: Annotation value to normalize.

    Returns:
        * ``annotation`` unchanged when it exposes no ``__metadata__`` or when all of its
          metadata already consists of ``Component`` items;
        * the bare base type when none of the metadata is a ``Component``;
        * otherwise a rebuilt ``Annotated`` form holding only the ``Component`` items,
          in their original order.
    """
    raw_metadata = getattr(annotation, "__metadata__", None)
    if raw_metadata is None:
        return annotation

    base = getattr(annotation, "__origin__", None)
    if base is None:
        # No usable ``__origin__`` attribute: fall back to the typing introspection helpers.
        if get_origin(annotation) is not Annotated:
            return annotation
        parts = get_args(annotation)
        if len(parts) < _ANNOTATED_MARKER_MIN_ARGS:
            return annotation
        base, raw_metadata = parts[0], parts[1:]

    # Walk the metadata exactly once, then partition the snapshot.
    entries = tuple(raw_metadata)
    components = tuple(entry for entry in entries if isinstance(entry, Component))
    if not components:
        return base
    if len(components) == len(entries):
        # Nothing but Component metadata: the key is already normalized.
        return annotation
    return _build_annotated((base, *components))
def _extract_provider_marker(annotation: Any) -> ProviderMarker | None:
    """Return the first ``ProviderMarker`` in ``annotation``'s metadata, if any.

    Args:
        annotation: Annotation value to inspect.

    Returns:
        The first ``ProviderMarker`` metadata item, or ``None`` when ``annotation`` is
        not an ``Annotated`` form or carries no such marker.
    """
    if get_origin(annotation) is Annotated:
        parts = get_args(annotation)
        if len(parts) >= _ANNOTATED_MARKER_MIN_ARGS:
            # parts[0] is the base type; everything after it is metadata.
            for extra in parts[1:]:
                if isinstance(extra, ProviderMarker):
                    return extra
    return None
def _extract_all_marker(annotation: Any) -> AllMarker | None:
    """Return the first ``AllMarker`` in ``annotation``'s metadata, if any.

    Args:
        annotation: Annotation value to inspect.

    Returns:
        The first ``AllMarker`` metadata item, or ``None`` when ``annotation`` is
        not an ``Annotated`` form or carries no such marker.
    """
    if get_origin(annotation) is Annotated:
        parts = get_args(annotation)
        if len(parts) >= _ANNOTATED_MARKER_MIN_ARGS:
            # parts[0] is the base type; everything after it is metadata.
            for extra in parts[1:]:
                if isinstance(extra, AllMarker):
                    return extra
    return None
def _build_provider_annotation(*, item: T, is_async: bool) -> Annotated[T, ProviderMarker]:
    """Build the ``Annotated`` form shared by ``Provider[...]`` and ``AsyncProvider[...]``.

    Args:
        item: Dependency key the provider stands for. It is stored whole on the marker;
            if it is already ``Annotated[...]``, its base type and metadata are also
            reused for the resulting annotation.
        is_async: Flag recorded on the ``ProviderMarker``.

    Returns:
        ``Annotated[base, *existing_metadata, ProviderMarker(item, is_async)]``.
    """
    marker = ProviderMarker(dependency_key=item, is_async=is_async)
    if get_origin(item) is Annotated:
        # get_args() of an Annotated form is (base, *metadata); extend it in place.
        annotated_params = (*get_args(item), marker)
    else:
        annotated_params = (item, marker)
    return _build_annotated(annotated_params)
def build_annotated_key(params: tuple[object, ...]) -> Any:
    """Return Annotated[...] with a pre-built params tuple (Py 3.10+ compatible).

    ``Annotated`` exposes its subscription hook as ``__class_getitem__`` on some Python
    versions and as ``__getitem__`` on others, so the hook is looked up first and
    called afterwards.

    Args:
        params: Pre-built tuple of Annotated arguments to combine into a dependency key
            (base type first, metadata items after it).

    Returns:
        The ``Annotated`` form equivalent to ``Annotated[params[0], params[1], ...]``.

    Raises:
        TypeError: Propagated from ``typing`` when ``params`` is not a valid argument
            list for ``Annotated`` (for example fewer than two items).
    """
    # Guard only the attribute lookup. Errors raised while the alias is being built
    # (including an AttributeError from a metadata object's own hooks) must surface
    # unchanged instead of being replaced by a failed fallback lookup.
    try:
        subscript = Annotated.__class_getitem__  # type: ignore[attr-defined]
    except AttributeError:
        subscript = Annotated.__getitem__  # type: ignore[attr-defined]
    return subscript(params)
def _build_annotated(params: tuple[object, ...]) -> Any:
    """Module-private alias that forwards ``params`` to ``build_annotated_key``.

    Args:
        params: Pre-built tuple of ``Annotated`` arguments (base type, then metadata).
    """
    return build_annotated_key(params)