1
Fork 0
mirror of https://github.com/RYGhub/royalnet.git synced 2024-11-22 19:14:20 +00:00

💥 Reimplement engineer module from scratch (#2)

This commit is contained in:
Steffo 2020-12-24 12:13:35 +01:00 committed by GitHub
parent 133f503926
commit 715c0e72df
Signed by: github
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 726 additions and 828 deletions

View file

@ -1,10 +1,5 @@
"""
A chatbot command router inspired by :mod:`fastapi`.
Chat bot utilities.
All names are inspired by the `Engineer Class of Team Fortress 2 <https://wiki.teamfortress.com/wiki/Engineer>`_.
"""
from .blueprints import *
from .teleporter import *
from .sentry import *
from .exc import *

View file

@ -1,4 +0,0 @@
from .blueprint import *
from .message import *
from .channel import *
from .user import *

View file

@ -1,92 +0,0 @@
import abc
from .. import exc
class Blueprint(metaclass=abc.ABCMeta):
"""
A class containing methods common between all blueprints.
To extend a blueprint, inherit from it while using the :class:`abc.ABCMeta` metaclass, and make all new functions
return :exc:`.exc.NeverAvailableError`:
.. code-block::
class Channel(Blueprint, metaclass=abc.ABCMeta):
def name(self):
raise exc.NeverAvailableError()
To implement a blueprint for a specific chat platform, inherit from the blueprint, override :meth:`__init__`,
:meth:`__hash__` and the methods that are implemented by the platform in question, either returning the
corresponding value or raising :exc:`.exc.NotAvailableError` if there is no data available.
.. code-block::
class ExampleChannel(Channel):
def __init__(self, chat_id: int):
self.chat_id: int = chat_id
def __hash__(self):
return self.chat_id
def name(self):
return ExampleClient.expensive_get_channel_name(self.chat_id)
.. note:: To improve performance, you might want to wrap all data methods in :func:`functools.lru_cache` decorators.
.. code-block::
@functools.lru_cache(24)
def name(self):
return ExampleClient.expensive_get_channel_name(self.chat_id)
"""
def __init__(self):
"""
:return: The created object.
"""
pass
@abc.abstractmethod
def __hash__(self):
"""
:return: A value that uniquely identifies the channel inside this Python process.
"""
raise NotImplementedError()
def requires(self, *fields) -> True:
"""
Ensure that this blueprint has the specified fields, re-raising the highest priority exception raised between
all of them.
.. code-block::
def print_msg(message: Message):
message.requires("text", "timestamp")
print(f"{message.timestamp().isoformat()}: {message.text()}")
:raises .exc.NeverAvailableError: If at least one of the fields raised a :exc:`.exc.NeverAvailableError`.
:raises .exc.NotAvailableError: If no field raised a :exc:`.exc.NeverAvailableError`, but at least one raised a
:exc:`.exc.NotAvailableError`.
"""
exceptions = []
for field in fields:
try:
self.__getattribute__(field)()
except exc.NeverAvailableError as ex:
exceptions.append(ex)
except exc.NotAvailableError as ex:
exceptions.append(ex)
if len(exceptions) > 0:
raise max(exceptions, key=lambda e: e.priority)
return True
__all__ = (
"Blueprint",
)

View file

@ -1,35 +0,0 @@
from __future__ import annotations
from royalnet.royaltyping import *
import abc
from .. import exc
from .blueprint import Blueprint
class Channel(Blueprint, metaclass=abc.ABCMeta):
"""
An abstract class representing a channel where messages can be sent.
.. seealso:: :class:`.Blueprint`
"""
def name(self) -> str:
"""
:return: The name of the message channel, such as the chat title.
:raises .exc.NeverAvailableError: If the chat platform does not support channel names.
:raises .exc.NotAvailableError: If this channel does not have any name.
"""
raise exc.NeverAvailableError()
def topic(self) -> str:
"""
:return: The topic or description of the message channel.
:raises .exc.NeverAvailableError: If the chat platform does not support channel topics / descriptions.
:raises .exc.NotAvailableError: If this channel does not have any name.
"""
raise exc.NeverAvailableError()
__all__ = (
"Channel",
)

View file

@ -1,53 +0,0 @@
from __future__ import annotations
from royalnet.royaltyping import *
import abc
import datetime
from .. import exc
from .blueprint import Blueprint
from .channel import Channel
class Message(Blueprint, metaclass=abc.ABCMeta):
"""
An abstract class representing a chat message sent in any platform.
.. seealso:: :class:`.Blueprint`
"""
def text(self) -> str:
"""
:return: The raw text contents of the message.
:raises .exc.NeverAvailableError: If the chat platform does not support text messages.
:raises .exc.NotAvailableError: If this message does not have any text.
"""
raise exc.NeverAvailableError()
def timestamp(self) -> datetime.datetime:
"""
:return: The :class:`datetime.datetime` at which the message was sent / received.
:raises .exc.NeverAvailableError: If the chat platform does not support timestamps.
:raises .exc.NotAvailableError: If this message is special and does not have any timestamp.
"""
raise exc.NeverAvailableError()
def reply_to(self) -> Message:
"""
:return: The :class:`.Message` this message is a reply to.
:raises .exc.NeverAvailableError: If the chat platform does not support replies.
:raises .exc.NotAvailableError: If this message is not a reply to any other message.
"""
raise exc.NeverAvailableError()
def channel(self) -> Channel:
"""
:return: The :class:`.Channel` this message was sent in.
:raises .exc.NeverAvailableError: If the chat platform does not support channels.
:raises .exc.NotAvailableError: If this message was not sent in any channel.
"""
raise exc.NeverAvailableError()
__all__ = (
"Message",
)

View file

@ -1,35 +0,0 @@
from __future__ import annotations
from royalnet.royaltyping import *
import abc
import sqlalchemy.orm
from .. import exc
from .blueprint import Blueprint
class User(Blueprint, metaclass=abc.ABCMeta):
"""
An abstract class representing a chat user.
.. seealso:: :class:`.Blueprint`
"""
def name(self) -> str:
"""
:return: The user's name.
:raises .exc.NeverAvailableError: If the chat platform does not support usernames.
:raises .exc.NotAvailableError: If this user does not have any name.
"""
raise exc.NeverAvailableError()
def database(self, session: sqlalchemy.orm.Session) -> Any:
"""
:param session: A :class:`sqlalchemy.orm.Session` instance to use to fetch the database entry.
:return: The database entry for this user.
"""
raise exc.NeverAvailableError()
__all__ = (
"User",
)

114
royalnet/engineer/bullet.py Normal file
View file

@ -0,0 +1,114 @@
"""
"""
from __future__ import annotations
import royalnet.royaltyping as t
import abc
import datetime
import sqlalchemy.orm
from . import exc
class Bullet(metaclass=abc.ABCMeta):
"""
The abstract base class for Bullet data models.
"""
@abc.abstractmethod
def __hash__(self) -> int:
"""
:return: A value that uniquely identifies the object in this Python interpreter process.
"""
raise NotImplementedError()
class Message(Bullet, metaclass=abc.ABCMeta):
"""
An abstract class representing a chat message.
"""
async def text(self) -> t.Optional[str]:
"""
:return: The raw text contents of the message.
:raises .exc.NotSupportedError: If the frontend does not support text messages.
"""
raise exc.NotSupportedError()
async def timestamp(self) -> t.Optional[datetime.datetime]:
"""
:return: The :class:`datetime.datetime` at which the message was sent.
:raises .exc.NotSupportedError: If the frontend does not support timestamps.
"""
raise exc.NotSupportedError()
async def reply_to(self) -> t.Optional[Message]:
"""
:return: The :class:`.Message` this message is a reply to.
:raises .exc.NotSupportedError: If the frontend does not support replies.
"""
raise exc.NotSupportedError()
async def channel(self) -> t.Optional[Channel]:
"""
:return: The :class:`.Channel` this message was sent in.
:raises .exc.NotSupportedError: If the frontend does not support channels.
"""
raise exc.NotSupportedError()
class Channel(Bullet, metaclass=abc.ABCMeta):
"""
An abstract class representing a channel where messages can be sent.
"""
async def name(self) -> t.Optional[str]:
"""
:return: The name of the message channel, such as the chat title.
:raises .exc.NotSupportedError: If the frontend does not support channel names.
"""
raise exc.NotSupportedError()
async def topic(self) -> t.Optional[str]:
"""
:return: The topic (description) of the message channel.
:raises .exc.NotSupportedError: If the frontend does not support channel topics / descriptions.
"""
raise exc.NotSupportedError()
async def users(self) -> t.List[User]:
"""
:return: A :class:`list` of :class:`.User` who can read messages sent in the channel.
:raises .exc.NotSupportedError: If the frontend does not support such a feature.
"""
raise exc.NotSupportedError()
class User(Bullet, metaclass=abc.ABCMeta):
"""
An abstract class representing a user who can read or send messages in the chat.
"""
async def name(self) -> t.Optional[str]:
"""
:return: The user's name.
:raises .exc.NotSupportedError: If the frontend does not support usernames.
"""
raise exc.NotSupportedError()
async def database(self, session: sqlalchemy.orm.Session) -> t.Any:
"""
:param session: A :class:`sqlalchemy.orm.Session` instance to use to fetch the database entry.
:return: The database entry for this user.
"""
raise exc.NotSupportedError()
__all__ = (
"Bullet",
"Message",
"Channel",
"User",
)

View file

@ -0,0 +1,13 @@
class Discard(BaseException):
"""
A special exception which should be raised by Metals if a certain object should be discarded from the queue.
"""
def __init__(self, obj, message):
self.obj = obj
self.message = message
def __repr__(self):
return f"<Discard>"
def __str__(self):
return f"Discarded {self.obj}: {self.message}"

View file

@ -0,0 +1,51 @@
"""
Dispensers instantiate sentries and dispatch events in bulk to the whole group.
"""
from __future__ import annotations
import royalnet.royaltyping as t
import logging
import contextlib
from .sentry import SentrySource
log = logging.getLogger(__name__)
class Dispenser:
def __init__(self):
self.sentries: t.List[SentrySource] = []
"""
A :class:`list` of all the running sentries of this dispenser.
"""
def put(self, item: t.Any) -> None:
"""
Insert a new item in the queues of all the running sentries.
:param item: The item to insert.
"""
log.debug(f"Putting {item}")
for sentry in self.sentries:
sentry.put(item)
@contextlib.contextmanager
def sentry(self, *args, **kwargs):
"""
A context manager which creates a :class:`.SentrySource` and keeps it in :attr:`.sentries` while it is being
used.
"""
sentry = SentrySource(dispenser=self, *args, **kwargs)
self.sentries.append(sentry)
yield sentry
self.sentries.remove(sentry)
async def run(self, conv: t.Conversation) -> None:
with self.sentry() as sentry:
state = conv(sentry)
while True:
state = await state

View file

@ -1,38 +1,27 @@
import royalnet.exc
import pydantic
class EngineerException(royalnet.exc.RoyalnetException):
class EngineerException(Exception):
"""
An exception raised by the engineer module.
The base class for errors in :mod:`royalnet.engineer`.
"""
class BlueprintError(EngineerException):
class WrenchException(EngineerException):
"""
An error related to the :mod:`royalnet.engineer.blueprints`.
The base class for errors in :mod:`royalnet.engineer.wrench`.
"""
class NeverAvailableError(BlueprintError, NotImplementedError):
class DeliberateException(WrenchException):
"""
The requested property is never supplied by the chat platform the message was sent in.
This exception was deliberately raised by :class:`royalnet.engineer.wrench.ErrorAll`.
"""
priority = 1
class NotAvailableError(BlueprintError):
"""
The requested property was not supplied by the chat platform for the specific message this exception was raised in.
"""
priority = 2
class TeleporterError(EngineerException, pydantic.ValidationError):
"""
The validation of some object though a :mod:`pydantic` model failed.
The base class for errors in :mod:`royalnet.engineer.teleporter`.
"""
@ -48,28 +37,13 @@ class OutTeleporterError(TeleporterError):
"""
class SentryError(EngineerException):
class BulletException(EngineerException):
"""
An error related to the :mod:`royalnet.engineer.sentry`.
The base class for errors in :mod:`royalnet.engineer.bullet`.
"""
class FilterError(SentryError):
class NotSupportedError(BulletException, NotImplementedError):
"""
An error related to the :class:`royalnet.engineer.sentry.Filter`.
The requested property isn't available on the current frontend.
"""
class Discard(FilterError):
"""
Discard the object from the queue.
"""
def __init__(self, obj, message):
self.obj = obj
self.message = message
def __repr__(self):
return f"<Discard>"
def __str__(self):
return f"Discarded {self.obj}: {self.message}"

188
royalnet/engineer/sentry.py Normal file
View file

@ -0,0 +1,188 @@
"""
Sentries are asyncronous receivers for events (usually :class:`bullet.Bullet`) incoming from Dispensers.
They support event filtering through Wrenches and coroutine functions.
"""
from __future__ import annotations
import royalnet.royaltyping as t
import abc
import logging
import asyncio
from . import discard
from . import bullet
if t.TYPE_CHECKING:
from .dispenser import Dispenser
log = logging.getLogger(__name__)
class Sentry(metaclass=abc.ABCMeta):
"""
The abstract object representing a node of the pipeline.
"""
@abc.abstractmethod
def __len__(self) -> int:
raise NotImplementedError()
@abc.abstractmethod
def get_nowait(self) -> bullet.Bullet:
"""
Try to get a single :class:`~.bullet.Bullet` from the pipeline, without blocking or handling discards.
:return: The **returned** :class:`~.bullet.Bullet`.
:raises asyncio.QueueEmpty: If the queue is empty.
:raises .discard.Discard: If the object was **discarded** by the pipeline.
:raises Exception: If an exception was **raised** in the pipeline.
"""
raise NotImplementedError()
@abc.abstractmethod
async def get(self) -> bullet.Bullet:
"""
Try to get a single :class:`~.bullet.Bullet` from the pipeline, blocking until something is available, but
without handling discards.
:return: The **returned** :class:`~.bullet.Bullet`.
:raises .discard.Discard: If the object was **discarded** by the pipeline.
:raises Exception: If an exception was **raised** in the pipeline.
"""
raise NotImplementedError()
async def wait(self) -> bullet.Bullet:
"""
Try to get a single :class:`~.bullet.Bullet` from the pipeline, blocking until something is available and is not
discarded.
:return: The **returned** :class:`~.bullet.Bullet`.
:raises Exception: If an exception was **raised** in the pipeline.
"""
while True:
try:
result = await self.get()
log.debug(f"Returned: {result}")
return result
except discard.Discard as d:
log.debug(f"{str(d)}")
continue
def __await__(self):
"""
Awaiting an object implementing :class:`.SentryInterface` corresponds to awaiting :meth:`.wait`.
"""
return self.get().__await__()
@abc.abstractmethod
async def put(self, item: t.Any) -> None:
"""
Insert a new item in the queue.
:param item: The item to be added.
"""
raise NotImplementedError()
def filter(self, wrench: t.Callable[[t.Any], t.Awaitable[t.Any]]) -> SentryFilter:
"""
Chain a new filter to the pipeline.
:param wrench: The filter to add to the chain. It can either be a :class:`.wrench.Wrench`, or a coroutine
function accepting a single object as parameter and returning the same or a different one.
:return: A new :class:`.SentryFilter` which includes the filter.
.. seealso:: :meth:`.__or__`
"""
if callable(wrench):
return SentryFilter(previous=self, wrench=wrench)
else:
raise TypeError("wrench must be either a Wrench or a coroutine function")
def __or__(self, other: t.Callable[[t.Any], t.Awaitable[t.Any]]) -> SentryFilter:
"""
A unix-pipe-like interface for :meth:`.filter`.
.. code-block::
await (sentry | wrench.Type(Message) | wrench.Sync(lambda o: o.text))
"""
try:
return self.filter(other)
except TypeError:
raise TypeError("Right-side must be either a Wrench or a coroutine function")
@abc.abstractmethod
def dispenser(self):
"""
Get the :class:`.Dispenser` that created this Sentry.
:return: The :class:`.Dispenser` object.
"""
raise NotImplementedError()
class SentryFilter(Sentry):
"""
A non-root node of the filtering pipeline.
"""
def __init__(self, previous: Sentry, wrench: t.Callable[[t.Any], t.Awaitable[t.Any]]):
self.previous: Sentry = previous
"""
The previous node of the pipeline.
"""
self.wrench: t.Callable[[t.Any], t.Awaitable[t.Any]] = wrench
"""
The coroutine function to apply to all objects passing through this node.
"""
def __len__(self) -> int:
return len(self.previous) + 1
def get_nowait(self) -> bullet.Bullet:
return self.previous.get_nowait()
async def get(self) -> bullet.Bullet:
return await self.previous.get()
async def put(self, item) -> None:
return await self.previous.put(item)
def dispenser(self):
return self.previous.dispenser()
class SentrySource(Sentry):
"""
The root and source of the pipeline.
"""
def __init__(self, dispenser: "Dispenser", queue_size: int = 12):
self.queue: asyncio.Queue = asyncio.Queue(maxsize=queue_size)
self._dispenser: "Dispenser" = dispenser
def __len__(self) -> int:
return 1
def get_nowait(self) -> bullet.Bullet:
return self.queue.get_nowait()
async def get(self) -> bullet.Bullet:
return await self.queue.get()
async def put(self, item) -> None:
return await self.queue.put(bullet)
async def dispenser(self):
return self._dispenser
__all__ = (
"Sentry",
"SentryFilter",
"SentrySource",
)

View file

@ -1 +0,0 @@
from .sentry import *

View file

@ -1,257 +0,0 @@
from __future__ import annotations
from royalnet.royaltyping import *
import functools
import logging
from .. import exc, blueprints
log = logging.getLogger(__name__)
class Filter:
"""
A fluent interface for filtering data.
"""
def __init__(self, func: Callable):
self.func: Callable = func
async def get(self) -> Any:
"""
Wait until an :class:`object` leaves the queue and passes through the filter, then return it.
:return: The :class:`object` which left the queue.
"""
while True:
try:
return await self.get_single()
except exc.Discard:
continue
async def get_single(self) -> Any:
"""
Let one :class:`object` pass through the filter, then either return it or raise an error if the object should be
discarded.
:return: The :class:`object` which left the queue.
:raises exc.Discard: If the object was filtered.
"""
try:
result = await self.func(None)
except exc.Discard as e:
log.debug(str(e))
raise
else:
log.debug(f"Dequeued {result}")
return result
@staticmethod
def _deco_filter(c: Callable[[Any], bool], *, error: str):
"""
A decorator which checks the condition ``c`` on all objects transiting through the queue:
- If the check **passes**, the object itself is returned;
- If the check **fails**, :exc:`.exc.Discard` is raised, with the object and the ``error`` string as parameters;
- If an error is raised, propagate the error upwards.
.. warning:: Raising :exc:`.exc.Discard` in ``c`` will automatically cause the object to be discarded, as if
:data:`False` was returned.
:param c: A function that takes in input an enqueued object and returns either the same object or a new one to
pass to the next filter in the queue.
:param error: The string that :exc:`.exc.Discard` should display if the object is discarded.
"""
def decorator(func):
@functools.wraps(func)
async def decorated(obj):
result: Any = await func(obj)
if c(result):
return result
else:
raise exc.Discard(obj=result, message=error)
return decorated
return decorator
def filter(self, c: Callable[[Any], bool], error: str) -> Filter:
"""
Check the condition ``c`` on all objects transiting through the queue:
- If the check **passes**, the object goes on to the next filter;
- If the check **fails**, the object is discarded, with ``error`` as reason;
- If an error is raised, propagate the error upwards.
:param c: A function that takes in input an object and performs a check on it, returning either :data:`True`
or :data:`False`.
:param error: The reason for which objects should be discarded.
:return: A new :class:`Filter` with this new condition.
.. seealso:: :meth:`._deco_filter`, :func:`filter`
"""
return self.__class__(self._deco_filter(c, error=error)(self.func))
@staticmethod
def _deco_map(c: Callable[[Any], object]):
"""
A decorator which applies the function ``c`` on all objects transiting through the queue:
- If the function **returns**, return its return value;
- If the function **raises** an error, it is propagated upwards.
:param c: A function that takes in input an enqueued object and returns either the same object or something
else.
.. seealso:: :func:`map`
"""
def decorator(func):
@functools.wraps(func)
async def decorated(obj):
result: Any = await func(obj)
return c(result)
return decorated
return decorator
def map(self, c: Callable[[Any], object]) -> Filter:
"""
Apply the function ``c`` on all objects transiting through the queue:
- If the function **returns**, its return value replaces the object in the queue;
- If the function **raises** :exc:`.exc.Discard`, the object is discarded;
- If the function **raises another error**, propagate the error upwards.
:param c: A function that takes in input an enqueued object and returns either the same object or something
else.
:return: A new :class:`Filter` with this new condition.
.. seealso:: :meth:`._deco_map`, :func:`filter`
"""
return self.__class__(self._deco_map(c)(self.func))
def type(self, t: type) -> Filter:
"""
Check if an object passing through the queue :func:`isinstance` of the type ``t``.
:param t: The type that objects should be instances of.
:return: A new :class:`Filter` with this new condition.
.. seealso:: :func:`isinstance`
"""
return self.filter(lambda o: isinstance(o, t), error=f"Not instance of type {t}")
def msg(self) -> Filter:
"""
Check if an object passing through the queue :func:`isinstance` of :class:`.blueprints.Message`.
:return: A new :class:`Filter` with this new condition.
"""
return self.type(blueprints.Message)
def requires(self, *fields,
propagate_not_available=False,
propagate_never_available=True) -> Filter:
"""
Test a :class:`.blueprints.Blueprint`'s fields by using its ``.requires()`` method:
- If the :class:`.blueprints.Blueprint` has the appropriate fields, return it;
- If the :class:`.blueprints.Blueprint` doesn't have data for at least one of the fields, the object is discarded;
- the :class:`.blueprints.Blueprint` never has data for at least one of the fields, :exc:`.exc.NotAvailableError` is propagated upwards.
:param fields: The fields to test for.
:param propagate_not_available: If :exc:`.exc.NotAvailableError` should be propagated
instead of discarding the errored object.
:param propagate_never_available: If :exc:`.exc.NeverAvailableError` should be propagated
instead of discarding the errored object.
:return: A new :class:`Filter` with this new condition.
.. seealso:: :meth:`.blueprints.Blueprint.requires`
"""
def check(obj):
try:
return obj.requires(*fields)
except exc.NotAvailableError:
if propagate_not_available:
raise
raise exc.Discard(obj, "Data is not available")
except exc.NeverAvailableError:
if propagate_never_available:
raise
raise exc.Discard(obj, "Data is never available")
return self.filter(check, error=".requires() method returned False")
def field(self, field: str,
propagate_not_available=False,
propagate_never_available=True) -> Filter:
"""
Replace a :class:`.blueprints.Blueprint` with the value of one of its fields.
:param field: The field to access.
:param propagate_not_available: If :exc:`.exc.NotAvailableError` should be propagated
instead of discarding the errored object.
:param propagate_never_available: If :exc:`.exc.NeverAvailableError` should be propagated
instead of discarding the errored object.
:return: A new :class:`Filter` with the new requirements.
"""
def replace(obj):
try:
return obj.__getattribute__(field)()
except exc.NotAvailableError:
if propagate_not_available:
raise
raise exc.Discard(obj, "Data is not available")
except exc.NeverAvailableError:
if propagate_never_available:
raise
raise exc.Discard(obj, "Data is never available")
return self.map(replace)
def startswith(self, prefix: str):
"""
Check if an object starts with the specified prefix and discard the objects that do not.
:param prefix: The prefix object should start with.
:return: A new :class:`Filter` with the new requirements.
.. seealso:: :meth:`str.startswith`
"""
return self.filter(lambda x: x.startswith(prefix), error=f"Text didn't start with {prefix}")
def endswith(self, suffix: str):
"""
Check if an object ends with the specified suffix and discard the objects that do not.
:param suffix: The prefix object should start with.
:return: A new :class:`Filter` with the new requirements.
.. seealso:: :meth:`str.endswith`
"""
return self.filter(lambda x: x.endswith(suffix), error=f"Text didn't end with {suffix}")
def regex(self, pattern: Pattern):
"""
Apply a regex over an object and discard the object if it does not match.
:param pattern: The pattern that should be matched by the text.
:return: A new :class:`Filter` with the new requirements.
"""
def mapping(x):
if match := pattern.match(x):
return match
else:
raise exc.Discard(x, f"Text didn't match pattern {pattern}")
return self.map(mapping)
def choices(self, *choices):
"""
Ensure an object is in the ``choices`` list, discarding the object otherwise.
:param choices: The pattern that should be matched by the text.
:return: A new :class:`Filter` with the new requirements.
"""
return self.filter(lambda o: o in choices, error="Not a valid choice")
__all__ = (
"Filter",
)

View file

@ -1,58 +0,0 @@
from __future__ import annotations
from royalnet.royaltyping import *
import logging
import asyncio
from .filter import Filter
log = logging.getLogger(__name__)
class Sentry:
"""
A class that allows using the ``await`` keyword to suspend a command execution until a new message is received.
"""
QUEUE_SIZE = 12
"""
The size of the object :attr:`.queue`.
"""
def __init__(self, filter_type: Type[Filter] = Filter):
self.queue: asyncio.Queue = asyncio.Queue(maxsize=self.QUEUE_SIZE)
"""
An object queue where incoming :class:`object` are stored, with a size limit of :attr:`.QUEUE_SIZE`.
"""
self.filter_type: Type[Filter] = filter_type
"""
The filter to be used in :meth:`.f` calls, by default :class:`.filters.Filter`.
"""
def __repr__(self):
return f"<Sentry>"
def f(self):
"""
Create a :attr:`.filter_type` object, which can be configured through its fluent interface.
Remember to call ``.get()`` on the end of the chain to finally get the object.
To get any object, call:
.. code-block::
await sentry.f().get()
.. seealso:: :class:`.filters.Filter`
:return: The created :class:`.filters.Filter`.
"""
async def func(_):
return await self.queue.get()
return self.filter_type(func)
__all__ = (
"Sentry",
)

View file

@ -1,13 +1,19 @@
import functools
"""
The teleporter uses :mod:`pydantic` to validate function parameters and return values.
"""
from royalnet.royaltyping import *
from __future__ import annotations
import royalnet.royaltyping as t
import logging
import pydantic
import inspect
import functools
from . import exc
Model = TypeVar("Model")
Value = TypeVar("Value")
Value = t.TypeVar("Value")
log = logging.getLogger(__name__)
class TeleporterConfig(pydantic.BaseConfig):
@ -17,7 +23,7 @@ class TeleporterConfig(pydantic.BaseConfig):
arbitrary_types_allowed = True
def parameter_to_field(param: inspect.Parameter, **kwargs) -> Tuple[type, pydantic.fields.FieldInfo]:
def parameter_to_field(param: inspect.Parameter, **kwargs) -> t.Tuple[type, pydantic.fields.FieldInfo]:
"""
Convert a :class:`inspect.Parameter` to a type-field :class:`tuple`, which can be easily passed to
:func:`pydantic.create_model`.
@ -45,9 +51,9 @@ def parameter_to_field(param: inspect.Parameter, **kwargs) -> Tuple[type, pydant
)
def signature_to_model(f: Callable,
__config__: Type[pydantic.BaseConfig] = TeleporterConfig,
extra_params: Dict[str, type] = None) -> Tuple[type, type]:
def signature_to_model(f: t.Callable,
__config__: t.Type[pydantic.BaseConfig] = TeleporterConfig,
extra_params: t.Dict[str, type] = None) -> t.Tuple[type, type]:
"""
Convert the signature of a function to two pydantic models: one for the input and another one for the output.
@ -80,7 +86,7 @@ def signature_to_model(f: Callable,
return input_model, output_model
def split_kwargs(**kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
def split_kwargs(**kwargs) -> t.Tuple[t.Dict[str, t.Any], t.Dict[str, t.Any]]:
"""
Split the kwargs passed to this function in two different :class:`dict`, based on whether their name starts with
``_`` or not.
@ -129,7 +135,7 @@ def teleport_out(__model: type, value: Value) -> Value:
raise exc.OutTeleporterError(errors=e.raw_errors, model=e.model)
def teleporter(__config__: Type[pydantic.BaseConfig] = TeleporterConfig,
def teleporter(__config__: t.Type[pydantic.BaseConfig] = TeleporterConfig,
is_async: bool = False,
validate_input: bool = True,
validate_output: bool = True):
@ -148,7 +154,7 @@ def teleporter(__config__: Type[pydantic.BaseConfig] = TeleporterConfig,
.. seealso:: :func:`.signature_to_model`
"""
def decorator(f: Callable):
def decorator(f: t.Callable):
# noinspection PyPep8Naming
InputModel, OutputModel = signature_to_model(f, __config__=__config__)

View file

@ -0,0 +1,12 @@
"""
"""
from __future__ import annotations
import royalnet.royaltyping as t
import logging
from . import *
log = logging.getLogger(__name__)

View file

@ -1,230 +0,0 @@
import pytest
import asyncio
import async_timeout
import re
from royalnet.engineer import sentry, exc, blueprints
@pytest.fixture
def s() -> sentry.Sentry:
return sentry.Sentry()
class TestSentry:
def test_creation(self, s: sentry.Sentry):
assert s
assert isinstance(s, sentry.Sentry)
@pytest.mark.asyncio
async def test_put(self, s: sentry.Sentry):
await s.queue.put(None)
@pytest.mark.asyncio
async def test_get(self, s: sentry.Sentry):
await s.queue.put(None)
assert await s.queue.get() is None
@pytest.mark.asyncio
async def test_f(self, s: sentry.Sentry):
await s.queue.put(None)
f = s.f()
assert f
assert isinstance(f, sentry.Filter)
assert hasattr(f, "get")
@pytest.fixture
def discarding_filter() -> sentry.Filter:
async def discard(_):
raise exc.Discard(None, "This filter discards everything!")
return sentry.Filter(discard)
class ErrorTest(Exception):
pass
def error_test(*_, **__):
raise ErrorTest("This was raised by error_raiser.")
class TestFilter:
def test_creation(self):
f = sentry.Filter(lambda _: _)
assert f
assert isinstance(f, sentry.Filter)
class TestGetSingle:
@pytest.mark.asyncio
async def test_success(self, s: sentry.Sentry):
await s.queue.put(None)
assert await s.f().get_single() is None
@pytest.mark.asyncio
async def test_failure(self, discarding_filter: sentry.Filter):
with pytest.raises(exc.Discard):
await discarding_filter.get_single()
class TestGet:
@pytest.mark.asyncio
async def test_success(self, s: sentry.Sentry):
await s.queue.put(None)
assert await s.f().get() is None
@pytest.mark.asyncio
async def test_timeout(self, s: sentry.Sentry):
with pytest.raises(asyncio.TimeoutError):
async with async_timeout.timeout(0.001):
await s.f().get()
@pytest.mark.asyncio
async def test_filter(self, s: sentry.Sentry):
await s.queue.put(None)
await s.queue.put(None)
await s.queue.put(None)
assert await s.f().filter(lambda x: x is None, "Is not None").get_single() is None
with pytest.raises(exc.Discard):
await s.f().filter(lambda x: isinstance(x, type), error="Is not type").get_single()
with pytest.raises(ErrorTest):
await s.f().filter(error_test, error="Is error").get_single()
@pytest.mark.asyncio
async def test_map(self, s: sentry.Sentry):
await s.queue.put(None)
await s.queue.put(None)
assert await s.f().map(lambda x: 1).get_single() == 1
with pytest.raises(ErrorTest):
await s.f().map(error_test).get_single()
@pytest.mark.asyncio
async def test_type(self, s: sentry.Sentry):
await s.queue.put(1)
await s.queue.put("no")
assert await s.f().type(int).get_single() == 1
with pytest.raises(exc.Discard):
await s.f().type(int).get_single()
@pytest.mark.asyncio
async def test_msg(self, s: sentry.Sentry):
class ExampleMessage(blueprints.Message):
def __hash__(self):
return 1
msg = ExampleMessage()
await s.queue.put(msg)
await s.queue.put("no")
assert await s.f().msg().get_single() is msg
with pytest.raises(exc.Discard):
await s.f().msg().get_single()
class AvailableMessage(blueprints.Message):
def __hash__(self):
return 1
def text(self) -> str:
return "1"
class NotAvailableMessage(blueprints.Message):
def __hash__(self):
return 2
def text(self) -> str:
raise exc.NotAvailableError()
class NeverAvailableMessage(blueprints.Message):
def __hash__(self):
return 3
@pytest.mark.asyncio
async def test_requires(self, s: sentry.Sentry):
avmsg = self.AvailableMessage()
await s.queue.put(avmsg)
assert await s.f().requires("text").get_single() is avmsg
await s.queue.put(self.NotAvailableMessage())
with pytest.raises(exc.Discard):
await s.f().requires("text").get_single()
await s.queue.put(self.NeverAvailableMessage())
with pytest.raises(exc.NeverAvailableError):
await s.f().requires("text").get_single()
await s.queue.put(self.NotAvailableMessage())
with pytest.raises(exc.NotAvailableError):
await s.f().requires("text", propagate_not_available=True).get_single()
await s.queue.put(self.NeverAvailableMessage())
with pytest.raises(exc.Discard):
await s.f().requires("text", propagate_never_available=False).get_single()
@pytest.mark.asyncio
async def test_field(self, s: sentry.Sentry):
avmsg = self.AvailableMessage()
await s.queue.put(avmsg)
assert await s.f().field("text").get_single() == "1"
await s.queue.put(self.NotAvailableMessage())
with pytest.raises(exc.Discard):
await s.f().field("text").get_single()
await s.queue.put(self.NeverAvailableMessage())
with pytest.raises(exc.NeverAvailableError):
await s.f().field("text").get_single()
await s.queue.put(self.NotAvailableMessage())
with pytest.raises(exc.NotAvailableError):
await s.f().field("text", propagate_not_available=True).get_single()
await s.queue.put(self.NeverAvailableMessage())
with pytest.raises(exc.Discard):
await s.f().field("text", propagate_never_available=False).get_single()
@pytest.mark.asyncio
async def test_startswith(self, s: sentry.Sentry):
await s.queue.put("yarrharr")
await s.queue.put("yohoho")
assert await s.f().startswith("yarr").get_single() == "yarrharr"
with pytest.raises(exc.Discard):
await s.f().startswith("yarr").get_single()
@pytest.mark.asyncio
async def test_endswith(self, s: sentry.Sentry):
await s.queue.put("yarrharr")
await s.queue.put("yohoho")
assert await s.f().endswith("harr").get_single() == "yarrharr"
with pytest.raises(exc.Discard):
await s.f().endswith("harr").get_single()
@pytest.mark.asyncio
async def test_regex(self, s: sentry.Sentry):
await s.queue.put("yarrharr")
await s.queue.put("yohoho")
assert isinstance(await s.f().regex(re.compile(r"[yh]arr")).get_single(), re.Match)
with pytest.raises(exc.Discard):
await s.f().regex(re.compile(r"[yh]arr")).get_single()
@pytest.mark.asyncio
async def test_choices(self, s: sentry.Sentry):
await s.queue.put("yarrharr")
await s.queue.put("yohoho")
assert await s.f().choices("yarrharr", "banana").get_single() == "yarrharr"
with pytest.raises(exc.Discard):
await s.f().choices("yarrharr", "banana").get_single()

View file

@ -2,12 +2,12 @@ import pytest
import inspect
import pydantic
import pydantic.fields
import royalnet.engineer as re
import typing
import royalnet.engineer.teleporter as tp
@pytest.fixture
def my_function():
# noinspection PyUnusedLocal
def f(*, big_f: str, _hidden: int) -> int:
return _hidden
return f
@ -16,14 +16,15 @@ def my_function():
def test_parameter_to_field(my_function):
signature = inspect.signature(my_function)
parameter = signature.parameters["big_f"]
t, fieldinfo = re.parameter_to_field(parameter)
t, fieldinfo = tp.parameter_to_field(parameter)
assert isinstance(fieldinfo, pydantic.fields.FieldInfo)
assert fieldinfo.default is ...
assert fieldinfo.title == parameter.name == "big_f"
def test_signature_to_model(my_function):
InputModel, OutputModel = re.signature_to_model(my_function)
# noinspection PyPep8Naming
InputModel, OutputModel = tp.signature_to_model(my_function)
assert callable(InputModel)
model = InputModel(big_f="banana")
@ -58,7 +59,7 @@ def test_signature_to_model(my_function):
# noinspection PyTypeChecker
class TestTeleporter:
def test_standard_function(self):
@re.teleporter()
@tp.teleporter()
def standard_function(a: int, b: int, _return_str: bool = False) -> int:
if _return_str:
return "You asked me this."
@ -91,7 +92,7 @@ class TestTeleporter:
@pytest.mark.asyncio
async def test_async_function(self):
@re.teleporter(is_async=True)
@tp.teleporter(is_async=True)
async def async_function(a: int, b: int, _return_str: bool = False) -> int:
if _return_str:
return "You asked me this."
@ -123,7 +124,7 @@ class TestTeleporter:
_ = await async_function(1, 2)
def test_only_input(self):
@re.teleporter(validate_output=False)
@tp.teleporter(validate_output=False)
def standard_function(a: int, b: int, _return_str: bool = False) -> int:
if _return_str:
return "You asked me this."
@ -154,7 +155,7 @@ class TestTeleporter:
_ = standard_function(1, 2)
def test_only_output(self):
@re.teleporter(validate_input=False)
@tp.teleporter(validate_input=False)
def standard_function(a: int, b: int, _return_str: bool = False) -> int:
if _return_str:
return "You asked me this."

306
royalnet/engineer/wrench.py Normal file
View file

@ -0,0 +1,306 @@
"""
Wrenches are objects which can used instead of coroutines in Sentry receiver filters,
acting similarly to function factories and allowing to easily define filter functions with parameters.
"""
from __future__ import annotations
import royalnet.royaltyping as t
import abc
from . import discard
from . import exc
class Wrench(metaclass=abc.ABCMeta):
"""
The abstract base class for Wrenches.
"""
@abc.abstractmethod
async def filter(self, obj: t.Any) -> t.Any:
"""
The function applied to all objects transiting through the pipeline:
- If the function **returns**, its return value will be passed to the next node in the pipeline;
- If the function **raises**, the error is propagated downwards.
A special exception is available for discarding objects: :exc:`.discard.Discard`.
If raised, the object will be silently ignored.
"""
raise NotImplementedError()
def __call__(self, obj: t.Any) -> t.Awaitable[t.Any]:
"""
Allow instances to be directly called, emulating coroutine functions.
"""
return self.filter(obj)
class PassAll(Wrench):
"""
**Return** each received object as it is.
.. note:: To be used only in testing.
"""
async def filter(self, obj: t.Any) -> t.Any:
return obj
class DiscardAll(Wrench):
"""
**Discard** each received object.
.. note:: To be used only in testing.
"""
async def filter(self, obj: t.Any) -> t.Any:
raise discard.Discard(obj, "Discard filter discards everything")
class ErrorAll(Wrench):
"""
**Raise** :exc:`.exc.DeliberateException` for each received object.
.. note:: To be used only in testing.
"""
async def filter(self, obj: t.Any) -> t.Any:
raise exc.DeliberateException("ErrorAll received an object")
class CheckBase(Wrench, metaclass=abc.ABCMeta):
"""
Check a condition on the received objects:
- If the check returns :data:`True`, the object is **returned**;
- If the check returns :data:`False`, the object is **discarded**;
- If an error is raised, it is **propagated**.
"""
def __init__(self, *, invert: bool = False):
self.invert: bool = invert
"""
If set to :data:`True`, this Nut will invert its results:
- If the check returns :data:`True`, the object is **discarded**;
- If the check returns :data:`False`, the object is **returned**;
- If an error is raised, it is **propagated**.
"""
def __invert__(self):
return self.__class__(invert=not self.invert)
@abc.abstractmethod
async def check(self, obj: t.Any) -> bool:
"""
The condition to check.
:param obj: The object passing through the pipeline.
:return: Whether the check was successful or not.
"""
raise NotImplementedError()
@abc.abstractmethod
def error(self, obj: t.Any) -> str:
"""
The error message to attach as :attr:`.Discard.message` if the object is discarded.
:param obj: The object passing through the pipeline.
:return: The error message.
"""
raise NotImplementedError()
async def filter(self, obj: t.Any) -> t.Any:
if await self.check(obj) ^ self.invert:
return obj
else:
raise discard.Discard(obj=obj, message=self.error(obj))
class Type(CheckBase):
"""
Check the type of an object:
- If the object **is** of the specified type, it is **returned**;
- If the object **isn't** of the specified type, it is **discarded**;
"""
def __init__(self, type_: t.Type, **kwargs):
super().__init__(**kwargs)
self.type: t.Type = type_
async def check(self, obj: t.Any) -> bool:
return isinstance(obj, self.type)
def error(self, obj: t.Any) -> str:
return f"Not instance of type {self.type}"
class StartsWith(CheckBase):
"""
Check if an object :func:`startswith` a certain prefix.
"""
def __init__(self, prefix: str, **kwargs):
super().__init__(**kwargs)
self.prefix: str = prefix
async def check(self, obj: t.Any) -> bool:
return obj.startswith(self.prefix)
def error(self, obj: t.Any) -> str:
return f"Didn't start with {self.prefix}"
class EndsWith(CheckBase):
"""
Check if an object :func:`endswith` a certain suffix.
"""
def __init__(self, suffix: str, **kwargs):
super().__init__(**kwargs)
self.suffix: str = suffix
async def check(self, obj: t.Any) -> bool:
return obj.startswith(self.suffix)
def error(self, obj: t.Any) -> str:
return f"Didn't end with {self.suffix}"
class Choice(CheckBase):
"""
Check if an object is among the accepted list.
"""
def __init__(self, *accepted, **kwargs):
super().__init__(**kwargs)
self.accepted: t.Collection = accepted
"""
A collection of elements which can be chosen.
"""
async def check(self, obj: t.Any) -> bool:
return obj in self.accepted
def error(self, obj: t.Any) -> str:
return f"Not a valid choice"
class RegexCheck(CheckBase):
"""
Check if an object matches a regex pattern.
"""
def __init__(self, pattern: t.Pattern, **kwargs):
super().__init__(**kwargs)
self.pattern: t.Pattern = pattern
"""
The pattern that should be matched.
"""
async def check(self, obj: t.Any) -> bool:
return bool(self.pattern.match(obj))
def error(self, obj: t.Any) -> str:
return f"Didn't match pattern {self.pattern}"
class RegexMatch(Wrench):
"""
Apply a regex over an object:
- If it matches, **return** the :class:`re.Match` object;
- If it doesn't match, **discard** the object.
"""
def __init__(self, pattern: t.Pattern, **kwargs):
super().__init__(**kwargs)
self.pattern: t.Pattern = pattern
"""
The pattern that should be matched.
"""
async def filter(self, obj: t.Any) -> t.Any:
if match := self.pattern.match(obj):
return match
else:
raise discard.Discard(obj, f"Didn't match pattern {obj}")
class RegexReplace(Wrench):
"""
Apply a regex over an object:
- If it matches, replace the match(es) with :attr:`.replacement` and **return** the result.
- If it doesn't match, **return** the object as it is.
"""
def __init__(self, pattern: t.Pattern, replacement: t.Union[str, bytes]):
self.pattern: t.Pattern = pattern
"""
The pattern that should be matched.
"""
self.replacement: t.Union[str, bytes] = replacement
"""
The substitution string for the object.
"""
async def filter(self, obj: t.Any) -> t.Any:
return self.pattern.sub(self.replacement, obj)
class Lambda(Wrench):
"""
Apply a syncronous function over the received objects.
"""
def __init__(self, func: t.Callable[[t.Any], t.Any]):
self.func: t.Callable[[t.Any], t.Any] = func
"""
The function to apply.
"""
async def filter(self, obj: t.Any) -> t.Any:
return self.func(obj)
class Check(CheckBase):
"""
Check a condition on the received objects.
"""
def __init__(self, func: t.Callable[[t.Any], t.Any], error: str, *args, **kwargs):
super().__init__(*args, **kwargs)
self.func: t.Callable[[t.Any], t.Any] = func
"""
The condition to check.
"""
self.error: str = error
"""
The error message to display if the check fails.
"""
async def check(self, obj: t.Any) -> bool:
return self.func(obj)
def error(self, obj: t.Any) -> str:
return self.error
__all__ = (
"Wrench",
"CheckBase",
"Type",
"StartsWith",
"EndsWith",
"Choice",
"RegexCheck",
"RegexMatch",
"RegexReplace",
"Lambda",
)

View file

@ -74,3 +74,6 @@ An async generator yielding either:
* another :data:`.AsyncAdventure`;
* :data:`None`.
"""
Conversation = Callable[["Sentry"], Awaitable[Optional["Conversation"]]]