diff --git a/royalnet/serf/discord/discord.py b/royalnet/serf/discord/discord.py index 24df3140..50501ad5 100644 --- a/royalnet/serf/discord/discord.py +++ b/royalnet/serf/discord/discord.py @@ -1,35 +1,50 @@ -import discord -import sentry_sdk -import logging as _logging -from .generic import GenericBot -from royalnet.utils import * -from royalnet.error import * -from royalnet.bard import * -from royalnet.commands import * +import logging +import asyncio +from typing import Type, Optional, List, Callable, Union +from royalnet.commands import Command, CommandInterface, CommandData, CommandArgs, CommandError, InvalidInputError, \ + UnsupportedError +from royalnet.utils import asyncify +from .escape import escape +from ..serf import Serf + +try: + import discord +except ImportError: + discord = None + +try: + from sqlalchemy.orm.session import Session + from ..alchemyconfig import AlchemyConfig +except ImportError: + Session = None + AlchemyConfig = None + +try: + from royalnet.herald import Config as HeraldConfig +except ImportError: + HeraldConfig = None + +log = logging.getLogger(__name__) -log = _logging.getLogger(__name__) - - -class MusicData: - def __init__(self): - self.playmode: playmodes.PlayMode = playmodes.Playlist() - self.voice_client: typing.Optional[discord.VoiceClient] = None - - def queue_preview(self): - return self.playmode.queue_preview() - - -class DiscordBot(GenericBot): - """A bot that connects to `Discord `_.""" +class DiscordSerf(Serf): + """A :class:`Serf` that connects to `Discord `_ as a bot.""" interface_name = "discord" - def _init_voice(self): - """Initialize the variables needed for the connection to voice chat.""" - log.debug(f"Creating music_data dict") - self.music_data: typing.Dict[discord.Guild, MusicData] = {} + def __init__(self, *, + alchemy_config: Optional[AlchemyConfig] = None, + commands: List[Type[Command]] = None, + network_config: Optional[HeraldConfig] = None, + secrets_name: str = "__default__"): + if discord is None: + raise ImportError("'discord' extra is not installed") - def _interface_factory(self) -> typing.Type[CommandInterface]: + super().__init__(alchemy_config=alchemy_config, + commands=commands, + network_config=network_config, + secrets_name=secrets_name) + + def _interface_factory(self) -> Type[CommandInterface]: # noinspection PyPep8Naming GenericInterface = super().interface_factory() @@ -40,22 +55,22 @@ class DiscordBot(GenericBot): return DiscordInterface - def _data_factory(self) -> typing.Type[CommandData]: + def _data_factory(self) -> Type[CommandData]: # noinspection PyMethodParameters,PyAbstractClass class DiscordData(CommandData): - def __init__(data, interface: CommandInterface, message: discord.Message): - super().__init__(interface) + def __init__(data, interface: CommandInterface, session, message: discord.Message): + super().__init__(interface=interface, session=session) data.message = message async def reply(data, text: str): - await data.message.channel.send(discord_escape(text)) + await data.message.channel.send(escape(text)) async def get_author(data, error_if_none=False): user: discord.Member = data.message.author - query = data.session.query(self.master_table) - for link in self.identity_chain: + query = data.session.query(self._master_table) + for link in self._identity_chain: query = query.join(link.mapper.class_) - query = query.filter(self.identity_column == user.id) + query = query.filter(self._identity_column == user.id) result = await asyncify(query.one_or_none) if result is None and error_if_none: raise CommandError("You must be registered to use this command.") @@ -66,117 +81,96 @@ class DiscordBot(GenericBot): return DiscordData - def _bot_factory(self) -> typing.Type[discord.Client]: - """Create a custom DiscordClient class inheriting from :py:class:`discord.Client`.""" - log.debug(f"Creating DiscordClient") + async def handle_message(self, message: discord.Message): + """Handle a Discord message by calling a command if appropriate.""" + text = message.content + # Skip non-text messages + if not text: + return + # Skip non-command updates + if not text.startswith("!"): + return + # Skip bot messages + author: Union[discord.User] = message.author + if author.bot: + return + # Find and clean parameters + command_text, *parameters = text.split(" ") + # Don't use a case-sensitive command name + command_name = command_text.lower() + # Find the command + try: + command = self.commands[command_name] + except KeyError: + # Skip the message + return + # Call the command + log.debug(f"Calling command '{command.name}'") + with message.channel.typing(): + # Open an alchemy session, if available + if self.alchemy is not None: + session = await asyncify(self.alchemy.Session) + else: + session = None + # Prepare data + data = self.Data(interface=command.interface, session=session, message=message) + try: + # Run the command + await command.run(CommandArgs(parameters), data) + except InvalidInputError as e: + await data.reply(f":warning: {e.message}\n" + f"Syntax: [c]/{command.name} {command.syntax}[/c]") + except UnsupportedError as e: + await data.reply(f":warning: {e.message}") + except CommandError as e: + await data.reply(f":warning: {e.message}") + except Exception as e: + self.sentry_exc(e) + error_message = f"🦀 [b]{e.__class__.__name__}[/b] 🦀\n" \ + '\n'.join(e.args) + await data.reply(error_message) + finally: + # Close the alchemy session + if session is not None: + await asyncify(session.close) + def _bot_factory(self) -> Type[discord.Client]: + """Create a custom class inheriting from :py:class:`discord.Client`.""" # noinspection PyMethodParameters class DiscordClient(discord.Client): - async def vc_connect_or_move(cli, channel: discord.VoiceChannel): - music_data = self.music_data.get(channel.guild) - if music_data is None: - # Create a MusicData object - music_data = MusicData() - self.music_data[channel.guild] = music_data - # Connect to voice - log.debug(f"Connecting to Voice in {channel}") - try: - music_data.voice_client = await channel.connect(reconnect=False, timeout=10) - except Exception: - log.warning(f"Failed to connect to Voice in {channel}") - del self.music_data[channel.guild] - raise - else: - log.debug(f"Connected to Voice in {channel}") - else: - if music_data.voice_client is None: - # TODO: change exception type - raise Exception("Another connection attempt is already in progress.") - # Try to move to a different channel - voice_client = music_data.voice_client - log.debug(f"Moving {voice_client} to {channel}") - await voice_client.move_to(channel) - log.debug(f"Moved {voice_client} to {channel}") - async def on_message(cli, message: discord.Message): - self.loop.create_task(cli._handle_message(message)) - - async def _handle_message(cli, message: discord.Message): - text = message.content - # Skip non-text messages - if not text: - return - # Skip non-command updates - if not text.startswith("!"): - return - # Skip bot messages - author: typing.Union[discord.User] = message.author - if author.bot: - return - # Find and clean parameters - command_text, *parameters = text.split(" ") - # Don't use a case-sensitive command name - command_name = command_text.lower() - # Find the command - try: - command = self.commands[command_name] - except KeyError: - # Skip the message - return - # Prepare data - data = self._Data(interface=command.interface, message=message) - # Call the command - log.debug(f"Calling command '{command.name}'") - with message.channel.typing(): - # Run the command - try: - await command.run(CommandArgs(parameters), data) - except InvalidInputError as e: - await data.reply(f":warning: {e.message}\n" - f"Syntax: [c]/{command.name} {command.syntax}[/c]") - except UnsupportedError as e: - await data.reply(f":warning: {e.message}") - except CommandError as e: - await data.reply(f":warning: {e.message}") - except Exception as e: - sentry_sdk.capture_exception(e) - error_message = f"🦀 [b]{e.__class__.__name__}[/b] 🦀\n" - error_message += '\n'.join(e.args) - await data.reply(error_message) - # Close the data session - await data.session_close() - - async def on_connect(cli): - log.debug("Connected to Discord") - - async def on_disconnect(cli): - log.error("Disconnected from Discord!") + """Handle messages received by passing them to the handle_message method of the bot.""" + # TODO: keep reference to these tasks somewhere + self.loop.create_task(self.handle_message(message)) async def on_ready(cli) -> None: - log.debug("Connection successful, client is ready") + """Change the bot presence to ``online`` when the bot is ready.""" await cli.change_presence(status=discord.Status.online) - def find_guild_by_name(cli, name: str) -> typing.List[discord.Guild]: - """Find the :py:class:`discord.Guild` with the specified name (case insensitive).""" - all_guilds: typing.List[discord.Guild] = cli.guilds - matching_channels: typing.List[discord.Guild] = [] + def find_guild(cli, name: str) -> List[discord.Guild]: + """Find the :class:`discord.Guild`s with the specified name (case insensitive). + + Returns: + A :class:`list` of :class:`discord.Guild` having the specified name.""" + all_guilds: List[discord.Guild] = cli.guilds + matching_channels: List[discord.Guild] = [] for guild in all_guilds: if guild.name.lower() == name.lower(): matching_channels.append(guild) return matching_channels - def find_channel_by_name(cli, - name: str, - guild: typing.Optional[discord.Guild] = None) -> typing.List[discord.abc.GuildChannel]: - """Find the :py:class:`TextChannel`, :py:class:`VoiceChannel` or :py:class:`CategoryChannel` with the + def find_channel(cli, + name: str, + guild: Optional[discord.Guild] = None) -> List[discord.abc.GuildChannel]: + """Find the :class:`TextChannel`, :class:`VoiceChannel` or :class:`CategoryChannel` with the specified name (case insensitive). - You can specify a guild to only find channels in that specific guild.""" + You can specify a guild to only search in that specific guild.""" if guild is not None: all_channels = guild.channels else: - all_channels: typing.List[discord.abc.GuildChannel] = cli.get_all_channels() - matching_channels: typing.List[discord.abc.GuildChannel] = [] + all_channels: List[discord.abc.GuildChannel] = cli.get_all_channels() + matching_channels: List[discord.abc.GuildChannel] = [] for channel in all_channels: if not (isinstance(channel, discord.TextChannel) or isinstance(channel, discord.VoiceChannel) @@ -186,8 +180,9 @@ class DiscordBot(GenericBot): matching_channels.append(channel) return matching_channels - def find_voice_client_by_guild(cli, guild: discord.Guild) -> typing.Optional[discord.VoiceClient]: + def find_voice_client(cli, guild: discord.Guild) -> Optional[discord.VoiceClient]: """Find the :py:class:`discord.VoiceClient` belonging to a specific :py:class:`discord.Guild`.""" + # TODO: the bug I was looking for might be here for voice_client in cli.voice_clients: if voice_client.guild == guild: return voice_client @@ -195,6 +190,8 @@ class DiscordBot(GenericBot): return DiscordClient + # TODO: restart from here + def _init_client(self): """Create an instance of the DiscordClient class created in :py:func:`royalnet.bots.DiscordBot._bot_factory`.""" log.debug(f"Creating DiscordClient instance") diff --git a/royalnet/serf/telegram/telegramserf.py b/royalnet/serf/telegram/telegramserf.py index 3ee679b6..c6db695c 100644 --- a/royalnet/serf/telegram/telegramserf.py +++ b/royalnet/serf/telegram/telegramserf.py @@ -32,7 +32,7 @@ log = logging.getLogger(__name__) class TelegramSerf(Serf): - """A Serf that connects to `Telegram `_.""" + """A Serf that connects to `Telegram `_ as a bot.""" interface_name = "telegram" def __init__(self, *, @@ -94,9 +94,6 @@ class TelegramSerf(Serf): name = self.interface_name prefix = "/" - def __init__(self): - super().__init__() - return TelegramInterface def data_factory(self) -> Type[CommandData]: