mirror of
https://github.com/RYGhub/royalnet.git
synced 2024-11-23 03:24:20 +00:00
💥 Reimplement engineer module from scratch (#2)
This commit is contained in:
parent
133f503926
commit
715c0e72df
20 changed files with 726 additions and 828 deletions
|
@ -1,10 +1,5 @@
|
||||||
"""
|
"""
|
||||||
A chatbot command router inspired by :mod:`fastapi`.
|
Chat bot utilities.
|
||||||
|
|
||||||
All names are inspired by the `Engineer Class of Team Fortress 2 <https://wiki.teamfortress.com/wiki/Engineer>`_.
|
All names are inspired by the `Engineer Class of Team Fortress 2 <https://wiki.teamfortress.com/wiki/Engineer>`_.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .blueprints import *
|
|
||||||
from .teleporter import *
|
|
||||||
from .sentry import *
|
|
||||||
from .exc import *
|
|
||||||
|
|
|
@ -1,4 +0,0 @@
|
||||||
from .blueprint import *
|
|
||||||
from .message import *
|
|
||||||
from .channel import *
|
|
||||||
from .user import *
|
|
|
@ -1,92 +0,0 @@
|
||||||
import abc
|
|
||||||
|
|
||||||
from .. import exc
|
|
||||||
|
|
||||||
|
|
||||||
class Blueprint(metaclass=abc.ABCMeta):
|
|
||||||
"""
|
|
||||||
A class containing methods common between all blueprints.
|
|
||||||
|
|
||||||
To extend a blueprint, inherit from it while using the :class:`abc.ABCMeta` metaclass, and make all new functions
|
|
||||||
return :exc:`.exc.NeverAvailableError`:
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
class Channel(Blueprint, metaclass=abc.ABCMeta):
|
|
||||||
def name(self):
|
|
||||||
raise exc.NeverAvailableError()
|
|
||||||
|
|
||||||
To implement a blueprint for a specific chat platform, inherit from the blueprint, override :meth:`__init__`,
|
|
||||||
:meth:`__hash__` and the methods that are implemented by the platform in question, either returning the
|
|
||||||
corresponding value or raising :exc:`.exc.NotAvailableError` if there is no data available.
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
class ExampleChannel(Channel):
|
|
||||||
def __init__(self, chat_id: int):
|
|
||||||
self.chat_id: int = chat_id
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return self.chat_id
|
|
||||||
|
|
||||||
def name(self):
|
|
||||||
return ExampleClient.expensive_get_channel_name(self.chat_id)
|
|
||||||
|
|
||||||
.. note:: To improve performance, you might want to wrap all data methods in :func:`functools.lru_cache` decorators.
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
@functools.lru_cache(24)
|
|
||||||
def name(self):
|
|
||||||
return ExampleClient.expensive_get_channel_name(self.chat_id)
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""
|
|
||||||
:return: The created object.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def __hash__(self):
|
|
||||||
"""
|
|
||||||
:return: A value that uniquely identifies the channel inside this Python process.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def requires(self, *fields) -> True:
|
|
||||||
"""
|
|
||||||
Ensure that this blueprint has the specified fields, re-raising the highest priority exception raised between
|
|
||||||
all of them.
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
def print_msg(message: Message):
|
|
||||||
message.requires("text", "timestamp")
|
|
||||||
print(f"{message.timestamp().isoformat()}: {message.text()}")
|
|
||||||
|
|
||||||
:raises .exc.NeverAvailableError: If at least one of the fields raised a :exc:`.exc.NeverAvailableError`.
|
|
||||||
:raises .exc.NotAvailableError: If no field raised a :exc:`.exc.NeverAvailableError`, but at least one raised a
|
|
||||||
:exc:`.exc.NotAvailableError`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
exceptions = []
|
|
||||||
|
|
||||||
for field in fields:
|
|
||||||
try:
|
|
||||||
self.__getattribute__(field)()
|
|
||||||
except exc.NeverAvailableError as ex:
|
|
||||||
exceptions.append(ex)
|
|
||||||
except exc.NotAvailableError as ex:
|
|
||||||
exceptions.append(ex)
|
|
||||||
|
|
||||||
if len(exceptions) > 0:
|
|
||||||
raise max(exceptions, key=lambda e: e.priority)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"Blueprint",
|
|
||||||
)
|
|
|
@ -1,35 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
from royalnet.royaltyping import *
|
|
||||||
import abc
|
|
||||||
|
|
||||||
from .. import exc
|
|
||||||
from .blueprint import Blueprint
|
|
||||||
|
|
||||||
|
|
||||||
class Channel(Blueprint, metaclass=abc.ABCMeta):
|
|
||||||
"""
|
|
||||||
An abstract class representing a channel where messages can be sent.
|
|
||||||
|
|
||||||
.. seealso:: :class:`.Blueprint`
|
|
||||||
"""
|
|
||||||
|
|
||||||
def name(self) -> str:
|
|
||||||
"""
|
|
||||||
:return: The name of the message channel, such as the chat title.
|
|
||||||
:raises .exc.NeverAvailableError: If the chat platform does not support channel names.
|
|
||||||
:raises .exc.NotAvailableError: If this channel does not have any name.
|
|
||||||
"""
|
|
||||||
raise exc.NeverAvailableError()
|
|
||||||
|
|
||||||
def topic(self) -> str:
|
|
||||||
"""
|
|
||||||
:return: The topic or description of the message channel.
|
|
||||||
:raises .exc.NeverAvailableError: If the chat platform does not support channel topics / descriptions.
|
|
||||||
:raises .exc.NotAvailableError: If this channel does not have any name.
|
|
||||||
"""
|
|
||||||
raise exc.NeverAvailableError()
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"Channel",
|
|
||||||
)
|
|
|
@ -1,53 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
from royalnet.royaltyping import *
|
|
||||||
import abc
|
|
||||||
import datetime
|
|
||||||
|
|
||||||
from .. import exc
|
|
||||||
from .blueprint import Blueprint
|
|
||||||
from .channel import Channel
|
|
||||||
|
|
||||||
|
|
||||||
class Message(Blueprint, metaclass=abc.ABCMeta):
|
|
||||||
"""
|
|
||||||
An abstract class representing a chat message sent in any platform.
|
|
||||||
|
|
||||||
.. seealso:: :class:`.Blueprint`
|
|
||||||
"""
|
|
||||||
|
|
||||||
def text(self) -> str:
|
|
||||||
"""
|
|
||||||
:return: The raw text contents of the message.
|
|
||||||
:raises .exc.NeverAvailableError: If the chat platform does not support text messages.
|
|
||||||
:raises .exc.NotAvailableError: If this message does not have any text.
|
|
||||||
"""
|
|
||||||
raise exc.NeverAvailableError()
|
|
||||||
|
|
||||||
def timestamp(self) -> datetime.datetime:
|
|
||||||
"""
|
|
||||||
:return: The :class:`datetime.datetime` at which the message was sent / received.
|
|
||||||
:raises .exc.NeverAvailableError: If the chat platform does not support timestamps.
|
|
||||||
:raises .exc.NotAvailableError: If this message is special and does not have any timestamp.
|
|
||||||
"""
|
|
||||||
raise exc.NeverAvailableError()
|
|
||||||
|
|
||||||
def reply_to(self) -> Message:
|
|
||||||
"""
|
|
||||||
:return: The :class:`.Message` this message is a reply to.
|
|
||||||
:raises .exc.NeverAvailableError: If the chat platform does not support replies.
|
|
||||||
:raises .exc.NotAvailableError: If this message is not a reply to any other message.
|
|
||||||
"""
|
|
||||||
raise exc.NeverAvailableError()
|
|
||||||
|
|
||||||
def channel(self) -> Channel:
|
|
||||||
"""
|
|
||||||
:return: The :class:`.Channel` this message was sent in.
|
|
||||||
:raises .exc.NeverAvailableError: If the chat platform does not support channels.
|
|
||||||
:raises .exc.NotAvailableError: If this message was not sent in any channel.
|
|
||||||
"""
|
|
||||||
raise exc.NeverAvailableError()
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"Message",
|
|
||||||
)
|
|
|
@ -1,35 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
from royalnet.royaltyping import *
|
|
||||||
import abc
|
|
||||||
import sqlalchemy.orm
|
|
||||||
|
|
||||||
from .. import exc
|
|
||||||
from .blueprint import Blueprint
|
|
||||||
|
|
||||||
|
|
||||||
class User(Blueprint, metaclass=abc.ABCMeta):
|
|
||||||
"""
|
|
||||||
An abstract class representing a chat user.
|
|
||||||
|
|
||||||
.. seealso:: :class:`.Blueprint`
|
|
||||||
"""
|
|
||||||
|
|
||||||
def name(self) -> str:
|
|
||||||
"""
|
|
||||||
:return: The user's name.
|
|
||||||
:raises .exc.NeverAvailableError: If the chat platform does not support usernames.
|
|
||||||
:raises .exc.NotAvailableError: If this user does not have any name.
|
|
||||||
"""
|
|
||||||
raise exc.NeverAvailableError()
|
|
||||||
|
|
||||||
def database(self, session: sqlalchemy.orm.Session) -> Any:
|
|
||||||
"""
|
|
||||||
:param session: A :class:`sqlalchemy.orm.Session` instance to use to fetch the database entry.
|
|
||||||
:return: The database entry for this user.
|
|
||||||
"""
|
|
||||||
raise exc.NeverAvailableError()
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"User",
|
|
||||||
)
|
|
114
royalnet/engineer/bullet.py
Normal file
114
royalnet/engineer/bullet.py
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import royalnet.royaltyping as t
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import datetime
|
||||||
|
import sqlalchemy.orm
|
||||||
|
|
||||||
|
from . import exc
|
||||||
|
|
||||||
|
|
||||||
|
class Bullet(metaclass=abc.ABCMeta):
|
||||||
|
"""
|
||||||
|
The abstract base class for Bullet data models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
"""
|
||||||
|
:return: A value that uniquely identifies the object in this Python interpreter process.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class Message(Bullet, metaclass=abc.ABCMeta):
|
||||||
|
"""
|
||||||
|
An abstract class representing a chat message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def text(self) -> t.Optional[str]:
|
||||||
|
"""
|
||||||
|
:return: The raw text contents of the message.
|
||||||
|
:raises .exc.NotSupportedError: If the frontend does not support text messages.
|
||||||
|
"""
|
||||||
|
raise exc.NotSupportedError()
|
||||||
|
|
||||||
|
async def timestamp(self) -> t.Optional[datetime.datetime]:
|
||||||
|
"""
|
||||||
|
:return: The :class:`datetime.datetime` at which the message was sent.
|
||||||
|
:raises .exc.NotSupportedError: If the frontend does not support timestamps.
|
||||||
|
"""
|
||||||
|
raise exc.NotSupportedError()
|
||||||
|
|
||||||
|
async def reply_to(self) -> t.Optional[Message]:
|
||||||
|
"""
|
||||||
|
:return: The :class:`.Message` this message is a reply to.
|
||||||
|
:raises .exc.NotSupportedError: If the frontend does not support replies.
|
||||||
|
"""
|
||||||
|
raise exc.NotSupportedError()
|
||||||
|
|
||||||
|
async def channel(self) -> t.Optional[Channel]:
|
||||||
|
"""
|
||||||
|
:return: The :class:`.Channel` this message was sent in.
|
||||||
|
:raises .exc.NotSupportedError: If the frontend does not support channels.
|
||||||
|
"""
|
||||||
|
raise exc.NotSupportedError()
|
||||||
|
|
||||||
|
|
||||||
|
class Channel(Bullet, metaclass=abc.ABCMeta):
|
||||||
|
"""
|
||||||
|
An abstract class representing a channel where messages can be sent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def name(self) -> t.Optional[str]:
|
||||||
|
"""
|
||||||
|
:return: The name of the message channel, such as the chat title.
|
||||||
|
:raises .exc.NotSupportedError: If the frontend does not support channel names.
|
||||||
|
"""
|
||||||
|
raise exc.NotSupportedError()
|
||||||
|
|
||||||
|
async def topic(self) -> t.Optional[str]:
|
||||||
|
"""
|
||||||
|
:return: The topic (description) of the message channel.
|
||||||
|
:raises .exc.NotSupportedError: If the frontend does not support channel topics / descriptions.
|
||||||
|
"""
|
||||||
|
raise exc.NotSupportedError()
|
||||||
|
|
||||||
|
async def users(self) -> t.List[User]:
|
||||||
|
"""
|
||||||
|
:return: A :class:`list` of :class:`.User` who can read messages sent in the channel.
|
||||||
|
:raises .exc.NotSupportedError: If the frontend does not support such a feature.
|
||||||
|
"""
|
||||||
|
raise exc.NotSupportedError()
|
||||||
|
|
||||||
|
|
||||||
|
class User(Bullet, metaclass=abc.ABCMeta):
|
||||||
|
"""
|
||||||
|
An abstract class representing a user who can read or send messages in the chat.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def name(self) -> t.Optional[str]:
|
||||||
|
"""
|
||||||
|
:return: The user's name.
|
||||||
|
:raises .exc.NotSupportedError: If the frontend does not support usernames.
|
||||||
|
"""
|
||||||
|
raise exc.NotSupportedError()
|
||||||
|
|
||||||
|
async def database(self, session: sqlalchemy.orm.Session) -> t.Any:
|
||||||
|
"""
|
||||||
|
:param session: A :class:`sqlalchemy.orm.Session` instance to use to fetch the database entry.
|
||||||
|
:return: The database entry for this user.
|
||||||
|
"""
|
||||||
|
raise exc.NotSupportedError()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
"Bullet",
|
||||||
|
"Message",
|
||||||
|
"Channel",
|
||||||
|
"User",
|
||||||
|
)
|
13
royalnet/engineer/discard.py
Normal file
13
royalnet/engineer/discard.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
class Discard(BaseException):
|
||||||
|
"""
|
||||||
|
A special exception which should be raised by Metals if a certain object should be discarded from the queue.
|
||||||
|
"""
|
||||||
|
def __init__(self, obj, message):
|
||||||
|
self.obj = obj
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<Discard>"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"Discarded {self.obj}: {self.message}"
|
|
@ -0,0 +1,51 @@
|
||||||
|
"""
|
||||||
|
Dispensers instantiate sentries and dispatch events in bulk to the whole group.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import royalnet.royaltyping as t
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
from .sentry import SentrySource
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Dispenser:
|
||||||
|
def __init__(self):
|
||||||
|
self.sentries: t.List[SentrySource] = []
|
||||||
|
"""
|
||||||
|
A :class:`list` of all the running sentries of this dispenser.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def put(self, item: t.Any) -> None:
|
||||||
|
"""
|
||||||
|
Insert a new item in the queues of all the running sentries.
|
||||||
|
|
||||||
|
:param item: The item to insert.
|
||||||
|
"""
|
||||||
|
log.debug(f"Putting {item}")
|
||||||
|
for sentry in self.sentries:
|
||||||
|
sentry.put(item)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def sentry(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
A context manager which creates a :class:`.SentrySource` and keeps it in :attr:`.sentries` while it is being
|
||||||
|
used.
|
||||||
|
"""
|
||||||
|
sentry = SentrySource(dispenser=self, *args, **kwargs)
|
||||||
|
self.sentries.append(sentry)
|
||||||
|
|
||||||
|
yield sentry
|
||||||
|
|
||||||
|
self.sentries.remove(sentry)
|
||||||
|
|
||||||
|
async def run(self, conv: t.Conversation) -> None:
|
||||||
|
with self.sentry() as sentry:
|
||||||
|
state = conv(sentry)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
state = await state
|
|
@ -1,38 +1,27 @@
|
||||||
import royalnet.exc
|
|
||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
|
|
||||||
class EngineerException(royalnet.exc.RoyalnetException):
|
class EngineerException(Exception):
|
||||||
"""
|
"""
|
||||||
An exception raised by the engineer module.
|
The base class for errors in :mod:`royalnet.engineer`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class BlueprintError(EngineerException):
|
class WrenchException(EngineerException):
|
||||||
"""
|
"""
|
||||||
An error related to the :mod:`royalnet.engineer.blueprints`.
|
The base class for errors in :mod:`royalnet.engineer.wrench`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class NeverAvailableError(BlueprintError, NotImplementedError):
|
class DeliberateException(WrenchException):
|
||||||
"""
|
"""
|
||||||
The requested property is never supplied by the chat platform the message was sent in.
|
This exception was deliberately raised by :class:`royalnet.engineer.wrench.ErrorAll`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
priority = 1
|
|
||||||
|
|
||||||
|
|
||||||
class NotAvailableError(BlueprintError):
|
|
||||||
"""
|
|
||||||
The requested property was not supplied by the chat platform for the specific message this exception was raised in.
|
|
||||||
"""
|
|
||||||
|
|
||||||
priority = 2
|
|
||||||
|
|
||||||
|
|
||||||
class TeleporterError(EngineerException, pydantic.ValidationError):
|
class TeleporterError(EngineerException, pydantic.ValidationError):
|
||||||
"""
|
"""
|
||||||
The validation of some object though a :mod:`pydantic` model failed.
|
The base class for errors in :mod:`royalnet.engineer.teleporter`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,28 +37,13 @@ class OutTeleporterError(TeleporterError):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class SentryError(EngineerException):
|
class BulletException(EngineerException):
|
||||||
"""
|
"""
|
||||||
An error related to the :mod:`royalnet.engineer.sentry`.
|
The base class for errors in :mod:`royalnet.engineer.bullet`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class FilterError(SentryError):
|
class NotSupportedError(BulletException, NotImplementedError):
|
||||||
"""
|
"""
|
||||||
An error related to the :class:`royalnet.engineer.sentry.Filter`.
|
The requested property isn't available on the current frontend.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class Discard(FilterError):
|
|
||||||
"""
|
|
||||||
Discard the object from the queue.
|
|
||||||
"""
|
|
||||||
def __init__(self, obj, message):
|
|
||||||
self.obj = obj
|
|
||||||
self.message = message
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"<Discard>"
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f"Discarded {self.obj}: {self.message}"
|
|
||||||
|
|
188
royalnet/engineer/sentry.py
Normal file
188
royalnet/engineer/sentry.py
Normal file
|
@ -0,0 +1,188 @@
|
||||||
|
"""
|
||||||
|
Sentries are asyncronous receivers for events (usually :class:`bullet.Bullet`) incoming from Dispensers.
|
||||||
|
|
||||||
|
They support event filtering through Wrenches and coroutine functions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import royalnet.royaltyping as t
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from . import discard
|
||||||
|
from . import bullet
|
||||||
|
|
||||||
|
if t.TYPE_CHECKING:
|
||||||
|
from .dispenser import Dispenser
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Sentry(metaclass=abc.ABCMeta):
|
||||||
|
"""
|
||||||
|
The abstract object representing a node of the pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __len__(self) -> int:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_nowait(self) -> bullet.Bullet:
|
||||||
|
"""
|
||||||
|
Try to get a single :class:`~.bullet.Bullet` from the pipeline, without blocking or handling discards.
|
||||||
|
|
||||||
|
:return: The **returned** :class:`~.bullet.Bullet`.
|
||||||
|
:raises asyncio.QueueEmpty: If the queue is empty.
|
||||||
|
:raises .discard.Discard: If the object was **discarded** by the pipeline.
|
||||||
|
:raises Exception: If an exception was **raised** in the pipeline.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def get(self) -> bullet.Bullet:
|
||||||
|
"""
|
||||||
|
Try to get a single :class:`~.bullet.Bullet` from the pipeline, blocking until something is available, but
|
||||||
|
without handling discards.
|
||||||
|
|
||||||
|
:return: The **returned** :class:`~.bullet.Bullet`.
|
||||||
|
:raises .discard.Discard: If the object was **discarded** by the pipeline.
|
||||||
|
:raises Exception: If an exception was **raised** in the pipeline.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
async def wait(self) -> bullet.Bullet:
|
||||||
|
"""
|
||||||
|
Try to get a single :class:`~.bullet.Bullet` from the pipeline, blocking until something is available and is not
|
||||||
|
discarded.
|
||||||
|
|
||||||
|
:return: The **returned** :class:`~.bullet.Bullet`.
|
||||||
|
:raises Exception: If an exception was **raised** in the pipeline.
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
result = await self.get()
|
||||||
|
log.debug(f"Returned: {result}")
|
||||||
|
return result
|
||||||
|
except discard.Discard as d:
|
||||||
|
log.debug(f"{str(d)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
def __await__(self):
|
||||||
|
"""
|
||||||
|
Awaiting an object implementing :class:`.SentryInterface` corresponds to awaiting :meth:`.wait`.
|
||||||
|
"""
|
||||||
|
return self.get().__await__()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def put(self, item: t.Any) -> None:
|
||||||
|
"""
|
||||||
|
Insert a new item in the queue.
|
||||||
|
|
||||||
|
:param item: The item to be added.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def filter(self, wrench: t.Callable[[t.Any], t.Awaitable[t.Any]]) -> SentryFilter:
|
||||||
|
"""
|
||||||
|
Chain a new filter to the pipeline.
|
||||||
|
|
||||||
|
:param wrench: The filter to add to the chain. It can either be a :class:`.wrench.Wrench`, or a coroutine
|
||||||
|
function accepting a single object as parameter and returning the same or a different one.
|
||||||
|
:return: A new :class:`.SentryFilter` which includes the filter.
|
||||||
|
|
||||||
|
.. seealso:: :meth:`.__or__`
|
||||||
|
"""
|
||||||
|
if callable(wrench):
|
||||||
|
return SentryFilter(previous=self, wrench=wrench)
|
||||||
|
else:
|
||||||
|
raise TypeError("wrench must be either a Wrench or a coroutine function")
|
||||||
|
|
||||||
|
def __or__(self, other: t.Callable[[t.Any], t.Awaitable[t.Any]]) -> SentryFilter:
|
||||||
|
"""
|
||||||
|
A unix-pipe-like interface for :meth:`.filter`.
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
await (sentry | wrench.Type(Message) | wrench.Sync(lambda o: o.text))
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return self.filter(other)
|
||||||
|
except TypeError:
|
||||||
|
raise TypeError("Right-side must be either a Wrench or a coroutine function")
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def dispenser(self):
|
||||||
|
"""
|
||||||
|
Get the :class:`.Dispenser` that created this Sentry.
|
||||||
|
|
||||||
|
:return: The :class:`.Dispenser` object.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class SentryFilter(Sentry):
|
||||||
|
"""
|
||||||
|
A non-root node of the filtering pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, previous: Sentry, wrench: t.Callable[[t.Any], t.Awaitable[t.Any]]):
|
||||||
|
self.previous: Sentry = previous
|
||||||
|
"""
|
||||||
|
The previous node of the pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.wrench: t.Callable[[t.Any], t.Awaitable[t.Any]] = wrench
|
||||||
|
"""
|
||||||
|
The coroutine function to apply to all objects passing through this node.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.previous) + 1
|
||||||
|
|
||||||
|
def get_nowait(self) -> bullet.Bullet:
|
||||||
|
return self.previous.get_nowait()
|
||||||
|
|
||||||
|
async def get(self) -> bullet.Bullet:
|
||||||
|
return await self.previous.get()
|
||||||
|
|
||||||
|
async def put(self, item) -> None:
|
||||||
|
return await self.previous.put(item)
|
||||||
|
|
||||||
|
def dispenser(self):
|
||||||
|
return self.previous.dispenser()
|
||||||
|
|
||||||
|
|
||||||
|
class SentrySource(Sentry):
|
||||||
|
"""
|
||||||
|
The root and source of the pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dispenser: "Dispenser", queue_size: int = 12):
|
||||||
|
self.queue: asyncio.Queue = asyncio.Queue(maxsize=queue_size)
|
||||||
|
self._dispenser: "Dispenser" = dispenser
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def get_nowait(self) -> bullet.Bullet:
|
||||||
|
return self.queue.get_nowait()
|
||||||
|
|
||||||
|
async def get(self) -> bullet.Bullet:
|
||||||
|
return await self.queue.get()
|
||||||
|
|
||||||
|
async def put(self, item) -> None:
|
||||||
|
return await self.queue.put(bullet)
|
||||||
|
|
||||||
|
async def dispenser(self):
|
||||||
|
return self._dispenser
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
"Sentry",
|
||||||
|
"SentryFilter",
|
||||||
|
"SentrySource",
|
||||||
|
)
|
|
@ -1 +0,0 @@
|
||||||
from .sentry import *
|
|
|
@ -1,257 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
from royalnet.royaltyping import *
|
|
||||||
import functools
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .. import exc, blueprints
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Filter:
|
|
||||||
"""
|
|
||||||
A fluent interface for filtering data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, func: Callable):
|
|
||||||
self.func: Callable = func
|
|
||||||
|
|
||||||
async def get(self) -> Any:
|
|
||||||
"""
|
|
||||||
Wait until an :class:`object` leaves the queue and passes through the filter, then return it.
|
|
||||||
|
|
||||||
:return: The :class:`object` which left the queue.
|
|
||||||
"""
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
return await self.get_single()
|
|
||||||
except exc.Discard:
|
|
||||||
continue
|
|
||||||
|
|
||||||
async def get_single(self) -> Any:
|
|
||||||
"""
|
|
||||||
Let one :class:`object` pass through the filter, then either return it or raise an error if the object should be
|
|
||||||
discarded.
|
|
||||||
|
|
||||||
:return: The :class:`object` which left the queue.
|
|
||||||
:raises exc.Discard: If the object was filtered.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await self.func(None)
|
|
||||||
except exc.Discard as e:
|
|
||||||
log.debug(str(e))
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
log.debug(f"Dequeued {result}")
|
|
||||||
return result
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _deco_filter(c: Callable[[Any], bool], *, error: str):
|
|
||||||
"""
|
|
||||||
A decorator which checks the condition ``c`` on all objects transiting through the queue:
|
|
||||||
|
|
||||||
- If the check **passes**, the object itself is returned;
|
|
||||||
- If the check **fails**, :exc:`.exc.Discard` is raised, with the object and the ``error`` string as parameters;
|
|
||||||
- If an error is raised, propagate the error upwards.
|
|
||||||
|
|
||||||
.. warning:: Raising :exc:`.exc.Discard` in ``c`` will automatically cause the object to be discarded, as if
|
|
||||||
:data:`False` was returned.
|
|
||||||
|
|
||||||
:param c: A function that takes in input an enqueued object and returns either the same object or a new one to
|
|
||||||
pass to the next filter in the queue.
|
|
||||||
:param error: The string that :exc:`.exc.Discard` should display if the object is discarded.
|
|
||||||
"""
|
|
||||||
def decorator(func):
|
|
||||||
@functools.wraps(func)
|
|
||||||
async def decorated(obj):
|
|
||||||
result: Any = await func(obj)
|
|
||||||
if c(result):
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
raise exc.Discard(obj=result, message=error)
|
|
||||||
return decorated
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def filter(self, c: Callable[[Any], bool], error: str) -> Filter:
|
|
||||||
"""
|
|
||||||
Check the condition ``c`` on all objects transiting through the queue:
|
|
||||||
|
|
||||||
- If the check **passes**, the object goes on to the next filter;
|
|
||||||
- If the check **fails**, the object is discarded, with ``error`` as reason;
|
|
||||||
- If an error is raised, propagate the error upwards.
|
|
||||||
|
|
||||||
:param c: A function that takes in input an object and performs a check on it, returning either :data:`True`
|
|
||||||
or :data:`False`.
|
|
||||||
:param error: The reason for which objects should be discarded.
|
|
||||||
:return: A new :class:`Filter` with this new condition.
|
|
||||||
|
|
||||||
.. seealso:: :meth:`._deco_filter`, :func:`filter`
|
|
||||||
"""
|
|
||||||
return self.__class__(self._deco_filter(c, error=error)(self.func))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _deco_map(c: Callable[[Any], object]):
|
|
||||||
"""
|
|
||||||
A decorator which applies the function ``c`` on all objects transiting through the queue:
|
|
||||||
|
|
||||||
- If the function **returns**, return its return value;
|
|
||||||
- If the function **raises** an error, it is propagated upwards.
|
|
||||||
|
|
||||||
:param c: A function that takes in input an enqueued object and returns either the same object or something
|
|
||||||
else.
|
|
||||||
|
|
||||||
.. seealso:: :func:`map`
|
|
||||||
"""
|
|
||||||
def decorator(func):
|
|
||||||
@functools.wraps(func)
|
|
||||||
async def decorated(obj):
|
|
||||||
result: Any = await func(obj)
|
|
||||||
return c(result)
|
|
||||||
return decorated
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def map(self, c: Callable[[Any], object]) -> Filter:
|
|
||||||
"""
|
|
||||||
Apply the function ``c`` on all objects transiting through the queue:
|
|
||||||
|
|
||||||
- If the function **returns**, its return value replaces the object in the queue;
|
|
||||||
- If the function **raises** :exc:`.exc.Discard`, the object is discarded;
|
|
||||||
- If the function **raises another error**, propagate the error upwards.
|
|
||||||
|
|
||||||
:param c: A function that takes in input an enqueued object and returns either the same object or something
|
|
||||||
else.
|
|
||||||
:return: A new :class:`Filter` with this new condition.
|
|
||||||
|
|
||||||
.. seealso:: :meth:`._deco_map`, :func:`filter`
|
|
||||||
"""
|
|
||||||
return self.__class__(self._deco_map(c)(self.func))
|
|
||||||
|
|
||||||
def type(self, t: type) -> Filter:
|
|
||||||
"""
|
|
||||||
Check if an object passing through the queue :func:`isinstance` of the type ``t``.
|
|
||||||
|
|
||||||
:param t: The type that objects should be instances of.
|
|
||||||
:return: A new :class:`Filter` with this new condition.
|
|
||||||
|
|
||||||
.. seealso:: :func:`isinstance`
|
|
||||||
"""
|
|
||||||
return self.filter(lambda o: isinstance(o, t), error=f"Not instance of type {t}")
|
|
||||||
|
|
||||||
def msg(self) -> Filter:
|
|
||||||
"""
|
|
||||||
Check if an object passing through the queue :func:`isinstance` of :class:`.blueprints.Message`.
|
|
||||||
|
|
||||||
:return: A new :class:`Filter` with this new condition.
|
|
||||||
"""
|
|
||||||
return self.type(blueprints.Message)
|
|
||||||
|
|
||||||
def requires(self, *fields,
|
|
||||||
propagate_not_available=False,
|
|
||||||
propagate_never_available=True) -> Filter:
|
|
||||||
"""
|
|
||||||
Test a :class:`.blueprints.Blueprint`'s fields by using its ``.requires()`` method:
|
|
||||||
|
|
||||||
- If the :class:`.blueprints.Blueprint` has the appropriate fields, return it;
|
|
||||||
- If the :class:`.blueprints.Blueprint` doesn't have data for at least one of the fields, the object is discarded;
|
|
||||||
- the :class:`.blueprints.Blueprint` never has data for at least one of the fields, :exc:`.exc.NotAvailableError` is propagated upwards.
|
|
||||||
|
|
||||||
:param fields: The fields to test for.
|
|
||||||
:param propagate_not_available: If :exc:`.exc.NotAvailableError` should be propagated
|
|
||||||
instead of discarding the errored object.
|
|
||||||
:param propagate_never_available: If :exc:`.exc.NeverAvailableError` should be propagated
|
|
||||||
instead of discarding the errored object.
|
|
||||||
:return: A new :class:`Filter` with this new condition.
|
|
||||||
|
|
||||||
.. seealso:: :meth:`.blueprints.Blueprint.requires`
|
|
||||||
"""
|
|
||||||
def check(obj):
|
|
||||||
try:
|
|
||||||
return obj.requires(*fields)
|
|
||||||
except exc.NotAvailableError:
|
|
||||||
if propagate_not_available:
|
|
||||||
raise
|
|
||||||
raise exc.Discard(obj, "Data is not available")
|
|
||||||
except exc.NeverAvailableError:
|
|
||||||
if propagate_never_available:
|
|
||||||
raise
|
|
||||||
raise exc.Discard(obj, "Data is never available")
|
|
||||||
|
|
||||||
return self.filter(check, error=".requires() method returned False")
|
|
||||||
|
|
||||||
def field(self, field: str,
|
|
||||||
propagate_not_available=False,
|
|
||||||
propagate_never_available=True) -> Filter:
|
|
||||||
"""
|
|
||||||
Replace a :class:`.blueprints.Blueprint` with the value of one of its fields.
|
|
||||||
|
|
||||||
:param field: The field to access.
|
|
||||||
:param propagate_not_available: If :exc:`.exc.NotAvailableError` should be propagated
|
|
||||||
instead of discarding the errored object.
|
|
||||||
:param propagate_never_available: If :exc:`.exc.NeverAvailableError` should be propagated
|
|
||||||
instead of discarding the errored object.
|
|
||||||
:return: A new :class:`Filter` with the new requirements.
|
|
||||||
"""
|
|
||||||
def replace(obj):
|
|
||||||
try:
|
|
||||||
return obj.__getattribute__(field)()
|
|
||||||
except exc.NotAvailableError:
|
|
||||||
if propagate_not_available:
|
|
||||||
raise
|
|
||||||
raise exc.Discard(obj, "Data is not available")
|
|
||||||
except exc.NeverAvailableError:
|
|
||||||
if propagate_never_available:
|
|
||||||
raise
|
|
||||||
raise exc.Discard(obj, "Data is never available")
|
|
||||||
|
|
||||||
return self.map(replace)
|
|
||||||
|
|
||||||
def startswith(self, prefix: str):
|
|
||||||
"""
|
|
||||||
Check if an object starts with the specified prefix and discard the objects that do not.
|
|
||||||
|
|
||||||
:param prefix: The prefix object should start with.
|
|
||||||
:return: A new :class:`Filter` with the new requirements.
|
|
||||||
|
|
||||||
.. seealso:: :meth:`str.startswith`
|
|
||||||
"""
|
|
||||||
return self.filter(lambda x: x.startswith(prefix), error=f"Text didn't start with {prefix}")
|
|
||||||
|
|
||||||
def endswith(self, suffix: str):
|
|
||||||
"""
|
|
||||||
Check if an object ends with the specified suffix and discard the objects that do not.
|
|
||||||
|
|
||||||
:param suffix: The prefix object should start with.
|
|
||||||
:return: A new :class:`Filter` with the new requirements.
|
|
||||||
|
|
||||||
.. seealso:: :meth:`str.endswith`
|
|
||||||
"""
|
|
||||||
return self.filter(lambda x: x.endswith(suffix), error=f"Text didn't end with {suffix}")
|
|
||||||
|
|
||||||
def regex(self, pattern: Pattern):
|
|
||||||
"""
|
|
||||||
Apply a regex over an object and discard the object if it does not match.
|
|
||||||
|
|
||||||
:param pattern: The pattern that should be matched by the text.
|
|
||||||
:return: A new :class:`Filter` with the new requirements.
|
|
||||||
"""
|
|
||||||
def mapping(x):
|
|
||||||
if match := pattern.match(x):
|
|
||||||
return match
|
|
||||||
else:
|
|
||||||
raise exc.Discard(x, f"Text didn't match pattern {pattern}")
|
|
||||||
|
|
||||||
return self.map(mapping)
|
|
||||||
|
|
||||||
def choices(self, *choices):
|
|
||||||
"""
|
|
||||||
Ensure an object is in the ``choices`` list, discarding the object otherwise.
|
|
||||||
|
|
||||||
:param choices: The pattern that should be matched by the text.
|
|
||||||
:return: A new :class:`Filter` with the new requirements.
|
|
||||||
"""
|
|
||||||
return self.filter(lambda o: o in choices, error="Not a valid choice")
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"Filter",
|
|
||||||
)
|
|
|
@ -1,58 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
from royalnet.royaltyping import *
|
|
||||||
import logging
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from .filter import Filter
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Sentry:
|
|
||||||
"""
|
|
||||||
A class that allows using the ``await`` keyword to suspend a command execution until a new message is received.
|
|
||||||
"""
|
|
||||||
|
|
||||||
QUEUE_SIZE = 12
|
|
||||||
"""
|
|
||||||
The size of the object :attr:`.queue`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, filter_type: Type[Filter] = Filter):
|
|
||||||
self.queue: asyncio.Queue = asyncio.Queue(maxsize=self.QUEUE_SIZE)
|
|
||||||
"""
|
|
||||||
An object queue where incoming :class:`object` are stored, with a size limit of :attr:`.QUEUE_SIZE`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.filter_type: Type[Filter] = filter_type
|
|
||||||
"""
|
|
||||||
The filter to be used in :meth:`.f` calls, by default :class:`.filters.Filter`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"<Sentry>"
|
|
||||||
|
|
||||||
def f(self):
|
|
||||||
"""
|
|
||||||
Create a :attr:`.filter_type` object, which can be configured through its fluent interface.
|
|
||||||
|
|
||||||
Remember to call ``.get()`` on the end of the chain to finally get the object.
|
|
||||||
|
|
||||||
To get any object, call:
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
await sentry.f().get()
|
|
||||||
|
|
||||||
.. seealso:: :class:`.filters.Filter`
|
|
||||||
|
|
||||||
:return: The created :class:`.filters.Filter`.
|
|
||||||
"""
|
|
||||||
async def func(_):
|
|
||||||
return await self.queue.get()
|
|
||||||
return self.filter_type(func)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
|
||||||
"Sentry",
|
|
||||||
)
|
|
|
@ -1,13 +1,19 @@
|
||||||
import functools
|
"""
|
||||||
|
The teleporter uses :mod:`pydantic` to validate function parameters and return values.
|
||||||
|
"""
|
||||||
|
|
||||||
from royalnet.royaltyping import *
|
from __future__ import annotations
|
||||||
|
import royalnet.royaltyping as t
|
||||||
|
|
||||||
|
import logging
|
||||||
import pydantic
|
import pydantic
|
||||||
import inspect
|
import inspect
|
||||||
|
import functools
|
||||||
|
|
||||||
from . import exc
|
from . import exc
|
||||||
|
|
||||||
|
Value = t.TypeVar("Value")
|
||||||
Model = TypeVar("Model")
|
log = logging.getLogger(__name__)
|
||||||
Value = TypeVar("Value")
|
|
||||||
|
|
||||||
|
|
||||||
class TeleporterConfig(pydantic.BaseConfig):
|
class TeleporterConfig(pydantic.BaseConfig):
|
||||||
|
@ -17,7 +23,7 @@ class TeleporterConfig(pydantic.BaseConfig):
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
def parameter_to_field(param: inspect.Parameter, **kwargs) -> Tuple[type, pydantic.fields.FieldInfo]:
|
def parameter_to_field(param: inspect.Parameter, **kwargs) -> t.Tuple[type, pydantic.fields.FieldInfo]:
|
||||||
"""
|
"""
|
||||||
Convert a :class:`inspect.Parameter` to a type-field :class:`tuple`, which can be easily passed to
|
Convert a :class:`inspect.Parameter` to a type-field :class:`tuple`, which can be easily passed to
|
||||||
:func:`pydantic.create_model`.
|
:func:`pydantic.create_model`.
|
||||||
|
@ -45,9 +51,9 @@ def parameter_to_field(param: inspect.Parameter, **kwargs) -> Tuple[type, pydant
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def signature_to_model(f: Callable,
|
def signature_to_model(f: t.Callable,
|
||||||
__config__: Type[pydantic.BaseConfig] = TeleporterConfig,
|
__config__: t.Type[pydantic.BaseConfig] = TeleporterConfig,
|
||||||
extra_params: Dict[str, type] = None) -> Tuple[type, type]:
|
extra_params: t.Dict[str, type] = None) -> t.Tuple[type, type]:
|
||||||
"""
|
"""
|
||||||
Convert the signature of a function to two pydantic models: one for the input and another one for the output.
|
Convert the signature of a function to two pydantic models: one for the input and another one for the output.
|
||||||
|
|
||||||
|
@ -80,7 +86,7 @@ def signature_to_model(f: Callable,
|
||||||
return input_model, output_model
|
return input_model, output_model
|
||||||
|
|
||||||
|
|
||||||
def split_kwargs(**kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
def split_kwargs(**kwargs) -> t.Tuple[t.Dict[str, t.Any], t.Dict[str, t.Any]]:
|
||||||
"""
|
"""
|
||||||
Split the kwargs passed to this function in two different :class:`dict`, based on whether their name starts with
|
Split the kwargs passed to this function in two different :class:`dict`, based on whether their name starts with
|
||||||
``_`` or not.
|
``_`` or not.
|
||||||
|
@ -129,7 +135,7 @@ def teleport_out(__model: type, value: Value) -> Value:
|
||||||
raise exc.OutTeleporterError(errors=e.raw_errors, model=e.model)
|
raise exc.OutTeleporterError(errors=e.raw_errors, model=e.model)
|
||||||
|
|
||||||
|
|
||||||
def teleporter(__config__: Type[pydantic.BaseConfig] = TeleporterConfig,
|
def teleporter(__config__: t.Type[pydantic.BaseConfig] = TeleporterConfig,
|
||||||
is_async: bool = False,
|
is_async: bool = False,
|
||||||
validate_input: bool = True,
|
validate_input: bool = True,
|
||||||
validate_output: bool = True):
|
validate_output: bool = True):
|
||||||
|
@ -148,7 +154,7 @@ def teleporter(__config__: Type[pydantic.BaseConfig] = TeleporterConfig,
|
||||||
|
|
||||||
.. seealso:: :func:`.signature_to_model`
|
.. seealso:: :func:`.signature_to_model`
|
||||||
"""
|
"""
|
||||||
def decorator(f: Callable):
|
def decorator(f: t.Callable):
|
||||||
# noinspection PyPep8Naming
|
# noinspection PyPep8Naming
|
||||||
InputModel, OutputModel = signature_to_model(f, __config__=__config__)
|
InputModel, OutputModel = signature_to_model(f, __config__=__config__)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import royalnet.royaltyping as t
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from . import *
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
|
@ -1,230 +0,0 @@
|
||||||
import pytest
|
|
||||||
import asyncio
|
|
||||||
import async_timeout
|
|
||||||
import re
|
|
||||||
from royalnet.engineer import sentry, exc, blueprints
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def s() -> sentry.Sentry:
|
|
||||||
return sentry.Sentry()
|
|
||||||
|
|
||||||
|
|
||||||
class TestSentry:
|
|
||||||
def test_creation(self, s: sentry.Sentry):
|
|
||||||
assert s
|
|
||||||
assert isinstance(s, sentry.Sentry)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_put(self, s: sentry.Sentry):
|
|
||||||
await s.queue.put(None)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get(self, s: sentry.Sentry):
|
|
||||||
await s.queue.put(None)
|
|
||||||
assert await s.queue.get() is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_f(self, s: sentry.Sentry):
|
|
||||||
await s.queue.put(None)
|
|
||||||
f = s.f()
|
|
||||||
assert f
|
|
||||||
assert isinstance(f, sentry.Filter)
|
|
||||||
assert hasattr(f, "get")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def discarding_filter() -> sentry.Filter:
|
|
||||||
async def discard(_):
|
|
||||||
raise exc.Discard(None, "This filter discards everything!")
|
|
||||||
|
|
||||||
return sentry.Filter(discard)
|
|
||||||
|
|
||||||
|
|
||||||
class ErrorTest(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def error_test(*_, **__):
|
|
||||||
raise ErrorTest("This was raised by error_raiser.")
|
|
||||||
|
|
||||||
|
|
||||||
class TestFilter:
|
|
||||||
def test_creation(self):
|
|
||||||
f = sentry.Filter(lambda _: _)
|
|
||||||
assert f
|
|
||||||
assert isinstance(f, sentry.Filter)
|
|
||||||
|
|
||||||
class TestGetSingle:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_success(self, s: sentry.Sentry):
|
|
||||||
await s.queue.put(None)
|
|
||||||
assert await s.f().get_single() is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_failure(self, discarding_filter: sentry.Filter):
|
|
||||||
with pytest.raises(exc.Discard):
|
|
||||||
await discarding_filter.get_single()
|
|
||||||
|
|
||||||
class TestGet:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_success(self, s: sentry.Sentry):
|
|
||||||
await s.queue.put(None)
|
|
||||||
assert await s.f().get() is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_timeout(self, s: sentry.Sentry):
|
|
||||||
with pytest.raises(asyncio.TimeoutError):
|
|
||||||
async with async_timeout.timeout(0.001):
|
|
||||||
await s.f().get()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_filter(self, s: sentry.Sentry):
|
|
||||||
await s.queue.put(None)
|
|
||||||
await s.queue.put(None)
|
|
||||||
await s.queue.put(None)
|
|
||||||
|
|
||||||
assert await s.f().filter(lambda x: x is None, "Is not None").get_single() is None
|
|
||||||
|
|
||||||
with pytest.raises(exc.Discard):
|
|
||||||
await s.f().filter(lambda x: isinstance(x, type), error="Is not type").get_single()
|
|
||||||
|
|
||||||
with pytest.raises(ErrorTest):
|
|
||||||
await s.f().filter(error_test, error="Is error").get_single()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_map(self, s: sentry.Sentry):
|
|
||||||
await s.queue.put(None)
|
|
||||||
await s.queue.put(None)
|
|
||||||
|
|
||||||
assert await s.f().map(lambda x: 1).get_single() == 1
|
|
||||||
|
|
||||||
with pytest.raises(ErrorTest):
|
|
||||||
await s.f().map(error_test).get_single()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_type(self, s: sentry.Sentry):
|
|
||||||
await s.queue.put(1)
|
|
||||||
await s.queue.put("no")
|
|
||||||
|
|
||||||
assert await s.f().type(int).get_single() == 1
|
|
||||||
|
|
||||||
with pytest.raises(exc.Discard):
|
|
||||||
await s.f().type(int).get_single()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_msg(self, s: sentry.Sentry):
|
|
||||||
class ExampleMessage(blueprints.Message):
|
|
||||||
def __hash__(self):
|
|
||||||
return 1
|
|
||||||
|
|
||||||
msg = ExampleMessage()
|
|
||||||
await s.queue.put(msg)
|
|
||||||
await s.queue.put("no")
|
|
||||||
|
|
||||||
assert await s.f().msg().get_single() is msg
|
|
||||||
|
|
||||||
with pytest.raises(exc.Discard):
|
|
||||||
await s.f().msg().get_single()
|
|
||||||
|
|
||||||
class AvailableMessage(blueprints.Message):
|
|
||||||
def __hash__(self):
|
|
||||||
return 1
|
|
||||||
|
|
||||||
def text(self) -> str:
|
|
||||||
return "1"
|
|
||||||
|
|
||||||
class NotAvailableMessage(blueprints.Message):
|
|
||||||
def __hash__(self):
|
|
||||||
return 2
|
|
||||||
|
|
||||||
def text(self) -> str:
|
|
||||||
raise exc.NotAvailableError()
|
|
||||||
|
|
||||||
class NeverAvailableMessage(blueprints.Message):
|
|
||||||
def __hash__(self):
|
|
||||||
return 3
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_requires(self, s: sentry.Sentry):
|
|
||||||
avmsg = self.AvailableMessage()
|
|
||||||
await s.queue.put(avmsg)
|
|
||||||
assert await s.f().requires("text").get_single() is avmsg
|
|
||||||
|
|
||||||
await s.queue.put(self.NotAvailableMessage())
|
|
||||||
with pytest.raises(exc.Discard):
|
|
||||||
await s.f().requires("text").get_single()
|
|
||||||
|
|
||||||
await s.queue.put(self.NeverAvailableMessage())
|
|
||||||
with pytest.raises(exc.NeverAvailableError):
|
|
||||||
await s.f().requires("text").get_single()
|
|
||||||
|
|
||||||
await s.queue.put(self.NotAvailableMessage())
|
|
||||||
with pytest.raises(exc.NotAvailableError):
|
|
||||||
await s.f().requires("text", propagate_not_available=True).get_single()
|
|
||||||
|
|
||||||
await s.queue.put(self.NeverAvailableMessage())
|
|
||||||
with pytest.raises(exc.Discard):
|
|
||||||
await s.f().requires("text", propagate_never_available=False).get_single()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_field(self, s: sentry.Sentry):
|
|
||||||
avmsg = self.AvailableMessage()
|
|
||||||
await s.queue.put(avmsg)
|
|
||||||
assert await s.f().field("text").get_single() == "1"
|
|
||||||
|
|
||||||
await s.queue.put(self.NotAvailableMessage())
|
|
||||||
with pytest.raises(exc.Discard):
|
|
||||||
await s.f().field("text").get_single()
|
|
||||||
|
|
||||||
await s.queue.put(self.NeverAvailableMessage())
|
|
||||||
with pytest.raises(exc.NeverAvailableError):
|
|
||||||
await s.f().field("text").get_single()
|
|
||||||
|
|
||||||
await s.queue.put(self.NotAvailableMessage())
|
|
||||||
with pytest.raises(exc.NotAvailableError):
|
|
||||||
await s.f().field("text", propagate_not_available=True).get_single()
|
|
||||||
|
|
||||||
await s.queue.put(self.NeverAvailableMessage())
|
|
||||||
with pytest.raises(exc.Discard):
|
|
||||||
await s.f().field("text", propagate_never_available=False).get_single()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_startswith(self, s: sentry.Sentry):
|
|
||||||
await s.queue.put("yarrharr")
|
|
||||||
await s.queue.put("yohoho")
|
|
||||||
|
|
||||||
assert await s.f().startswith("yarr").get_single() == "yarrharr"
|
|
||||||
|
|
||||||
with pytest.raises(exc.Discard):
|
|
||||||
await s.f().startswith("yarr").get_single()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_endswith(self, s: sentry.Sentry):
|
|
||||||
await s.queue.put("yarrharr")
|
|
||||||
await s.queue.put("yohoho")
|
|
||||||
|
|
||||||
assert await s.f().endswith("harr").get_single() == "yarrharr"
|
|
||||||
|
|
||||||
with pytest.raises(exc.Discard):
|
|
||||||
await s.f().endswith("harr").get_single()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_regex(self, s: sentry.Sentry):
|
|
||||||
await s.queue.put("yarrharr")
|
|
||||||
await s.queue.put("yohoho")
|
|
||||||
|
|
||||||
assert isinstance(await s.f().regex(re.compile(r"[yh]arr")).get_single(), re.Match)
|
|
||||||
|
|
||||||
with pytest.raises(exc.Discard):
|
|
||||||
await s.f().regex(re.compile(r"[yh]arr")).get_single()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_choices(self, s: sentry.Sentry):
|
|
||||||
await s.queue.put("yarrharr")
|
|
||||||
await s.queue.put("yohoho")
|
|
||||||
|
|
||||||
assert await s.f().choices("yarrharr", "banana").get_single() == "yarrharr"
|
|
||||||
|
|
||||||
with pytest.raises(exc.Discard):
|
|
||||||
await s.f().choices("yarrharr", "banana").get_single()
|
|
|
@ -2,12 +2,12 @@ import pytest
|
||||||
import inspect
|
import inspect
|
||||||
import pydantic
|
import pydantic
|
||||||
import pydantic.fields
|
import pydantic.fields
|
||||||
import royalnet.engineer as re
|
import royalnet.engineer.teleporter as tp
|
||||||
import typing
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def my_function():
|
def my_function():
|
||||||
|
# noinspection PyUnusedLocal
|
||||||
def f(*, big_f: str, _hidden: int) -> int:
|
def f(*, big_f: str, _hidden: int) -> int:
|
||||||
return _hidden
|
return _hidden
|
||||||
return f
|
return f
|
||||||
|
@ -16,14 +16,15 @@ def my_function():
|
||||||
def test_parameter_to_field(my_function):
|
def test_parameter_to_field(my_function):
|
||||||
signature = inspect.signature(my_function)
|
signature = inspect.signature(my_function)
|
||||||
parameter = signature.parameters["big_f"]
|
parameter = signature.parameters["big_f"]
|
||||||
t, fieldinfo = re.parameter_to_field(parameter)
|
t, fieldinfo = tp.parameter_to_field(parameter)
|
||||||
assert isinstance(fieldinfo, pydantic.fields.FieldInfo)
|
assert isinstance(fieldinfo, pydantic.fields.FieldInfo)
|
||||||
assert fieldinfo.default is ...
|
assert fieldinfo.default is ...
|
||||||
assert fieldinfo.title == parameter.name == "big_f"
|
assert fieldinfo.title == parameter.name == "big_f"
|
||||||
|
|
||||||
|
|
||||||
def test_signature_to_model(my_function):
|
def test_signature_to_model(my_function):
|
||||||
InputModel, OutputModel = re.signature_to_model(my_function)
|
# noinspection PyPep8Naming
|
||||||
|
InputModel, OutputModel = tp.signature_to_model(my_function)
|
||||||
assert callable(InputModel)
|
assert callable(InputModel)
|
||||||
|
|
||||||
model = InputModel(big_f="banana")
|
model = InputModel(big_f="banana")
|
||||||
|
@ -58,7 +59,7 @@ def test_signature_to_model(my_function):
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
class TestTeleporter:
|
class TestTeleporter:
|
||||||
def test_standard_function(self):
|
def test_standard_function(self):
|
||||||
@re.teleporter()
|
@tp.teleporter()
|
||||||
def standard_function(a: int, b: int, _return_str: bool = False) -> int:
|
def standard_function(a: int, b: int, _return_str: bool = False) -> int:
|
||||||
if _return_str:
|
if _return_str:
|
||||||
return "You asked me this."
|
return "You asked me this."
|
||||||
|
@ -91,7 +92,7 @@ class TestTeleporter:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_async_function(self):
|
async def test_async_function(self):
|
||||||
@re.teleporter(is_async=True)
|
@tp.teleporter(is_async=True)
|
||||||
async def async_function(a: int, b: int, _return_str: bool = False) -> int:
|
async def async_function(a: int, b: int, _return_str: bool = False) -> int:
|
||||||
if _return_str:
|
if _return_str:
|
||||||
return "You asked me this."
|
return "You asked me this."
|
||||||
|
@ -123,7 +124,7 @@ class TestTeleporter:
|
||||||
_ = await async_function(1, 2)
|
_ = await async_function(1, 2)
|
||||||
|
|
||||||
def test_only_input(self):
|
def test_only_input(self):
|
||||||
@re.teleporter(validate_output=False)
|
@tp.teleporter(validate_output=False)
|
||||||
def standard_function(a: int, b: int, _return_str: bool = False) -> int:
|
def standard_function(a: int, b: int, _return_str: bool = False) -> int:
|
||||||
if _return_str:
|
if _return_str:
|
||||||
return "You asked me this."
|
return "You asked me this."
|
||||||
|
@ -154,7 +155,7 @@ class TestTeleporter:
|
||||||
_ = standard_function(1, 2)
|
_ = standard_function(1, 2)
|
||||||
|
|
||||||
def test_only_output(self):
|
def test_only_output(self):
|
||||||
@re.teleporter(validate_input=False)
|
@tp.teleporter(validate_input=False)
|
||||||
def standard_function(a: int, b: int, _return_str: bool = False) -> int:
|
def standard_function(a: int, b: int, _return_str: bool = False) -> int:
|
||||||
if _return_str:
|
if _return_str:
|
||||||
return "You asked me this."
|
return "You asked me this."
|
||||||
|
|
306
royalnet/engineer/wrench.py
Normal file
306
royalnet/engineer/wrench.py
Normal file
|
@ -0,0 +1,306 @@
|
||||||
|
"""
|
||||||
|
Wrenches are objects which can used instead of coroutines in Sentry receiver filters,
|
||||||
|
acting similarly to function factories and allowing to easily define filter functions with parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import royalnet.royaltyping as t
|
||||||
|
|
||||||
|
import abc
|
||||||
|
|
||||||
|
from . import discard
|
||||||
|
from . import exc
|
||||||
|
|
||||||
|
|
||||||
|
class Wrench(metaclass=abc.ABCMeta):
|
||||||
|
"""
|
||||||
|
The abstract base class for Wrenches.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def filter(self, obj: t.Any) -> t.Any:
|
||||||
|
"""
|
||||||
|
The function applied to all objects transiting through the pipeline:
|
||||||
|
|
||||||
|
- If the function **returns**, its return value will be passed to the next node in the pipeline;
|
||||||
|
- If the function **raises**, the error is propagated downwards.
|
||||||
|
|
||||||
|
A special exception is available for discarding objects: :exc:`.discard.Discard`.
|
||||||
|
If raised, the object will be silently ignored.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def __call__(self, obj: t.Any) -> t.Awaitable[t.Any]:
|
||||||
|
"""
|
||||||
|
Allow instances to be directly called, emulating coroutine functions.
|
||||||
|
"""
|
||||||
|
return self.filter(obj)
|
||||||
|
|
||||||
|
|
||||||
|
class PassAll(Wrench):
|
||||||
|
"""
|
||||||
|
**Return** each received object as it is.
|
||||||
|
|
||||||
|
.. note:: To be used only in testing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def filter(self, obj: t.Any) -> t.Any:
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
class DiscardAll(Wrench):
|
||||||
|
"""
|
||||||
|
**Discard** each received object.
|
||||||
|
|
||||||
|
.. note:: To be used only in testing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def filter(self, obj: t.Any) -> t.Any:
|
||||||
|
raise discard.Discard(obj, "Discard filter discards everything")
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorAll(Wrench):
|
||||||
|
"""
|
||||||
|
**Raise** :exc:`.exc.DeliberateException` for each received object.
|
||||||
|
|
||||||
|
.. note:: To be used only in testing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def filter(self, obj: t.Any) -> t.Any:
|
||||||
|
raise exc.DeliberateException("ErrorAll received an object")
|
||||||
|
|
||||||
|
|
||||||
|
class CheckBase(Wrench, metaclass=abc.ABCMeta):
|
||||||
|
"""
|
||||||
|
Check a condition on the received objects:
|
||||||
|
|
||||||
|
- If the check returns :data:`True`, the object is **returned**;
|
||||||
|
- If the check returns :data:`False`, the object is **discarded**;
|
||||||
|
- If an error is raised, it is **propagated**.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, invert: bool = False):
|
||||||
|
self.invert: bool = invert
|
||||||
|
"""
|
||||||
|
If set to :data:`True`, this Nut will invert its results:
|
||||||
|
|
||||||
|
- If the check returns :data:`True`, the object is **discarded**;
|
||||||
|
- If the check returns :data:`False`, the object is **returned**;
|
||||||
|
- If an error is raised, it is **propagated**.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __invert__(self):
|
||||||
|
return self.__class__(invert=not self.invert)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def check(self, obj: t.Any) -> bool:
|
||||||
|
"""
|
||||||
|
The condition to check.
|
||||||
|
|
||||||
|
:param obj: The object passing through the pipeline.
|
||||||
|
:return: Whether the check was successful or not.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def error(self, obj: t.Any) -> str:
|
||||||
|
"""
|
||||||
|
The error message to attach as :attr:`.Discard.message` if the object is discarded.
|
||||||
|
|
||||||
|
:param obj: The object passing through the pipeline.
|
||||||
|
:return: The error message.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
async def filter(self, obj: t.Any) -> t.Any:
|
||||||
|
if await self.check(obj) ^ self.invert:
|
||||||
|
return obj
|
||||||
|
else:
|
||||||
|
raise discard.Discard(obj=obj, message=self.error(obj))
|
||||||
|
|
||||||
|
|
||||||
|
class Type(CheckBase):
|
||||||
|
"""
|
||||||
|
Check the type of an object:
|
||||||
|
|
||||||
|
- If the object **is** of the specified type, it is **returned**;
|
||||||
|
- If the object **isn't** of the specified type, it is **discarded**;
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, type_: t.Type, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.type: t.Type = type_
|
||||||
|
|
||||||
|
async def check(self, obj: t.Any) -> bool:
|
||||||
|
return isinstance(obj, self.type)
|
||||||
|
|
||||||
|
def error(self, obj: t.Any) -> str:
|
||||||
|
return f"Not instance of type {self.type}"
|
||||||
|
|
||||||
|
|
||||||
|
class StartsWith(CheckBase):
|
||||||
|
"""
|
||||||
|
Check if an object :func:`startswith` a certain prefix.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, prefix: str, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.prefix: str = prefix
|
||||||
|
|
||||||
|
async def check(self, obj: t.Any) -> bool:
|
||||||
|
return obj.startswith(self.prefix)
|
||||||
|
|
||||||
|
def error(self, obj: t.Any) -> str:
|
||||||
|
return f"Didn't start with {self.prefix}"
|
||||||
|
|
||||||
|
|
||||||
|
class EndsWith(CheckBase):
|
||||||
|
"""
|
||||||
|
Check if an object :func:`endswith` a certain suffix.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, suffix: str, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.suffix: str = suffix
|
||||||
|
|
||||||
|
async def check(self, obj: t.Any) -> bool:
|
||||||
|
return obj.startswith(self.suffix)
|
||||||
|
|
||||||
|
def error(self, obj: t.Any) -> str:
|
||||||
|
return f"Didn't end with {self.suffix}"
|
||||||
|
|
||||||
|
|
||||||
|
class Choice(CheckBase):
|
||||||
|
"""
|
||||||
|
Check if an object is among the accepted list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *accepted, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.accepted: t.Collection = accepted
|
||||||
|
"""
|
||||||
|
A collection of elements which can be chosen.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def check(self, obj: t.Any) -> bool:
|
||||||
|
return obj in self.accepted
|
||||||
|
|
||||||
|
def error(self, obj: t.Any) -> str:
|
||||||
|
return f"Not a valid choice"
|
||||||
|
|
||||||
|
|
||||||
|
class RegexCheck(CheckBase):
|
||||||
|
"""
|
||||||
|
Check if an object matches a regex pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, pattern: t.Pattern, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.pattern: t.Pattern = pattern
|
||||||
|
"""
|
||||||
|
The pattern that should be matched.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def check(self, obj: t.Any) -> bool:
|
||||||
|
return bool(self.pattern.match(obj))
|
||||||
|
|
||||||
|
def error(self, obj: t.Any) -> str:
|
||||||
|
return f"Didn't match pattern {self.pattern}"
|
||||||
|
|
||||||
|
|
||||||
|
class RegexMatch(Wrench):
|
||||||
|
"""
|
||||||
|
Apply a regex over an object:
|
||||||
|
|
||||||
|
- If it matches, **return** the :class:`re.Match` object;
|
||||||
|
- If it doesn't match, **discard** the object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, pattern: t.Pattern, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.pattern: t.Pattern = pattern
|
||||||
|
"""
|
||||||
|
The pattern that should be matched.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def filter(self, obj: t.Any) -> t.Any:
|
||||||
|
if match := self.pattern.match(obj):
|
||||||
|
return match
|
||||||
|
else:
|
||||||
|
raise discard.Discard(obj, f"Didn't match pattern {obj}")
|
||||||
|
|
||||||
|
|
||||||
|
class RegexReplace(Wrench):
|
||||||
|
"""
|
||||||
|
Apply a regex over an object:
|
||||||
|
|
||||||
|
- If it matches, replace the match(es) with :attr:`.replacement` and **return** the result.
|
||||||
|
- If it doesn't match, **return** the object as it is.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, pattern: t.Pattern, replacement: t.Union[str, bytes]):
|
||||||
|
self.pattern: t.Pattern = pattern
|
||||||
|
"""
|
||||||
|
The pattern that should be matched.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.replacement: t.Union[str, bytes] = replacement
|
||||||
|
"""
|
||||||
|
The substitution string for the object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def filter(self, obj: t.Any) -> t.Any:
|
||||||
|
return self.pattern.sub(self.replacement, obj)
|
||||||
|
|
||||||
|
|
||||||
|
class Lambda(Wrench):
|
||||||
|
"""
|
||||||
|
Apply a syncronous function over the received objects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, func: t.Callable[[t.Any], t.Any]):
|
||||||
|
self.func: t.Callable[[t.Any], t.Any] = func
|
||||||
|
"""
|
||||||
|
The function to apply.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def filter(self, obj: t.Any) -> t.Any:
|
||||||
|
return self.func(obj)
|
||||||
|
|
||||||
|
|
||||||
|
class Check(CheckBase):
|
||||||
|
"""
|
||||||
|
Check a condition on the received objects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, func: t.Callable[[t.Any], t.Any], error: str, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.func: t.Callable[[t.Any], t.Any] = func
|
||||||
|
"""
|
||||||
|
The condition to check.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.error: str = error
|
||||||
|
"""
|
||||||
|
The error message to display if the check fails.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def check(self, obj: t.Any) -> bool:
|
||||||
|
return self.func(obj)
|
||||||
|
|
||||||
|
def error(self, obj: t.Any) -> str:
|
||||||
|
return self.error
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
"Wrench",
|
||||||
|
"CheckBase",
|
||||||
|
"Type",
|
||||||
|
"StartsWith",
|
||||||
|
"EndsWith",
|
||||||
|
"Choice",
|
||||||
|
"RegexCheck",
|
||||||
|
"RegexMatch",
|
||||||
|
"RegexReplace",
|
||||||
|
"Lambda",
|
||||||
|
)
|
|
@ -74,3 +74,6 @@ An async generator yielding either:
|
||||||
* another :data:`.AsyncAdventure`;
|
* another :data:`.AsyncAdventure`;
|
||||||
* :data:`None`.
|
* :data:`None`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
Conversation = Callable[["Sentry"], Awaitable[Optional["Conversation"]]]
|
||||||
|
|
Loading…
Reference in a new issue