mirror of
https://github.com/RYGhub/royalnet.git
synced 2024-11-23 03:24:20 +00:00
💥 Reimplement engineer module from scratch (#2)
This commit is contained in:
parent
133f503926
commit
715c0e72df
20 changed files with 726 additions and 828 deletions
|
@ -1,10 +1,5 @@
|
|||
"""
|
||||
A chatbot command router inspired by :mod:`fastapi`.
|
||||
Chat bot utilities.
|
||||
|
||||
All names are inspired by the `Engineer Class of Team Fortress 2 <https://wiki.teamfortress.com/wiki/Engineer>`_.
|
||||
"""
|
||||
|
||||
from .blueprints import *
|
||||
from .teleporter import *
|
||||
from .sentry import *
|
||||
from .exc import *
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
from .blueprint import *
|
||||
from .message import *
|
||||
from .channel import *
|
||||
from .user import *
|
|
@ -1,92 +0,0 @@
|
|||
import abc
|
||||
|
||||
from .. import exc
|
||||
|
||||
|
||||
class Blueprint(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
A class containing methods common between all blueprints.
|
||||
|
||||
To extend a blueprint, inherit from it while using the :class:`abc.ABCMeta` metaclass, and make all new functions
|
||||
return :exc:`.exc.NeverAvailableError`:
|
||||
|
||||
.. code-block::
|
||||
|
||||
class Channel(Blueprint, metaclass=abc.ABCMeta):
|
||||
def name(self):
|
||||
raise exc.NeverAvailableError()
|
||||
|
||||
To implement a blueprint for a specific chat platform, inherit from the blueprint, override :meth:`__init__`,
|
||||
:meth:`__hash__` and the methods that are implemented by the platform in question, either returning the
|
||||
corresponding value or raising :exc:`.exc.NotAvailableError` if there is no data available.
|
||||
|
||||
.. code-block::
|
||||
|
||||
class ExampleChannel(Channel):
|
||||
def __init__(self, chat_id: int):
|
||||
self.chat_id: int = chat_id
|
||||
|
||||
def __hash__(self):
|
||||
return self.chat_id
|
||||
|
||||
def name(self):
|
||||
return ExampleClient.expensive_get_channel_name(self.chat_id)
|
||||
|
||||
.. note:: To improve performance, you might want to wrap all data methods in :func:`functools.lru_cache` decorators.
|
||||
|
||||
.. code-block::
|
||||
|
||||
@functools.lru_cache(24)
|
||||
def name(self):
|
||||
return ExampleClient.expensive_get_channel_name(self.chat_id)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
:return: The created object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def __hash__(self):
|
||||
"""
|
||||
:return: A value that uniquely identifies the channel inside this Python process.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def requires(self, *fields) -> True:
|
||||
"""
|
||||
Ensure that this blueprint has the specified fields, re-raising the highest priority exception raised between
|
||||
all of them.
|
||||
|
||||
.. code-block::
|
||||
|
||||
def print_msg(message: Message):
|
||||
message.requires("text", "timestamp")
|
||||
print(f"{message.timestamp().isoformat()}: {message.text()}")
|
||||
|
||||
:raises .exc.NeverAvailableError: If at least one of the fields raised a :exc:`.exc.NeverAvailableError`.
|
||||
:raises .exc.NotAvailableError: If no field raised a :exc:`.exc.NeverAvailableError`, but at least one raised a
|
||||
:exc:`.exc.NotAvailableError`.
|
||||
"""
|
||||
|
||||
exceptions = []
|
||||
|
||||
for field in fields:
|
||||
try:
|
||||
self.__getattribute__(field)()
|
||||
except exc.NeverAvailableError as ex:
|
||||
exceptions.append(ex)
|
||||
except exc.NotAvailableError as ex:
|
||||
exceptions.append(ex)
|
||||
|
||||
if len(exceptions) > 0:
|
||||
raise max(exceptions, key=lambda e: e.priority)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
__all__ = (
|
||||
"Blueprint",
|
||||
)
|
|
@ -1,35 +0,0 @@
|
|||
from __future__ import annotations
|
||||
from royalnet.royaltyping import *
|
||||
import abc
|
||||
|
||||
from .. import exc
|
||||
from .blueprint import Blueprint
|
||||
|
||||
|
||||
class Channel(Blueprint, metaclass=abc.ABCMeta):
|
||||
"""
|
||||
An abstract class representing a channel where messages can be sent.
|
||||
|
||||
.. seealso:: :class:`.Blueprint`
|
||||
"""
|
||||
|
||||
def name(self) -> str:
|
||||
"""
|
||||
:return: The name of the message channel, such as the chat title.
|
||||
:raises .exc.NeverAvailableError: If the chat platform does not support channel names.
|
||||
:raises .exc.NotAvailableError: If this channel does not have any name.
|
||||
"""
|
||||
raise exc.NeverAvailableError()
|
||||
|
||||
def topic(self) -> str:
|
||||
"""
|
||||
:return: The topic or description of the message channel.
|
||||
:raises .exc.NeverAvailableError: If the chat platform does not support channel topics / descriptions.
|
||||
:raises .exc.NotAvailableError: If this channel does not have any name.
|
||||
"""
|
||||
raise exc.NeverAvailableError()
|
||||
|
||||
|
||||
__all__ = (
|
||||
"Channel",
|
||||
)
|
|
@ -1,53 +0,0 @@
|
|||
from __future__ import annotations
|
||||
from royalnet.royaltyping import *
|
||||
import abc
|
||||
import datetime
|
||||
|
||||
from .. import exc
|
||||
from .blueprint import Blueprint
|
||||
from .channel import Channel
|
||||
|
||||
|
||||
class Message(Blueprint, metaclass=abc.ABCMeta):
|
||||
"""
|
||||
An abstract class representing a chat message sent in any platform.
|
||||
|
||||
.. seealso:: :class:`.Blueprint`
|
||||
"""
|
||||
|
||||
def text(self) -> str:
|
||||
"""
|
||||
:return: The raw text contents of the message.
|
||||
:raises .exc.NeverAvailableError: If the chat platform does not support text messages.
|
||||
:raises .exc.NotAvailableError: If this message does not have any text.
|
||||
"""
|
||||
raise exc.NeverAvailableError()
|
||||
|
||||
def timestamp(self) -> datetime.datetime:
|
||||
"""
|
||||
:return: The :class:`datetime.datetime` at which the message was sent / received.
|
||||
:raises .exc.NeverAvailableError: If the chat platform does not support timestamps.
|
||||
:raises .exc.NotAvailableError: If this message is special and does not have any timestamp.
|
||||
"""
|
||||
raise exc.NeverAvailableError()
|
||||
|
||||
def reply_to(self) -> Message:
|
||||
"""
|
||||
:return: The :class:`.Message` this message is a reply to.
|
||||
:raises .exc.NeverAvailableError: If the chat platform does not support replies.
|
||||
:raises .exc.NotAvailableError: If this message is not a reply to any other message.
|
||||
"""
|
||||
raise exc.NeverAvailableError()
|
||||
|
||||
def channel(self) -> Channel:
|
||||
"""
|
||||
:return: The :class:`.Channel` this message was sent in.
|
||||
:raises .exc.NeverAvailableError: If the chat platform does not support channels.
|
||||
:raises .exc.NotAvailableError: If this message was not sent in any channel.
|
||||
"""
|
||||
raise exc.NeverAvailableError()
|
||||
|
||||
|
||||
__all__ = (
|
||||
"Message",
|
||||
)
|
|
@ -1,35 +0,0 @@
|
|||
from __future__ import annotations
|
||||
from royalnet.royaltyping import *
|
||||
import abc
|
||||
import sqlalchemy.orm
|
||||
|
||||
from .. import exc
|
||||
from .blueprint import Blueprint
|
||||
|
||||
|
||||
class User(Blueprint, metaclass=abc.ABCMeta):
|
||||
"""
|
||||
An abstract class representing a chat user.
|
||||
|
||||
.. seealso:: :class:`.Blueprint`
|
||||
"""
|
||||
|
||||
def name(self) -> str:
|
||||
"""
|
||||
:return: The user's name.
|
||||
:raises .exc.NeverAvailableError: If the chat platform does not support usernames.
|
||||
:raises .exc.NotAvailableError: If this user does not have any name.
|
||||
"""
|
||||
raise exc.NeverAvailableError()
|
||||
|
||||
def database(self, session: sqlalchemy.orm.Session) -> Any:
|
||||
"""
|
||||
:param session: A :class:`sqlalchemy.orm.Session` instance to use to fetch the database entry.
|
||||
:return: The database entry for this user.
|
||||
"""
|
||||
raise exc.NeverAvailableError()
|
||||
|
||||
|
||||
__all__ = (
|
||||
"User",
|
||||
)
|
114
royalnet/engineer/bullet.py
Normal file
114
royalnet/engineer/bullet.py
Normal file
|
@ -0,0 +1,114 @@
|
|||
"""
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import royalnet.royaltyping as t
|
||||
|
||||
import abc
|
||||
import datetime
|
||||
import sqlalchemy.orm
|
||||
|
||||
from . import exc
|
||||
|
||||
|
||||
class Bullet(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
The abstract base class for Bullet data models.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __hash__(self) -> int:
|
||||
"""
|
||||
:return: A value that uniquely identifies the object in this Python interpreter process.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class Message(Bullet, metaclass=abc.ABCMeta):
|
||||
"""
|
||||
An abstract class representing a chat message.
|
||||
"""
|
||||
|
||||
async def text(self) -> t.Optional[str]:
|
||||
"""
|
||||
:return: The raw text contents of the message.
|
||||
:raises .exc.NotSupportedError: If the frontend does not support text messages.
|
||||
"""
|
||||
raise exc.NotSupportedError()
|
||||
|
||||
async def timestamp(self) -> t.Optional[datetime.datetime]:
|
||||
"""
|
||||
:return: The :class:`datetime.datetime` at which the message was sent.
|
||||
:raises .exc.NotSupportedError: If the frontend does not support timestamps.
|
||||
"""
|
||||
raise exc.NotSupportedError()
|
||||
|
||||
async def reply_to(self) -> t.Optional[Message]:
|
||||
"""
|
||||
:return: The :class:`.Message` this message is a reply to.
|
||||
:raises .exc.NotSupportedError: If the frontend does not support replies.
|
||||
"""
|
||||
raise exc.NotSupportedError()
|
||||
|
||||
async def channel(self) -> t.Optional[Channel]:
|
||||
"""
|
||||
:return: The :class:`.Channel` this message was sent in.
|
||||
:raises .exc.NotSupportedError: If the frontend does not support channels.
|
||||
"""
|
||||
raise exc.NotSupportedError()
|
||||
|
||||
|
||||
class Channel(Bullet, metaclass=abc.ABCMeta):
|
||||
"""
|
||||
An abstract class representing a channel where messages can be sent.
|
||||
"""
|
||||
|
||||
async def name(self) -> t.Optional[str]:
|
||||
"""
|
||||
:return: The name of the message channel, such as the chat title.
|
||||
:raises .exc.NotSupportedError: If the frontend does not support channel names.
|
||||
"""
|
||||
raise exc.NotSupportedError()
|
||||
|
||||
async def topic(self) -> t.Optional[str]:
|
||||
"""
|
||||
:return: The topic (description) of the message channel.
|
||||
:raises .exc.NotSupportedError: If the frontend does not support channel topics / descriptions.
|
||||
"""
|
||||
raise exc.NotSupportedError()
|
||||
|
||||
async def users(self) -> t.List[User]:
|
||||
"""
|
||||
:return: A :class:`list` of :class:`.User` who can read messages sent in the channel.
|
||||
:raises .exc.NotSupportedError: If the frontend does not support such a feature.
|
||||
"""
|
||||
raise exc.NotSupportedError()
|
||||
|
||||
|
||||
class User(Bullet, metaclass=abc.ABCMeta):
|
||||
"""
|
||||
An abstract class representing a user who can read or send messages in the chat.
|
||||
"""
|
||||
|
||||
async def name(self) -> t.Optional[str]:
|
||||
"""
|
||||
:return: The user's name.
|
||||
:raises .exc.NotSupportedError: If the frontend does not support usernames.
|
||||
"""
|
||||
raise exc.NotSupportedError()
|
||||
|
||||
async def database(self, session: sqlalchemy.orm.Session) -> t.Any:
|
||||
"""
|
||||
:param session: A :class:`sqlalchemy.orm.Session` instance to use to fetch the database entry.
|
||||
:return: The database entry for this user.
|
||||
"""
|
||||
raise exc.NotSupportedError()
|
||||
|
||||
|
||||
__all__ = (
|
||||
"Bullet",
|
||||
"Message",
|
||||
"Channel",
|
||||
"User",
|
||||
)
|
13
royalnet/engineer/discard.py
Normal file
13
royalnet/engineer/discard.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
class Discard(BaseException):
|
||||
"""
|
||||
A special exception which should be raised by Metals if a certain object should be discarded from the queue.
|
||||
"""
|
||||
def __init__(self, obj, message):
|
||||
self.obj = obj
|
||||
self.message = message
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Discard>"
|
||||
|
||||
def __str__(self):
|
||||
return f"Discarded {self.obj}: {self.message}"
|
|
@ -0,0 +1,51 @@
|
|||
"""
|
||||
Dispensers instantiate sentries and dispatch events in bulk to the whole group.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import royalnet.royaltyping as t
|
||||
|
||||
import logging
|
||||
import contextlib
|
||||
|
||||
from .sentry import SentrySource
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Dispenser:
|
||||
def __init__(self):
|
||||
self.sentries: t.List[SentrySource] = []
|
||||
"""
|
||||
A :class:`list` of all the running sentries of this dispenser.
|
||||
"""
|
||||
|
||||
def put(self, item: t.Any) -> None:
|
||||
"""
|
||||
Insert a new item in the queues of all the running sentries.
|
||||
|
||||
:param item: The item to insert.
|
||||
"""
|
||||
log.debug(f"Putting {item}")
|
||||
for sentry in self.sentries:
|
||||
sentry.put(item)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def sentry(self, *args, **kwargs):
|
||||
"""
|
||||
A context manager which creates a :class:`.SentrySource` and keeps it in :attr:`.sentries` while it is being
|
||||
used.
|
||||
"""
|
||||
sentry = SentrySource(dispenser=self, *args, **kwargs)
|
||||
self.sentries.append(sentry)
|
||||
|
||||
yield sentry
|
||||
|
||||
self.sentries.remove(sentry)
|
||||
|
||||
async def run(self, conv: t.Conversation) -> None:
|
||||
with self.sentry() as sentry:
|
||||
state = conv(sentry)
|
||||
|
||||
while True:
|
||||
state = await state
|
|
@ -1,38 +1,27 @@
|
|||
import royalnet.exc
|
||||
import pydantic
|
||||
|
||||
|
||||
class EngineerException(royalnet.exc.RoyalnetException):
|
||||
class EngineerException(Exception):
|
||||
"""
|
||||
An exception raised by the engineer module.
|
||||
The base class for errors in :mod:`royalnet.engineer`.
|
||||
"""
|
||||
|
||||
|
||||
class BlueprintError(EngineerException):
|
||||
class WrenchException(EngineerException):
|
||||
"""
|
||||
An error related to the :mod:`royalnet.engineer.blueprints`.
|
||||
The base class for errors in :mod:`royalnet.engineer.wrench`.
|
||||
"""
|
||||
|
||||
|
||||
class NeverAvailableError(BlueprintError, NotImplementedError):
|
||||
class DeliberateException(WrenchException):
|
||||
"""
|
||||
The requested property is never supplied by the chat platform the message was sent in.
|
||||
This exception was deliberately raised by :class:`royalnet.engineer.wrench.ErrorAll`.
|
||||
"""
|
||||
|
||||
priority = 1
|
||||
|
||||
|
||||
class NotAvailableError(BlueprintError):
|
||||
"""
|
||||
The requested property was not supplied by the chat platform for the specific message this exception was raised in.
|
||||
"""
|
||||
|
||||
priority = 2
|
||||
|
||||
|
||||
class TeleporterError(EngineerException, pydantic.ValidationError):
|
||||
"""
|
||||
The validation of some object though a :mod:`pydantic` model failed.
|
||||
The base class for errors in :mod:`royalnet.engineer.teleporter`.
|
||||
"""
|
||||
|
||||
|
||||
|
@ -48,28 +37,13 @@ class OutTeleporterError(TeleporterError):
|
|||
"""
|
||||
|
||||
|
||||
class SentryError(EngineerException):
|
||||
class BulletException(EngineerException):
|
||||
"""
|
||||
An error related to the :mod:`royalnet.engineer.sentry`.
|
||||
The base class for errors in :mod:`royalnet.engineer.bullet`.
|
||||
"""
|
||||
|
||||
|
||||
class FilterError(SentryError):
|
||||
class NotSupportedError(BulletException, NotImplementedError):
|
||||
"""
|
||||
An error related to the :class:`royalnet.engineer.sentry.Filter`.
|
||||
The requested property isn't available on the current frontend.
|
||||
"""
|
||||
|
||||
|
||||
class Discard(FilterError):
|
||||
"""
|
||||
Discard the object from the queue.
|
||||
"""
|
||||
def __init__(self, obj, message):
|
||||
self.obj = obj
|
||||
self.message = message
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Discard>"
|
||||
|
||||
def __str__(self):
|
||||
return f"Discarded {self.obj}: {self.message}"
|
||||
|
|
188
royalnet/engineer/sentry.py
Normal file
188
royalnet/engineer/sentry.py
Normal file
|
@ -0,0 +1,188 @@
|
|||
"""
|
||||
Sentries are asyncronous receivers for events (usually :class:`bullet.Bullet`) incoming from Dispensers.
|
||||
|
||||
They support event filtering through Wrenches and coroutine functions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import royalnet.royaltyping as t
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from . import discard
|
||||
from . import bullet
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .dispenser import Dispenser
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Sentry(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
The abstract object representing a node of the pipeline.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_nowait(self) -> bullet.Bullet:
|
||||
"""
|
||||
Try to get a single :class:`~.bullet.Bullet` from the pipeline, without blocking or handling discards.
|
||||
|
||||
:return: The **returned** :class:`~.bullet.Bullet`.
|
||||
:raises asyncio.QueueEmpty: If the queue is empty.
|
||||
:raises .discard.Discard: If the object was **discarded** by the pipeline.
|
||||
:raises Exception: If an exception was **raised** in the pipeline.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get(self) -> bullet.Bullet:
|
||||
"""
|
||||
Try to get a single :class:`~.bullet.Bullet` from the pipeline, blocking until something is available, but
|
||||
without handling discards.
|
||||
|
||||
:return: The **returned** :class:`~.bullet.Bullet`.
|
||||
:raises .discard.Discard: If the object was **discarded** by the pipeline.
|
||||
:raises Exception: If an exception was **raised** in the pipeline.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def wait(self) -> bullet.Bullet:
|
||||
"""
|
||||
Try to get a single :class:`~.bullet.Bullet` from the pipeline, blocking until something is available and is not
|
||||
discarded.
|
||||
|
||||
:return: The **returned** :class:`~.bullet.Bullet`.
|
||||
:raises Exception: If an exception was **raised** in the pipeline.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
result = await self.get()
|
||||
log.debug(f"Returned: {result}")
|
||||
return result
|
||||
except discard.Discard as d:
|
||||
log.debug(f"{str(d)}")
|
||||
continue
|
||||
|
||||
def __await__(self):
|
||||
"""
|
||||
Awaiting an object implementing :class:`.SentryInterface` corresponds to awaiting :meth:`.wait`.
|
||||
"""
|
||||
return self.get().__await__()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def put(self, item: t.Any) -> None:
|
||||
"""
|
||||
Insert a new item in the queue.
|
||||
|
||||
:param item: The item to be added.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def filter(self, wrench: t.Callable[[t.Any], t.Awaitable[t.Any]]) -> SentryFilter:
|
||||
"""
|
||||
Chain a new filter to the pipeline.
|
||||
|
||||
:param wrench: The filter to add to the chain. It can either be a :class:`.wrench.Wrench`, or a coroutine
|
||||
function accepting a single object as parameter and returning the same or a different one.
|
||||
:return: A new :class:`.SentryFilter` which includes the filter.
|
||||
|
||||
.. seealso:: :meth:`.__or__`
|
||||
"""
|
||||
if callable(wrench):
|
||||
return SentryFilter(previous=self, wrench=wrench)
|
||||
else:
|
||||
raise TypeError("wrench must be either a Wrench or a coroutine function")
|
||||
|
||||
def __or__(self, other: t.Callable[[t.Any], t.Awaitable[t.Any]]) -> SentryFilter:
|
||||
"""
|
||||
A unix-pipe-like interface for :meth:`.filter`.
|
||||
|
||||
.. code-block::
|
||||
|
||||
await (sentry | wrench.Type(Message) | wrench.Sync(lambda o: o.text))
|
||||
|
||||
"""
|
||||
try:
|
||||
return self.filter(other)
|
||||
except TypeError:
|
||||
raise TypeError("Right-side must be either a Wrench or a coroutine function")
|
||||
|
||||
@abc.abstractmethod
|
||||
def dispenser(self):
|
||||
"""
|
||||
Get the :class:`.Dispenser` that created this Sentry.
|
||||
|
||||
:return: The :class:`.Dispenser` object.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SentryFilter(Sentry):
|
||||
"""
|
||||
A non-root node of the filtering pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, previous: Sentry, wrench: t.Callable[[t.Any], t.Awaitable[t.Any]]):
|
||||
self.previous: Sentry = previous
|
||||
"""
|
||||
The previous node of the pipeline.
|
||||
"""
|
||||
|
||||
self.wrench: t.Callable[[t.Any], t.Awaitable[t.Any]] = wrench
|
||||
"""
|
||||
The coroutine function to apply to all objects passing through this node.
|
||||
"""
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.previous) + 1
|
||||
|
||||
def get_nowait(self) -> bullet.Bullet:
|
||||
return self.previous.get_nowait()
|
||||
|
||||
async def get(self) -> bullet.Bullet:
|
||||
return await self.previous.get()
|
||||
|
||||
async def put(self, item) -> None:
|
||||
return await self.previous.put(item)
|
||||
|
||||
def dispenser(self):
|
||||
return self.previous.dispenser()
|
||||
|
||||
|
||||
class SentrySource(Sentry):
|
||||
"""
|
||||
The root and source of the pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, dispenser: "Dispenser", queue_size: int = 12):
|
||||
self.queue: asyncio.Queue = asyncio.Queue(maxsize=queue_size)
|
||||
self._dispenser: "Dispenser" = dispenser
|
||||
|
||||
def __len__(self) -> int:
|
||||
return 1
|
||||
|
||||
def get_nowait(self) -> bullet.Bullet:
|
||||
return self.queue.get_nowait()
|
||||
|
||||
async def get(self) -> bullet.Bullet:
|
||||
return await self.queue.get()
|
||||
|
||||
async def put(self, item) -> None:
|
||||
return await self.queue.put(bullet)
|
||||
|
||||
async def dispenser(self):
|
||||
return self._dispenser
|
||||
|
||||
|
||||
__all__ = (
|
||||
"Sentry",
|
||||
"SentryFilter",
|
||||
"SentrySource",
|
||||
)
|
|
@ -1 +0,0 @@
|
|||
from .sentry import *
|
|
@ -1,257 +0,0 @@
|
|||
from __future__ import annotations
|
||||
from royalnet.royaltyping import *
|
||||
import functools
|
||||
import logging
|
||||
|
||||
from .. import exc, blueprints
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Filter:
|
||||
"""
|
||||
A fluent interface for filtering data.
|
||||
"""
|
||||
|
||||
def __init__(self, func: Callable):
|
||||
self.func: Callable = func
|
||||
|
||||
async def get(self) -> Any:
|
||||
"""
|
||||
Wait until an :class:`object` leaves the queue and passes through the filter, then return it.
|
||||
|
||||
:return: The :class:`object` which left the queue.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
return await self.get_single()
|
||||
except exc.Discard:
|
||||
continue
|
||||
|
||||
async def get_single(self) -> Any:
|
||||
"""
|
||||
Let one :class:`object` pass through the filter, then either return it or raise an error if the object should be
|
||||
discarded.
|
||||
|
||||
:return: The :class:`object` which left the queue.
|
||||
:raises exc.Discard: If the object was filtered.
|
||||
"""
|
||||
try:
|
||||
result = await self.func(None)
|
||||
except exc.Discard as e:
|
||||
log.debug(str(e))
|
||||
raise
|
||||
else:
|
||||
log.debug(f"Dequeued {result}")
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _deco_filter(c: Callable[[Any], bool], *, error: str):
|
||||
"""
|
||||
A decorator which checks the condition ``c`` on all objects transiting through the queue:
|
||||
|
||||
- If the check **passes**, the object itself is returned;
|
||||
- If the check **fails**, :exc:`.exc.Discard` is raised, with the object and the ``error`` string as parameters;
|
||||
- If an error is raised, propagate the error upwards.
|
||||
|
||||
.. warning:: Raising :exc:`.exc.Discard` in ``c`` will automatically cause the object to be discarded, as if
|
||||
:data:`False` was returned.
|
||||
|
||||
:param c: A function that takes in input an enqueued object and returns either the same object or a new one to
|
||||
pass to the next filter in the queue.
|
||||
:param error: The string that :exc:`.exc.Discard` should display if the object is discarded.
|
||||
"""
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
async def decorated(obj):
|
||||
result: Any = await func(obj)
|
||||
if c(result):
|
||||
return result
|
||||
else:
|
||||
raise exc.Discard(obj=result, message=error)
|
||||
return decorated
|
||||
return decorator
|
||||
|
||||
def filter(self, c: Callable[[Any], bool], error: str) -> Filter:
|
||||
"""
|
||||
Check the condition ``c`` on all objects transiting through the queue:
|
||||
|
||||
- If the check **passes**, the object goes on to the next filter;
|
||||
- If the check **fails**, the object is discarded, with ``error`` as reason;
|
||||
- If an error is raised, propagate the error upwards.
|
||||
|
||||
:param c: A function that takes in input an object and performs a check on it, returning either :data:`True`
|
||||
or :data:`False`.
|
||||
:param error: The reason for which objects should be discarded.
|
||||
:return: A new :class:`Filter` with this new condition.
|
||||
|
||||
.. seealso:: :meth:`._deco_filter`, :func:`filter`
|
||||
"""
|
||||
return self.__class__(self._deco_filter(c, error=error)(self.func))
|
||||
|
||||
@staticmethod
|
||||
def _deco_map(c: Callable[[Any], object]):
|
||||
"""
|
||||
A decorator which applies the function ``c`` on all objects transiting through the queue:
|
||||
|
||||
- If the function **returns**, return its return value;
|
||||
- If the function **raises** an error, it is propagated upwards.
|
||||
|
||||
:param c: A function that takes in input an enqueued object and returns either the same object or something
|
||||
else.
|
||||
|
||||
.. seealso:: :func:`map`
|
||||
"""
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
async def decorated(obj):
|
||||
result: Any = await func(obj)
|
||||
return c(result)
|
||||
return decorated
|
||||
return decorator
|
||||
|
||||
def map(self, c: Callable[[Any], object]) -> Filter:
|
||||
"""
|
||||
Apply the function ``c`` on all objects transiting through the queue:
|
||||
|
||||
- If the function **returns**, its return value replaces the object in the queue;
|
||||
- If the function **raises** :exc:`.exc.Discard`, the object is discarded;
|
||||
- If the function **raises another error**, propagate the error upwards.
|
||||
|
||||
:param c: A function that takes in input an enqueued object and returns either the same object or something
|
||||
else.
|
||||
:return: A new :class:`Filter` with this new condition.
|
||||
|
||||
.. seealso:: :meth:`._deco_map`, :func:`filter`
|
||||
"""
|
||||
return self.__class__(self._deco_map(c)(self.func))
|
||||
|
||||
def type(self, t: type) -> Filter:
|
||||
"""
|
||||
Check if an object passing through the queue :func:`isinstance` of the type ``t``.
|
||||
|
||||
:param t: The type that objects should be instances of.
|
||||
:return: A new :class:`Filter` with this new condition.
|
||||
|
||||
.. seealso:: :func:`isinstance`
|
||||
"""
|
||||
return self.filter(lambda o: isinstance(o, t), error=f"Not instance of type {t}")
|
||||
|
||||
def msg(self) -> Filter:
|
||||
"""
|
||||
Check if an object passing through the queue :func:`isinstance` of :class:`.blueprints.Message`.
|
||||
|
||||
:return: A new :class:`Filter` with this new condition.
|
||||
"""
|
||||
return self.type(blueprints.Message)
|
||||
|
||||
def requires(self, *fields,
|
||||
propagate_not_available=False,
|
||||
propagate_never_available=True) -> Filter:
|
||||
"""
|
||||
Test a :class:`.blueprints.Blueprint`'s fields by using its ``.requires()`` method:
|
||||
|
||||
- If the :class:`.blueprints.Blueprint` has the appropriate fields, return it;
|
||||
- If the :class:`.blueprints.Blueprint` doesn't have data for at least one of the fields, the object is discarded;
|
||||
- the :class:`.blueprints.Blueprint` never has data for at least one of the fields, :exc:`.exc.NotAvailableError` is propagated upwards.
|
||||
|
||||
:param fields: The fields to test for.
|
||||
:param propagate_not_available: If :exc:`.exc.NotAvailableError` should be propagated
|
||||
instead of discarding the errored object.
|
||||
:param propagate_never_available: If :exc:`.exc.NeverAvailableError` should be propagated
|
||||
instead of discarding the errored object.
|
||||
:return: A new :class:`Filter` with this new condition.
|
||||
|
||||
.. seealso:: :meth:`.blueprints.Blueprint.requires`
|
||||
"""
|
||||
def check(obj):
|
||||
try:
|
||||
return obj.requires(*fields)
|
||||
except exc.NotAvailableError:
|
||||
if propagate_not_available:
|
||||
raise
|
||||
raise exc.Discard(obj, "Data is not available")
|
||||
except exc.NeverAvailableError:
|
||||
if propagate_never_available:
|
||||
raise
|
||||
raise exc.Discard(obj, "Data is never available")
|
||||
|
||||
return self.filter(check, error=".requires() method returned False")
|
||||
|
||||
def field(self, field: str,
|
||||
propagate_not_available=False,
|
||||
propagate_never_available=True) -> Filter:
|
||||
"""
|
||||
Replace a :class:`.blueprints.Blueprint` with the value of one of its fields.
|
||||
|
||||
:param field: The field to access.
|
||||
:param propagate_not_available: If :exc:`.exc.NotAvailableError` should be propagated
|
||||
instead of discarding the errored object.
|
||||
:param propagate_never_available: If :exc:`.exc.NeverAvailableError` should be propagated
|
||||
instead of discarding the errored object.
|
||||
:return: A new :class:`Filter` with the new requirements.
|
||||
"""
|
||||
def replace(obj):
|
||||
try:
|
||||
return obj.__getattribute__(field)()
|
||||
except exc.NotAvailableError:
|
||||
if propagate_not_available:
|
||||
raise
|
||||
raise exc.Discard(obj, "Data is not available")
|
||||
except exc.NeverAvailableError:
|
||||
if propagate_never_available:
|
||||
raise
|
||||
raise exc.Discard(obj, "Data is never available")
|
||||
|
||||
return self.map(replace)
|
||||
|
||||
def startswith(self, prefix: str):
|
||||
"""
|
||||
Check if an object starts with the specified prefix and discard the objects that do not.
|
||||
|
||||
:param prefix: The prefix object should start with.
|
||||
:return: A new :class:`Filter` with the new requirements.
|
||||
|
||||
.. seealso:: :meth:`str.startswith`
|
||||
"""
|
||||
return self.filter(lambda x: x.startswith(prefix), error=f"Text didn't start with {prefix}")
|
||||
|
||||
def endswith(self, suffix: str):
|
||||
"""
|
||||
Check if an object ends with the specified suffix and discard the objects that do not.
|
||||
|
||||
:param suffix: The prefix object should start with.
|
||||
:return: A new :class:`Filter` with the new requirements.
|
||||
|
||||
.. seealso:: :meth:`str.endswith`
|
||||
"""
|
||||
return self.filter(lambda x: x.endswith(suffix), error=f"Text didn't end with {suffix}")
|
||||
|
||||
def regex(self, pattern: Pattern):
|
||||
"""
|
||||
Apply a regex over an object and discard the object if it does not match.
|
||||
|
||||
:param pattern: The pattern that should be matched by the text.
|
||||
:return: A new :class:`Filter` with the new requirements.
|
||||
"""
|
||||
def mapping(x):
|
||||
if match := pattern.match(x):
|
||||
return match
|
||||
else:
|
||||
raise exc.Discard(x, f"Text didn't match pattern {pattern}")
|
||||
|
||||
return self.map(mapping)
|
||||
|
||||
def choices(self, *choices):
|
||||
"""
|
||||
Ensure an object is in the ``choices`` list, discarding the object otherwise.
|
||||
|
||||
:param choices: The pattern that should be matched by the text.
|
||||
:return: A new :class:`Filter` with the new requirements.
|
||||
"""
|
||||
return self.filter(lambda o: o in choices, error="Not a valid choice")
|
||||
|
||||
|
||||
__all__ = (
|
||||
"Filter",
|
||||
)
|
|
@ -1,58 +0,0 @@
|
|||
from __future__ import annotations
|
||||
from royalnet.royaltyping import *
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from .filter import Filter
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Sentry:
|
||||
"""
|
||||
A class that allows using the ``await`` keyword to suspend a command execution until a new message is received.
|
||||
"""
|
||||
|
||||
QUEUE_SIZE = 12
|
||||
"""
|
||||
The size of the object :attr:`.queue`.
|
||||
"""
|
||||
|
||||
def __init__(self, filter_type: Type[Filter] = Filter):
|
||||
self.queue: asyncio.Queue = asyncio.Queue(maxsize=self.QUEUE_SIZE)
|
||||
"""
|
||||
An object queue where incoming :class:`object` are stored, with a size limit of :attr:`.QUEUE_SIZE`.
|
||||
"""
|
||||
|
||||
self.filter_type: Type[Filter] = filter_type
|
||||
"""
|
||||
The filter to be used in :meth:`.f` calls, by default :class:`.filters.Filter`.
|
||||
"""
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Sentry>"
|
||||
|
||||
def f(self):
|
||||
"""
|
||||
Create a :attr:`.filter_type` object, which can be configured through its fluent interface.
|
||||
|
||||
Remember to call ``.get()`` on the end of the chain to finally get the object.
|
||||
|
||||
To get any object, call:
|
||||
|
||||
.. code-block::
|
||||
|
||||
await sentry.f().get()
|
||||
|
||||
.. seealso:: :class:`.filters.Filter`
|
||||
|
||||
:return: The created :class:`.filters.Filter`.
|
||||
"""
|
||||
async def func(_):
|
||||
return await self.queue.get()
|
||||
return self.filter_type(func)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"Sentry",
|
||||
)
|
|
@ -1,13 +1,19 @@
|
|||
import functools
|
||||
"""
|
||||
The teleporter uses :mod:`pydantic` to validate function parameters and return values.
|
||||
"""
|
||||
|
||||
from royalnet.royaltyping import *
|
||||
from __future__ import annotations
|
||||
import royalnet.royaltyping as t
|
||||
|
||||
import logging
|
||||
import pydantic
|
||||
import inspect
|
||||
import functools
|
||||
|
||||
from . import exc
|
||||
|
||||
|
||||
Model = TypeVar("Model")
|
||||
Value = TypeVar("Value")
|
||||
Value = t.TypeVar("Value")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TeleporterConfig(pydantic.BaseConfig):
|
||||
|
@ -17,7 +23,7 @@ class TeleporterConfig(pydantic.BaseConfig):
|
|||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
def parameter_to_field(param: inspect.Parameter, **kwargs) -> Tuple[type, pydantic.fields.FieldInfo]:
|
||||
def parameter_to_field(param: inspect.Parameter, **kwargs) -> t.Tuple[type, pydantic.fields.FieldInfo]:
|
||||
"""
|
||||
Convert a :class:`inspect.Parameter` to a type-field :class:`tuple`, which can be easily passed to
|
||||
:func:`pydantic.create_model`.
|
||||
|
@ -45,9 +51,9 @@ def parameter_to_field(param: inspect.Parameter, **kwargs) -> Tuple[type, pydant
|
|||
)
|
||||
|
||||
|
||||
def signature_to_model(f: Callable,
|
||||
__config__: Type[pydantic.BaseConfig] = TeleporterConfig,
|
||||
extra_params: Dict[str, type] = None) -> Tuple[type, type]:
|
||||
def signature_to_model(f: t.Callable,
|
||||
__config__: t.Type[pydantic.BaseConfig] = TeleporterConfig,
|
||||
extra_params: t.Dict[str, type] = None) -> t.Tuple[type, type]:
|
||||
"""
|
||||
Convert the signature of a function to two pydantic models: one for the input and another one for the output.
|
||||
|
||||
|
@ -80,7 +86,7 @@ def signature_to_model(f: Callable,
|
|||
return input_model, output_model
|
||||
|
||||
|
||||
def split_kwargs(**kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
def split_kwargs(**kwargs) -> t.Tuple[t.Dict[str, t.Any], t.Dict[str, t.Any]]:
|
||||
"""
|
||||
Split the kwargs passed to this function in two different :class:`dict`, based on whether their name starts with
|
||||
``_`` or not.
|
||||
|
@ -129,7 +135,7 @@ def teleport_out(__model: type, value: Value) -> Value:
|
|||
raise exc.OutTeleporterError(errors=e.raw_errors, model=e.model)
|
||||
|
||||
|
||||
def teleporter(__config__: Type[pydantic.BaseConfig] = TeleporterConfig,
|
||||
def teleporter(__config__: t.Type[pydantic.BaseConfig] = TeleporterConfig,
|
||||
is_async: bool = False,
|
||||
validate_input: bool = True,
|
||||
validate_output: bool = True):
|
||||
|
@ -148,7 +154,7 @@ def teleporter(__config__: Type[pydantic.BaseConfig] = TeleporterConfig,
|
|||
|
||||
.. seealso:: :func:`.signature_to_model`
|
||||
"""
|
||||
def decorator(f: Callable):
|
||||
def decorator(f: t.Callable):
|
||||
# noinspection PyPep8Naming
|
||||
InputModel, OutputModel = signature_to_model(f, __config__=__config__)
|
||||
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
"""
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import royalnet.royaltyping as t
|
||||
|
||||
import logging
|
||||
|
||||
from . import *
|
||||
|
||||
log = logging.getLogger(__name__)
|
|
@ -1,230 +0,0 @@
|
|||
import pytest
|
||||
import asyncio
|
||||
import async_timeout
|
||||
import re
|
||||
from royalnet.engineer import sentry, exc, blueprints
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s() -> sentry.Sentry:
|
||||
return sentry.Sentry()
|
||||
|
||||
|
||||
class TestSentry:
|
||||
def test_creation(self, s: sentry.Sentry):
|
||||
assert s
|
||||
assert isinstance(s, sentry.Sentry)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_put(self, s: sentry.Sentry):
|
||||
await s.queue.put(None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get(self, s: sentry.Sentry):
|
||||
await s.queue.put(None)
|
||||
assert await s.queue.get() is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_f(self, s: sentry.Sentry):
|
||||
await s.queue.put(None)
|
||||
f = s.f()
|
||||
assert f
|
||||
assert isinstance(f, sentry.Filter)
|
||||
assert hasattr(f, "get")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def discarding_filter() -> sentry.Filter:
|
||||
async def discard(_):
|
||||
raise exc.Discard(None, "This filter discards everything!")
|
||||
|
||||
return sentry.Filter(discard)
|
||||
|
||||
|
||||
class ErrorTest(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def error_test(*_, **__):
|
||||
raise ErrorTest("This was raised by error_raiser.")
|
||||
|
||||
|
||||
class TestFilter:
|
||||
def test_creation(self):
|
||||
f = sentry.Filter(lambda _: _)
|
||||
assert f
|
||||
assert isinstance(f, sentry.Filter)
|
||||
|
||||
class TestGetSingle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_success(self, s: sentry.Sentry):
|
||||
await s.queue.put(None)
|
||||
assert await s.f().get_single() is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failure(self, discarding_filter: sentry.Filter):
|
||||
with pytest.raises(exc.Discard):
|
||||
await discarding_filter.get_single()
|
||||
|
||||
class TestGet:
|
||||
@pytest.mark.asyncio
|
||||
async def test_success(self, s: sentry.Sentry):
|
||||
await s.queue.put(None)
|
||||
assert await s.f().get() is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout(self, s: sentry.Sentry):
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
async with async_timeout.timeout(0.001):
|
||||
await s.f().get()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter(self, s: sentry.Sentry):
|
||||
await s.queue.put(None)
|
||||
await s.queue.put(None)
|
||||
await s.queue.put(None)
|
||||
|
||||
assert await s.f().filter(lambda x: x is None, "Is not None").get_single() is None
|
||||
|
||||
with pytest.raises(exc.Discard):
|
||||
await s.f().filter(lambda x: isinstance(x, type), error="Is not type").get_single()
|
||||
|
||||
with pytest.raises(ErrorTest):
|
||||
await s.f().filter(error_test, error="Is error").get_single()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_map(self, s: sentry.Sentry):
|
||||
await s.queue.put(None)
|
||||
await s.queue.put(None)
|
||||
|
||||
assert await s.f().map(lambda x: 1).get_single() == 1
|
||||
|
||||
with pytest.raises(ErrorTest):
|
||||
await s.f().map(error_test).get_single()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_type(self, s: sentry.Sentry):
|
||||
await s.queue.put(1)
|
||||
await s.queue.put("no")
|
||||
|
||||
assert await s.f().type(int).get_single() == 1
|
||||
|
||||
with pytest.raises(exc.Discard):
|
||||
await s.f().type(int).get_single()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_msg(self, s: sentry.Sentry):
|
||||
class ExampleMessage(blueprints.Message):
|
||||
def __hash__(self):
|
||||
return 1
|
||||
|
||||
msg = ExampleMessage()
|
||||
await s.queue.put(msg)
|
||||
await s.queue.put("no")
|
||||
|
||||
assert await s.f().msg().get_single() is msg
|
||||
|
||||
with pytest.raises(exc.Discard):
|
||||
await s.f().msg().get_single()
|
||||
|
||||
class AvailableMessage(blueprints.Message):
|
||||
def __hash__(self):
|
||||
return 1
|
||||
|
||||
def text(self) -> str:
|
||||
return "1"
|
||||
|
||||
class NotAvailableMessage(blueprints.Message):
|
||||
def __hash__(self):
|
||||
return 2
|
||||
|
||||
def text(self) -> str:
|
||||
raise exc.NotAvailableError()
|
||||
|
||||
class NeverAvailableMessage(blueprints.Message):
|
||||
def __hash__(self):
|
||||
return 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires(self, s: sentry.Sentry):
|
||||
avmsg = self.AvailableMessage()
|
||||
await s.queue.put(avmsg)
|
||||
assert await s.f().requires("text").get_single() is avmsg
|
||||
|
||||
await s.queue.put(self.NotAvailableMessage())
|
||||
with pytest.raises(exc.Discard):
|
||||
await s.f().requires("text").get_single()
|
||||
|
||||
await s.queue.put(self.NeverAvailableMessage())
|
||||
with pytest.raises(exc.NeverAvailableError):
|
||||
await s.f().requires("text").get_single()
|
||||
|
||||
await s.queue.put(self.NotAvailableMessage())
|
||||
with pytest.raises(exc.NotAvailableError):
|
||||
await s.f().requires("text", propagate_not_available=True).get_single()
|
||||
|
||||
await s.queue.put(self.NeverAvailableMessage())
|
||||
with pytest.raises(exc.Discard):
|
||||
await s.f().requires("text", propagate_never_available=False).get_single()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_field(self, s: sentry.Sentry):
|
||||
avmsg = self.AvailableMessage()
|
||||
await s.queue.put(avmsg)
|
||||
assert await s.f().field("text").get_single() == "1"
|
||||
|
||||
await s.queue.put(self.NotAvailableMessage())
|
||||
with pytest.raises(exc.Discard):
|
||||
await s.f().field("text").get_single()
|
||||
|
||||
await s.queue.put(self.NeverAvailableMessage())
|
||||
with pytest.raises(exc.NeverAvailableError):
|
||||
await s.f().field("text").get_single()
|
||||
|
||||
await s.queue.put(self.NotAvailableMessage())
|
||||
with pytest.raises(exc.NotAvailableError):
|
||||
await s.f().field("text", propagate_not_available=True).get_single()
|
||||
|
||||
await s.queue.put(self.NeverAvailableMessage())
|
||||
with pytest.raises(exc.Discard):
|
||||
await s.f().field("text", propagate_never_available=False).get_single()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_startswith(self, s: sentry.Sentry):
|
||||
await s.queue.put("yarrharr")
|
||||
await s.queue.put("yohoho")
|
||||
|
||||
assert await s.f().startswith("yarr").get_single() == "yarrharr"
|
||||
|
||||
with pytest.raises(exc.Discard):
|
||||
await s.f().startswith("yarr").get_single()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endswith(self, s: sentry.Sentry):
|
||||
await s.queue.put("yarrharr")
|
||||
await s.queue.put("yohoho")
|
||||
|
||||
assert await s.f().endswith("harr").get_single() == "yarrharr"
|
||||
|
||||
with pytest.raises(exc.Discard):
|
||||
await s.f().endswith("harr").get_single()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regex(self, s: sentry.Sentry):
|
||||
await s.queue.put("yarrharr")
|
||||
await s.queue.put("yohoho")
|
||||
|
||||
assert isinstance(await s.f().regex(re.compile(r"[yh]arr")).get_single(), re.Match)
|
||||
|
||||
with pytest.raises(exc.Discard):
|
||||
await s.f().regex(re.compile(r"[yh]arr")).get_single()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_choices(self, s: sentry.Sentry):
|
||||
await s.queue.put("yarrharr")
|
||||
await s.queue.put("yohoho")
|
||||
|
||||
assert await s.f().choices("yarrharr", "banana").get_single() == "yarrharr"
|
||||
|
||||
with pytest.raises(exc.Discard):
|
||||
await s.f().choices("yarrharr", "banana").get_single()
|
|
@ -2,12 +2,12 @@ import pytest
|
|||
import inspect
|
||||
import pydantic
|
||||
import pydantic.fields
|
||||
import royalnet.engineer as re
|
||||
import typing
|
||||
import royalnet.engineer.teleporter as tp
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def my_function():
|
||||
# noinspection PyUnusedLocal
|
||||
def f(*, big_f: str, _hidden: int) -> int:
|
||||
return _hidden
|
||||
return f
|
||||
|
@ -16,14 +16,15 @@ def my_function():
|
|||
def test_parameter_to_field(my_function):
|
||||
signature = inspect.signature(my_function)
|
||||
parameter = signature.parameters["big_f"]
|
||||
t, fieldinfo = re.parameter_to_field(parameter)
|
||||
t, fieldinfo = tp.parameter_to_field(parameter)
|
||||
assert isinstance(fieldinfo, pydantic.fields.FieldInfo)
|
||||
assert fieldinfo.default is ...
|
||||
assert fieldinfo.title == parameter.name == "big_f"
|
||||
|
||||
|
||||
def test_signature_to_model(my_function):
|
||||
InputModel, OutputModel = re.signature_to_model(my_function)
|
||||
# noinspection PyPep8Naming
|
||||
InputModel, OutputModel = tp.signature_to_model(my_function)
|
||||
assert callable(InputModel)
|
||||
|
||||
model = InputModel(big_f="banana")
|
||||
|
@ -58,7 +59,7 @@ def test_signature_to_model(my_function):
|
|||
# noinspection PyTypeChecker
|
||||
class TestTeleporter:
|
||||
def test_standard_function(self):
|
||||
@re.teleporter()
|
||||
@tp.teleporter()
|
||||
def standard_function(a: int, b: int, _return_str: bool = False) -> int:
|
||||
if _return_str:
|
||||
return "You asked me this."
|
||||
|
@ -91,7 +92,7 @@ class TestTeleporter:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_function(self):
|
||||
@re.teleporter(is_async=True)
|
||||
@tp.teleporter(is_async=True)
|
||||
async def async_function(a: int, b: int, _return_str: bool = False) -> int:
|
||||
if _return_str:
|
||||
return "You asked me this."
|
||||
|
@ -123,7 +124,7 @@ class TestTeleporter:
|
|||
_ = await async_function(1, 2)
|
||||
|
||||
def test_only_input(self):
|
||||
@re.teleporter(validate_output=False)
|
||||
@tp.teleporter(validate_output=False)
|
||||
def standard_function(a: int, b: int, _return_str: bool = False) -> int:
|
||||
if _return_str:
|
||||
return "You asked me this."
|
||||
|
@ -154,7 +155,7 @@ class TestTeleporter:
|
|||
_ = standard_function(1, 2)
|
||||
|
||||
def test_only_output(self):
|
||||
@re.teleporter(validate_input=False)
|
||||
@tp.teleporter(validate_input=False)
|
||||
def standard_function(a: int, b: int, _return_str: bool = False) -> int:
|
||||
if _return_str:
|
||||
return "You asked me this."
|
||||
|
|
306
royalnet/engineer/wrench.py
Normal file
306
royalnet/engineer/wrench.py
Normal file
|
@ -0,0 +1,306 @@
|
|||
"""
|
||||
Wrenches are objects which can used instead of coroutines in Sentry receiver filters,
|
||||
acting similarly to function factories and allowing to easily define filter functions with parameters.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import royalnet.royaltyping as t
|
||||
|
||||
import abc
|
||||
|
||||
from . import discard
|
||||
from . import exc
|
||||
|
||||
|
||||
class Wrench(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
The abstract base class for Wrenches.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def filter(self, obj: t.Any) -> t.Any:
|
||||
"""
|
||||
The function applied to all objects transiting through the pipeline:
|
||||
|
||||
- If the function **returns**, its return value will be passed to the next node in the pipeline;
|
||||
- If the function **raises**, the error is propagated downwards.
|
||||
|
||||
A special exception is available for discarding objects: :exc:`.discard.Discard`.
|
||||
If raised, the object will be silently ignored.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def __call__(self, obj: t.Any) -> t.Awaitable[t.Any]:
|
||||
"""
|
||||
Allow instances to be directly called, emulating coroutine functions.
|
||||
"""
|
||||
return self.filter(obj)
|
||||
|
||||
|
||||
class PassAll(Wrench):
|
||||
"""
|
||||
**Return** each received object as it is.
|
||||
|
||||
.. note:: To be used only in testing.
|
||||
"""
|
||||
|
||||
async def filter(self, obj: t.Any) -> t.Any:
|
||||
return obj
|
||||
|
||||
|
||||
class DiscardAll(Wrench):
|
||||
"""
|
||||
**Discard** each received object.
|
||||
|
||||
.. note:: To be used only in testing.
|
||||
"""
|
||||
|
||||
async def filter(self, obj: t.Any) -> t.Any:
|
||||
raise discard.Discard(obj, "Discard filter discards everything")
|
||||
|
||||
|
||||
class ErrorAll(Wrench):
|
||||
"""
|
||||
**Raise** :exc:`.exc.DeliberateException` for each received object.
|
||||
|
||||
.. note:: To be used only in testing.
|
||||
"""
|
||||
|
||||
async def filter(self, obj: t.Any) -> t.Any:
|
||||
raise exc.DeliberateException("ErrorAll received an object")
|
||||
|
||||
|
||||
class CheckBase(Wrench, metaclass=abc.ABCMeta):
|
||||
"""
|
||||
Check a condition on the received objects:
|
||||
|
||||
- If the check returns :data:`True`, the object is **returned**;
|
||||
- If the check returns :data:`False`, the object is **discarded**;
|
||||
- If an error is raised, it is **propagated**.
|
||||
"""
|
||||
|
||||
def __init__(self, *, invert: bool = False):
|
||||
self.invert: bool = invert
|
||||
"""
|
||||
If set to :data:`True`, this Nut will invert its results:
|
||||
|
||||
- If the check returns :data:`True`, the object is **discarded**;
|
||||
- If the check returns :data:`False`, the object is **returned**;
|
||||
- If an error is raised, it is **propagated**.
|
||||
"""
|
||||
|
||||
def __invert__(self):
|
||||
return self.__class__(invert=not self.invert)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def check(self, obj: t.Any) -> bool:
|
||||
"""
|
||||
The condition to check.
|
||||
|
||||
:param obj: The object passing through the pipeline.
|
||||
:return: Whether the check was successful or not.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def error(self, obj: t.Any) -> str:
|
||||
"""
|
||||
The error message to attach as :attr:`.Discard.message` if the object is discarded.
|
||||
|
||||
:param obj: The object passing through the pipeline.
|
||||
:return: The error message.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def filter(self, obj: t.Any) -> t.Any:
|
||||
if await self.check(obj) ^ self.invert:
|
||||
return obj
|
||||
else:
|
||||
raise discard.Discard(obj=obj, message=self.error(obj))
|
||||
|
||||
|
||||
class Type(CheckBase):
|
||||
"""
|
||||
Check the type of an object:
|
||||
|
||||
- If the object **is** of the specified type, it is **returned**;
|
||||
- If the object **isn't** of the specified type, it is **discarded**;
|
||||
"""
|
||||
|
||||
def __init__(self, type_: t.Type, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.type: t.Type = type_
|
||||
|
||||
async def check(self, obj: t.Any) -> bool:
|
||||
return isinstance(obj, self.type)
|
||||
|
||||
def error(self, obj: t.Any) -> str:
|
||||
return f"Not instance of type {self.type}"
|
||||
|
||||
|
||||
class StartsWith(CheckBase):
|
||||
"""
|
||||
Check if an object :func:`startswith` a certain prefix.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.prefix: str = prefix
|
||||
|
||||
async def check(self, obj: t.Any) -> bool:
|
||||
return obj.startswith(self.prefix)
|
||||
|
||||
def error(self, obj: t.Any) -> str:
|
||||
return f"Didn't start with {self.prefix}"
|
||||
|
||||
|
||||
class EndsWith(CheckBase):
|
||||
"""
|
||||
Check if an object :func:`endswith` a certain suffix.
|
||||
"""
|
||||
|
||||
def __init__(self, suffix: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.suffix: str = suffix
|
||||
|
||||
async def check(self, obj: t.Any) -> bool:
|
||||
return obj.startswith(self.suffix)
|
||||
|
||||
def error(self, obj: t.Any) -> str:
|
||||
return f"Didn't end with {self.suffix}"
|
||||
|
||||
|
||||
class Choice(CheckBase):
|
||||
"""
|
||||
Check if an object is among the accepted list.
|
||||
"""
|
||||
|
||||
def __init__(self, *accepted, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.accepted: t.Collection = accepted
|
||||
"""
|
||||
A collection of elements which can be chosen.
|
||||
"""
|
||||
|
||||
async def check(self, obj: t.Any) -> bool:
|
||||
return obj in self.accepted
|
||||
|
||||
def error(self, obj: t.Any) -> str:
|
||||
return f"Not a valid choice"
|
||||
|
||||
|
||||
class RegexCheck(CheckBase):
|
||||
"""
|
||||
Check if an object matches a regex pattern.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern: t.Pattern, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.pattern: t.Pattern = pattern
|
||||
"""
|
||||
The pattern that should be matched.
|
||||
"""
|
||||
|
||||
async def check(self, obj: t.Any) -> bool:
|
||||
return bool(self.pattern.match(obj))
|
||||
|
||||
def error(self, obj: t.Any) -> str:
|
||||
return f"Didn't match pattern {self.pattern}"
|
||||
|
||||
|
||||
class RegexMatch(Wrench):
|
||||
"""
|
||||
Apply a regex over an object:
|
||||
|
||||
- If it matches, **return** the :class:`re.Match` object;
|
||||
- If it doesn't match, **discard** the object.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern: t.Pattern, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.pattern: t.Pattern = pattern
|
||||
"""
|
||||
The pattern that should be matched.
|
||||
"""
|
||||
|
||||
async def filter(self, obj: t.Any) -> t.Any:
|
||||
if match := self.pattern.match(obj):
|
||||
return match
|
||||
else:
|
||||
raise discard.Discard(obj, f"Didn't match pattern {obj}")
|
||||
|
||||
|
||||
class RegexReplace(Wrench):
|
||||
"""
|
||||
Apply a regex over an object:
|
||||
|
||||
- If it matches, replace the match(es) with :attr:`.replacement` and **return** the result.
|
||||
- If it doesn't match, **return** the object as it is.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern: t.Pattern, replacement: t.Union[str, bytes]):
|
||||
self.pattern: t.Pattern = pattern
|
||||
"""
|
||||
The pattern that should be matched.
|
||||
"""
|
||||
|
||||
self.replacement: t.Union[str, bytes] = replacement
|
||||
"""
|
||||
The substitution string for the object.
|
||||
"""
|
||||
|
||||
async def filter(self, obj: t.Any) -> t.Any:
|
||||
return self.pattern.sub(self.replacement, obj)
|
||||
|
||||
|
||||
class Lambda(Wrench):
|
||||
"""
|
||||
Apply a syncronous function over the received objects.
|
||||
"""
|
||||
|
||||
def __init__(self, func: t.Callable[[t.Any], t.Any]):
|
||||
self.func: t.Callable[[t.Any], t.Any] = func
|
||||
"""
|
||||
The function to apply.
|
||||
"""
|
||||
|
||||
async def filter(self, obj: t.Any) -> t.Any:
|
||||
return self.func(obj)
|
||||
|
||||
|
||||
class Check(CheckBase):
|
||||
"""
|
||||
Check a condition on the received objects.
|
||||
"""
|
||||
|
||||
def __init__(self, func: t.Callable[[t.Any], t.Any], error: str, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.func: t.Callable[[t.Any], t.Any] = func
|
||||
"""
|
||||
The condition to check.
|
||||
"""
|
||||
|
||||
self.error: str = error
|
||||
"""
|
||||
The error message to display if the check fails.
|
||||
"""
|
||||
|
||||
async def check(self, obj: t.Any) -> bool:
|
||||
return self.func(obj)
|
||||
|
||||
def error(self, obj: t.Any) -> str:
|
||||
return self.error
|
||||
|
||||
|
||||
__all__ = (
|
||||
"Wrench",
|
||||
"CheckBase",
|
||||
"Type",
|
||||
"StartsWith",
|
||||
"EndsWith",
|
||||
"Choice",
|
||||
"RegexCheck",
|
||||
"RegexMatch",
|
||||
"RegexReplace",
|
||||
"Lambda",
|
||||
)
|
|
@ -74,3 +74,6 @@ An async generator yielding either:
|
|||
* another :data:`.AsyncAdventure`;
|
||||
* :data:`None`.
|
||||
"""
|
||||
|
||||
|
||||
Conversation = Callable[["Sentry"], Awaitable[Optional["Conversation"]]]
|
||||
|
|
Loading…
Reference in a new issue