# SPDX-FileCopyrightText: 2023 Hynek Schlawack <hs@ox.cx>
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import inspect
import logging
import warnings
from collections.abc import Callable
from contextlib import (
AbstractAsyncContextManager,
AbstractContextManager,
asynccontextmanager,
contextmanager,
suppress,
)
from inspect import (
isasyncgenfunction,
isawaitable,
iscoroutine,
iscoroutinefunction,
isgeneratorfunction,
)
from types import TracebackType
from typing import Any, Awaitable, TypeVar, overload
import attrs
from .exceptions import ServiceNotFoundError
# Logger for this module's debug messages (registrations, cleanups) and for
# warnings about failed cleanup callbacks.
log = logging.getLogger("svcs")
def _full_name(obj: object) -> str:
try:
return f"{obj.__module__}.{obj.__qualname__}" # type: ignore[attr-defined]
except AttributeError:
return repr(obj)
# Default keys under which integrations store the registry and the container.
_KEY_REGISTRY = "svcs_registry"
_KEY_CONTAINER = "svcs_container"
@attrs.frozen
class RegisteredService:
    """
    One registration inside a :class:`Registry`: a service type together
    with the recipe used to create it.
    """

    # The type this registration produces; also the basis for `name`.
    svc_type: type
    # Excluded from hashing (`hash=False`), just like `ping` below.
    factory: Callable = attrs.field(hash=False)
    # Whether `factory` is called with the container as its argument.
    takes_container: bool
    # Whether context managers returned by `factory` get entered.
    enter: bool
    # Optional health-check callable; `None` when the service has none.
    ping: Callable | None = attrs.field(hash=False)

    @property
    def name(self) -> str:
        """
        The fully-qualified name of the registered service type.
        """
        return _full_name(self.svc_type)

    def __repr__(self) -> str:
        parts = (
            f"svc_type={self.name}",
            f"factory={self.factory}",
            f"takes_container={self.takes_container}",
            f"enter={self.enter}",
            f"has_ping={self.ping is not None}",
        )
        return "<RegisteredService(" + ", ".join(parts) + ")>"
@attrs.frozen
class ServicePing:
    """
    A service health check as returned by :meth:`svcs.Container.get_pings`.

    Attributes:
        name: A fully-qualified name of the service type.

        is_async: Whether the service needs to be pinged using :meth:`aping`.

    See Also:
        :ref:`health`
    """

    name: str
    is_async: bool
    # The service type to fetch from `_container` before pinging.
    _svc_type: type
    # The registered ping callable; it receives the service instance.
    _ping: Callable
    # The container the ping is bound to; it owns the instance and cleanup.
    _container: Container

    def ping(self) -> None:
        """
        Instantiate the service, schedule its cleanup, and call its ping
        method.
        """
        svc: Any = self._container.get(self._svc_type)
        self._ping(svc)

    async def aping(self) -> None:
        """
        Same as :meth:`ping` but instantiate and/or ping asynchronously, if
        necessary.

        Also works with synchronous services, so in an async application, just
        use this.
        """
        svc: Any = await self._container.aget(self._svc_type)
        if self.is_async:
            await self._ping(svc)
        else:
            self._ping(svc)
@attrs.define
class Registry:
    """
    A central registry of recipes for creating services.

    An instance of this should live as long as your application does.

    Also works as a context manager that runs ``on_registry_close`` callbacks
    on exit:

    .. doctest::

        >>> import svcs
        >>> with svcs.Registry() as reg:
        ...     reg.register_value(
        ...         int, 42,
        ...         on_registry_close=lambda: print("closed!")
        ...     )
        closed!

    ``async with`` is also supported.

    Warns:
        ResourceWarning:
            If a registry with pending cleanups is garbage-collected.
    """

    # Service type -> its most recent registration.
    _services: dict[type, RegisteredService] = attrs.Factory(dict)
    # (service name, callback) pairs; run in reverse order on (a)close().
    _on_close: list[tuple[str, Callable | Awaitable]] = attrs.Factory(list)

    def __repr__(self) -> str:
        return f"<svcs.Registry(num_services={len(self._services)})>"

    def __contains__(self, svc_type: type) -> bool:
        """
        Check whether this registry knows how to create *svc_type*:

        .. doctest::

            >>> reg = svcs.Registry()
            >>> reg.register_value(int, 42)
            >>> int in reg
            True
            >>> str in reg
            False
        """
        return svc_type in self._services

    def __enter__(self) -> Registry:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    async def __aenter__(self) -> Registry:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.aclose()

    def __del__(self) -> None:
        """
        Warn if the registry is gc'ed before being closed.
        """
        # getattr() because __del__ can run on an instance whose __init__
        # never completed, so `_on_close` may not exist.
        if getattr(self, "_on_close", None):
            warnings.warn(
                "Registry was garbage-collected with pending cleanups.",
                ResourceWarning,
                stacklevel=1,
            )

    def register_factory(
        self,
        svc_type: type,
        factory: Callable,
        *,
        enter: bool = True,
        ping: Callable | None = None,
        on_registry_close: Callable | Awaitable | None = None,
    ) -> None:
        """
        Register *factory* to be used when asked for a *svc_type*.

        Repeated registrations overwrite previous ones, but the
        *on_registry_close* callbacks are run all together when the registry is
        closed.

        Args:
            svc_type: The type of the service to register.

            factory:
                A callable that is used to instantiate *svc_type* if asked. If
                it's a generator or a context manager (and *enter* is true), a
                cleanup is registered after instantiation.

                Can also be an async callable/generator/context manager.

                If *factory* takes a first argument called ``svcs_container``
                or the first argument (of any name) is annotated as being
                :class:`svcs.Container`, the container instance that is
                instantiating the service is passed into the factory as the
                first positional argument.

                Note:
                    Generally speaking, given the churn and edge cases in the
                    typing ecosystem, we recommend using the name route to
                    detect the container argument because it's most reliable.

            enter:
                Whether to enter context managers if one is returned by
                *factory*. Usually you want that, but there are occasions --
                like database transaction managers -- that you want to enter
                manually.

            ping:
                A callable that marks the service as having a health check.

                See Also:
                    :meth:`Container.get_pings` and :class:`ServicePing`.

            on_registry_close:
                A callable that is called when the
                :meth:`svcs.Registry.close()` method is called.

                Can also be an async callable or an
                :class:`collections.abc.Awaitable`; then
                :meth:`svcs.Registry.aclose()` must be called.

        Raises:
            TypeError: If *factory* takes more than one parameter.
        """
        rs = self._register_factory(
            svc_type,
            factory,
            enter=enter,
            ping=ping,
            on_registry_close=on_registry_close,
        )

        log.debug(
            "registered factory %r for service type %s",
            factory,
            rs.name,
            extra={
                "svcs_service_name": rs.name,
                "svcs_factory_name": _full_name(factory),
            },
            stack_info=True,
        )

    def register_value(
        self,
        svc_type: type,
        value: object,
        *,
        enter: bool = False,
        ping: Callable | None = None,
        on_registry_close: Callable | Awaitable | None = None,
    ) -> None:
        """
        Syntactic sugar for::

            register_factory(
                svc_type,
                lambda: value,
                enter=enter,
                ping=ping,
                on_registry_close=on_registry_close
            )

        Please note that, unlike with :meth:`register_factory`, entering
        context managers is **disabled** by default.

        .. versionchanged:: 23.21.0
            *enter* is now ``False`` by default.
        """
        rs = self._register_factory(
            svc_type,
            lambda: value,
            enter=enter,
            ping=ping,
            on_registry_close=on_registry_close,
        )

        log.debug(
            "registered value %r for service type %s",
            value,
            rs.name,
            extra={"svcs_service_name": rs.name, "svcs_value": value},
            stack_info=True,
        )

    def _register_factory(
        self,
        svc_type: type,
        factory: Callable,
        enter: bool,
        ping: Callable | None,
        on_registry_close: Callable | Awaitable | None = None,
    ) -> RegisteredService:
        """
        Store a registration for *svc_type* and queue *on_registry_close*.

        Returns:
            The newly stored :class:`RegisteredService`.
        """
        # Generator functions are wrapped into context-manager factories so
        # that Container.get()/aget() can enter them and schedule cleanup.
        if isgeneratorfunction(factory):
            factory = contextmanager(factory)
        elif isasyncgenfunction(factory):
            factory = asynccontextmanager(factory)

        rs = RegisteredService(
            svc_type, factory, _takes_container(factory), enter, ping
        )
        self._services[svc_type] = rs

        if on_registry_close is not None:
            self._on_close.append((rs.name, on_registry_close))

        return rs

    def get_registered_service_for(self, svc_type: type) -> RegisteredService:
        """
        Return the registration for *svc_type*.

        Raises:
            svcs.exceptions.ServiceNotFoundError:
                If *svc_type* has not been registered.
        """
        try:
            return self._services[svc_type]
        except KeyError:
            raise ServiceNotFoundError(svc_type) from None

    def close(self) -> None:
        """
        Clear registrations and run synchronous *on_registry_close* callbacks.

        Async callbacks are *not* awaited and a warning is raised.

        Errors are logged at warning level, but otherwise ignored.
        """
        # Run callbacks in reverse registration order.
        for name, oc in reversed(self._on_close):
            if iscoroutinefunction(oc) or isawaitable(oc):
                warnings.warn(
                    f"Skipped async cleanup for {name!r}. "
                    "Use aclose() instead.",
                    # stacklevel doesn't matter here; it's coming from a
                    # framework.
                    stacklevel=1,
                )
                continue

            try:
                log.debug("closing %r", name)
                oc()  # type: ignore[operator]
                log.debug("closed %r", name)
            except Exception:  # noqa: BLE001
                log.warning(
                    "Registry's on_registry_close callback failed for %r.",
                    name,
                    exc_info=True,
                    extra={"svcs_service_name": name},
                )

        self._services.clear()
        self._on_close.clear()

    async def aclose(self) -> None:
        """
        Clear registrations and run all *on_registry_close* callbacks.

        Errors are logged at warning level, but otherwise ignored.

        Also works with synchronous services, so in an async application, just
        use this.
        """
        # Run callbacks in reverse registration order.
        for name, oc in reversed(self._on_close):
            try:
                # An async callable is called first to get its awaitable.
                if iscoroutinefunction(oc):
                    oc = oc()  # noqa: PLW2901

                if isawaitable(oc):
                    log.debug("async closing %r", name)
                    await oc
                    log.debug("async closed %r", name)
                else:
                    log.debug("closing %r", name)
                    oc()  # type: ignore[operator]
                    log.debug("closed %r", name)
            except Exception:  # noqa: BLE001, PERF203
                log.warning(
                    "Registry's on_registry_close callback failed for %r.",
                    name,
                    exc_info=True,
                    extra={"svcs_service_name": name},
                )

        self._services.clear()
        self._on_close.clear()
def _takes_container(factory: Callable) -> bool:
"""
Return True if *factory* takes a svcs.Container as its first argument.
"""
try:
# Provide the locals so that `eval_str` will work even if the user places the `Container`
# under a `if TYPE_CHECKING` block
sig = inspect.signature(
factory, locals={"Container": Container}, eval_str=True
)
except Exception: # noqa: BLE001
# Retry without `eval_str` since if the annotation is "svcs.Container" the eval
# will fail due to it not finding the `svcs` module
try:
sig = inspect.signature(factory)
except Exception: # noqa: BLE001
return False
if not sig.parameters:
return False
if len(sig.parameters) != 1:
msg = "Factories must take 0 or 1 parameters."
raise TypeError(msg)
((name, p),) = tuple(sig.parameters.items())
return (
name == "svcs_container"
or p.annotation is Container
or p.annotation == "svcs.Container"
or p.annotation == "Container"
)
# Type variables for the `Container.get()` / `Container.aget()` overloads,
# which give precise return types for up to ten service types per call.
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
T6 = TypeVar("T6")
T7 = TypeVar("T7")
T8 = TypeVar("T8")
T9 = TypeVar("T9")
T10 = TypeVar("T10")
@attrs.define
class Container:
    """
    A per-context container for instantiated services and cleanups.

    The instance of this should live as long as a request or a task.

    Also works as a context manager that runs clean ups on exit:

    .. doctest::

        >>> reg = svcs.Registry()
        >>> def factory() -> str:
        ...     yield "Hello World"
        ...     print("Cleaned up!")
        >>> reg.register_factory(str, factory)
        >>> with svcs.Container(reg) as con:
        ...     _ = con.get(str)
        Cleaned up!

    Warns:
        ResourceWarning:
            If a container with pending cleanups is garbage-collected.

    Attributes:
        registry:
            The :class:`Registry` instance that this container uses for service
            type lookup.
    """

    registry: Registry
    # Created on the first register_local_*() call; consulted before
    # `registry` during lookups.
    _lazy_local_registry: Registry | None = None
    # Cache of already-created services, keyed by service type.
    _instantiated: dict[type, object] = attrs.Factory(dict)
    # (service name, entered context manager) pairs, exited in reverse order.
    _on_close: list[
        tuple[str, AbstractContextManager | AbstractAsyncContextManager]
    ] = attrs.Factory(list)

    def __repr__(self) -> str:
        return (
            f"<Container(instantiated={len(self._instantiated)}, "
            f"cleanups={len(self._on_close)})>"
        )

    def __contains__(self, svc_type: type) -> bool:
        """
        Check whether this container has a cached instance of *svc_type*.
        """
        return svc_type in self._instantiated

    def __enter__(self) -> Container:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    async def __aenter__(self) -> Container:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.aclose()

    def __del__(self) -> None:
        """
        Warn if the container is gc'ed before being closed.
        """
        # getattr() because __del__ can run on an instance whose __init__
        # never completed, so `_on_close` may not exist.
        if getattr(self, "_on_close", None):
            warnings.warn(
                "Container was garbage-collected with pending cleanups.",
                ResourceWarning,
                stacklevel=1,
            )

    def close(self) -> None:
        """
        Run all registered *synchronous* cleanups.

        Async closes are *not* awaited and a warning is raised.

        Errors are logged at warning level, but otherwise ignored.

        Hint:
            The Container can be used again after this. Closing it is an
            idempotent way to reset it.
        """
        for name, cm in reversed(self._on_close):
            try:
                if isinstance(cm, AbstractAsyncContextManager):
                    warnings.warn(
                        f"Skipped async cleanup for {name!r}. "
                        "Use aclose() instead.",
                        # stacklevel doesn't matter here; it's coming from a
                        # framework.
                        stacklevel=1,
                    )
                    continue

                cm.__exit__(None, None, None)
            except Exception:  # noqa: BLE001
                log.warning(
                    "Container clean up failed for %r.",
                    name,
                    exc_info=True,
                    extra={"svcs_service_name": name},
                )

        if self._lazy_local_registry is not None:
            self._lazy_local_registry.close()

        self._on_close.clear()
        self._instantiated.clear()

    async def aclose(self) -> None:
        """
        Run *all* registered cleanups -- synchronous **and** asynchronous.

        Errors are logged at warning level, but otherwise ignored.

        Also works with synchronous services, so in an async application, just
        use this.

        Hint:
            The container can be used again after this. Closing it is an
            idempotent way to reset it.
        """
        for name, cm in reversed(self._on_close):
            try:
                # Check the async protocol first (like close() and aget()):
                # a service implementing both protocols can only have been
                # entered via __aenter__ in aget() -- get() rejects async
                # context managers -- so it must be exited via __aexit__.
                if isinstance(cm, AbstractAsyncContextManager):
                    await cm.__aexit__(None, None, None)
                else:
                    cm.__exit__(None, None, None)
            except Exception:  # noqa: BLE001, PERF203
                log.warning(
                    "Container clean up failed for %r.",
                    name,
                    exc_info=True,
                    extra={"svcs_service_name": name},
                )

        if self._lazy_local_registry is not None:
            await self._lazy_local_registry.aclose()

        self._on_close.clear()
        self._instantiated.clear()

    def get_pings(self) -> list[ServicePing]:
        """
        Return all services that have defined a *ping* and bind them to this
        container.

        Returns:
            A list of services that have registered a ping callable.
        """
        # NOTE(review): only the main registry is consulted; services
        # registered via register_local_*() are not included.
        return [
            ServicePing(
                rs.name,
                iscoroutinefunction(rs.ping),
                rs.svc_type,
                rs.ping,
                self,
            )
            for rs in self.registry._services.values()
            if rs.ping is not None
        ]

    def get_abstract(self, *svc_types: type) -> Any:
        """
        Like :meth:`get` but is annotated to return :data:`typing.Any` which
        allows it to be used with abstract types like :class:`typing.Protocol`
        or :mod:`abc` classes.

        Note:
            See :doc:`typing-caveats` why this is necessary.
        """
        return self.get(*svc_types)

    async def aget_abstract(self, *svc_types: type) -> Any:
        """
        Same as :meth:`get_abstract` but instantiates asynchronously, if
        necessary.

        Also works with synchronous services, so in an async application, just
        use this.
        """
        return await self.aget(*svc_types)

    def _lookup(self, svc_type: type) -> tuple[bool, object, str, bool]:
        """
        Look up svc_type first in our cache, then in the registry.

        If it's cached, only the first two items of the returned tuple are
        meaningful.

        Returns:
            ``(cached, svc, name, enter)``. For uncached services, *svc* is
            the raw result of calling the factory (not yet entered/awaited).

        Raises:
            svcs.exceptions.ServiceNotFoundError:
                If neither the local nor the main registry knows *svc_type*.
        """
        if (
            svc := self._instantiated.get(svc_type, attrs.NOTHING)
        ) is not attrs.NOTHING:
            return True, svc, "", False

        # Local registrations take precedence over the main registry.
        rs = None
        if self._lazy_local_registry is not None:
            with suppress(ServiceNotFoundError):
                rs = self._lazy_local_registry.get_registered_service_for(
                    svc_type
                )

        if rs is None:
            rs = self.registry.get_registered_service_for(svc_type)

        svc = rs.factory(self) if rs.takes_container else rs.factory()

        return False, svc, rs.name, rs.enter

    def register_local_factory(
        self,
        svc_type: type,
        factory: Callable,
        *,
        enter: bool = True,
        ping: Callable | None = None,
        on_registry_close: Callable | Awaitable | None = None,
    ) -> None:
        """
        Same as :meth:`svcs.Registry.register_factory()`, but registers the
        factory only for this container.

        A temporary :class:`svcs.Registry` is transparently created -- and
        closed together with the container it belongs to.

        See Also:
            :ref:`local-registries`

        .. versionadded:: 23.21.0
        """
        if self._lazy_local_registry is None:
            self._lazy_local_registry = Registry()

        self._lazy_local_registry.register_factory(
            svc_type=svc_type,
            factory=factory,
            enter=enter,
            ping=ping,
            on_registry_close=on_registry_close,
        )

    def register_local_value(
        self,
        svc_type: type,
        value: object,
        *,
        enter: bool = False,
        ping: Callable | None = None,
        on_registry_close: Callable | Awaitable | None = None,
    ) -> None:
        """
        Syntactic sugar for::

            register_local_factory(
                svc_type,
                lambda: value,
                enter=enter,
                ping=ping,
                on_registry_close=on_registry_close
            )

        Please note that, unlike with :meth:`register_local_factory`, entering
        context managers is **disabled** by default.

        See Also:
            :ref:`local-registries`

        .. versionadded:: 23.21.0
        """
        self.register_local_factory(
            svc_type,
            lambda: value,
            enter=enter,
            ping=ping,
            on_registry_close=on_registry_close,
        )

    @overload
    def get(self, svc_type: type[T1], /) -> T1:
        ...

    @overload
    def get(
        self, svc_type1: type[T1], svc_type2: type[T2], /
    ) -> tuple[T1, T2]:
        ...

    @overload
    def get(
        self, svc_type1: type[T1], svc_type2: type[T2], svc_type3: type[T3], /
    ) -> tuple[T1, T2, T3]:
        ...

    @overload
    def get(
        self,
        svc_type1: type[T1],
        svc_type2: type[T2],
        svc_type3: type[T3],
        svc_type4: type[T4],
        /,
    ) -> tuple[T1, T2, T3, T4]:
        ...

    @overload
    def get(
        self,
        svc_type1: type[T1],
        svc_type2: type[T2],
        svc_type3: type[T3],
        svc_type4: type[T4],
        svc_type5: type[T5],
        /,
    ) -> tuple[T1, T2, T3, T4, T5]:
        ...

    @overload
    def get(
        self,
        svc_type1: type[T1],
        svc_type2: type[T2],
        svc_type3: type[T3],
        svc_type4: type[T4],
        svc_type5: type[T5],
        svc_type6: type[T6],
        /,
    ) -> tuple[T1, T2, T3, T4, T5, T6]:
        ...

    @overload
    def get(
        self,
        svc_type1: type[T1],
        svc_type2: type[T2],
        svc_type3: type[T3],
        svc_type4: type[T4],
        svc_type5: type[T5],
        svc_type6: type[T6],
        svc_type7: type[T7],
        /,
    ) -> tuple[T1, T2, T3, T4, T5, T6, T7]:
        ...

    @overload
    def get(
        self,
        svc_type1: type[T1],
        svc_type2: type[T2],
        svc_type3: type[T3],
        svc_type4: type[T4],
        svc_type5: type[T5],
        svc_type6: type[T6],
        svc_type7: type[T7],
        svc_type8: type[T8],
        /,
    ) -> tuple[T1, T2, T3, T4, T5, T6, T7, T8]:
        ...

    @overload
    def get(
        self,
        svc_type1: type[T1],
        svc_type2: type[T2],
        svc_type3: type[T3],
        svc_type4: type[T4],
        svc_type5: type[T5],
        svc_type6: type[T6],
        svc_type7: type[T7],
        svc_type8: type[T8],
        svc_type9: type[T9],
        /,
    ) -> tuple[T1, T2, T3, T4, T5, T6, T7, T8, T9]:
        ...

    @overload
    def get(
        self,
        svc_type1: type[T1],
        svc_type2: type[T2],
        svc_type3: type[T3],
        svc_type4: type[T4],
        svc_type5: type[T5],
        svc_type6: type[T6],
        svc_type7: type[T7],
        svc_type8: type[T8],
        svc_type9: type[T9],
        svc_type10: type[T10],
        /,
    ) -> tuple[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10]:
        ...

    def get(self, *svc_types: type) -> object:
        """
        Get services of *svc_types*.

        Instantiate them if necessary and register their cleanup.

        Returns:
            ``svc_types[0]`` | ``tuple[*svc_types]``: If one service is
            requested, it's returned directly. If multiple are requested, a
            tuple of services is returned.

        Raises:
            svcs.exceptions.ServiceNotFoundError:
                If a requested type is not registered.
            TypeError:
                If a factory produced a coroutine or an async context
                manager; use :meth:`aget` for those.
        """
        rv = []
        for svc_type in svc_types:
            cached, svc, name, enter = self._lookup(svc_type)
            if cached:
                rv.append(svc)
                continue

            if iscoroutine(svc) or isinstance(
                svc, AbstractAsyncContextManager
            ):
                msg = "Use `aget()` for async factories."
                raise TypeError(msg)

            if enter and isinstance(svc, AbstractContextManager):
                self._on_close.append((name, svc))
                svc = svc.__enter__()

            self._instantiated[svc_type] = svc
            rv.append(svc)

        if len(rv) == 1:
            return rv[0]

        # NOTE(review): a list is returned here although the overloads and
        # docstring say tuple; unpacking works the same either way.
        return rv

    @overload
    async def aget(self, svc_type: type[T1], /) -> T1:
        ...

    @overload
    async def aget(
        self, svc_type1: type[T1], svc_type2: type[T2], /
    ) -> tuple[T1, T2]:
        ...

    @overload
    async def aget(
        self, svc_type1: type[T1], svc_type2: type[T2], svc_type3: type[T3], /
    ) -> tuple[T1, T2, T3]:
        ...

    @overload
    async def aget(
        self,
        svc_type1: type[T1],
        svc_type2: type[T2],
        svc_type3: type[T3],
        svc_type4: type[T4],
        /,
    ) -> tuple[T1, T2, T3, T4]:
        ...

    @overload
    async def aget(
        self,
        svc_type1: type[T1],
        svc_type2: type[T2],
        svc_type3: type[T3],
        svc_type4: type[T4],
        svc_type5: type[T5],
        /,
    ) -> tuple[T1, T2, T3, T4, T5]:
        ...

    @overload
    async def aget(
        self,
        svc_type1: type[T1],
        svc_type2: type[T2],
        svc_type3: type[T3],
        svc_type4: type[T4],
        svc_type5: type[T5],
        svc_type6: type[T6],
        /,
    ) -> tuple[T1, T2, T3, T4, T5, T6]:
        ...

    @overload
    async def aget(
        self,
        svc_type1: type[T1],
        svc_type2: type[T2],
        svc_type3: type[T3],
        svc_type4: type[T4],
        svc_type5: type[T5],
        svc_type6: type[T6],
        svc_type7: type[T7],
        /,
    ) -> tuple[T1, T2, T3, T4, T5, T6, T7]:
        ...

    @overload
    async def aget(
        self,
        svc_type1: type[T1],
        svc_type2: type[T2],
        svc_type3: type[T3],
        svc_type4: type[T4],
        svc_type5: type[T5],
        svc_type6: type[T6],
        svc_type7: type[T7],
        svc_type8: type[T8],
        /,
    ) -> tuple[T1, T2, T3, T4, T5, T6, T7, T8]:
        ...

    @overload
    async def aget(
        self,
        svc_type1: type[T1],
        svc_type2: type[T2],
        svc_type3: type[T3],
        svc_type4: type[T4],
        svc_type5: type[T5],
        svc_type6: type[T6],
        svc_type7: type[T7],
        svc_type8: type[T8],
        svc_type9: type[T9],
        /,
    ) -> tuple[T1, T2, T3, T4, T5, T6, T7, T8, T9]:
        ...

    @overload
    async def aget(
        self,
        svc_type1: type[T1],
        svc_type2: type[T2],
        svc_type3: type[T3],
        svc_type4: type[T4],
        svc_type5: type[T5],
        svc_type6: type[T6],
        svc_type7: type[T7],
        svc_type8: type[T8],
        svc_type9: type[T9],
        svc_type10: type[T10],
        /,
    ) -> tuple[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10]:
        ...

    async def aget(self, *svc_types: type) -> object:
        """
        Same as :meth:`get` but instantiates asynchronously, if necessary.

        Also works with synchronous services, so in an async application, just
        use this.

        Raises:
            svcs.exceptions.ServiceNotFoundError:
                If a requested type is not registered.
        """
        rv = []
        for svc_type in svc_types:
            cached, svc, name, enter = self._lookup(svc_type)
            if cached:
                rv.append(svc)
                continue

            # The async protocol wins for objects implementing both, which
            # is why aclose() also checks for it first.
            if enter and isinstance(svc, AbstractAsyncContextManager):
                self._on_close.append((name, svc))
                svc = await svc.__aenter__()
            elif enter and isinstance(svc, AbstractContextManager):
                # Synchronous context managers (e.g. from generator
                # factories) are entered just like get() does; otherwise the
                # unentered manager itself would be cached and returned and
                # its cleanup would never be scheduled.
                self._on_close.append((name, svc))
                svc = svc.__enter__()
            elif isawaitable(svc):
                svc = await svc

            self._instantiated[svc_type] = svc
            rv.append(svc)

        if len(rv) == 1:
            return rv[0]

        # NOTE(review): a list is returned here although the overloads and
        # docstring say tuple; unpacking works the same either way.
        return rv