diff --git a/pyproject.toml b/pyproject.toml index 5641f378..3259981f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ [tool.poetry] name = "royalnet" -version = "5.11.3" +version = "5.11.4" description = "A multipurpose bot and web framework" authors = ["Stefano Pigozzi "] license = "AGPL-3.0+" diff --git a/royalnet/commands/commanddata.py b/royalnet/commands/commanddata.py index a3809ee1..884a1fa5 100644 --- a/royalnet/commands/commanddata.py +++ b/royalnet/commands/commanddata.py @@ -4,7 +4,7 @@ import logging from typing import * from royalnet.backpack.tables.users import User -from .errors import UnsupportedError +from .errors import * if TYPE_CHECKING: from .keyboardkey import KeyboardKey @@ -46,14 +46,6 @@ class CommandData: caption: The caption to attach to the image.""" raise UnsupportedError(f"'{self.reply_image.__name__}' is not supported") - async def get_author(self, error_if_none: bool = False): - """Try to find the identifier of the user that sent the message. - That probably means, the database row identifying the user. - - Parameters: - error_if_none: Raise an exception if this is True and the call has no author.""" - raise UnsupportedError(f"'{self.get_author.__name__}' is not supported") - async def delete_invoking(self, error_if_unavailable: bool = False) -> None: """Delete the invoking message, if supported by the interface. @@ -64,13 +56,28 @@ class CommandData: if error_if_unavailable: raise UnsupportedError(f"'{self.delete_invoking.__name__}' is not supported") - async def find_user(self, identifier: Union[str, int], *, session) -> Optional["User"]: + async def find_author(self, *, session, required: bool = False) -> Optional["User"]: + """Try to find the identifier of the user that sent the message. + That probably means, the database row identifying the user. + + Parameters: + session: the session that the user should be returned from. + required: Raise an exception if this is True and the call has no author. + """ + raise UnsupportedError(f"'{self.find_author.__name__}' is not supported") + + async def find_user(self, identifier: Union[str, int], *, session, required: bool = False) -> Optional["User"]: """Find the User having a specific identifier. Parameters: identifier: the identifier to search for. - session: the session that the user should be returned from""" - return await User.find(alchemy=self.alchemy, session=session, identifier=identifier) + session: the session that the user should be returned from. + required: Raise an exception if this is True and no user was found.. + """ + user: Optional["User"] = await User.find(alchemy=self.alchemy, session=session, identifier=identifier) + if required and user is None: + raise InvalidInputError(f"User '{identifier}' was not found.") + return user @contextlib.asynccontextmanager async def keyboard(self, text, keys: List["KeyboardKey"]): diff --git a/royalnet/serf/discord/discordserf.py b/royalnet/serf/discord/discordserf.py index 0a8b1731..338934aa 100644 --- a/royalnet/serf/discord/discordserf.py +++ b/royalnet/serf/discord/discordserf.py @@ -20,9 +20,6 @@ class DiscordSerf(Serf): interface_name = "discord" prefix = "!" - _identity_table = rbt.Discord - _identity_column = "discord_id" - def __init__(self, loop: aio.AbstractEventLoop, alchemy_cfg: rc.ConfigDict, @@ -67,15 +64,16 @@ class DiscordSerf(Serf): async def reply_image(data, image: io.IOBase, caption: Optional[str] = None) -> None: await data.message.channel.send(caption, file=discord.File(image, 'image')) - async def get_author(data, error_if_none=False): - user: "discord.Member" = data.message.author - async with data.session_acm() as session: - query = session.query(self.master_table) - for link in self.identity_chain: - query = query.join(link.mapper.class_) - query = query.filter(self.identity_column == user.id) - result = await asyncify(query.one_or_none) - if result is None and error_if_none: + async def find_author(data, + *, + session, + required: bool = False) -> Optional[rbt.User]: + user: Union["discord.User", "discord.Member"] = data.message.author + DiscordT = data.alchemy.get(rbt.Discord) + result = await asyncify( + session.query(DiscordT).filter(DiscordT.discord_id == user.id).one_or_none + ) + if result is None and required: raise rc.CommandError("You must be registered to use this command.") return result diff --git a/royalnet/serf/telegram/telegramserf.py b/royalnet/serf/telegram/telegramserf.py index f49ae944..6766d86f 100644 --- a/royalnet/serf/telegram/telegramserf.py +++ b/royalnet/serf/telegram/telegramserf.py @@ -121,21 +121,18 @@ class TelegramSerf(Serf): parse_mode="HTML", disable_web_page_preview=True) - async def get_author(data, error_if_none=False): - user: Optional[telegram.User] = data.message.from_user - if user is None: - if error_if_none: - raise rc.CommandError("No command caller for this message") - return None - async with data.session_acm() as session: - query = session.query(self.master_table) - for link in self.identity_chain: - query = query.join(link.mapper.class_) - query = query.filter(self.identity_column == user.id) - result = await ru.asyncify(query.one_or_none) - if result is None and error_if_none: - raise rc.CommandError("Command caller is not registered") - return result + async def find_author(data, + *, + session, + required: bool = False) -> Optional[rbt.User]: + user: "telegram.User" = data.message.from_user + TelegramT = data.alchemy.get(rbt.Telegram) + result = await ru.asyncify( + session.query(TelegramT).filter(TelegramT.discord_id == user.id).one_or_none + ) + if result is None and required: + raise rc.CommandError("You must be registered to use this command.") + return result.user async def delete_invoking(data, error_if_unavailable=False) -> None: await self.api_call(data.message.delete)