From e677cbe9b3a4361c82fd7d7ba3741ff771402dd5 Mon Sep 17 00:00:00 2001 From: Stefano Pigozzi Date: Fri, 19 Apr 2019 02:12:37 +0200 Subject: [PATCH] Refactor most of the DiscordBot --- royalnet/bots/discord.py | 256 ++++++++++----------- royalnet/{utils/bot.py => bots/generic.py} | 68 ++++-- royalnet/commands/error_handler.py | 6 +- royalnet/utils/call.py | 9 +- 4 files changed, 185 insertions(+), 154 deletions(-) rename royalnet/{utils/bot.py => bots/generic.py} (52%) diff --git a/royalnet/bots/discord.py b/royalnet/bots/discord.py index 83f52ebb..8f600520 100644 --- a/royalnet/bots/discord.py +++ b/royalnet/bots/discord.py @@ -3,10 +3,11 @@ import asyncio import typing import logging as _logging import sys +from .generic import GenericBot from ..commands import NullCommand -from ..utils import asyncify, Call, Command, NetworkHandler +from ..utils import asyncify, Call, Command from ..error import UnregisteredError, NoneFoundError, TooManyFoundError, InvalidConfigError -from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError +from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError, RoyalnetConfig from ..database import Alchemy, relationshiplinkchain, DatabaseConfig from ..audio import RoyalPCMFile, PlayMode, Playlist @@ -18,34 +19,16 @@ if not discord.opus.is_loaded(): log.error("Opus is not loaded. Weird behaviour might emerge.") -class DiscordBot: - def __init__(self, - token: str, - master_server_uri: str, - master_server_secret: str, - commands: typing.List[typing.Type[Command]], - missing_command: typing.Type[Command] = NullCommand, - error_command: typing.Type[Command] = NullCommand, - database_config: typing.Optional[DatabaseConfig] = None): +class DiscordConfig: + def __init__(self, token: str): self.token = token - # Generate the Alchemy database - if database_config: - self.alchemy = Alchemy(database_config.database_uri, required_tables) - self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__) - self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__) - self.identity_column = self.identity_table.__getattribute__(self.identity_table, database_config.identity_column_name) - self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table) - else: - if required_tables: - raise InvalidConfigError("Tables are required by the _commands, but Alchemy is not configured") - self.alchemy = None - self.master_table = None - self.identity_table = None - self.identity_column = None - self.identity_chain = None - # Create the PlayModes dictionary + + +class DiscordBot(GenericBot): + def _init_voice(self): self.music_data: typing.Dict[discord.Guild, PlayMode] = {} + def _call_factory(self) -> typing.Type[Call]: # noinspection PyMethodParameters class DiscordCall(Call): interface_name = "discord" @@ -55,6 +38,7 @@ class DiscordBot: alchemy = self.alchemy async def reply(call, text: str): + # TODO: don't escape characters inside [c][/c] blocks escaped_text = text.replace("*", "\\*") \ .replace("_", "\\_") \ .replace("`", "\\`") \ @@ -69,6 +53,8 @@ class DiscordBot: await call.channel.send(escaped_text) async def net_request(call, message: Message, destination: str): + if self.network is None: + raise InvalidConfigError("Royalnet is not enabled on this bot") response = await self.network.request(message, destination) if isinstance(response, RequestError): raise response.exc @@ -86,19 +72,20 @@ class DiscordBot: raise UnregisteredError("Author is not registered") return result - self.DiscordCall = DiscordCall + return DiscordCall + def _bot_factory(self) -> typing.Type[discord.Client]: + """Create a new DiscordClient class based on this DiscordBot.""" # noinspection PyMethodParameters class DiscordClient(discord.Client): - @staticmethod - async def vc_connect_or_move(channel: discord.VoiceChannel): + async def vc_connect_or_move(cli, channel: discord.VoiceChannel): # Connect to voice chat try: await channel.connect() except discord.errors.ClientException: # Move to the selected channel, instead of connecting # noinspection PyUnusedLocal - for voice_client in self.bot.voice_clients: + for voice_client in cli.voice_clients: voice_client: discord.VoiceClient if voice_client.guild != channel.guild: continue @@ -107,124 +94,123 @@ class DiscordBot: if not self.music_data.get(channel.guild): self.music_data[channel.guild] = Playlist() - async def on_message(cli, message: discord.Message): + @staticmethod # Not really static because of the self reference + async def on_message(message: discord.Message): text = message.content # Skip non-text messages if not text: return # Find and clean parameters command_text, *parameters = text.split(" ") - # Find the function - try: - selected_command = self.commands[command_text] - except KeyError: - # Skip inexistent _commands - selected_command = self.missing_command - log.error(f"Running {selected_command}") # Call the command - try: - return await self.DiscordCall(message.channel, selected_command, parameters, log, - message=message).run() - except Exception as exc: - try: - return await self.DiscordCall(message.channel, self.error_command, parameters, log, - message=message, - exception_info=sys.exc_info(), - previous_command=selected_command).run() - except Exception as exc2: - log.error(f"Exception in error handler command: {exc2}") + await self.call(command_text, message.channel, parameters, message=message) - self.DiscordClient = DiscordClient - self.bot = self.DiscordClient() + def find_guild_by_name(cli, name: str) -> discord.Guild: + """Find the Guild with the specified name. Case-insensitive. + Will raise a NoneFoundError if no channels are found, or a TooManyFoundError if more than one is found.""" + all_guilds: typing.List[discord.Guild] = cli.guilds + matching_channels: typing.List[discord.Guild] = [] + for guild in all_guilds: + if guild.name.lower() == name.lower(): + matching_channels.append(guild) + if len(matching_channels) == 0: + raise NoneFoundError("No channels were found") + elif len(matching_channels) > 1: + raise TooManyFoundError("Too many channels were found") + return matching_channels[0] - def find_guild(self, identifier: typing.Union[str, int]) -> discord.Guild: - """Find the Guild with the specified identifier. Names are case-insensitive.""" - if isinstance(identifier, str): - all_guilds: typing.List[discord.Guild] = self.bot.guilds - matching_channels: typing.List[discord.Guild] = [] - for guild in all_guilds: - if guild.name.lower() == identifier.lower(): - matching_channels.append(guild) - if len(matching_channels) == 0: - raise NoneFoundError("No channels were found") - elif len(matching_channels) > 1: - raise TooManyFoundError("Too many channels were found") - return matching_channels[0] - elif isinstance(identifier, int): - return self.bot.get_guild(identifier) - raise TypeError("Invalid identifier type, should be str or int") + def find_channel_by_name(cli, + name: str, + guild: typing.Optional[discord.Guild] = None) -> discord.abc.GuildChannel: + """Find the TextChannel, VoiceChannel or CategoryChannel with the specified name. Case-insensitive. + Guild is optional, but the method will raise a TooManyFoundError if none is specified and there is more than one channel with the same name. + Will also raise a NoneFoundError if no channels are found.""" + if guild is not None: + all_channels = guild.channels + else: + all_channels: typing.List[discord.abc.GuildChannel] = cli.get_all_channels() + matching_channels: typing.List[discord.abc.GuildChannel] = [] + for channel in all_channels: + if not (isinstance(channel, discord.TextChannel) + or isinstance(channel, discord.VoiceChannel) + or isinstance(channel, discord.CategoryChannel)): + continue + if channel.name.lower() == name.lower(): + matching_channels.append(channel) + if len(matching_channels) == 0: + raise NoneFoundError("No channels were found") + elif len(matching_channels) > 1: + raise TooManyFoundError("Too many channels were found") + return matching_channels[0] - def find_channel(self, - identifier: typing.Union[str, int], - guild: typing.Optional[discord.Guild] = None) -> discord.abc.GuildChannel: - """Find the GuildChannel with the specified identifier. Names are case-insensitive.""" - if isinstance(identifier, str): - if guild is not None: - all_channels = guild.channels - else: - all_channels: typing.List[discord.abc.GuildChannel] = self.bot.get_all_channels() - matching_channels: typing.List[discord.abc.GuildChannel] = [] - for channel in all_channels: - if not (isinstance(channel, discord.TextChannel) - or isinstance(channel, discord.VoiceChannel) - or isinstance(channel, discord.CategoryChannel)): - continue - if channel.name.lower() == identifier.lower(): - matching_channels.append(channel) - if len(matching_channels) == 0: - raise NoneFoundError("No channels were found") - elif len(matching_channels) > 1: - raise TooManyFoundError("Too many channels were found") - return matching_channels[0] - elif isinstance(identifier, int): - channel: typing.List[discord.abc.GuildChannel] = self.bot.get_channel(identifier) - if ((isinstance(channel, discord.TextChannel) - or isinstance(channel, discord.VoiceChannel) - or isinstance(channel, discord.CategoryChannel)) - and guild): - assert channel.guild == guild - return channel - raise TypeError("Invalid identifier type, should be str or int") + def find_voice_client_by_guild(cli, guild: discord.Guild): + """Find the VoiceClient belonging to a specific Guild. + Raises a NoneFoundError if the Guild currently has no VoiceClient.""" + for voice_client in cli.voice_clients: + voice_client: discord.VoiceClient + if voice_client.guild == guild: + return voice_client + raise NoneFoundError("No voice clients found") - def find_voice_client(self, guild: discord.Guild): - for voice_client in self.bot.voice_clients: - voice_client: discord.VoiceClient - if voice_client.guild == guild: - return voice_client - raise NoneFoundError("No voice clients found") + return DiscordClient - async def add_to_music_data(self, url: str, guild: discord.Guild): - """Add a file to the corresponding music_data object.""" - log.debug(f"Downloading {url} to add to music_data") - files: typing.List[RoyalPCMFile] = await asyncify(RoyalPCMFile.create_from_url, url) - guild_music_data = self.music_data[guild] - for file in files: - log.debug(f"Adding {file} to music_data") - guild_music_data.add(file) - if guild_music_data.now_playing is None: - log.debug(f"Starting playback chain") - await self.advance_music_data(guild) + def _init_bot(self): + """Create a bot instance.""" + self.bot = self._bot_factory()() - async def advance_music_data(self, guild: discord.Guild): - """Try to play the next song, while it exists. Otherwise, just return.""" - guild_music_data = self.music_data[guild] - voice_client = self.find_voice_client(guild) - next_file: RoyalPCMFile = await guild_music_data.next() - if next_file is None: - log.debug(f"Ending playback chain") - return - - def advance(error=None): - log.debug(f"Deleting {next_file}") - next_file.delete_audio_file() - loop.create_task(self.advance_music_data(guild)) - - log.debug(f"Creating AudioSource of {next_file}") - next_source = next_file.create_audio_source() - log.debug(f"Starting playback of {next_source}") - voice_client.play(next_source, after=advance) + def __init__(self, *, + discord_config: DiscordConfig, + royalnet_config: RoyalnetConfig, + database_config: typing.Optional[DatabaseConfig] = None, + commands: typing.List[typing.Type[Command]] = None, + missing_command: typing.Type[Command] = NullCommand, + error_command: typing.Type[Command] = NullCommand): + super().__init__(royalnet_config=royalnet_config, + database_config=database_config, + commands=commands, + missing_command=missing_command, + error_command=error_command) + self._discord_config = discord_config + self._init_bot() async def run(self): - await self.bot.login(self.token) + await self.bot.login(self._discord_config.token) await self.bot.connect() # TODO: how to stop? + + + +# class DiscordBot: +# async def add_to_music_data(self, url: str, guild: discord.Guild): +# """Add a file to the corresponding music_data object.""" +# log.debug(f"Downloading {url} to add to music_data") +# files: typing.List[RoyalPCMFile] = await asyncify(RoyalPCMFile.create_from_url, url) +# guild_music_data = self.music_data[guild] +# for file in files: +# log.debug(f"Adding {file} to music_data") +# guild_music_data.add(file) +# if guild_music_data.now_playing is None: +# log.debug(f"Starting playback chain") +# await self.advance_music_data(guild) +# +# async def advance_music_data(self, guild: discord.Guild): +# """Try to play the next song, while it exists. Otherwise, just return.""" +# guild_music_data = self.music_data[guild] +# voice_client = self.find_voice_client(guild) +# next_file: RoyalPCMFile = await guild_music_data.next() +# if next_file is None: +# log.debug(f"Ending playback chain") +# return +# +# def advance(error=None): +# log.debug(f"Deleting {next_file}") +# next_file.delete_audio_file() +# loop.create_task(self.advance_music_data(guild)) +# +# log.debug(f"Creating AudioSource of {next_file}") +# next_source = next_file.create_audio_source() +# log.debug(f"Starting playback of {next_source}") +# voice_client.play(next_source, after=advance) +# + +# \ No newline at end of file diff --git a/royalnet/utils/bot.py b/royalnet/bots/generic.py similarity index 52% rename from royalnet/utils/bot.py rename to royalnet/bots/generic.py index 6d46de38..f9dd8592 100644 --- a/royalnet/utils/bot.py +++ b/royalnet/bots/generic.py @@ -1,7 +1,8 @@ +import sys import typing import asyncio import logging -from ..utils import Command, NetworkHandler +from ..utils import Command, NetworkHandler, Call from ..commands import NullCommand from ..network import RoyalnetLink, Message, RequestError, RoyalnetConfig from ..database import Alchemy, DatabaseConfig, relationshiplinkchain @@ -11,10 +12,12 @@ log = logging.getLogger(__name__) class GenericBot: + """A generic bot class, to be used as base for the other more specific classes, such as TelegramBot and DiscordBot.""" def _init_commands(self, commands: typing.List[typing.Type[Command]], missing_command: typing.Type[Command], error_command: typing.Type[Command]): + """Generate the commands dictionary required to handle incoming messages, and the network_handlers dictionary required to handle incoming requests.""" log.debug(f"Now generating commands") self.commands: typing.Dict[str, typing.Type[Command]] = {} self.network_handlers: typing.Dict[typing.Type[Message], typing.Type[NetworkHandler]] = {} @@ -25,47 +28,61 @@ class GenericBot: self.error_command: typing.Type[Command] = error_command log.debug(f"Successfully generated commands") + def _call_factory(self) -> typing.Type[Call]: + """Create the Call class, representing a Call command. It should inherit from the utils.Call class.""" + raise NotImplementedError() + def _init_royalnet(self, royalnet_config: RoyalnetConfig): + """Create a RoyalnetLink, and run it as a task.""" self.network: RoyalnetLink = RoyalnetLink(royalnet_config.master_uri, royalnet_config.master_secret, "discord", self._network_handler) log.debug(f"Running RoyalnetLink {self.network}") loop.create_task(self.network.run()) def _network_handler(self, message: Message) -> Message: - log.debug(f"Received {message} from the RoyalnetLink") - try: - network_handler = self.network_handlers[message.__class__] - except KeyError as exc: - log.debug(f"Missing network_handler for {message}") - return RequestError(KeyError("Missing network_handler")) - try: - log.debug(f"Using {network_handler} as handler for {message}") - return await network_handler.discord(message) - except Exception as exc: - log.debug(f"Exception {exc} in {network_handler}") - return RequestError(exc) + """Handle a single Message received from the RoyalnetLink""" + log.debug(f"Received {message} from the RoyalnetLink") + try: + network_handler = self.network_handlers[message.__class__] + except KeyError as exc: + log.debug(f"Missing network_handler for {message}") + return RequestError(KeyError("Missing network_handler")) + try: + log.debug(f"Using {network_handler} as handler for {message}") + return await network_handler.discord(message) + except Exception as exc: + log.debug(f"Exception {exc} in {network_handler}") + return RequestError(exc) def _init_database(self, commands: typing.List[typing.Type[Command]], database_config: DatabaseConfig): + """Connect to the database, and create the missing tables required by the selected commands.""" + log.debug(f"Initializing database") required_tables = set() for command in commands: required_tables = required_tables.union(command.require_alchemy_tables) + log.debug(f"Found {len(required_tables)} required tables") self.alchemy = Alchemy(database_config.database_uri, required_tables) self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__) self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__) self.identity_column = self.identity_table.__getattribute__(self.identity_table, database_config.identity_column_name) self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table) + log.debug(f"Identity chain is {self.identity_chain}") def __init__(self, *, - royalnet_config: RoyalnetConfig, + royalnet_config: typing.Optional[RoyalnetConfig] = None, database_config: typing.Optional[DatabaseConfig] = None, - commands: typing.Optional[typing.List[typing.Type[Command]]] = None, + commands: typing.List[typing.Type[Command]] = None, missing_command: typing.Type[Command] = NullCommand, error_command: typing.Type[Command] = NullCommand): if commands is None: commands = [] self._init_commands(commands, missing_command=missing_command, error_command=error_command) - self._init_royalnet(royalnet_config=royalnet_config) + self._Call = self._call_factory() + if royalnet_config is None: + self.network = None + else: + self._init_royalnet(royalnet_config=royalnet_config) if database_config is None: self.alchemy = None self.master_table = None @@ -73,3 +90,22 @@ class GenericBot: self.identity_column = None else: self._init_database(commands=commands, database_config=database_config) + + async def call(self, command_name: str, channel, parameters: typing.List[str] = None, **kwargs): + """Call a command by its string, or missing_command if it doesn't exists, or error_command if an exception is raised during the execution.""" + if parameters is None: + parameters = [] + try: + command: typing.Type[Command] = self.commands[command_name] + except KeyError: + command = self.missing_command + try: + await self._Call(channel, command, parameters, **kwargs).run() + except Exception as exc: + await self._Call(channel, self.error_command, + exception_info=sys.exc_info(), + previous_command=command, **kwargs).run() + + async def run(self): + """A blocking coroutine that should make the bot start listening to commands and requests.""" + raise NotImplementedError() diff --git a/royalnet/commands/error_handler.py b/royalnet/commands/error_handler.py index 0081ffde..7ef38ef2 100644 --- a/royalnet/commands/error_handler.py +++ b/royalnet/commands/error_handler.py @@ -1,3 +1,4 @@ +import logging as _logging import traceback from ..utils import Command, Call from ..error import NoneFoundError, \ @@ -9,6 +10,9 @@ from ..error import NoneFoundError, \ ExternalError +log = _logging.getLogger(__name__) + + class ErrorHandlerCommand(Command): command_name = "error_handler" @@ -46,4 +50,4 @@ class ErrorHandlerCommand(Command): return await call.reply(f"❌ Eccezione non gestita durante l'esecuzione del comando:\n[b]{e_type.__name__}[/b]\n{e_value}") formatted_tb: str = '\n'.join(traceback.format_tb(e_tb)) - call.logger.error(f"Unhandled exception - {e_type.__name__}: {e_value}\n{formatted_tb}") + log.error(f"Unhandled exception - {e_type.__name__}: {e_value}\n{formatted_tb}") diff --git a/royalnet/utils/call.py b/royalnet/utils/call.py index e5fb007d..d5c7db91 100644 --- a/royalnet/utils/call.py +++ b/royalnet/utils/call.py @@ -36,13 +36,18 @@ class Call: raise NotImplementedError() # These parameters / methods should be left alone - def __init__(self, channel, command: typing.Type[Command], command_args: list, logger: logging.Logger, **kwargs): + def __init__(self, + channel, + command: typing.Type[Command], + command_args: typing.List[str] = None, + **kwargs): + if command_args is None: + command_args = [] self.channel = channel self.command = command self.args = CommandArgs(command_args) self.kwargs = kwargs self.session = None - self.logger = logger async def session_init(self): if not self.command.require_alchemy_tables: