diff --git a/royalnet/serf/discord/discord.py b/royalnet/serf/discord/discord.py
index 24df3140..50501ad5 100644
--- a/royalnet/serf/discord/discord.py
+++ b/royalnet/serf/discord/discord.py
@@ -1,35 +1,50 @@
-import discord
-import sentry_sdk
-import logging as _logging
-from .generic import GenericBot
-from royalnet.utils import *
-from royalnet.error import *
-from royalnet.bard import *
-from royalnet.commands import *
+import logging
+import asyncio
+from typing import Type, Optional, List, Callable, Union
+from royalnet.commands import Command, CommandInterface, CommandData, CommandArgs, CommandError, InvalidInputError, \
+ UnsupportedError
+from royalnet.utils import asyncify
+from .escape import escape
+from ..serf import Serf
+
+try:
+ import discord
+except ImportError:
+ discord = None
+
+try:
+ from sqlalchemy.orm.session import Session
+ from ..alchemyconfig import AlchemyConfig
+except ImportError:
+ Session = None
+ AlchemyConfig = None
+
+try:
+ from royalnet.herald import Config as HeraldConfig
+except ImportError:
+ HeraldConfig = None
+
+log = logging.getLogger(__name__)
-log = _logging.getLogger(__name__)
-
-
-class MusicData:
- def __init__(self):
- self.playmode: playmodes.PlayMode = playmodes.Playlist()
- self.voice_client: typing.Optional[discord.VoiceClient] = None
-
- def queue_preview(self):
- return self.playmode.queue_preview()
-
-
-class DiscordBot(GenericBot):
- """A bot that connects to `Discord `_."""
+class DiscordSerf(Serf):
+ """A :class:`Serf` that connects to `Discord `_ as a bot."""
interface_name = "discord"
- def _init_voice(self):
- """Initialize the variables needed for the connection to voice chat."""
- log.debug(f"Creating music_data dict")
- self.music_data: typing.Dict[discord.Guild, MusicData] = {}
+ def __init__(self, *,
+ alchemy_config: Optional[AlchemyConfig] = None,
+ commands: List[Type[Command]] = None,
+ network_config: Optional[HeraldConfig] = None,
+ secrets_name: str = "__default__"):
+ if discord is None:
+ raise ImportError("'discord' extra is not installed")
- def _interface_factory(self) -> typing.Type[CommandInterface]:
+ super().__init__(alchemy_config=alchemy_config,
+ commands=commands,
+ network_config=network_config,
+ secrets_name=secrets_name)
+
+ def _interface_factory(self) -> Type[CommandInterface]:
# noinspection PyPep8Naming
GenericInterface = super().interface_factory()
@@ -40,22 +55,22 @@ class DiscordBot(GenericBot):
return DiscordInterface
- def _data_factory(self) -> typing.Type[CommandData]:
+ def _data_factory(self) -> Type[CommandData]:
# noinspection PyMethodParameters,PyAbstractClass
class DiscordData(CommandData):
- def __init__(data, interface: CommandInterface, message: discord.Message):
- super().__init__(interface)
+ def __init__(data, interface: CommandInterface, session, message: discord.Message):
+ super().__init__(interface=interface, session=session)
data.message = message
async def reply(data, text: str):
- await data.message.channel.send(discord_escape(text))
+ await data.message.channel.send(escape(text))
async def get_author(data, error_if_none=False):
user: discord.Member = data.message.author
- query = data.session.query(self.master_table)
- for link in self.identity_chain:
+ query = data.session.query(self._master_table)
+ for link in self._identity_chain:
query = query.join(link.mapper.class_)
- query = query.filter(self.identity_column == user.id)
+ query = query.filter(self._identity_column == user.id)
result = await asyncify(query.one_or_none)
if result is None and error_if_none:
raise CommandError("You must be registered to use this command.")
@@ -66,117 +81,96 @@ class DiscordBot(GenericBot):
return DiscordData
- def _bot_factory(self) -> typing.Type[discord.Client]:
- """Create a custom DiscordClient class inheriting from :py:class:`discord.Client`."""
- log.debug(f"Creating DiscordClient")
+ async def handle_message(self, message: discord.Message):
+ """Handle a Discord message by calling a command if appropriate."""
+ text = message.content
+ # Skip non-text messages
+ if not text:
+ return
+ # Skip non-command updates
+ if not text.startswith("!"):
+ return
+ # Skip bot messages
+ author: Union[discord.User] = message.author
+ if author.bot:
+ return
+ # Find and clean parameters
+ command_text, *parameters = text.split(" ")
+ # Don't use a case-sensitive command name
+ command_name = command_text.lower()
+ # Find the command
+ try:
+ command = self.commands[command_name]
+ except KeyError:
+ # Skip the message
+ return
+ # Call the command
+ log.debug(f"Calling command '{command.name}'")
+ with message.channel.typing():
+ # Open an alchemy session, if available
+ if self.alchemy is not None:
+ session = await asyncify(self.alchemy.Session)
+ else:
+ session = None
+ # Prepare data
+ data = self.Data(interface=command.interface, session=session, message=message)
+ try:
+ # Run the command
+ await command.run(CommandArgs(parameters), data)
+ except InvalidInputError as e:
+ await data.reply(f":warning: {e.message}\n"
+ f"Syntax: [c]/{command.name} {command.syntax}[/c]")
+ except UnsupportedError as e:
+ await data.reply(f":warning: {e.message}")
+ except CommandError as e:
+ await data.reply(f":warning: {e.message}")
+ except Exception as e:
+ self.sentry_exc(e)
+ error_message = f"🦀 [b]{e.__class__.__name__}[/b] 🦀\n" \
+ '\n'.join(e.args)
+ await data.reply(error_message)
+ finally:
+ # Close the alchemy session
+ if session is not None:
+ await asyncify(session.close)
+ def _bot_factory(self) -> Type[discord.Client]:
+ """Create a custom class inheriting from :py:class:`discord.Client`."""
# noinspection PyMethodParameters
class DiscordClient(discord.Client):
- async def vc_connect_or_move(cli, channel: discord.VoiceChannel):
- music_data = self.music_data.get(channel.guild)
- if music_data is None:
- # Create a MusicData object
- music_data = MusicData()
- self.music_data[channel.guild] = music_data
- # Connect to voice
- log.debug(f"Connecting to Voice in {channel}")
- try:
- music_data.voice_client = await channel.connect(reconnect=False, timeout=10)
- except Exception:
- log.warning(f"Failed to connect to Voice in {channel}")
- del self.music_data[channel.guild]
- raise
- else:
- log.debug(f"Connected to Voice in {channel}")
- else:
- if music_data.voice_client is None:
- # TODO: change exception type
- raise Exception("Another connection attempt is already in progress.")
- # Try to move to a different channel
- voice_client = music_data.voice_client
- log.debug(f"Moving {voice_client} to {channel}")
- await voice_client.move_to(channel)
- log.debug(f"Moved {voice_client} to {channel}")
-
async def on_message(cli, message: discord.Message):
- self.loop.create_task(cli._handle_message(message))
-
- async def _handle_message(cli, message: discord.Message):
- text = message.content
- # Skip non-text messages
- if not text:
- return
- # Skip non-command updates
- if not text.startswith("!"):
- return
- # Skip bot messages
- author: typing.Union[discord.User] = message.author
- if author.bot:
- return
- # Find and clean parameters
- command_text, *parameters = text.split(" ")
- # Don't use a case-sensitive command name
- command_name = command_text.lower()
- # Find the command
- try:
- command = self.commands[command_name]
- except KeyError:
- # Skip the message
- return
- # Prepare data
- data = self._Data(interface=command.interface, message=message)
- # Call the command
- log.debug(f"Calling command '{command.name}'")
- with message.channel.typing():
- # Run the command
- try:
- await command.run(CommandArgs(parameters), data)
- except InvalidInputError as e:
- await data.reply(f":warning: {e.message}\n"
- f"Syntax: [c]/{command.name} {command.syntax}[/c]")
- except UnsupportedError as e:
- await data.reply(f":warning: {e.message}")
- except CommandError as e:
- await data.reply(f":warning: {e.message}")
- except Exception as e:
- sentry_sdk.capture_exception(e)
- error_message = f"🦀 [b]{e.__class__.__name__}[/b] 🦀\n"
- error_message += '\n'.join(e.args)
- await data.reply(error_message)
- # Close the data session
- await data.session_close()
-
- async def on_connect(cli):
- log.debug("Connected to Discord")
-
- async def on_disconnect(cli):
- log.error("Disconnected from Discord!")
+ """Handle messages received by passing them to the handle_message method of the bot."""
+ # TODO: keep reference to these tasks somewhere
+ self.loop.create_task(self.handle_message(message))
async def on_ready(cli) -> None:
- log.debug("Connection successful, client is ready")
+ """Change the bot presence to ``online`` when the bot is ready."""
await cli.change_presence(status=discord.Status.online)
- def find_guild_by_name(cli, name: str) -> typing.List[discord.Guild]:
- """Find the :py:class:`discord.Guild` with the specified name (case insensitive)."""
- all_guilds: typing.List[discord.Guild] = cli.guilds
- matching_channels: typing.List[discord.Guild] = []
+ def find_guild(cli, name: str) -> List[discord.Guild]:
+ """Find the :class:`discord.Guild`s with the specified name (case insensitive).
+
+ Returns:
+ A :class:`list` of :class:`discord.Guild` having the specified name."""
+ all_guilds: List[discord.Guild] = cli.guilds
+ matching_channels: List[discord.Guild] = []
for guild in all_guilds:
if guild.name.lower() == name.lower():
matching_channels.append(guild)
return matching_channels
- def find_channel_by_name(cli,
- name: str,
- guild: typing.Optional[discord.Guild] = None) -> typing.List[discord.abc.GuildChannel]:
- """Find the :py:class:`TextChannel`, :py:class:`VoiceChannel` or :py:class:`CategoryChannel` with the
+ def find_channel(cli,
+ name: str,
+ guild: Optional[discord.Guild] = None) -> List[discord.abc.GuildChannel]:
+ """Find the :class:`TextChannel`, :class:`VoiceChannel` or :class:`CategoryChannel` with the
specified name (case insensitive).
- You can specify a guild to only find channels in that specific guild."""
+ You can specify a guild to only search in that specific guild."""
if guild is not None:
all_channels = guild.channels
else:
- all_channels: typing.List[discord.abc.GuildChannel] = cli.get_all_channels()
- matching_channels: typing.List[discord.abc.GuildChannel] = []
+ all_channels: List[discord.abc.GuildChannel] = cli.get_all_channels()
+ matching_channels: List[discord.abc.GuildChannel] = []
for channel in all_channels:
if not (isinstance(channel, discord.TextChannel)
or isinstance(channel, discord.VoiceChannel)
@@ -186,8 +180,9 @@ class DiscordBot(GenericBot):
matching_channels.append(channel)
return matching_channels
- def find_voice_client_by_guild(cli, guild: discord.Guild) -> typing.Optional[discord.VoiceClient]:
+ def find_voice_client(cli, guild: discord.Guild) -> Optional[discord.VoiceClient]:
"""Find the :py:class:`discord.VoiceClient` belonging to a specific :py:class:`discord.Guild`."""
+ # TODO: the bug I was looking for might be here
for voice_client in cli.voice_clients:
if voice_client.guild == guild:
return voice_client
@@ -195,6 +190,8 @@ class DiscordBot(GenericBot):
return DiscordClient
+ # TODO: restart from here
+
def _init_client(self):
"""Create an instance of the DiscordClient class created in :py:func:`royalnet.bots.DiscordBot._bot_factory`."""
log.debug(f"Creating DiscordClient instance")
diff --git a/royalnet/serf/telegram/telegramserf.py b/royalnet/serf/telegram/telegramserf.py
index 3ee679b6..c6db695c 100644
--- a/royalnet/serf/telegram/telegramserf.py
+++ b/royalnet/serf/telegram/telegramserf.py
@@ -32,7 +32,7 @@ log = logging.getLogger(__name__)
class TelegramSerf(Serf):
- """A Serf that connects to `Telegram `_."""
+ """A Serf that connects to `Telegram `_ as a bot."""
interface_name = "telegram"
def __init__(self, *,
@@ -94,9 +94,6 @@ class TelegramSerf(Serf):
name = self.interface_name
prefix = "/"
- def __init__(self):
- super().__init__()
-
return TelegramInterface
def data_factory(self) -> Type[CommandData]: