From a4f2cc5e46b679917754b3ed29b395e31819a6da Mon Sep 17 00:00:00 2001 From: Stefano Pigozzi Date: Tue, 18 Aug 2020 03:43:11 +0200 Subject: [PATCH] Improve upon a few thingies --- royalnet/alchemy/alchemy.py | 17 ++- royalnet/backpack/commands/royalnetsync.py | 115 +++++++++--------- royalnet/backpack/commands/royalnetversion.py | 2 +- royalnet/backpack/stars/api_user_create.py | 1 + royalnet/backpack/tables/discord.py | 2 +- royalnet/backpack/tables/matrix.py | 2 +- royalnet/backpack/tables/telegram.py | 2 +- royalnet/commands/command.py | 9 +- royalnet/commands/commanddata.py | 10 ++ royalnet/constellation/constellation.py | 4 +- royalnet/serf/discord/discordserf.py | 6 +- royalnet/serf/serf.py | 16 +-- royalnet/serf/telegram/telegramserf.py | 57 ++++++--- royalnet/utils/__init__.py | 2 + royalnet/utils/multilock.py | 5 +- royalnet/utils/taskslist.py | 38 ++++++ 16 files changed, 190 insertions(+), 98 deletions(-) create mode 100644 royalnet/utils/taskslist.py diff --git a/royalnet/alchemy/alchemy.py b/royalnet/alchemy/alchemy.py index 34b3abd7..eb08be3f 100644 --- a/royalnet/alchemy/alchemy.py +++ b/royalnet/alchemy/alchemy.py @@ -1,7 +1,9 @@ -from contextlib import contextmanager, asynccontextmanager from typing import * +import contextlib +import logging from sqlalchemy import create_engine +from sqlalchemy.exc import ProgrammingError from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative.api import DeclarativeMeta from sqlalchemy.orm import sessionmaker @@ -15,6 +17,8 @@ if TYPE_CHECKING: # noinspection PyProtectedMember from sqlalchemy.engine import Engine +log = logging.getLogger(__name__) + class Alchemy: """A wrapper around :mod:`sqlalchemy.orm` that allows the instantiation of multiple engines at once while @@ -42,7 +46,12 @@ class Alchemy: # noinspection PyTypeChecker bound_table: Table = type(name, (self._Base, table), {}) self._tables[name] = bound_table - self._Base.metadata.create_all() + # FIXME: Dirty hack + try: + self._Base.metadata.create_all() + except ProgrammingError: + log.warning("Skipping table creation, as it is probably being created by a different process.") + def get(self, table: Union[str, type]) -> DeclarativeMeta: """Get the table with a specified name or class. @@ -66,7 +75,7 @@ class Alchemy: else: raise TypeError(f"Can't get tables with objects of type '{table.__class__.__qualname__}'") - @contextmanager + @contextlib.contextmanager def session_cm(self) -> Iterator[Session]: """Create a Session as a context manager (that can be used in ``with`` statements). @@ -91,7 +100,7 @@ class Alchemy: finally: session.close() - @asynccontextmanager + @contextlib.asynccontextmanager async def session_acm(self) -> AsyncIterator[Session]: """Create a Session as a async context manager (that can be used in ``async with`` statements). diff --git a/royalnet/backpack/commands/royalnetsync.py b/royalnet/backpack/commands/royalnetsync.py index 77d61a2d..82e9ddbd 100644 --- a/royalnet/backpack/commands/royalnetsync.py +++ b/royalnet/backpack/commands/royalnetsync.py @@ -15,12 +15,16 @@ class RoyalnetsyncCommand(rc.Command): syntax: str = "{username} {password}" async def run(self, args: rc.CommandArgs, data: rc.CommandData) -> None: + author = await data.get_author(error_if_none=False) + if author is not None: + raise rc.UserError(f"This account is already connected to {author}!") + username = args[0] password = " ".join(args[1:]) - author = await data.get_author(error_if_none=True) - user = await data.find_user(username) + if user is None: + raise rc.UserError("No such user.") try: successful = user.test_password(password) except ValueError: @@ -28,60 +32,61 @@ class RoyalnetsyncCommand(rc.Command): if not successful: raise rc.InvalidInputError(f"Invalid password!") - if isinstance(self.serf, rst.TelegramSerf): - import telegram - message: telegram.Message = data.message - from_user: telegram.User = message.from_user - TelegramT = self.alchemy.get(Telegram) - tg_user: Telegram = await ru.asyncify( - data.session.query(TelegramT).filter_by(tg_id=from_user.id).one_or_none - ) - if tg_user is None: - # Create - tg_user = TelegramT( - user=author, - tg_id=from_user.id, - first_name=from_user.first_name, - last_name=from_user.last_name, - username=from_user.username + async with data.session_acm() as session: + if isinstance(self.serf, rst.TelegramSerf): + import telegram + message: telegram.Message = data.message + from_user: telegram.User = message.from_user + TelegramT = self.alchemy.get(Telegram) + tg_user: Telegram = await ru.asyncify( + session.query(TelegramT).filter_by(tg_id=from_user.id).one_or_none ) - data.session.add(tg_user) - else: - # Edit - tg_user.first_name = from_user.first_name - tg_user.last_name = from_user.last_name - tg_user.username = from_user.username - await data.session_commit() - await data.reply(f"↔️ Account {tg_user} synced to {author}!") + if tg_user is None: + # Create + tg_user = TelegramT( + user=user, + tg_id=from_user.id, + first_name=from_user.first_name, + last_name=from_user.last_name, + username=from_user.username + ) + session.add(tg_user) + else: + # Edit + tg_user.first_name = from_user.first_name + tg_user.last_name = from_user.last_name + tg_user.username = from_user.username + await ru.asyncify(session.commit) + await data.reply(f"↔️ Account {tg_user} synced to {user}!") - elif isinstance(self.serf, rsd.DiscordSerf): - import discord - message: discord.Message = data.message - author: discord.User = message.author - DiscordT = self.alchemy.get(Discord) - ds_user: Discord = await ru.asyncify( - data.session.query(DiscordT).filter_by(discord_id=author.id).one_or_none - ) - if ds_user is None: - # Create - ds_user = DiscordT( - user=author, - discord_id=author.id, - username=author.name, - discriminator=author.discriminator, - avatar_url=author.avatar_url + elif isinstance(self.serf, rsd.DiscordSerf): + import discord + message: discord.Message = data.message + ds_author: discord.User = message.author + DiscordT = self.alchemy.get(Discord) + ds_user: Discord = await ru.asyncify( + session.query(DiscordT).filter_by(discord_id=ds_author.id).one_or_none ) - data.session.add(ds_user) + if ds_user is None: + # Create + ds_user = DiscordT( + user=user, + discord_id=ds_author.id, + username=ds_author.name, + discriminator=ds_author.discriminator, + avatar_url=ds_author.avatar_url + ) + session.add(ds_user) + else: + # Edit + ds_user.username = ds_author.name + ds_user.discriminator = ds_author.discriminator + ds_user.avatar_url = ds_author.avatar_url + await ru.asyncify(session.commit) + await data.reply(f"↔️ Account {ds_user} synced to {ds_author}!") + + elif isinstance(self.serf, rsm.MatrixSerf): + raise rc.UnsupportedError(f"{self} hasn't been implemented for Matrix yet") + else: - # Edit - ds_user.username = author.name - ds_user.discriminator = author.discriminator - ds_user.avatar_url = author.avatar_url - await data.session_commit() - await data.reply(f"↔️ Account {ds_user} synced to {author}!") - - elif isinstance(self.serf, rsm.MatrixSerf): - raise rc.UnsupportedError(f"{self} hasn't been implemented for Matrix yet") - - else: - raise rc.UnsupportedError(f"Unknown interface: {self.serf.__class__.__qualname__}") + raise rc.UnsupportedError(f"Unknown interface: {self.serf.__class__.__qualname__}") diff --git a/royalnet/backpack/commands/royalnetversion.py b/royalnet/backpack/commands/royalnetversion.py index 805942e8..18c7dadd 100644 --- a/royalnet/backpack/commands/royalnetversion.py +++ b/royalnet/backpack/commands/royalnetversion.py @@ -19,6 +19,6 @@ class RoyalnetversionCommand(Command): else: message = f"ℹ️ Royalnet [url=https://github.com/Steffo99/royalnet/releases/tag/{self.royalnet_version}]" \ f"{self.royalnet_version}[/url]\n" - if "69" in royalnet.version.semantic: + if "69" in self.royalnet_version: message += "(Nice.)" await data.reply(message) diff --git a/royalnet/backpack/stars/api_user_create.py b/royalnet/backpack/stars/api_user_create.py index 723f2df5..7e3f276f 100644 --- a/royalnet/backpack/stars/api_user_create.py +++ b/royalnet/backpack/stars/api_user_create.py @@ -17,6 +17,7 @@ class ApiUserCreateStar(rca.ApiStar): tags = ["user"] + @rca.magic async def post(self, data: rca.ApiData) -> ru.JSON: """Create a new Royalnet account.""" UserT = self.alchemy.get(User) diff --git a/royalnet/backpack/tables/discord.py b/royalnet/backpack/tables/discord.py index 9d808f74..bc840745 100644 --- a/royalnet/backpack/tables/discord.py +++ b/royalnet/backpack/tables/discord.py @@ -15,7 +15,7 @@ class Discord: @declared_attr def user_id(self): - return Column(Integer, ForeignKey("users.uid")) + return Column(Integer, ForeignKey("users.uid"), nullable=False) @declared_attr def user(self): diff --git a/royalnet/backpack/tables/matrix.py b/royalnet/backpack/tables/matrix.py index fe4be9b5..f7223ce2 100644 --- a/royalnet/backpack/tables/matrix.py +++ b/royalnet/backpack/tables/matrix.py @@ -13,7 +13,7 @@ class Matrix: @declared_attr def user_id(self): - return Column(Integer, ForeignKey("users.uid")) + return Column(Integer, ForeignKey("users.uid"), nullable=False) @declared_attr def user(self): diff --git a/royalnet/backpack/tables/telegram.py b/royalnet/backpack/tables/telegram.py index 53f41955..99b1bf2f 100644 --- a/royalnet/backpack/tables/telegram.py +++ b/royalnet/backpack/tables/telegram.py @@ -15,7 +15,7 @@ class Telegram: @declared_attr def user_id(self): - return Column(Integer, ForeignKey("users.uid")) + return Column(Integer, ForeignKey("users.uid"), nullable=False) @declared_attr def user(self): diff --git a/royalnet/commands/command.py b/royalnet/commands/command.py index 5bd6f321..f75c8954 100644 --- a/royalnet/commands/command.py +++ b/royalnet/commands/command.py @@ -40,12 +40,17 @@ class Command(metaclass=abc.ABCMeta): @property def alchemy(self) -> "Alchemy": - """A shortcut for :attr:`.interface.alchemy`.""" + """A shortcut for :attr:`.serf.alchemy`.""" return self.serf.alchemy + @property + def session_acm(self): + """A shortcut for :attr:`.alchemy.session_acm`.""" + return self.alchemy.session_acm + @property def loop(self) -> aio.AbstractEventLoop: - """A shortcut for :attr:`.interface.loop`.""" + """A shortcut for :attr:`.serf.loop`.""" return self.serf.loop @abc.abstractmethod diff --git a/royalnet/commands/commanddata.py b/royalnet/commands/commanddata.py index fe8d0420..72dd036e 100644 --- a/royalnet/commands/commanddata.py +++ b/royalnet/commands/commanddata.py @@ -21,6 +21,16 @@ class CommandData: def loop(self): return self.command.serf.loop + @property + def alchemy(self): + """A shortcut for :attr:`.command.alchemy`.""" + return self.command.alchemy + + @property + def session_acm(self): + """A shortcut for :attr:`.alchemy.session_acm`.""" + return self.alchemy.session_acm + async def reply(self, text: str) -> None: """Send a text message to the channel where the call was made. diff --git a/royalnet/constellation/constellation.py b/royalnet/constellation/constellation.py index ce712733..30741000 100644 --- a/royalnet/constellation/constellation.py +++ b/royalnet/constellation/constellation.py @@ -220,7 +220,7 @@ class Constellation: for SelectedEvent in events: # Initialize the event try: - event = SelectedEvent(constellation=self, config=pack_cfg) + event = SelectedEvent(parent=self, config=pack_cfg) except Exception as e: log.error(f"Skipping: " f"{SelectedEvent.__qualname__} - {e.__class__.__qualname__} in the initialization.") @@ -262,7 +262,7 @@ class Constellation: self.starlette.add_route(*self._page_star_wrapper(page_star_instance)) def run_blocking(self): - log.info(f"Running Constellation on https://{self.address}:{self.port}/...") + log.info(f"Running Constellation on http://{self.address}:{self.port}/...") self.running = True try: uvicorn.run(self.starlette, host=self.address, port=self.port, log_config=UVICORN_LOGGING_CONFIG) diff --git a/royalnet/serf/discord/discordserf.py b/royalnet/serf/discord/discordserf.py index 5f31049f..ec7543ca 100644 --- a/royalnet/serf/discord/discordserf.py +++ b/royalnet/serf/discord/discordserf.py @@ -73,7 +73,8 @@ class DiscordSerf(Serf): async def get_author(data, error_if_none=False): user: "discord.Member" = data.message.author - query = data.session.query(self.master_table) + async with data.session_acm() as session: + query = session.query(self.master_table) for link in self.identity_chain: query = query.join(link.mapper.class_) query = query.filter(self.identity_column == user.id) @@ -135,8 +136,7 @@ class DiscordSerf(Serf): # noinspection PyMethodMayBeStatic async def on_message(cli, message: "discord.Message") -> None: """Handle messages received by passing them to the handle_message method of the bot.""" - # TODO: keep reference to these tasks somewhere - self.loop.create_task(self.handle_message(message)) + self.tasks.add(self.handle_message(message)) async def on_ready(cli) -> None: """Change the bot presence to ``online`` when the bot is ready.""" diff --git a/royalnet/serf/serf.py b/royalnet/serf/serf.py index 57fa3766..d759ea70 100644 --- a/royalnet/serf/serf.py +++ b/royalnet/serf/serf.py @@ -36,6 +36,9 @@ class Serf(abc.ABC): self.loop: Optional[aio.AbstractEventLoop] = loop """The event loop this Serf is running on.""" + self.tasks: Optional[ru.TaskList] = ru.TaskList(self.loop) + """A list of all running tasks of the serf. Initialized at the serf start.""" + # Import packs pack_names = packs_cfg["active"] packs = {} @@ -63,8 +66,8 @@ class Serf(abc.ABC): """The identity table containing the interface data (such as the Telegram user data) and that is in a many-to-one relationship with the master table.""" - # TODO: I'm not sure what this is either self.identity_column: Optional[str] = None + """The name of the column in the identity table that contains a unique user identifier. (???)""" # Alchemy if ra.Alchemy is None: @@ -76,6 +79,7 @@ class Serf(abc.ABC): tables = set() for pack in packs.values(): try: + # noinspection PyUnresolvedReferences tables = tables.union(pack["tables"].available_tables) except AttributeError: log.warning(f"Pack `{pack}` does not have the `available_tables` attribute.") @@ -100,12 +104,14 @@ class Serf(abc.ABC): pack = packs[pack_name] pack_cfg = packs_cfg.get(pack_name, {}) try: + # noinspection PyUnresolvedReferences events = pack["events"].available_events except AttributeError: log.warning(f"Pack `{pack}` does not have the `available_events` attribute.") else: self.register_events(events, pack_cfg) try: + # noinspection PyUnresolvedReferences commands = pack["commands"].available_commands except AttributeError: log.warning(f"Pack `{pack}` does not have the `available_commands` attribute.") @@ -216,7 +222,7 @@ class Serf(abc.ABC): for SelectedEvent in events: # Initialize the event try: - event = SelectedEvent(serf=self, config=pack_cfg) + event = SelectedEvent(parent=self, config=pack_cfg) except Exception as e: log.error(f"Skipping: " f"{SelectedEvent.__qualname__} - {e.__class__.__qualname__} in the initialization.") @@ -282,8 +288,6 @@ class Serf(abc.ABC): except Exception as e: ru.sentry_exc(e) await data.reply(f"⛔️ [b]{e.__class__.__name__}[/b]\n" + '\n'.join(map(lambda a: repr(a), e.args))) - finally: - await data.session_close() @staticmethod async def press(key: rc.KeyboardKey, data: rc.CommandData): @@ -307,12 +311,10 @@ class Serf(abc.ABC): except Exception as e: ru.sentry_exc(e) await data.reply(f"⛔️ [b]{e.__class__.__name__}[/b]\n" + '\n'.join(map(lambda a: repr(a), e.args))) - finally: - await data.session_close() async def run(self): """A coroutine that starts the event loop and handles command calls.""" - self.herald_task = self.loop.create_task(self.herald.run()) + self.herald_task = self.tasks.add(self.herald.run()) # OVERRIDE THIS METHOD! @classmethod diff --git a/royalnet/serf/telegram/telegramserf.py b/royalnet/serf/telegram/telegramserf.py index d74b0b76..f49ae944 100644 --- a/royalnet/serf/telegram/telegramserf.py +++ b/royalnet/serf/telegram/telegramserf.py @@ -127,7 +127,8 @@ class TelegramSerf(Serf): if error_if_none: raise rc.CommandError("No command caller for this message") return None - query = data.session.query(self.master_table) + async with data.session_acm() as session: + query = session.query(self.master_table) for link in self.identity_chain: query = query.join(link.mapper.class_) query = query.filter(self.identity_column == user.id) @@ -189,7 +190,8 @@ class TelegramSerf(Serf): if error_if_none: raise rc.CommandError("No command caller for this message") return None - query = data.session.query(self.master_table) + async with data.session_acm() as session: + query = session.query(self.master_table) for link in self.identity_chain: query = query.join(link.mapper.class_) query = query.filter(self.identity_column == user.id) @@ -206,25 +208,27 @@ class TelegramSerf(Serf): async def handle_update(self, update: telegram.Update): """Delegate :class:`telegram.Update` handling to the correct message type submethod.""" if update.message is not None: + log.debug(f"Handling update as a message") await self.handle_message(update.message) elif update.edited_message is not None: - pass + log.debug(f"Update is a edited message, not doing anything") elif update.channel_post is not None: - pass + log.debug(f"Update is a channel post, not doing anything") elif update.edited_channel_post is not None: - pass + log.debug(f"Update is a channel edit, not doing anything") elif update.inline_query is not None: - pass + log.debug(f"Update is a inline query, not doing anything") elif update.chosen_inline_result is not None: - pass + log.debug(f"Update is a chosen inline result, not doing anything") elif update.callback_query is not None: + log.debug(f"Handling update as a callback query") await self.handle_callback_query(update.callback_query) elif update.shipping_query is not None: - pass + log.debug(f"Update is a shipping query, not doing anything") elif update.pre_checkout_query is not None: - pass + log.debug(f"Update is a precheckout query, not doing anything") elif update.poll is not None: - pass + log.debug(f"Update is a poll, not doing anything") else: log.warning(f"Unknown update type: {update}") @@ -236,25 +240,31 @@ class TelegramSerf(Serf): text: str = message.caption # No text or caption, ignore the message if text is None: + log.debug("Skipping message as it had no text or caption") return # Skip non-command updates - if not text.startswith("/"): + if not text.startswith(self.prefix): + log.debug(f"Skipping message as it didn't start with {self.prefix}") return # Find and clean parameters command_text, *parameters = text.split(" ") - command_name = command_text.replace(f"@{self.client.username}", "").lower() + command_name = command_text.replace(f"@{self.client.username}", "").lstrip(self.prefix).lower() + log.debug(f"Parsed '{command_name}' as command name") # Find the command try: command = self.commands[command_name] except KeyError: # Skip the message + log.debug(f"Skipping message as I could not find the command {command_name}") return # Send a typing notification + log.debug(f"Sending typing notification") await self.api_call(message.chat.send_action, telegram.ChatAction.TYPING) # Prepare data # noinspection PyArgumentList data = self.MessageData(command=command, message=message) # Call the command + log.debug(f"Calling {command}") await self.call(command, data, parameters) async def handle_callback_query(self, cbq: telegram.CallbackQuery): @@ -270,16 +280,25 @@ class TelegramSerf(Serf): async def run(self): await super().run() while True: + # Collect ended tasks + self.tasks.collect() # Get the latest 100 updates - last_updates: List[telegram.Update] = await self.api_call(self.client.get_updates, - offset=self.update_offset, - timeout=60, - read_latency=5.0) + log.debug("Getting updates...") + last_updates: Optional[List[telegram.Update]] = await self.api_call( + self.client.get_updates, + offset=self.update_offset, + timeout=60, + read_latency=5.0 + ) + # Ensure a list was returned from the API call + if not isinstance(last_updates, list): + log.warning("Received invalid data from get_updates, sleeping for 60 seconds, hoping it fixes itself.") + await aio.sleep(60) + continue # Handle updates + log.debug("Handling updates...") for update in last_updates: - # TODO: don't lose the reference to the task - # noinspection PyAsyncCall - self.loop.create_task(self.handle_update(update)) + self.tasks.add(self.handle_update(update)) # Recalculate offset try: self.update_offset = last_updates[-1].update_id + 1 diff --git a/royalnet/utils/__init__.py b/royalnet/utils/__init__.py index 68d9c28b..97ea0ad4 100644 --- a/royalnet/utils/__init__.py +++ b/royalnet/utils/__init__.py @@ -7,6 +7,7 @@ from .sentry import init_sentry, sentry_exc, sentry_wrap, sentry_async_wrap from .sleep_until import sleep_until from .strip_tabs import strip_tabs from .urluuid import to_urluuid, from_urluuid +from .taskslist import TaskList __all__ = [ "asyncify", @@ -26,4 +27,5 @@ __all__ = [ "init_logging", "JSON", "strip_tabs", + "TaskList", ] diff --git a/royalnet/utils/multilock.py b/royalnet/utils/multilock.py index 227d9f0f..120c2dd4 100644 --- a/royalnet/utils/multilock.py +++ b/royalnet/utils/multilock.py @@ -6,7 +6,9 @@ log = logging.getLogger(__name__) class MultiLock: - """A lock that can allow both simultaneous access and exclusive access to a resource.""" + """A lock that allows either simultaneous read access or exclusive write access. + + Basically, a reimplementation of Rust's `RwLock `_ .""" def __init__(self): self._counter: int = 0 @@ -40,7 +42,6 @@ class MultiLock: async def exclusive(self): """Acquire the lock for exclusive access.""" log.debug(f"Waiting for exclusive lock end: {self}") - # TODO: check if this actually works await self._exclusive_event.wait() self._exclusive_event.clear() log.debug(f"Waiting for normal lock end: {self}") diff --git a/royalnet/utils/taskslist.py b/royalnet/utils/taskslist.py new file mode 100644 index 00000000..ee7d397e --- /dev/null +++ b/royalnet/utils/taskslist.py @@ -0,0 +1,38 @@ +from typing import * +import asyncio as aio +import logging +from .sentry import sentry_exc + +log = logging.getLogger(__name__) + + +class TaskList: + def __init__(self, loop: aio.AbstractEventLoop): + self.loop: aio.AbstractEventLoop = loop + self.tasks: List[aio.Task] = [] + + def collect(self): + """Remove finished tasks from the list.""" + log.debug(f"Collecting done tasks") + new_list = [] + for task in self.tasks: + try: + task.result() + except aio.CancelledError: + log.warning(f"Task {task} was unexpectedly cancelled.") + except aio.InvalidStateError: + log.debug(f"Task {task} hasn't finished running yet, readding it to the list.") + new_list.append(task) + except Exception as err: + sentry_exc(err) + self.tasks = new_list + + def add(self, coroutine: Awaitable[Any], timeout: float = None) -> aio.Task: + """Add a new task to the list; the task will be cancelled if ``timeout`` seconds pass.""" + log.debug(f"Creating new task {coroutine}") + if timeout: + task = self.loop.create_task(aio.wait_for(coroutine, timeout=timeout)) + else: + task = self.loop.create_task(coroutine) + self.tasks.append(task) + return task