1
Fork 0
mirror of https://github.com/RYGhub/royalnet.git synced 2024-11-23 19:44:20 +00:00

Refactor most of the DiscordBot

This commit is contained in:
Steffo 2019-04-19 02:12:37 +02:00
parent f4fc7dd971
commit e677cbe9b3
4 changed files with 185 additions and 154 deletions

View file

@ -3,10 +3,11 @@ import asyncio
import typing import typing
import logging as _logging import logging as _logging
import sys import sys
from .generic import GenericBot
from ..commands import NullCommand from ..commands import NullCommand
from ..utils import asyncify, Call, Command, NetworkHandler from ..utils import asyncify, Call, Command
from ..error import UnregisteredError, NoneFoundError, TooManyFoundError, InvalidConfigError from ..error import UnregisteredError, NoneFoundError, TooManyFoundError, InvalidConfigError
from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError, RoyalnetConfig
from ..database import Alchemy, relationshiplinkchain, DatabaseConfig from ..database import Alchemy, relationshiplinkchain, DatabaseConfig
from ..audio import RoyalPCMFile, PlayMode, Playlist from ..audio import RoyalPCMFile, PlayMode, Playlist
@ -18,34 +19,16 @@ if not discord.opus.is_loaded():
log.error("Opus is not loaded. Weird behaviour might emerge.") log.error("Opus is not loaded. Weird behaviour might emerge.")
class DiscordBot: class DiscordConfig:
def __init__(self, def __init__(self, token: str):
token: str,
master_server_uri: str,
master_server_secret: str,
commands: typing.List[typing.Type[Command]],
missing_command: typing.Type[Command] = NullCommand,
error_command: typing.Type[Command] = NullCommand,
database_config: typing.Optional[DatabaseConfig] = None):
self.token = token self.token = token
# Generate the Alchemy database
if database_config:
self.alchemy = Alchemy(database_config.database_uri, required_tables) class DiscordBot(GenericBot):
self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__) def _init_voice(self):
self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__)
self.identity_column = self.identity_table.__getattribute__(self.identity_table, database_config.identity_column_name)
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
else:
if required_tables:
raise InvalidConfigError("Tables are required by the _commands, but Alchemy is not configured")
self.alchemy = None
self.master_table = None
self.identity_table = None
self.identity_column = None
self.identity_chain = None
# Create the PlayModes dictionary
self.music_data: typing.Dict[discord.Guild, PlayMode] = {} self.music_data: typing.Dict[discord.Guild, PlayMode] = {}
def _call_factory(self) -> typing.Type[Call]:
# noinspection PyMethodParameters # noinspection PyMethodParameters
class DiscordCall(Call): class DiscordCall(Call):
interface_name = "discord" interface_name = "discord"
@ -55,6 +38,7 @@ class DiscordBot:
alchemy = self.alchemy alchemy = self.alchemy
async def reply(call, text: str): async def reply(call, text: str):
# TODO: don't escape characters inside [c][/c] blocks
escaped_text = text.replace("*", "\\*") \ escaped_text = text.replace("*", "\\*") \
.replace("_", "\\_") \ .replace("_", "\\_") \
.replace("`", "\\`") \ .replace("`", "\\`") \
@ -69,6 +53,8 @@ class DiscordBot:
await call.channel.send(escaped_text) await call.channel.send(escaped_text)
async def net_request(call, message: Message, destination: str): async def net_request(call, message: Message, destination: str):
if self.network is None:
raise InvalidConfigError("Royalnet is not enabled on this bot")
response = await self.network.request(message, destination) response = await self.network.request(message, destination)
if isinstance(response, RequestError): if isinstance(response, RequestError):
raise response.exc raise response.exc
@ -86,19 +72,20 @@ class DiscordBot:
raise UnregisteredError("Author is not registered") raise UnregisteredError("Author is not registered")
return result return result
self.DiscordCall = DiscordCall return DiscordCall
def _bot_factory(self) -> typing.Type[discord.Client]:
"""Create a new DiscordClient class based on this DiscordBot."""
# noinspection PyMethodParameters # noinspection PyMethodParameters
class DiscordClient(discord.Client): class DiscordClient(discord.Client):
@staticmethod async def vc_connect_or_move(cli, channel: discord.VoiceChannel):
async def vc_connect_or_move(channel: discord.VoiceChannel):
# Connect to voice chat # Connect to voice chat
try: try:
await channel.connect() await channel.connect()
except discord.errors.ClientException: except discord.errors.ClientException:
# Move to the selected channel, instead of connecting # Move to the selected channel, instead of connecting
# noinspection PyUnusedLocal # noinspection PyUnusedLocal
for voice_client in self.bot.voice_clients: for voice_client in cli.voice_clients:
voice_client: discord.VoiceClient voice_client: discord.VoiceClient
if voice_client.guild != channel.guild: if voice_client.guild != channel.guild:
continue continue
@ -107,124 +94,123 @@ class DiscordBot:
if not self.music_data.get(channel.guild): if not self.music_data.get(channel.guild):
self.music_data[channel.guild] = Playlist() self.music_data[channel.guild] = Playlist()
async def on_message(cli, message: discord.Message): @staticmethod # Not really static because of the self reference
async def on_message(message: discord.Message):
text = message.content text = message.content
# Skip non-text messages # Skip non-text messages
if not text: if not text:
return return
# Find and clean parameters # Find and clean parameters
command_text, *parameters = text.split(" ") command_text, *parameters = text.split(" ")
# Find the function
try:
selected_command = self.commands[command_text]
except KeyError:
# Skip inexistent _commands
selected_command = self.missing_command
log.error(f"Running {selected_command}")
# Call the command # Call the command
try: await self.call(command_text, message.channel, parameters, message=message)
return await self.DiscordCall(message.channel, selected_command, parameters, log,
message=message).run()
except Exception as exc:
try:
return await self.DiscordCall(message.channel, self.error_command, parameters, log,
message=message,
exception_info=sys.exc_info(),
previous_command=selected_command).run()
except Exception as exc2:
log.error(f"Exception in error handler command: {exc2}")
self.DiscordClient = DiscordClient def find_guild_by_name(cli, name: str) -> discord.Guild:
self.bot = self.DiscordClient() """Find the Guild with the specified name. Case-insensitive.
Will raise a NoneFoundError if no channels are found, or a TooManyFoundError if more than one is found."""
all_guilds: typing.List[discord.Guild] = cli.guilds
matching_channels: typing.List[discord.Guild] = []
for guild in all_guilds:
if guild.name.lower() == name.lower():
matching_channels.append(guild)
if len(matching_channels) == 0:
raise NoneFoundError("No channels were found")
elif len(matching_channels) > 1:
raise TooManyFoundError("Too many channels were found")
return matching_channels[0]
def find_guild(self, identifier: typing.Union[str, int]) -> discord.Guild: def find_channel_by_name(cli,
"""Find the Guild with the specified identifier. Names are case-insensitive.""" name: str,
if isinstance(identifier, str): guild: typing.Optional[discord.Guild] = None) -> discord.abc.GuildChannel:
all_guilds: typing.List[discord.Guild] = self.bot.guilds """Find the TextChannel, VoiceChannel or CategoryChannel with the specified name. Case-insensitive.
matching_channels: typing.List[discord.Guild] = [] Guild is optional, but the method will raise a TooManyFoundError if none is specified and there is more than one channel with the same name.
for guild in all_guilds: Will also raise a NoneFoundError if no channels are found."""
if guild.name.lower() == identifier.lower(): if guild is not None:
matching_channels.append(guild) all_channels = guild.channels
if len(matching_channels) == 0: else:
raise NoneFoundError("No channels were found") all_channels: typing.List[discord.abc.GuildChannel] = cli.get_all_channels()
elif len(matching_channels) > 1: matching_channels: typing.List[discord.abc.GuildChannel] = []
raise TooManyFoundError("Too many channels were found") for channel in all_channels:
return matching_channels[0] if not (isinstance(channel, discord.TextChannel)
elif isinstance(identifier, int): or isinstance(channel, discord.VoiceChannel)
return self.bot.get_guild(identifier) or isinstance(channel, discord.CategoryChannel)):
raise TypeError("Invalid identifier type, should be str or int") continue
if channel.name.lower() == name.lower():
matching_channels.append(channel)
if len(matching_channels) == 0:
raise NoneFoundError("No channels were found")
elif len(matching_channels) > 1:
raise TooManyFoundError("Too many channels were found")
return matching_channels[0]
def find_channel(self, def find_voice_client_by_guild(cli, guild: discord.Guild):
identifier: typing.Union[str, int], """Find the VoiceClient belonging to a specific Guild.
guild: typing.Optional[discord.Guild] = None) -> discord.abc.GuildChannel: Raises a NoneFoundError if the Guild currently has no VoiceClient."""
"""Find the GuildChannel with the specified identifier. Names are case-insensitive.""" for voice_client in cli.voice_clients:
if isinstance(identifier, str): voice_client: discord.VoiceClient
if guild is not None: if voice_client.guild == guild:
all_channels = guild.channels return voice_client
else: raise NoneFoundError("No voice clients found")
all_channels: typing.List[discord.abc.GuildChannel] = self.bot.get_all_channels()
matching_channels: typing.List[discord.abc.GuildChannel] = []
for channel in all_channels:
if not (isinstance(channel, discord.TextChannel)
or isinstance(channel, discord.VoiceChannel)
or isinstance(channel, discord.CategoryChannel)):
continue
if channel.name.lower() == identifier.lower():
matching_channels.append(channel)
if len(matching_channels) == 0:
raise NoneFoundError("No channels were found")
elif len(matching_channels) > 1:
raise TooManyFoundError("Too many channels were found")
return matching_channels[0]
elif isinstance(identifier, int):
channel: typing.List[discord.abc.GuildChannel] = self.bot.get_channel(identifier)
if ((isinstance(channel, discord.TextChannel)
or isinstance(channel, discord.VoiceChannel)
or isinstance(channel, discord.CategoryChannel))
and guild):
assert channel.guild == guild
return channel
raise TypeError("Invalid identifier type, should be str or int")
def find_voice_client(self, guild: discord.Guild): return DiscordClient
for voice_client in self.bot.voice_clients:
voice_client: discord.VoiceClient
if voice_client.guild == guild:
return voice_client
raise NoneFoundError("No voice clients found")
async def add_to_music_data(self, url: str, guild: discord.Guild): def _init_bot(self):
"""Add a file to the corresponding music_data object.""" """Create a bot instance."""
log.debug(f"Downloading {url} to add to music_data") self.bot = self._bot_factory()()
files: typing.List[RoyalPCMFile] = await asyncify(RoyalPCMFile.create_from_url, url)
guild_music_data = self.music_data[guild]
for file in files:
log.debug(f"Adding {file} to music_data")
guild_music_data.add(file)
if guild_music_data.now_playing is None:
log.debug(f"Starting playback chain")
await self.advance_music_data(guild)
async def advance_music_data(self, guild: discord.Guild): def __init__(self, *,
"""Try to play the next song, while it exists. Otherwise, just return.""" discord_config: DiscordConfig,
guild_music_data = self.music_data[guild] royalnet_config: RoyalnetConfig,
voice_client = self.find_voice_client(guild) database_config: typing.Optional[DatabaseConfig] = None,
next_file: RoyalPCMFile = await guild_music_data.next() commands: typing.List[typing.Type[Command]] = None,
if next_file is None: missing_command: typing.Type[Command] = NullCommand,
log.debug(f"Ending playback chain") error_command: typing.Type[Command] = NullCommand):
return super().__init__(royalnet_config=royalnet_config,
database_config=database_config,
def advance(error=None): commands=commands,
log.debug(f"Deleting {next_file}") missing_command=missing_command,
next_file.delete_audio_file() error_command=error_command)
loop.create_task(self.advance_music_data(guild)) self._discord_config = discord_config
self._init_bot()
log.debug(f"Creating AudioSource of {next_file}")
next_source = next_file.create_audio_source()
log.debug(f"Starting playback of {next_source}")
voice_client.play(next_source, after=advance)
async def run(self): async def run(self):
await self.bot.login(self.token) await self.bot.login(self._discord_config.token)
await self.bot.connect() await self.bot.connect()
# TODO: how to stop? # TODO: how to stop?
# class DiscordBot:
# async def add_to_music_data(self, url: str, guild: discord.Guild):
# """Add a file to the corresponding music_data object."""
# log.debug(f"Downloading {url} to add to music_data")
# files: typing.List[RoyalPCMFile] = await asyncify(RoyalPCMFile.create_from_url, url)
# guild_music_data = self.music_data[guild]
# for file in files:
# log.debug(f"Adding {file} to music_data")
# guild_music_data.add(file)
# if guild_music_data.now_playing is None:
# log.debug(f"Starting playback chain")
# await self.advance_music_data(guild)
#
# async def advance_music_data(self, guild: discord.Guild):
# """Try to play the next song, while it exists. Otherwise, just return."""
# guild_music_data = self.music_data[guild]
# voice_client = self.find_voice_client(guild)
# next_file: RoyalPCMFile = await guild_music_data.next()
# if next_file is None:
# log.debug(f"Ending playback chain")
# return
#
# def advance(error=None):
# log.debug(f"Deleting {next_file}")
# next_file.delete_audio_file()
# loop.create_task(self.advance_music_data(guild))
#
# log.debug(f"Creating AudioSource of {next_file}")
# next_source = next_file.create_audio_source()
# log.debug(f"Starting playback of {next_source}")
# voice_client.play(next_source, after=advance)
#
#

View file

@ -1,7 +1,8 @@
import sys
import typing import typing
import asyncio import asyncio
import logging import logging
from ..utils import Command, NetworkHandler from ..utils import Command, NetworkHandler, Call
from ..commands import NullCommand from ..commands import NullCommand
from ..network import RoyalnetLink, Message, RequestError, RoyalnetConfig from ..network import RoyalnetLink, Message, RequestError, RoyalnetConfig
from ..database import Alchemy, DatabaseConfig, relationshiplinkchain from ..database import Alchemy, DatabaseConfig, relationshiplinkchain
@ -11,10 +12,12 @@ log = logging.getLogger(__name__)
class GenericBot: class GenericBot:
"""A generic bot class, to be used as base for the other more specific classes, such as TelegramBot and DiscordBot."""
def _init_commands(self, def _init_commands(self,
commands: typing.List[typing.Type[Command]], commands: typing.List[typing.Type[Command]],
missing_command: typing.Type[Command], missing_command: typing.Type[Command],
error_command: typing.Type[Command]): error_command: typing.Type[Command]):
"""Generate the commands dictionary required to handle incoming messages, and the network_handlers dictionary required to handle incoming requests."""
log.debug(f"Now generating commands") log.debug(f"Now generating commands")
self.commands: typing.Dict[str, typing.Type[Command]] = {} self.commands: typing.Dict[str, typing.Type[Command]] = {}
self.network_handlers: typing.Dict[typing.Type[Message], typing.Type[NetworkHandler]] = {} self.network_handlers: typing.Dict[typing.Type[Message], typing.Type[NetworkHandler]] = {}
@ -25,47 +28,61 @@ class GenericBot:
self.error_command: typing.Type[Command] = error_command self.error_command: typing.Type[Command] = error_command
log.debug(f"Successfully generated commands") log.debug(f"Successfully generated commands")
def _call_factory(self) -> typing.Type[Call]:
"""Create the Call class, representing a Call command. It should inherit from the utils.Call class."""
raise NotImplementedError()
def _init_royalnet(self, royalnet_config: RoyalnetConfig): def _init_royalnet(self, royalnet_config: RoyalnetConfig):
"""Create a RoyalnetLink, and run it as a task."""
self.network: RoyalnetLink = RoyalnetLink(royalnet_config.master_uri, royalnet_config.master_secret, "discord", self.network: RoyalnetLink = RoyalnetLink(royalnet_config.master_uri, royalnet_config.master_secret, "discord",
self._network_handler) self._network_handler)
log.debug(f"Running RoyalnetLink {self.network}") log.debug(f"Running RoyalnetLink {self.network}")
loop.create_task(self.network.run()) loop.create_task(self.network.run())
def _network_handler(self, message: Message) -> Message: def _network_handler(self, message: Message) -> Message:
log.debug(f"Received {message} from the RoyalnetLink") """Handle a single Message received from the RoyalnetLink"""
try: log.debug(f"Received {message} from the RoyalnetLink")
network_handler = self.network_handlers[message.__class__] try:
except KeyError as exc: network_handler = self.network_handlers[message.__class__]
log.debug(f"Missing network_handler for {message}") except KeyError as exc:
return RequestError(KeyError("Missing network_handler")) log.debug(f"Missing network_handler for {message}")
try: return RequestError(KeyError("Missing network_handler"))
log.debug(f"Using {network_handler} as handler for {message}") try:
return await network_handler.discord(message) log.debug(f"Using {network_handler} as handler for {message}")
except Exception as exc: return await network_handler.discord(message)
log.debug(f"Exception {exc} in {network_handler}") except Exception as exc:
return RequestError(exc) log.debug(f"Exception {exc} in {network_handler}")
return RequestError(exc)
def _init_database(self, commands: typing.List[typing.Type[Command]], database_config: DatabaseConfig): def _init_database(self, commands: typing.List[typing.Type[Command]], database_config: DatabaseConfig):
"""Connect to the database, and create the missing tables required by the selected commands."""
log.debug(f"Initializing database")
required_tables = set() required_tables = set()
for command in commands: for command in commands:
required_tables = required_tables.union(command.require_alchemy_tables) required_tables = required_tables.union(command.require_alchemy_tables)
log.debug(f"Found {len(required_tables)} required tables")
self.alchemy = Alchemy(database_config.database_uri, required_tables) self.alchemy = Alchemy(database_config.database_uri, required_tables)
self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__) self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__)
self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__) self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__)
self.identity_column = self.identity_table.__getattribute__(self.identity_table, self.identity_column = self.identity_table.__getattribute__(self.identity_table,
database_config.identity_column_name) database_config.identity_column_name)
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table) self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
log.debug(f"Identity chain is {self.identity_chain}")
def __init__(self, *, def __init__(self, *,
royalnet_config: RoyalnetConfig, royalnet_config: typing.Optional[RoyalnetConfig] = None,
database_config: typing.Optional[DatabaseConfig] = None, database_config: typing.Optional[DatabaseConfig] = None,
commands: typing.Optional[typing.List[typing.Type[Command]]] = None, commands: typing.List[typing.Type[Command]] = None,
missing_command: typing.Type[Command] = NullCommand, missing_command: typing.Type[Command] = NullCommand,
error_command: typing.Type[Command] = NullCommand): error_command: typing.Type[Command] = NullCommand):
if commands is None: if commands is None:
commands = [] commands = []
self._init_commands(commands, missing_command=missing_command, error_command=error_command) self._init_commands(commands, missing_command=missing_command, error_command=error_command)
self._init_royalnet(royalnet_config=royalnet_config) self._Call = self._call_factory()
if royalnet_config is None:
self.network = None
else:
self._init_royalnet(royalnet_config=royalnet_config)
if database_config is None: if database_config is None:
self.alchemy = None self.alchemy = None
self.master_table = None self.master_table = None
@ -73,3 +90,22 @@ class GenericBot:
self.identity_column = None self.identity_column = None
else: else:
self._init_database(commands=commands, database_config=database_config) self._init_database(commands=commands, database_config=database_config)
async def call(self, command_name: str, channel, parameters: typing.List[str] = None, **kwargs):
"""Call a command by its string, or missing_command if it doesn't exists, or error_command if an exception is raised during the execution."""
if parameters is None:
parameters = []
try:
command: typing.Type[Command] = self.commands[command_name]
except KeyError:
command = self.missing_command
try:
await self._Call(channel, command, parameters, **kwargs).run()
except Exception as exc:
await self._Call(channel, self.error_command,
exception_info=sys.exc_info(),
previous_command=command, **kwargs).run()
async def run(self):
"""A blocking coroutine that should make the bot start listening to commands and requests."""
raise NotImplementedError()

View file

@ -1,3 +1,4 @@
import logging as _logging
import traceback import traceback
from ..utils import Command, Call from ..utils import Command, Call
from ..error import NoneFoundError, \ from ..error import NoneFoundError, \
@ -9,6 +10,9 @@ from ..error import NoneFoundError, \
ExternalError ExternalError
log = _logging.getLogger(__name__)
class ErrorHandlerCommand(Command): class ErrorHandlerCommand(Command):
command_name = "error_handler" command_name = "error_handler"
@ -46,4 +50,4 @@ class ErrorHandlerCommand(Command):
return return
await call.reply(f"❌ Eccezione non gestita durante l'esecuzione del comando:\n[b]{e_type.__name__}[/b]\n{e_value}") await call.reply(f"❌ Eccezione non gestita durante l'esecuzione del comando:\n[b]{e_type.__name__}[/b]\n{e_value}")
formatted_tb: str = '\n'.join(traceback.format_tb(e_tb)) formatted_tb: str = '\n'.join(traceback.format_tb(e_tb))
call.logger.error(f"Unhandled exception - {e_type.__name__}: {e_value}\n{formatted_tb}") log.error(f"Unhandled exception - {e_type.__name__}: {e_value}\n{formatted_tb}")

View file

@ -36,13 +36,18 @@ class Call:
raise NotImplementedError() raise NotImplementedError()
# These parameters / methods should be left alone # These parameters / methods should be left alone
def __init__(self, channel, command: typing.Type[Command], command_args: list, logger: logging.Logger, **kwargs): def __init__(self,
channel,
command: typing.Type[Command],
command_args: typing.List[str] = None,
**kwargs):
if command_args is None:
command_args = []
self.channel = channel self.channel = channel
self.command = command self.command = command
self.args = CommandArgs(command_args) self.args = CommandArgs(command_args)
self.kwargs = kwargs self.kwargs = kwargs
self.session = None self.session = None
self.logger = logger
async def session_init(self): async def session_init(self):
if not self.command.require_alchemy_tables: if not self.command.require_alchemy_tables: