diff --git a/royalnet/engineer/__init__.py b/royalnet/engineer/__init__.py index d688bb27..76eb26d9 100644 --- a/royalnet/engineer/__init__.py +++ b/royalnet/engineer/__init__.py @@ -1,10 +1,5 @@ """ -A chatbot command router inspired by :mod:`fastapi`. +Chat bot utilities. All names are inspired by the `Engineer Class of Team Fortress 2 `_. """ - -from .blueprints import * -from .teleporter import * -from .sentry import * -from .exc import * diff --git a/royalnet/engineer/blueprints/__init__.py b/royalnet/engineer/blueprints/__init__.py deleted file mode 100644 index c2250a8d..00000000 --- a/royalnet/engineer/blueprints/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .blueprint import * -from .message import * -from .channel import * -from .user import * diff --git a/royalnet/engineer/blueprints/blueprint.py b/royalnet/engineer/blueprints/blueprint.py deleted file mode 100644 index 66a60f31..00000000 --- a/royalnet/engineer/blueprints/blueprint.py +++ /dev/null @@ -1,92 +0,0 @@ -import abc - -from .. import exc - - -class Blueprint(metaclass=abc.ABCMeta): - """ - A class containing methods common between all blueprints. - - To extend a blueprint, inherit from it while using the :class:`abc.ABCMeta` metaclass, and make all new functions - return :exc:`.exc.NeverAvailableError`: - - .. code-block:: - - class Channel(Blueprint, metaclass=abc.ABCMeta): - def name(self): - raise exc.NeverAvailableError() - - To implement a blueprint for a specific chat platform, inherit from the blueprint, override :meth:`__init__`, - :meth:`__hash__` and the methods that are implemented by the platform in question, either returning the - corresponding value or raising :exc:`.exc.NotAvailableError` if there is no data available. - - .. code-block:: - - class ExampleChannel(Channel): - def __init__(self, chat_id: int): - self.chat_id: int = chat_id - - def __hash__(self): - return self.chat_id - - def name(self): - return ExampleClient.expensive_get_channel_name(self.chat_id) - - .. note:: To improve performance, you might want to wrap all data methods in :func:`functools.lru_cache` decorators. - - .. code-block:: - - @functools.lru_cache(24) - def name(self): - return ExampleClient.expensive_get_channel_name(self.chat_id) - - """ - - def __init__(self): - """ - :return: The created object. - """ - pass - - @abc.abstractmethod - def __hash__(self): - """ - :return: A value that uniquely identifies the channel inside this Python process. - """ - raise NotImplementedError() - - def requires(self, *fields) -> True: - """ - Ensure that this blueprint has the specified fields, re-raising the highest priority exception raised between - all of them. - - .. code-block:: - - def print_msg(message: Message): - message.requires("text", "timestamp") - print(f"{message.timestamp().isoformat()}: {message.text()}") - - :raises .exc.NeverAvailableError: If at least one of the fields raised a :exc:`.exc.NeverAvailableError`. - :raises .exc.NotAvailableError: If no field raised a :exc:`.exc.NeverAvailableError`, but at least one raised a - :exc:`.exc.NotAvailableError`. - """ - - exceptions = [] - - for field in fields: - try: - self.__getattribute__(field)() - except exc.NeverAvailableError as ex: - exceptions.append(ex) - except exc.NotAvailableError as ex: - exceptions.append(ex) - - if len(exceptions) > 0: - raise max(exceptions, key=lambda e: e.priority) - - return True - - -__all__ = ( - "Blueprint", -) diff --git a/royalnet/engineer/blueprints/channel.py b/royalnet/engineer/blueprints/channel.py deleted file mode 100644 index 56264806..00000000 --- a/royalnet/engineer/blueprints/channel.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations -from royalnet.royaltyping import * -import abc - -from .. import exc -from .blueprint import Blueprint - - -class Channel(Blueprint, metaclass=abc.ABCMeta): - """ - An abstract class representing a channel where messages can be sent. - - .. seealso:: :class:`.Blueprint` - """ - - def name(self) -> str: - """ - :return: The name of the message channel, such as the chat title. - :raises .exc.NeverAvailableError: If the chat platform does not support channel names. - :raises .exc.NotAvailableError: If this channel does not have any name. - """ - raise exc.NeverAvailableError() - - def topic(self) -> str: - """ - :return: The topic or description of the message channel. - :raises .exc.NeverAvailableError: If the chat platform does not support channel topics / descriptions. - :raises .exc.NotAvailableError: If this channel does not have any name. - """ - raise exc.NeverAvailableError() - - -__all__ = ( - "Channel", -) diff --git a/royalnet/engineer/blueprints/message.py b/royalnet/engineer/blueprints/message.py deleted file mode 100644 index 3ce12080..00000000 --- a/royalnet/engineer/blueprints/message.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import annotations -from royalnet.royaltyping import * -import abc -import datetime - -from .. import exc -from .blueprint import Blueprint -from .channel import Channel - - -class Message(Blueprint, metaclass=abc.ABCMeta): - """ - An abstract class representing a chat message sent in any platform. - - .. seealso:: :class:`.Blueprint` - """ - - def text(self) -> str: - """ - :return: The raw text contents of the message. - :raises .exc.NeverAvailableError: If the chat platform does not support text messages. - :raises .exc.NotAvailableError: If this message does not have any text. - """ - raise exc.NeverAvailableError() - - def timestamp(self) -> datetime.datetime: - """ - :return: The :class:`datetime.datetime` at which the message was sent / received. - :raises .exc.NeverAvailableError: If the chat platform does not support timestamps. - :raises .exc.NotAvailableError: If this message is special and does not have any timestamp. - """ - raise exc.NeverAvailableError() - - def reply_to(self) -> Message: - """ - :return: The :class:`.Message` this message is a reply to. - :raises .exc.NeverAvailableError: If the chat platform does not support replies. - :raises .exc.NotAvailableError: If this message is not a reply to any other message. - """ - raise exc.NeverAvailableError() - - def channel(self) -> Channel: - """ - :return: The :class:`.Channel` this message was sent in. - :raises .exc.NeverAvailableError: If the chat platform does not support channels. - :raises .exc.NotAvailableError: If this message was not sent in any channel. - """ - raise exc.NeverAvailableError() - - -__all__ = ( - "Message", -) diff --git a/royalnet/engineer/blueprints/user.py b/royalnet/engineer/blueprints/user.py deleted file mode 100644 index 81fd473f..00000000 --- a/royalnet/engineer/blueprints/user.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations -from royalnet.royaltyping import * -import abc -import sqlalchemy.orm - -from .. import exc -from .blueprint import Blueprint - - -class User(Blueprint, metaclass=abc.ABCMeta): - """ - An abstract class representing a chat user. - - .. seealso:: :class:`.Blueprint` - """ - - def name(self) -> str: - """ - :return: The user's name. - :raises .exc.NeverAvailableError: If the chat platform does not support usernames. - :raises .exc.NotAvailableError: If this user does not have any name. - """ - raise exc.NeverAvailableError() - - def database(self, session: sqlalchemy.orm.Session) -> Any: - """ - :param session: A :class:`sqlalchemy.orm.Session` instance to use to fetch the database entry. - :return: The database entry for this user. - """ - raise exc.NeverAvailableError() - - -__all__ = ( - "User", -) diff --git a/royalnet/engineer/bullet.py b/royalnet/engineer/bullet.py new file mode 100644 index 00000000..dc35c432 --- /dev/null +++ b/royalnet/engineer/bullet.py @@ -0,0 +1,114 @@ +""" + +""" + +from __future__ import annotations +import royalnet.royaltyping as t + +import abc +import datetime +import sqlalchemy.orm + +from . import exc + + +class Bullet(metaclass=abc.ABCMeta): + """ + The abstract base class for Bullet data models. + """ + + @abc.abstractmethod + def __hash__(self) -> int: + """ + :return: A value that uniquely identifies the object in this Python interpreter process. + """ + raise NotImplementedError() + + +class Message(Bullet, metaclass=abc.ABCMeta): + """ + An abstract class representing a chat message. + """ + + async def text(self) -> t.Optional[str]: + """ + :return: The raw text contents of the message. + :raises .exc.NotSupportedError: If the frontend does not support text messages. + """ + raise exc.NotSupportedError() + + async def timestamp(self) -> t.Optional[datetime.datetime]: + """ + :return: The :class:`datetime.datetime` at which the message was sent. + :raises .exc.NotSupportedError: If the frontend does not support timestamps. + """ + raise exc.NotSupportedError() + + async def reply_to(self) -> t.Optional[Message]: + """ + :return: The :class:`.Message` this message is a reply to. + :raises .exc.NotSupportedError: If the frontend does not support replies. + """ + raise exc.NotSupportedError() + + async def channel(self) -> t.Optional[Channel]: + """ + :return: The :class:`.Channel` this message was sent in. + :raises .exc.NotSupportedError: If the frontend does not support channels. + """ + raise exc.NotSupportedError() + + +class Channel(Bullet, metaclass=abc.ABCMeta): + """ + An abstract class representing a channel where messages can be sent. + """ + + async def name(self) -> t.Optional[str]: + """ + :return: The name of the message channel, such as the chat title. + :raises .exc.NotSupportedError: If the frontend does not support channel names. + """ + raise exc.NotSupportedError() + + async def topic(self) -> t.Optional[str]: + """ + :return: The topic (description) of the message channel. + :raises .exc.NotSupportedError: If the frontend does not support channel topics / descriptions. + """ + raise exc.NotSupportedError() + + async def users(self) -> t.List[User]: + """ + :return: A :class:`list` of :class:`.User` who can read messages sent in the channel. + :raises .exc.NotSupportedError: If the frontend does not support such a feature. + """ + raise exc.NotSupportedError() + + +class User(Bullet, metaclass=abc.ABCMeta): + """ + An abstract class representing a user who can read or send messages in the chat. + """ + + async def name(self) -> t.Optional[str]: + """ + :return: The user's name. + :raises .exc.NotSupportedError: If the frontend does not support usernames. + """ + raise exc.NotSupportedError() + + async def database(self, session: sqlalchemy.orm.Session) -> t.Any: + """ + :param session: A :class:`sqlalchemy.orm.Session` instance to use to fetch the database entry. + :return: The database entry for this user. + """ + raise exc.NotSupportedError() + + +__all__ = ( + "Bullet", + "Message", + "Channel", + "User", +) diff --git a/royalnet/engineer/discard.py b/royalnet/engineer/discard.py new file mode 100644 index 00000000..e9167e62 --- /dev/null +++ b/royalnet/engineer/discard.py @@ -0,0 +1,13 @@ +class Discard(BaseException): + """ + A special exception which should be raised by Metals if a certain object should be discarded from the queue. + """ + def __init__(self, obj, message): + self.obj = obj + self.message = message + + def __repr__(self): + return f"" + + def __str__(self): + return f"Discarded {self.obj}: {self.message}" diff --git a/royalnet/engineer/dispenser.py b/royalnet/engineer/dispenser.py index e69de29b..d3413dd7 100644 --- a/royalnet/engineer/dispenser.py +++ b/royalnet/engineer/dispenser.py @@ -0,0 +1,51 @@ +""" +Dispensers instantiate sentries and dispatch events in bulk to the whole group. +""" + +from __future__ import annotations +import royalnet.royaltyping as t + +import logging +import contextlib + +from .sentry import SentrySource + +log = logging.getLogger(__name__) + + +class Dispenser: + def __init__(self): + self.sentries: t.List[SentrySource] = [] + """ + A :class:`list` of all the running sentries of this dispenser. + """ + + def put(self, item: t.Any) -> None: + """ + Insert a new item in the queues of all the running sentries. + + :param item: The item to insert. + """ + log.debug(f"Putting {item}") + for sentry in self.sentries: + sentry.put(item) + + @contextlib.contextmanager + def sentry(self, *args, **kwargs): + """ + A context manager which creates a :class:`.SentrySource` and keeps it in :attr:`.sentries` while it is being + used. + """ + sentry = SentrySource(dispenser=self, *args, **kwargs) + self.sentries.append(sentry) + + yield sentry + + self.sentries.remove(sentry) + + async def run(self, conv: t.Conversation) -> None: + with self.sentry() as sentry: + state = conv(sentry) + + while True: + state = await state diff --git a/royalnet/engineer/exc.py b/royalnet/engineer/exc.py index 7c18c7b3..ece28617 100644 --- a/royalnet/engineer/exc.py +++ b/royalnet/engineer/exc.py @@ -1,38 +1,27 @@ -import royalnet.exc import pydantic -class EngineerException(royalnet.exc.RoyalnetException): +class EngineerException(Exception): """ - An exception raised by the engineer module. + The base class for errors in :mod:`royalnet.engineer`. """ -class BlueprintError(EngineerException): +class WrenchException(EngineerException): """ - An error related to the :mod:`royalnet.engineer.blueprints`. + The base class for errors in :mod:`royalnet.engineer.wrench`. """ -class NeverAvailableError(BlueprintError, NotImplementedError): +class DeliberateException(WrenchException): """ - The requested property is never supplied by the chat platform the message was sent in. + This exception was deliberately raised by :class:`royalnet.engineer.wrench.ErrorAll`. """ - priority = 1 - - -class NotAvailableError(BlueprintError): - """ - The requested property was not supplied by the chat platform for the specific message this exception was raised in. - """ - - priority = 2 - class TeleporterError(EngineerException, pydantic.ValidationError): """ - The validation of some object though a :mod:`pydantic` model failed. + The base class for errors in :mod:`royalnet.engineer.teleporter`. """ @@ -48,28 +37,13 @@ class OutTeleporterError(TeleporterError): """ -class SentryError(EngineerException): +class BulletException(EngineerException): """ - An error related to the :mod:`royalnet.engineer.sentry`. + The base class for errors in :mod:`royalnet.engineer.bullet`. """ -class FilterError(SentryError): +class NotSupportedError(BulletException, NotImplementedError): """ - An error related to the :class:`royalnet.engineer.sentry.Filter`. + The requested property isn't available on the current frontend. """ - - -class Discard(FilterError): - """ - Discard the object from the queue. - """ - def __init__(self, obj, message): - self.obj = obj - self.message = message - - def __repr__(self): - return f"" - - def __str__(self): - return f"Discarded {self.obj}: {self.message}" diff --git a/royalnet/engineer/sentry.py b/royalnet/engineer/sentry.py new file mode 100644 index 00000000..2ae2bf05 --- /dev/null +++ b/royalnet/engineer/sentry.py @@ -0,0 +1,188 @@ +""" +Sentries are asyncronous receivers for events (usually :class:`bullet.Bullet`) incoming from Dispensers. + +They support event filtering through Wrenches and coroutine functions. +""" + +from __future__ import annotations +import royalnet.royaltyping as t + +import abc +import logging +import asyncio + +from . import discard +from . import bullet + +if t.TYPE_CHECKING: + from .dispenser import Dispenser + +log = logging.getLogger(__name__) + + +class Sentry(metaclass=abc.ABCMeta): + """ + The abstract object representing a node of the pipeline. + """ + + @abc.abstractmethod + def __len__(self) -> int: + raise NotImplementedError() + + @abc.abstractmethod + def get_nowait(self) -> bullet.Bullet: + """ + Try to get a single :class:`~.bullet.Bullet` from the pipeline, without blocking or handling discards. + + :return: The **returned** :class:`~.bullet.Bullet`. + :raises asyncio.QueueEmpty: If the queue is empty. + :raises .discard.Discard: If the object was **discarded** by the pipeline. + :raises Exception: If an exception was **raised** in the pipeline. + """ + raise NotImplementedError() + + @abc.abstractmethod + async def get(self) -> bullet.Bullet: + """ + Try to get a single :class:`~.bullet.Bullet` from the pipeline, blocking until something is available, but + without handling discards. + + :return: The **returned** :class:`~.bullet.Bullet`. + :raises .discard.Discard: If the object was **discarded** by the pipeline. + :raises Exception: If an exception was **raised** in the pipeline. + """ + raise NotImplementedError() + + async def wait(self) -> bullet.Bullet: + """ + Try to get a single :class:`~.bullet.Bullet` from the pipeline, blocking until something is available and is not + discarded. + + :return: The **returned** :class:`~.bullet.Bullet`. + :raises Exception: If an exception was **raised** in the pipeline. + """ + while True: + try: + result = await self.get() + log.debug(f"Returned: {result}") + return result + except discard.Discard as d: + log.debug(f"{str(d)}") + continue + + def __await__(self): + """ + Awaiting an object implementing :class:`.SentryInterface` corresponds to awaiting :meth:`.wait`. + """ + return self.get().__await__() + + @abc.abstractmethod + async def put(self, item: t.Any) -> None: + """ + Insert a new item in the queue. + + :param item: The item to be added. + """ + raise NotImplementedError() + + def filter(self, wrench: t.Callable[[t.Any], t.Awaitable[t.Any]]) -> SentryFilter: + """ + Chain a new filter to the pipeline. + + :param wrench: The filter to add to the chain. It can either be a :class:`.wrench.Wrench`, or a coroutine + function accepting a single object as parameter and returning the same or a different one. + :return: A new :class:`.SentryFilter` which includes the filter. + + .. seealso:: :meth:`.__or__` + """ + if callable(wrench): + return SentryFilter(previous=self, wrench=wrench) + else: + raise TypeError("wrench must be either a Wrench or a coroutine function") + + def __or__(self, other: t.Callable[[t.Any], t.Awaitable[t.Any]]) -> SentryFilter: + """ + A unix-pipe-like interface for :meth:`.filter`. + + .. code-block:: + + await (sentry | wrench.Type(Message) | wrench.Sync(lambda o: o.text)) + + """ + try: + return self.filter(other) + except TypeError: + raise TypeError("Right-side must be either a Wrench or a coroutine function") + + @abc.abstractmethod + def dispenser(self): + """ + Get the :class:`.Dispenser` that created this Sentry. + + :return: The :class:`.Dispenser` object. + """ + raise NotImplementedError() + + +class SentryFilter(Sentry): + """ + A non-root node of the filtering pipeline. + """ + + def __init__(self, previous: Sentry, wrench: t.Callable[[t.Any], t.Awaitable[t.Any]]): + self.previous: Sentry = previous + """ + The previous node of the pipeline. + """ + + self.wrench: t.Callable[[t.Any], t.Awaitable[t.Any]] = wrench + """ + The coroutine function to apply to all objects passing through this node. + """ + + def __len__(self) -> int: + return len(self.previous) + 1 + + def get_nowait(self) -> bullet.Bullet: + return self.previous.get_nowait() + + async def get(self) -> bullet.Bullet: + return await self.previous.get() + + async def put(self, item) -> None: + return await self.previous.put(item) + + def dispenser(self): + return self.previous.dispenser() + + +class SentrySource(Sentry): + """ + The root and source of the pipeline. + """ + + def __init__(self, dispenser: "Dispenser", queue_size: int = 12): + self.queue: asyncio.Queue = asyncio.Queue(maxsize=queue_size) + self._dispenser: "Dispenser" = dispenser + + def __len__(self) -> int: + return 1 + + def get_nowait(self) -> bullet.Bullet: + return self.queue.get_nowait() + + async def get(self) -> bullet.Bullet: + return await self.queue.get() + + async def put(self, item) -> None: + return await self.queue.put(bullet) + + async def dispenser(self): + return self._dispenser + + +__all__ = ( + "Sentry", + "SentryFilter", + "SentrySource", +) diff --git a/royalnet/engineer/sentry/__init__.py b/royalnet/engineer/sentry/__init__.py deleted file mode 100644 index ef4d371e..00000000 --- a/royalnet/engineer/sentry/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .sentry import * diff --git a/royalnet/engineer/sentry/filter.py b/royalnet/engineer/sentry/filter.py deleted file mode 100644 index 7fb62f04..00000000 --- a/royalnet/engineer/sentry/filter.py +++ /dev/null @@ -1,257 +0,0 @@ -from __future__ import annotations -from royalnet.royaltyping import * -import functools -import logging - -from .. import exc, blueprints - -log = logging.getLogger(__name__) - - -class Filter: - """ - A fluent interface for filtering data. - """ - - def __init__(self, func: Callable): - self.func: Callable = func - - async def get(self) -> Any: - """ - Wait until an :class:`object` leaves the queue and passes through the filter, then return it. - - :return: The :class:`object` which left the queue. - """ - while True: - try: - return await self.get_single() - except exc.Discard: - continue - - async def get_single(self) -> Any: - """ - Let one :class:`object` pass through the filter, then either return it or raise an error if the object should be - discarded. - - :return: The :class:`object` which left the queue. - :raises exc.Discard: If the object was filtered. - """ - try: - result = await self.func(None) - except exc.Discard as e: - log.debug(str(e)) - raise - else: - log.debug(f"Dequeued {result}") - return result - - @staticmethod - def _deco_filter(c: Callable[[Any], bool], *, error: str): - """ - A decorator which checks the condition ``c`` on all objects transiting through the queue: - - - If the check **passes**, the object itself is returned; - - If the check **fails**, :exc:`.exc.Discard` is raised, with the object and the ``error`` string as parameters; - - If an error is raised, propagate the error upwards. - - .. warning:: Raising :exc:`.exc.Discard` in ``c`` will automatically cause the object to be discarded, as if - :data:`False` was returned. - - :param c: A function that takes in input an enqueued object and returns either the same object or a new one to - pass to the next filter in the queue. - :param error: The string that :exc:`.exc.Discard` should display if the object is discarded. - """ - def decorator(func): - @functools.wraps(func) - async def decorated(obj): - result: Any = await func(obj) - if c(result): - return result - else: - raise exc.Discard(obj=result, message=error) - return decorated - return decorator - - def filter(self, c: Callable[[Any], bool], error: str) -> Filter: - """ - Check the condition ``c`` on all objects transiting through the queue: - - - If the check **passes**, the object goes on to the next filter; - - If the check **fails**, the object is discarded, with ``error`` as reason; - - If an error is raised, propagate the error upwards. - - :param c: A function that takes in input an object and performs a check on it, returning either :data:`True` - or :data:`False`. - :param error: The reason for which objects should be discarded. - :return: A new :class:`Filter` with this new condition. - - .. seealso:: :meth:`._deco_filter`, :func:`filter` - """ - return self.__class__(self._deco_filter(c, error=error)(self.func)) - - @staticmethod - def _deco_map(c: Callable[[Any], object]): - """ - A decorator which applies the function ``c`` on all objects transiting through the queue: - - - If the function **returns**, return its return value; - - If the function **raises** an error, it is propagated upwards. - - :param c: A function that takes in input an enqueued object and returns either the same object or something - else. - - .. seealso:: :func:`map` - """ - def decorator(func): - @functools.wraps(func) - async def decorated(obj): - result: Any = await func(obj) - return c(result) - return decorated - return decorator - - def map(self, c: Callable[[Any], object]) -> Filter: - """ - Apply the function ``c`` on all objects transiting through the queue: - - - If the function **returns**, its return value replaces the object in the queue; - - If the function **raises** :exc:`.exc.Discard`, the object is discarded; - - If the function **raises another error**, propagate the error upwards. - - :param c: A function that takes in input an enqueued object and returns either the same object or something - else. - :return: A new :class:`Filter` with this new condition. - - .. seealso:: :meth:`._deco_map`, :func:`filter` - """ - return self.__class__(self._deco_map(c)(self.func)) - - def type(self, t: type) -> Filter: - """ - Check if an object passing through the queue :func:`isinstance` of the type ``t``. - - :param t: The type that objects should be instances of. - :return: A new :class:`Filter` with this new condition. - - .. seealso:: :func:`isinstance` - """ - return self.filter(lambda o: isinstance(o, t), error=f"Not instance of type {t}") - - def msg(self) -> Filter: - """ - Check if an object passing through the queue :func:`isinstance` of :class:`.blueprints.Message`. - - :return: A new :class:`Filter` with this new condition. - """ - return self.type(blueprints.Message) - - def requires(self, *fields, - propagate_not_available=False, - propagate_never_available=True) -> Filter: - """ - Test a :class:`.blueprints.Blueprint`'s fields by using its ``.requires()`` method: - - - If the :class:`.blueprints.Blueprint` has the appropriate fields, return it; - - If the :class:`.blueprints.Blueprint` doesn't have data for at least one of the fields, the object is discarded; - - the :class:`.blueprints.Blueprint` never has data for at least one of the fields, :exc:`.exc.NotAvailableError` is propagated upwards. - - :param fields: The fields to test for. - :param propagate_not_available: If :exc:`.exc.NotAvailableError` should be propagated - instead of discarding the errored object. - :param propagate_never_available: If :exc:`.exc.NeverAvailableError` should be propagated - instead of discarding the errored object. - :return: A new :class:`Filter` with this new condition. - - .. seealso:: :meth:`.blueprints.Blueprint.requires` - """ - def check(obj): - try: - return obj.requires(*fields) - except exc.NotAvailableError: - if propagate_not_available: - raise - raise exc.Discard(obj, "Data is not available") - except exc.NeverAvailableError: - if propagate_never_available: - raise - raise exc.Discard(obj, "Data is never available") - - return self.filter(check, error=".requires() method returned False") - - def field(self, field: str, - propagate_not_available=False, - propagate_never_available=True) -> Filter: - """ - Replace a :class:`.blueprints.Blueprint` with the value of one of its fields. - - :param field: The field to access. - :param propagate_not_available: If :exc:`.exc.NotAvailableError` should be propagated - instead of discarding the errored object. - :param propagate_never_available: If :exc:`.exc.NeverAvailableError` should be propagated - instead of discarding the errored object. - :return: A new :class:`Filter` with the new requirements. - """ - def replace(obj): - try: - return obj.__getattribute__(field)() - except exc.NotAvailableError: - if propagate_not_available: - raise - raise exc.Discard(obj, "Data is not available") - except exc.NeverAvailableError: - if propagate_never_available: - raise - raise exc.Discard(obj, "Data is never available") - - return self.map(replace) - - def startswith(self, prefix: str): - """ - Check if an object starts with the specified prefix and discard the objects that do not. - - :param prefix: The prefix object should start with. - :return: A new :class:`Filter` with the new requirements. - - .. seealso:: :meth:`str.startswith` - """ - return self.filter(lambda x: x.startswith(prefix), error=f"Text didn't start with {prefix}") - - def endswith(self, suffix: str): - """ - Check if an object ends with the specified suffix and discard the objects that do not. - - :param suffix: The prefix object should start with. - :return: A new :class:`Filter` with the new requirements. - - .. seealso:: :meth:`str.endswith` - """ - return self.filter(lambda x: x.endswith(suffix), error=f"Text didn't end with {suffix}") - - def regex(self, pattern: Pattern): - """ - Apply a regex over an object and discard the object if it does not match. - - :param pattern: The pattern that should be matched by the text. - :return: A new :class:`Filter` with the new requirements. - """ - def mapping(x): - if match := pattern.match(x): - return match - else: - raise exc.Discard(x, f"Text didn't match pattern {pattern}") - - return self.map(mapping) - - def choices(self, *choices): - """ - Ensure an object is in the ``choices`` list, discarding the object otherwise. - - :param choices: The pattern that should be matched by the text. - :return: A new :class:`Filter` with the new requirements. - """ - return self.filter(lambda o: o in choices, error="Not a valid choice") - - -__all__ = ( - "Filter", -) diff --git a/royalnet/engineer/sentry/sentry.py b/royalnet/engineer/sentry/sentry.py deleted file mode 100644 index 6d891158..00000000 --- a/royalnet/engineer/sentry/sentry.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations -from royalnet.royaltyping import * -import logging -import asyncio - -from .filter import Filter - -log = logging.getLogger(__name__) - - -class Sentry: - """ - A class that allows using the ``await`` keyword to suspend a command execution until a new message is received. - """ - - QUEUE_SIZE = 12 - """ - The size of the object :attr:`.queue`. - """ - - def __init__(self, filter_type: Type[Filter] = Filter): - self.queue: asyncio.Queue = asyncio.Queue(maxsize=self.QUEUE_SIZE) - """ - An object queue where incoming :class:`object` are stored, with a size limit of :attr:`.QUEUE_SIZE`. - """ - - self.filter_type: Type[Filter] = filter_type - """ - The filter to be used in :meth:`.f` calls, by default :class:`.filters.Filter`. - """ - - def __repr__(self): - return f"" - - def f(self): - """ - Create a :attr:`.filter_type` object, which can be configured through its fluent interface. - - Remember to call ``.get()`` on the end of the chain to finally get the object. - - To get any object, call: - - .. code-block:: - - await sentry.f().get() - - .. seealso:: :class:`.filters.Filter` - - :return: The created :class:`.filters.Filter`. - """ - async def func(_): - return await self.queue.get() - return self.filter_type(func) - - -__all__ = ( - "Sentry", -) diff --git a/royalnet/engineer/teleporter.py b/royalnet/engineer/teleporter.py index 7f36ba00..68ce2eba 100644 --- a/royalnet/engineer/teleporter.py +++ b/royalnet/engineer/teleporter.py @@ -1,13 +1,19 @@ -import functools +""" +The teleporter uses :mod:`pydantic` to validate function parameters and return values. +""" -from royalnet.royaltyping import * +from __future__ import annotations +import royalnet.royaltyping as t + +import logging import pydantic import inspect +import functools + from . import exc - -Model = TypeVar("Model") -Value = TypeVar("Value") +Value = t.TypeVar("Value") +log = logging.getLogger(__name__) class TeleporterConfig(pydantic.BaseConfig): @@ -17,7 +23,7 @@ class TeleporterConfig(pydantic.BaseConfig): arbitrary_types_allowed = True -def parameter_to_field(param: inspect.Parameter, **kwargs) -> Tuple[type, pydantic.fields.FieldInfo]: +def parameter_to_field(param: inspect.Parameter, **kwargs) -> t.Tuple[type, pydantic.fields.FieldInfo]: """ Convert a :class:`inspect.Parameter` to a type-field :class:`tuple`, which can be easily passed to :func:`pydantic.create_model`. @@ -45,9 +51,9 @@ def parameter_to_field(param: inspect.Parameter, **kwargs) -> Tuple[type, pydant ) -def signature_to_model(f: Callable, - __config__: Type[pydantic.BaseConfig] = TeleporterConfig, - extra_params: Dict[str, type] = None) -> Tuple[type, type]: +def signature_to_model(f: t.Callable, + __config__: t.Type[pydantic.BaseConfig] = TeleporterConfig, + extra_params: t.Dict[str, type] = None) -> t.Tuple[type, type]: """ Convert the signature of a function to two pydantic models: one for the input and another one for the output. @@ -80,7 +86,7 @@ def signature_to_model(f: Callable, return input_model, output_model -def split_kwargs(**kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]: +def split_kwargs(**kwargs) -> t.Tuple[t.Dict[str, t.Any], t.Dict[str, t.Any]]: """ Split the kwargs passed to this function in two different :class:`dict`, based on whether their name starts with ``_`` or not. @@ -129,7 +135,7 @@ def teleport_out(__model: type, value: Value) -> Value: raise exc.OutTeleporterError(errors=e.raw_errors, model=e.model) -def teleporter(__config__: Type[pydantic.BaseConfig] = TeleporterConfig, +def teleporter(__config__: t.Type[pydantic.BaseConfig] = TeleporterConfig, is_async: bool = False, validate_input: bool = True, validate_output: bool = True): @@ -148,7 +154,7 @@ def teleporter(__config__: Type[pydantic.BaseConfig] = TeleporterConfig, .. seealso:: :func:`.signature_to_model` """ - def decorator(f: Callable): + def decorator(f: t.Callable): # noinspection PyPep8Naming InputModel, OutputModel = signature_to_model(f, __config__=__config__) diff --git a/royalnet/engineer/tests/__init__.py b/royalnet/engineer/tests/__init__.py index e69de29b..5baa7611 100644 --- a/royalnet/engineer/tests/__init__.py +++ b/royalnet/engineer/tests/__init__.py @@ -0,0 +1,12 @@ +""" + +""" + +from __future__ import annotations +import royalnet.royaltyping as t + +import logging + +from . import * + +log = logging.getLogger(__name__) diff --git a/royalnet/engineer/tests/test_sentry.py b/royalnet/engineer/tests/test_sentry.py deleted file mode 100644 index c22598cf..00000000 --- a/royalnet/engineer/tests/test_sentry.py +++ /dev/null @@ -1,230 +0,0 @@ -import pytest -import asyncio -import async_timeout -import re -from royalnet.engineer import sentry, exc, blueprints - - -@pytest.fixture -def s() -> sentry.Sentry: - return sentry.Sentry() - - -class TestSentry: - def test_creation(self, s: sentry.Sentry): - assert s - assert isinstance(s, sentry.Sentry) - - @pytest.mark.asyncio - async def test_put(self, s: sentry.Sentry): - await s.queue.put(None) - - @pytest.mark.asyncio - async def test_get(self, s: sentry.Sentry): - await s.queue.put(None) - assert await s.queue.get() is None - - @pytest.mark.asyncio - async def test_f(self, s: sentry.Sentry): - await s.queue.put(None) - f = s.f() - assert f - assert isinstance(f, sentry.Filter) - assert hasattr(f, "get") - - -@pytest.fixture -def discarding_filter() -> sentry.Filter: - async def discard(_): - raise exc.Discard(None, "This filter discards everything!") - - return sentry.Filter(discard) - - -class ErrorTest(Exception): - pass - - -def error_test(*_, **__): - raise ErrorTest("This was raised by error_raiser.") - - -class TestFilter: - def test_creation(self): - f = sentry.Filter(lambda _: _) - assert f - assert isinstance(f, sentry.Filter) - - class TestGetSingle: - @pytest.mark.asyncio - async def test_success(self, s: sentry.Sentry): - await s.queue.put(None) - assert await s.f().get_single() is None - - @pytest.mark.asyncio - async def test_failure(self, discarding_filter: sentry.Filter): - with pytest.raises(exc.Discard): - await discarding_filter.get_single() - - class TestGet: - @pytest.mark.asyncio - async def test_success(self, s: sentry.Sentry): - await s.queue.put(None) - assert await s.f().get() is None - - @pytest.mark.asyncio - async def test_timeout(self, s: sentry.Sentry): - with pytest.raises(asyncio.TimeoutError): - async with async_timeout.timeout(0.001): - await s.f().get() - - @pytest.mark.asyncio - async def test_filter(self, s: sentry.Sentry): - await s.queue.put(None) - await s.queue.put(None) - await s.queue.put(None) - - assert await s.f().filter(lambda x: x is None, "Is not None").get_single() is None - - with pytest.raises(exc.Discard): - await s.f().filter(lambda x: isinstance(x, type), error="Is not type").get_single() - - with pytest.raises(ErrorTest): - await s.f().filter(error_test, error="Is error").get_single() - - @pytest.mark.asyncio - async def test_map(self, s: sentry.Sentry): - await s.queue.put(None) - await s.queue.put(None) - - assert await s.f().map(lambda x: 1).get_single() == 1 - - with pytest.raises(ErrorTest): - await s.f().map(error_test).get_single() - - @pytest.mark.asyncio - async def test_type(self, s: sentry.Sentry): - await s.queue.put(1) - await s.queue.put("no") - - assert await s.f().type(int).get_single() == 1 - - with pytest.raises(exc.Discard): - await s.f().type(int).get_single() - - @pytest.mark.asyncio - async def test_msg(self, s: sentry.Sentry): - class ExampleMessage(blueprints.Message): - def __hash__(self): - return 1 - - msg = ExampleMessage() - await s.queue.put(msg) - await s.queue.put("no") - - assert await s.f().msg().get_single() is msg - - with pytest.raises(exc.Discard): - await s.f().msg().get_single() - - class AvailableMessage(blueprints.Message): - def __hash__(self): - return 1 - - def text(self) -> str: - return "1" - - class NotAvailableMessage(blueprints.Message): - def __hash__(self): - return 2 - - def text(self) -> str: - raise exc.NotAvailableError() - - class NeverAvailableMessage(blueprints.Message): - def __hash__(self): - return 3 - - @pytest.mark.asyncio - async def test_requires(self, s: sentry.Sentry): - avmsg = self.AvailableMessage() - await s.queue.put(avmsg) - assert await s.f().requires("text").get_single() is avmsg - - await s.queue.put(self.NotAvailableMessage()) - with pytest.raises(exc.Discard): - await s.f().requires("text").get_single() - - await s.queue.put(self.NeverAvailableMessage()) - with pytest.raises(exc.NeverAvailableError): - await s.f().requires("text").get_single() - - await s.queue.put(self.NotAvailableMessage()) - with pytest.raises(exc.NotAvailableError): - await s.f().requires("text", propagate_not_available=True).get_single() - - await s.queue.put(self.NeverAvailableMessage()) - with pytest.raises(exc.Discard): - await s.f().requires("text", propagate_never_available=False).get_single() - - @pytest.mark.asyncio - async def test_field(self, s: sentry.Sentry): - avmsg = self.AvailableMessage() - await s.queue.put(avmsg) - assert await s.f().field("text").get_single() == "1" - - await s.queue.put(self.NotAvailableMessage()) - with pytest.raises(exc.Discard): - await s.f().field("text").get_single() - - await s.queue.put(self.NeverAvailableMessage()) - with pytest.raises(exc.NeverAvailableError): - await s.f().field("text").get_single() - - await s.queue.put(self.NotAvailableMessage()) - with pytest.raises(exc.NotAvailableError): - await s.f().field("text", propagate_not_available=True).get_single() - - await s.queue.put(self.NeverAvailableMessage()) - with pytest.raises(exc.Discard): - await s.f().field("text", propagate_never_available=False).get_single() - - @pytest.mark.asyncio - async def test_startswith(self, s: sentry.Sentry): - await s.queue.put("yarrharr") - await s.queue.put("yohoho") - - assert await s.f().startswith("yarr").get_single() == "yarrharr" - - with pytest.raises(exc.Discard): - await s.f().startswith("yarr").get_single() - - @pytest.mark.asyncio - async def test_endswith(self, s: sentry.Sentry): - await s.queue.put("yarrharr") - await s.queue.put("yohoho") - - assert await s.f().endswith("harr").get_single() == "yarrharr" - - with pytest.raises(exc.Discard): - await s.f().endswith("harr").get_single() - - @pytest.mark.asyncio - async def test_regex(self, s: sentry.Sentry): - await s.queue.put("yarrharr") - await s.queue.put("yohoho") - - assert isinstance(await s.f().regex(re.compile(r"[yh]arr")).get_single(), re.Match) - - with pytest.raises(exc.Discard): - await s.f().regex(re.compile(r"[yh]arr")).get_single() - - @pytest.mark.asyncio - async def test_choices(self, s: sentry.Sentry): - await s.queue.put("yarrharr") - await s.queue.put("yohoho") - - assert await s.f().choices("yarrharr", "banana").get_single() == "yarrharr" - - with pytest.raises(exc.Discard): - await s.f().choices("yarrharr", "banana").get_single() diff --git a/royalnet/engineer/tests/test_teleporter.py b/royalnet/engineer/tests/test_teleporter.py index 70ef7d77..875b707f 100644 --- a/royalnet/engineer/tests/test_teleporter.py +++ b/royalnet/engineer/tests/test_teleporter.py @@ -2,12 +2,12 @@ import pytest import inspect import pydantic import pydantic.fields -import royalnet.engineer as re -import typing +import royalnet.engineer.teleporter as tp @pytest.fixture def my_function(): + # noinspection PyUnusedLocal def f(*, big_f: str, _hidden: int) -> int: return _hidden return f @@ -16,14 +16,15 @@ def my_function(): def test_parameter_to_field(my_function): signature = inspect.signature(my_function) parameter = signature.parameters["big_f"] - t, fieldinfo = re.parameter_to_field(parameter) + t, fieldinfo = tp.parameter_to_field(parameter) assert isinstance(fieldinfo, pydantic.fields.FieldInfo) assert fieldinfo.default is ... assert fieldinfo.title == parameter.name == "big_f" def test_signature_to_model(my_function): - InputModel, OutputModel = re.signature_to_model(my_function) + # noinspection PyPep8Naming + InputModel, OutputModel = tp.signature_to_model(my_function) assert callable(InputModel) model = InputModel(big_f="banana") @@ -58,7 +59,7 @@ def test_signature_to_model(my_function): # noinspection PyTypeChecker class TestTeleporter: def test_standard_function(self): - @re.teleporter() + @tp.teleporter() def standard_function(a: int, b: int, _return_str: bool = False) -> int: if _return_str: return "You asked me this." @@ -91,7 +92,7 @@ class TestTeleporter: @pytest.mark.asyncio async def test_async_function(self): - @re.teleporter(is_async=True) + @tp.teleporter(is_async=True) async def async_function(a: int, b: int, _return_str: bool = False) -> int: if _return_str: return "You asked me this." @@ -123,7 +124,7 @@ class TestTeleporter: _ = await async_function(1, 2) def test_only_input(self): - @re.teleporter(validate_output=False) + @tp.teleporter(validate_output=False) def standard_function(a: int, b: int, _return_str: bool = False) -> int: if _return_str: return "You asked me this." @@ -154,7 +155,7 @@ class TestTeleporter: _ = standard_function(1, 2) def test_only_output(self): - @re.teleporter(validate_input=False) + @tp.teleporter(validate_input=False) def standard_function(a: int, b: int, _return_str: bool = False) -> int: if _return_str: return "You asked me this." diff --git a/royalnet/engineer/wrench.py b/royalnet/engineer/wrench.py new file mode 100644 index 00000000..b018e321 --- /dev/null +++ b/royalnet/engineer/wrench.py @@ -0,0 +1,306 @@ +""" +Wrenches are objects which can used instead of coroutines in Sentry receiver filters, +acting similarly to function factories and allowing to easily define filter functions with parameters. +""" + +from __future__ import annotations +import royalnet.royaltyping as t + +import abc + +from . import discard +from . import exc + + +class Wrench(metaclass=abc.ABCMeta): + """ + The abstract base class for Wrenches. + """ + + @abc.abstractmethod + async def filter(self, obj: t.Any) -> t.Any: + """ + The function applied to all objects transiting through the pipeline: + + - If the function **returns**, its return value will be passed to the next node in the pipeline; + - If the function **raises**, the error is propagated downwards. + + A special exception is available for discarding objects: :exc:`.discard.Discard`. + If raised, the object will be silently ignored. + """ + raise NotImplementedError() + + def __call__(self, obj: t.Any) -> t.Awaitable[t.Any]: + """ + Allow instances to be directly called, emulating coroutine functions. + """ + return self.filter(obj) + + +class PassAll(Wrench): + """ + **Return** each received object as it is. + + .. note:: To be used only in testing. + """ + + async def filter(self, obj: t.Any) -> t.Any: + return obj + + +class DiscardAll(Wrench): + """ + **Discard** each received object. + + .. note:: To be used only in testing. + """ + + async def filter(self, obj: t.Any) -> t.Any: + raise discard.Discard(obj, "Discard filter discards everything") + + +class ErrorAll(Wrench): + """ + **Raise** :exc:`.exc.DeliberateException` for each received object. + + .. note:: To be used only in testing. + """ + + async def filter(self, obj: t.Any) -> t.Any: + raise exc.DeliberateException("ErrorAll received an object") + + +class CheckBase(Wrench, metaclass=abc.ABCMeta): + """ + Check a condition on the received objects: + + - If the check returns :data:`True`, the object is **returned**; + - If the check returns :data:`False`, the object is **discarded**; + - If an error is raised, it is **propagated**. + """ + + def __init__(self, *, invert: bool = False): + self.invert: bool = invert + """ + If set to :data:`True`, this Nut will invert its results: + + - If the check returns :data:`True`, the object is **discarded**; + - If the check returns :data:`False`, the object is **returned**; + - If an error is raised, it is **propagated**. + """ + + def __invert__(self): + return self.__class__(invert=not self.invert) + + @abc.abstractmethod + async def check(self, obj: t.Any) -> bool: + """ + The condition to check. + + :param obj: The object passing through the pipeline. + :return: Whether the check was successful or not. + """ + raise NotImplementedError() + + @abc.abstractmethod + def error(self, obj: t.Any) -> str: + """ + The error message to attach as :attr:`.Discard.message` if the object is discarded. + + :param obj: The object passing through the pipeline. + :return: The error message. + """ + raise NotImplementedError() + + async def filter(self, obj: t.Any) -> t.Any: + if await self.check(obj) ^ self.invert: + return obj + else: + raise discard.Discard(obj=obj, message=self.error(obj)) + + +class Type(CheckBase): + """ + Check the type of an object: + + - If the object **is** of the specified type, it is **returned**; + - If the object **isn't** of the specified type, it is **discarded**; + """ + + def __init__(self, type_: t.Type, **kwargs): + super().__init__(**kwargs) + self.type: t.Type = type_ + + async def check(self, obj: t.Any) -> bool: + return isinstance(obj, self.type) + + def error(self, obj: t.Any) -> str: + return f"Not instance of type {self.type}" + + +class StartsWith(CheckBase): + """ + Check if an object :func:`startswith` a certain prefix. + """ + + def __init__(self, prefix: str, **kwargs): + super().__init__(**kwargs) + self.prefix: str = prefix + + async def check(self, obj: t.Any) -> bool: + return obj.startswith(self.prefix) + + def error(self, obj: t.Any) -> str: + return f"Didn't start with {self.prefix}" + + +class EndsWith(CheckBase): + """ + Check if an object :func:`endswith` a certain suffix. + """ + + def __init__(self, suffix: str, **kwargs): + super().__init__(**kwargs) + self.suffix: str = suffix + + async def check(self, obj: t.Any) -> bool: + return obj.startswith(self.suffix) + + def error(self, obj: t.Any) -> str: + return f"Didn't end with {self.suffix}" + + +class Choice(CheckBase): + """ + Check if an object is among the accepted list. + """ + + def __init__(self, *accepted, **kwargs): + super().__init__(**kwargs) + self.accepted: t.Collection = accepted + """ + A collection of elements which can be chosen. + """ + + async def check(self, obj: t.Any) -> bool: + return obj in self.accepted + + def error(self, obj: t.Any) -> str: + return f"Not a valid choice" + + +class RegexCheck(CheckBase): + """ + Check if an object matches a regex pattern. + """ + + def __init__(self, pattern: t.Pattern, **kwargs): + super().__init__(**kwargs) + self.pattern: t.Pattern = pattern + """ + The pattern that should be matched. + """ + + async def check(self, obj: t.Any) -> bool: + return bool(self.pattern.match(obj)) + + def error(self, obj: t.Any) -> str: + return f"Didn't match pattern {self.pattern}" + + +class RegexMatch(Wrench): + """ + Apply a regex over an object: + + - If it matches, **return** the :class:`re.Match` object; + - If it doesn't match, **discard** the object. + """ + + def __init__(self, pattern: t.Pattern, **kwargs): + super().__init__(**kwargs) + self.pattern: t.Pattern = pattern + """ + The pattern that should be matched. + """ + + async def filter(self, obj: t.Any) -> t.Any: + if match := self.pattern.match(obj): + return match + else: + raise discard.Discard(obj, f"Didn't match pattern {obj}") + + +class RegexReplace(Wrench): + """ + Apply a regex over an object: + + - If it matches, replace the match(es) with :attr:`.replacement` and **return** the result. + - If it doesn't match, **return** the object as it is. + """ + + def __init__(self, pattern: t.Pattern, replacement: t.Union[str, bytes]): + self.pattern: t.Pattern = pattern + """ + The pattern that should be matched. + """ + + self.replacement: t.Union[str, bytes] = replacement + """ + The substitution string for the object. + """ + + async def filter(self, obj: t.Any) -> t.Any: + return self.pattern.sub(self.replacement, obj) + + +class Lambda(Wrench): + """ + Apply a syncronous function over the received objects. + """ + + def __init__(self, func: t.Callable[[t.Any], t.Any]): + self.func: t.Callable[[t.Any], t.Any] = func + """ + The function to apply. + """ + + async def filter(self, obj: t.Any) -> t.Any: + return self.func(obj) + + +class Check(CheckBase): + """ + Check a condition on the received objects. + """ + + def __init__(self, func: t.Callable[[t.Any], t.Any], error: str, *args, **kwargs): + super().__init__(*args, **kwargs) + self.func: t.Callable[[t.Any], t.Any] = func + """ + The condition to check. + """ + + self.error: str = error + """ + The error message to display if the check fails. + """ + + async def check(self, obj: t.Any) -> bool: + return self.func(obj) + + def error(self, obj: t.Any) -> str: + return self.error + + +__all__ = ( + "Wrench", + "CheckBase", + "Type", + "StartsWith", + "EndsWith", + "Choice", + "RegexCheck", + "RegexMatch", + "RegexReplace", + "Lambda", +) diff --git a/royalnet/royaltyping/__init__.py b/royalnet/royaltyping/__init__.py index 4372f887..f41e6477 100644 --- a/royalnet/royaltyping/__init__.py +++ b/royalnet/royaltyping/__init__.py @@ -74,3 +74,6 @@ An async generator yielding either: * another :data:`.AsyncAdventure`; * :data:`None`. """ + + +Conversation = Callable[["Sentry"], Awaitable[Optional["Conversation"]]]