1
Fork 0
mirror of https://github.com/RYGhub/royalnet.git synced 2024-11-27 13:34:28 +00:00

Refactor most of the DiscordBot

This commit is contained in:
Steffo 2019-04-19 02:12:37 +02:00
parent f4fc7dd971
commit e677cbe9b3
4 changed files with 185 additions and 154 deletions

View file

@ -3,10 +3,11 @@ import asyncio
import typing
import logging as _logging
import sys
from .generic import GenericBot
from ..commands import NullCommand
from ..utils import asyncify, Call, Command, NetworkHandler
from ..utils import asyncify, Call, Command
from ..error import UnregisteredError, NoneFoundError, TooManyFoundError, InvalidConfigError
from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError
from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError, RoyalnetConfig
from ..database import Alchemy, relationshiplinkchain, DatabaseConfig
from ..audio import RoyalPCMFile, PlayMode, Playlist
@ -18,34 +19,16 @@ if not discord.opus.is_loaded():
log.error("Opus is not loaded. Weird behaviour might emerge.")
class DiscordBot:
def __init__(self,
token: str,
master_server_uri: str,
master_server_secret: str,
commands: typing.List[typing.Type[Command]],
missing_command: typing.Type[Command] = NullCommand,
error_command: typing.Type[Command] = NullCommand,
database_config: typing.Optional[DatabaseConfig] = None):
class DiscordConfig:
def __init__(self, token: str):
self.token = token
# Generate the Alchemy database
if database_config:
self.alchemy = Alchemy(database_config.database_uri, required_tables)
self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__)
self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__)
self.identity_column = self.identity_table.__getattribute__(self.identity_table, database_config.identity_column_name)
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
else:
if required_tables:
raise InvalidConfigError("Tables are required by the _commands, but Alchemy is not configured")
self.alchemy = None
self.master_table = None
self.identity_table = None
self.identity_column = None
self.identity_chain = None
# Create the PlayModes dictionary
class DiscordBot(GenericBot):
def _init_voice(self):
self.music_data: typing.Dict[discord.Guild, PlayMode] = {}
def _call_factory(self) -> typing.Type[Call]:
# noinspection PyMethodParameters
class DiscordCall(Call):
interface_name = "discord"
@ -55,6 +38,7 @@ class DiscordBot:
alchemy = self.alchemy
async def reply(call, text: str):
# TODO: don't escape characters inside [c][/c] blocks
escaped_text = text.replace("*", "\\*") \
.replace("_", "\\_") \
.replace("`", "\\`") \
@ -69,6 +53,8 @@ class DiscordBot:
await call.channel.send(escaped_text)
async def net_request(call, message: Message, destination: str):
if self.network is None:
raise InvalidConfigError("Royalnet is not enabled on this bot")
response = await self.network.request(message, destination)
if isinstance(response, RequestError):
raise response.exc
@ -86,19 +72,20 @@ class DiscordBot:
raise UnregisteredError("Author is not registered")
return result
self.DiscordCall = DiscordCall
return DiscordCall
def _bot_factory(self) -> typing.Type[discord.Client]:
"""Create a new DiscordClient class based on this DiscordBot."""
# noinspection PyMethodParameters
class DiscordClient(discord.Client):
@staticmethod
async def vc_connect_or_move(channel: discord.VoiceChannel):
async def vc_connect_or_move(cli, channel: discord.VoiceChannel):
# Connect to voice chat
try:
await channel.connect()
except discord.errors.ClientException:
# Move to the selected channel, instead of connecting
# noinspection PyUnusedLocal
for voice_client in self.bot.voice_clients:
for voice_client in cli.voice_clients:
voice_client: discord.VoiceClient
if voice_client.guild != channel.guild:
continue
@ -107,124 +94,123 @@ class DiscordBot:
if not self.music_data.get(channel.guild):
self.music_data[channel.guild] = Playlist()
async def on_message(cli, message: discord.Message):
@staticmethod # Not really static because of the self reference
async def on_message(message: discord.Message):
text = message.content
# Skip non-text messages
if not text:
return
# Find and clean parameters
command_text, *parameters = text.split(" ")
# Find the function
try:
selected_command = self.commands[command_text]
except KeyError:
# Skip inexistent _commands
selected_command = self.missing_command
log.error(f"Running {selected_command}")
# Call the command
try:
return await self.DiscordCall(message.channel, selected_command, parameters, log,
message=message).run()
except Exception as exc:
try:
return await self.DiscordCall(message.channel, self.error_command, parameters, log,
message=message,
exception_info=sys.exc_info(),
previous_command=selected_command).run()
except Exception as exc2:
log.error(f"Exception in error handler command: {exc2}")
await self.call(command_text, message.channel, parameters, message=message)
self.DiscordClient = DiscordClient
self.bot = self.DiscordClient()
def find_guild_by_name(cli, name: str) -> discord.Guild:
"""Find the Guild with the specified name. Case-insensitive.
Will raise a NoneFoundError if no channels are found, or a TooManyFoundError if more than one is found."""
all_guilds: typing.List[discord.Guild] = cli.guilds
matching_channels: typing.List[discord.Guild] = []
for guild in all_guilds:
if guild.name.lower() == name.lower():
matching_channels.append(guild)
if len(matching_channels) == 0:
raise NoneFoundError("No channels were found")
elif len(matching_channels) > 1:
raise TooManyFoundError("Too many channels were found")
return matching_channels[0]
def find_guild(self, identifier: typing.Union[str, int]) -> discord.Guild:
"""Find the Guild with the specified identifier. Names are case-insensitive."""
if isinstance(identifier, str):
all_guilds: typing.List[discord.Guild] = self.bot.guilds
matching_channels: typing.List[discord.Guild] = []
for guild in all_guilds:
if guild.name.lower() == identifier.lower():
matching_channels.append(guild)
if len(matching_channels) == 0:
raise NoneFoundError("No channels were found")
elif len(matching_channels) > 1:
raise TooManyFoundError("Too many channels were found")
return matching_channels[0]
elif isinstance(identifier, int):
return self.bot.get_guild(identifier)
raise TypeError("Invalid identifier type, should be str or int")
def find_channel_by_name(cli,
name: str,
guild: typing.Optional[discord.Guild] = None) -> discord.abc.GuildChannel:
"""Find the TextChannel, VoiceChannel or CategoryChannel with the specified name. Case-insensitive.
Guild is optional, but the method will raise a TooManyFoundError if none is specified and there is more than one channel with the same name.
Will also raise a NoneFoundError if no channels are found."""
if guild is not None:
all_channels = guild.channels
else:
all_channels: typing.List[discord.abc.GuildChannel] = cli.get_all_channels()
matching_channels: typing.List[discord.abc.GuildChannel] = []
for channel in all_channels:
if not (isinstance(channel, discord.TextChannel)
or isinstance(channel, discord.VoiceChannel)
or isinstance(channel, discord.CategoryChannel)):
continue
if channel.name.lower() == name.lower():
matching_channels.append(channel)
if len(matching_channels) == 0:
raise NoneFoundError("No channels were found")
elif len(matching_channels) > 1:
raise TooManyFoundError("Too many channels were found")
return matching_channels[0]
def find_channel(self,
identifier: typing.Union[str, int],
guild: typing.Optional[discord.Guild] = None) -> discord.abc.GuildChannel:
"""Find the GuildChannel with the specified identifier. Names are case-insensitive."""
if isinstance(identifier, str):
if guild is not None:
all_channels = guild.channels
else:
all_channels: typing.List[discord.abc.GuildChannel] = self.bot.get_all_channels()
matching_channels: typing.List[discord.abc.GuildChannel] = []
for channel in all_channels:
if not (isinstance(channel, discord.TextChannel)
or isinstance(channel, discord.VoiceChannel)
or isinstance(channel, discord.CategoryChannel)):
continue
if channel.name.lower() == identifier.lower():
matching_channels.append(channel)
if len(matching_channels) == 0:
raise NoneFoundError("No channels were found")
elif len(matching_channels) > 1:
raise TooManyFoundError("Too many channels were found")
return matching_channels[0]
elif isinstance(identifier, int):
channel: typing.List[discord.abc.GuildChannel] = self.bot.get_channel(identifier)
if ((isinstance(channel, discord.TextChannel)
or isinstance(channel, discord.VoiceChannel)
or isinstance(channel, discord.CategoryChannel))
and guild):
assert channel.guild == guild
return channel
raise TypeError("Invalid identifier type, should be str or int")
def find_voice_client_by_guild(cli, guild: discord.Guild):
"""Find the VoiceClient belonging to a specific Guild.
Raises a NoneFoundError if the Guild currently has no VoiceClient."""
for voice_client in cli.voice_clients:
voice_client: discord.VoiceClient
if voice_client.guild == guild:
return voice_client
raise NoneFoundError("No voice clients found")
def find_voice_client(self, guild: discord.Guild):
for voice_client in self.bot.voice_clients:
voice_client: discord.VoiceClient
if voice_client.guild == guild:
return voice_client
raise NoneFoundError("No voice clients found")
return DiscordClient
async def add_to_music_data(self, url: str, guild: discord.Guild):
"""Add a file to the corresponding music_data object."""
log.debug(f"Downloading {url} to add to music_data")
files: typing.List[RoyalPCMFile] = await asyncify(RoyalPCMFile.create_from_url, url)
guild_music_data = self.music_data[guild]
for file in files:
log.debug(f"Adding {file} to music_data")
guild_music_data.add(file)
if guild_music_data.now_playing is None:
log.debug(f"Starting playback chain")
await self.advance_music_data(guild)
def _init_bot(self):
"""Create a bot instance."""
self.bot = self._bot_factory()()
async def advance_music_data(self, guild: discord.Guild):
"""Try to play the next song, while it exists. Otherwise, just return."""
guild_music_data = self.music_data[guild]
voice_client = self.find_voice_client(guild)
next_file: RoyalPCMFile = await guild_music_data.next()
if next_file is None:
log.debug(f"Ending playback chain")
return
def advance(error=None):
log.debug(f"Deleting {next_file}")
next_file.delete_audio_file()
loop.create_task(self.advance_music_data(guild))
log.debug(f"Creating AudioSource of {next_file}")
next_source = next_file.create_audio_source()
log.debug(f"Starting playback of {next_source}")
voice_client.play(next_source, after=advance)
def __init__(self, *,
discord_config: DiscordConfig,
royalnet_config: RoyalnetConfig,
database_config: typing.Optional[DatabaseConfig] = None,
commands: typing.List[typing.Type[Command]] = None,
missing_command: typing.Type[Command] = NullCommand,
error_command: typing.Type[Command] = NullCommand):
super().__init__(royalnet_config=royalnet_config,
database_config=database_config,
commands=commands,
missing_command=missing_command,
error_command=error_command)
self._discord_config = discord_config
self._init_bot()
async def run(self):
await self.bot.login(self.token)
await self.bot.login(self._discord_config.token)
await self.bot.connect()
# TODO: how to stop?
# class DiscordBot:
# async def add_to_music_data(self, url: str, guild: discord.Guild):
# """Add a file to the corresponding music_data object."""
# log.debug(f"Downloading {url} to add to music_data")
# files: typing.List[RoyalPCMFile] = await asyncify(RoyalPCMFile.create_from_url, url)
# guild_music_data = self.music_data[guild]
# for file in files:
# log.debug(f"Adding {file} to music_data")
# guild_music_data.add(file)
# if guild_music_data.now_playing is None:
# log.debug(f"Starting playback chain")
# await self.advance_music_data(guild)
#
# async def advance_music_data(self, guild: discord.Guild):
# """Try to play the next song, while it exists. Otherwise, just return."""
# guild_music_data = self.music_data[guild]
# voice_client = self.find_voice_client(guild)
# next_file: RoyalPCMFile = await guild_music_data.next()
# if next_file is None:
# log.debug(f"Ending playback chain")
# return
#
# def advance(error=None):
# log.debug(f"Deleting {next_file}")
# next_file.delete_audio_file()
# loop.create_task(self.advance_music_data(guild))
#
# log.debug(f"Creating AudioSource of {next_file}")
# next_source = next_file.create_audio_source()
# log.debug(f"Starting playback of {next_source}")
# voice_client.play(next_source, after=advance)
#
#

View file

@ -1,7 +1,8 @@
import sys
import typing
import asyncio
import logging
from ..utils import Command, NetworkHandler
from ..utils import Command, NetworkHandler, Call
from ..commands import NullCommand
from ..network import RoyalnetLink, Message, RequestError, RoyalnetConfig
from ..database import Alchemy, DatabaseConfig, relationshiplinkchain
@ -11,10 +12,12 @@ log = logging.getLogger(__name__)
class GenericBot:
"""A generic bot class, to be used as base for the other more specific classes, such as TelegramBot and DiscordBot."""
def _init_commands(self,
commands: typing.List[typing.Type[Command]],
missing_command: typing.Type[Command],
error_command: typing.Type[Command]):
"""Generate the commands dictionary required to handle incoming messages, and the network_handlers dictionary required to handle incoming requests."""
log.debug(f"Now generating commands")
self.commands: typing.Dict[str, typing.Type[Command]] = {}
self.network_handlers: typing.Dict[typing.Type[Message], typing.Type[NetworkHandler]] = {}
@ -25,47 +28,61 @@ class GenericBot:
self.error_command: typing.Type[Command] = error_command
log.debug(f"Successfully generated commands")
def _call_factory(self) -> typing.Type[Call]:
"""Create the Call class, representing a Call command. It should inherit from the utils.Call class."""
raise NotImplementedError()
def _init_royalnet(self, royalnet_config: RoyalnetConfig):
"""Create a RoyalnetLink, and run it as a task."""
self.network: RoyalnetLink = RoyalnetLink(royalnet_config.master_uri, royalnet_config.master_secret, "discord",
self._network_handler)
log.debug(f"Running RoyalnetLink {self.network}")
loop.create_task(self.network.run())
def _network_handler(self, message: Message) -> Message:
log.debug(f"Received {message} from the RoyalnetLink")
try:
network_handler = self.network_handlers[message.__class__]
except KeyError as exc:
log.debug(f"Missing network_handler for {message}")
return RequestError(KeyError("Missing network_handler"))
try:
log.debug(f"Using {network_handler} as handler for {message}")
return await network_handler.discord(message)
except Exception as exc:
log.debug(f"Exception {exc} in {network_handler}")
return RequestError(exc)
"""Handle a single Message received from the RoyalnetLink"""
log.debug(f"Received {message} from the RoyalnetLink")
try:
network_handler = self.network_handlers[message.__class__]
except KeyError as exc:
log.debug(f"Missing network_handler for {message}")
return RequestError(KeyError("Missing network_handler"))
try:
log.debug(f"Using {network_handler} as handler for {message}")
return await network_handler.discord(message)
except Exception as exc:
log.debug(f"Exception {exc} in {network_handler}")
return RequestError(exc)
def _init_database(self, commands: typing.List[typing.Type[Command]], database_config: DatabaseConfig):
"""Connect to the database, and create the missing tables required by the selected commands."""
log.debug(f"Initializing database")
required_tables = set()
for command in commands:
required_tables = required_tables.union(command.require_alchemy_tables)
log.debug(f"Found {len(required_tables)} required tables")
self.alchemy = Alchemy(database_config.database_uri, required_tables)
self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__)
self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__)
self.identity_column = self.identity_table.__getattribute__(self.identity_table,
database_config.identity_column_name)
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
log.debug(f"Identity chain is {self.identity_chain}")
def __init__(self, *,
royalnet_config: RoyalnetConfig,
royalnet_config: typing.Optional[RoyalnetConfig] = None,
database_config: typing.Optional[DatabaseConfig] = None,
commands: typing.Optional[typing.List[typing.Type[Command]]] = None,
commands: typing.List[typing.Type[Command]] = None,
missing_command: typing.Type[Command] = NullCommand,
error_command: typing.Type[Command] = NullCommand):
if commands is None:
commands = []
self._init_commands(commands, missing_command=missing_command, error_command=error_command)
self._init_royalnet(royalnet_config=royalnet_config)
self._Call = self._call_factory()
if royalnet_config is None:
self.network = None
else:
self._init_royalnet(royalnet_config=royalnet_config)
if database_config is None:
self.alchemy = None
self.master_table = None
@ -73,3 +90,22 @@ class GenericBot:
self.identity_column = None
else:
self._init_database(commands=commands, database_config=database_config)
async def call(self, command_name: str, channel, parameters: typing.List[str] = None, **kwargs):
"""Call a command by its string, or missing_command if it doesn't exists, or error_command if an exception is raised during the execution."""
if parameters is None:
parameters = []
try:
command: typing.Type[Command] = self.commands[command_name]
except KeyError:
command = self.missing_command
try:
await self._Call(channel, command, parameters, **kwargs).run()
except Exception as exc:
await self._Call(channel, self.error_command,
exception_info=sys.exc_info(),
previous_command=command, **kwargs).run()
async def run(self):
"""A blocking coroutine that should make the bot start listening to commands and requests."""
raise NotImplementedError()

View file

@ -1,3 +1,4 @@
import logging as _logging
import traceback
from ..utils import Command, Call
from ..error import NoneFoundError, \
@ -9,6 +10,9 @@ from ..error import NoneFoundError, \
ExternalError
log = _logging.getLogger(__name__)
class ErrorHandlerCommand(Command):
command_name = "error_handler"
@ -46,4 +50,4 @@ class ErrorHandlerCommand(Command):
return
await call.reply(f"❌ Eccezione non gestita durante l'esecuzione del comando:\n[b]{e_type.__name__}[/b]\n{e_value}")
formatted_tb: str = '\n'.join(traceback.format_tb(e_tb))
call.logger.error(f"Unhandled exception - {e_type.__name__}: {e_value}\n{formatted_tb}")
log.error(f"Unhandled exception - {e_type.__name__}: {e_value}\n{formatted_tb}")

View file

@ -36,13 +36,18 @@ class Call:
raise NotImplementedError()
# These parameters / methods should be left alone
def __init__(self, channel, command: typing.Type[Command], command_args: list, logger: logging.Logger, **kwargs):
def __init__(self,
channel,
command: typing.Type[Command],
command_args: typing.List[str] = None,
**kwargs):
if command_args is None:
command_args = []
self.channel = channel
self.command = command
self.args = CommandArgs(command_args)
self.kwargs = kwargs
self.session = None
self.logger = logger
async def session_init(self):
if not self.command.require_alchemy_tables: