mirror of
https://github.com/RYGhub/royalnet.git
synced 2024-11-23 19:44:20 +00:00
Refactor most of the DiscordBot
This commit is contained in:
parent
f4fc7dd971
commit
e677cbe9b3
4 changed files with 185 additions and 154 deletions
|
@ -3,10 +3,11 @@ import asyncio
|
|||
import typing
|
||||
import logging as _logging
|
||||
import sys
|
||||
from .generic import GenericBot
|
||||
from ..commands import NullCommand
|
||||
from ..utils import asyncify, Call, Command, NetworkHandler
|
||||
from ..utils import asyncify, Call, Command
|
||||
from ..error import UnregisteredError, NoneFoundError, TooManyFoundError, InvalidConfigError
|
||||
from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError
|
||||
from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError, RoyalnetConfig
|
||||
from ..database import Alchemy, relationshiplinkchain, DatabaseConfig
|
||||
from ..audio import RoyalPCMFile, PlayMode, Playlist
|
||||
|
||||
|
@ -18,34 +19,16 @@ if not discord.opus.is_loaded():
|
|||
log.error("Opus is not loaded. Weird behaviour might emerge.")
|
||||
|
||||
|
||||
class DiscordBot:
|
||||
def __init__(self,
|
||||
token: str,
|
||||
master_server_uri: str,
|
||||
master_server_secret: str,
|
||||
commands: typing.List[typing.Type[Command]],
|
||||
missing_command: typing.Type[Command] = NullCommand,
|
||||
error_command: typing.Type[Command] = NullCommand,
|
||||
database_config: typing.Optional[DatabaseConfig] = None):
|
||||
class DiscordConfig:
|
||||
def __init__(self, token: str):
|
||||
self.token = token
|
||||
# Generate the Alchemy database
|
||||
if database_config:
|
||||
self.alchemy = Alchemy(database_config.database_uri, required_tables)
|
||||
self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__)
|
||||
self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__)
|
||||
self.identity_column = self.identity_table.__getattribute__(self.identity_table, database_config.identity_column_name)
|
||||
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
|
||||
else:
|
||||
if required_tables:
|
||||
raise InvalidConfigError("Tables are required by the _commands, but Alchemy is not configured")
|
||||
self.alchemy = None
|
||||
self.master_table = None
|
||||
self.identity_table = None
|
||||
self.identity_column = None
|
||||
self.identity_chain = None
|
||||
# Create the PlayModes dictionary
|
||||
|
||||
|
||||
class DiscordBot(GenericBot):
|
||||
def _init_voice(self):
|
||||
self.music_data: typing.Dict[discord.Guild, PlayMode] = {}
|
||||
|
||||
def _call_factory(self) -> typing.Type[Call]:
|
||||
# noinspection PyMethodParameters
|
||||
class DiscordCall(Call):
|
||||
interface_name = "discord"
|
||||
|
@ -55,6 +38,7 @@ class DiscordBot:
|
|||
alchemy = self.alchemy
|
||||
|
||||
async def reply(call, text: str):
|
||||
# TODO: don't escape characters inside [c][/c] blocks
|
||||
escaped_text = text.replace("*", "\\*") \
|
||||
.replace("_", "\\_") \
|
||||
.replace("`", "\\`") \
|
||||
|
@ -69,6 +53,8 @@ class DiscordBot:
|
|||
await call.channel.send(escaped_text)
|
||||
|
||||
async def net_request(call, message: Message, destination: str):
|
||||
if self.network is None:
|
||||
raise InvalidConfigError("Royalnet is not enabled on this bot")
|
||||
response = await self.network.request(message, destination)
|
||||
if isinstance(response, RequestError):
|
||||
raise response.exc
|
||||
|
@ -86,19 +72,20 @@ class DiscordBot:
|
|||
raise UnregisteredError("Author is not registered")
|
||||
return result
|
||||
|
||||
self.DiscordCall = DiscordCall
|
||||
return DiscordCall
|
||||
|
||||
def _bot_factory(self) -> typing.Type[discord.Client]:
|
||||
"""Create a new DiscordClient class based on this DiscordBot."""
|
||||
# noinspection PyMethodParameters
|
||||
class DiscordClient(discord.Client):
|
||||
@staticmethod
|
||||
async def vc_connect_or_move(channel: discord.VoiceChannel):
|
||||
async def vc_connect_or_move(cli, channel: discord.VoiceChannel):
|
||||
# Connect to voice chat
|
||||
try:
|
||||
await channel.connect()
|
||||
except discord.errors.ClientException:
|
||||
# Move to the selected channel, instead of connecting
|
||||
# noinspection PyUnusedLocal
|
||||
for voice_client in self.bot.voice_clients:
|
||||
for voice_client in cli.voice_clients:
|
||||
voice_client: discord.VoiceClient
|
||||
if voice_client.guild != channel.guild:
|
||||
continue
|
||||
|
@ -107,124 +94,123 @@ class DiscordBot:
|
|||
if not self.music_data.get(channel.guild):
|
||||
self.music_data[channel.guild] = Playlist()
|
||||
|
||||
async def on_message(cli, message: discord.Message):
|
||||
@staticmethod # Not really static because of the self reference
|
||||
async def on_message(message: discord.Message):
|
||||
text = message.content
|
||||
# Skip non-text messages
|
||||
if not text:
|
||||
return
|
||||
# Find and clean parameters
|
||||
command_text, *parameters = text.split(" ")
|
||||
# Find the function
|
||||
try:
|
||||
selected_command = self.commands[command_text]
|
||||
except KeyError:
|
||||
# Skip inexistent _commands
|
||||
selected_command = self.missing_command
|
||||
log.error(f"Running {selected_command}")
|
||||
# Call the command
|
||||
try:
|
||||
return await self.DiscordCall(message.channel, selected_command, parameters, log,
|
||||
message=message).run()
|
||||
except Exception as exc:
|
||||
try:
|
||||
return await self.DiscordCall(message.channel, self.error_command, parameters, log,
|
||||
message=message,
|
||||
exception_info=sys.exc_info(),
|
||||
previous_command=selected_command).run()
|
||||
except Exception as exc2:
|
||||
log.error(f"Exception in error handler command: {exc2}")
|
||||
await self.call(command_text, message.channel, parameters, message=message)
|
||||
|
||||
self.DiscordClient = DiscordClient
|
||||
self.bot = self.DiscordClient()
|
||||
|
||||
def find_guild(self, identifier: typing.Union[str, int]) -> discord.Guild:
|
||||
"""Find the Guild with the specified identifier. Names are case-insensitive."""
|
||||
if isinstance(identifier, str):
|
||||
all_guilds: typing.List[discord.Guild] = self.bot.guilds
|
||||
def find_guild_by_name(cli, name: str) -> discord.Guild:
|
||||
"""Find the Guild with the specified name. Case-insensitive.
|
||||
Will raise a NoneFoundError if no channels are found, or a TooManyFoundError if more than one is found."""
|
||||
all_guilds: typing.List[discord.Guild] = cli.guilds
|
||||
matching_channels: typing.List[discord.Guild] = []
|
||||
for guild in all_guilds:
|
||||
if guild.name.lower() == identifier.lower():
|
||||
if guild.name.lower() == name.lower():
|
||||
matching_channels.append(guild)
|
||||
if len(matching_channels) == 0:
|
||||
raise NoneFoundError("No channels were found")
|
||||
elif len(matching_channels) > 1:
|
||||
raise TooManyFoundError("Too many channels were found")
|
||||
return matching_channels[0]
|
||||
elif isinstance(identifier, int):
|
||||
return self.bot.get_guild(identifier)
|
||||
raise TypeError("Invalid identifier type, should be str or int")
|
||||
|
||||
def find_channel(self,
|
||||
identifier: typing.Union[str, int],
|
||||
def find_channel_by_name(cli,
|
||||
name: str,
|
||||
guild: typing.Optional[discord.Guild] = None) -> discord.abc.GuildChannel:
|
||||
"""Find the GuildChannel with the specified identifier. Names are case-insensitive."""
|
||||
if isinstance(identifier, str):
|
||||
"""Find the TextChannel, VoiceChannel or CategoryChannel with the specified name. Case-insensitive.
|
||||
Guild is optional, but the method will raise a TooManyFoundError if none is specified and there is more than one channel with the same name.
|
||||
Will also raise a NoneFoundError if no channels are found."""
|
||||
if guild is not None:
|
||||
all_channels = guild.channels
|
||||
else:
|
||||
all_channels: typing.List[discord.abc.GuildChannel] = self.bot.get_all_channels()
|
||||
all_channels: typing.List[discord.abc.GuildChannel] = cli.get_all_channels()
|
||||
matching_channels: typing.List[discord.abc.GuildChannel] = []
|
||||
for channel in all_channels:
|
||||
if not (isinstance(channel, discord.TextChannel)
|
||||
or isinstance(channel, discord.VoiceChannel)
|
||||
or isinstance(channel, discord.CategoryChannel)):
|
||||
continue
|
||||
if channel.name.lower() == identifier.lower():
|
||||
if channel.name.lower() == name.lower():
|
||||
matching_channels.append(channel)
|
||||
if len(matching_channels) == 0:
|
||||
raise NoneFoundError("No channels were found")
|
||||
elif len(matching_channels) > 1:
|
||||
raise TooManyFoundError("Too many channels were found")
|
||||
return matching_channels[0]
|
||||
elif isinstance(identifier, int):
|
||||
channel: typing.List[discord.abc.GuildChannel] = self.bot.get_channel(identifier)
|
||||
if ((isinstance(channel, discord.TextChannel)
|
||||
or isinstance(channel, discord.VoiceChannel)
|
||||
or isinstance(channel, discord.CategoryChannel))
|
||||
and guild):
|
||||
assert channel.guild == guild
|
||||
return channel
|
||||
raise TypeError("Invalid identifier type, should be str or int")
|
||||
|
||||
def find_voice_client(self, guild: discord.Guild):
|
||||
for voice_client in self.bot.voice_clients:
|
||||
def find_voice_client_by_guild(cli, guild: discord.Guild):
|
||||
"""Find the VoiceClient belonging to a specific Guild.
|
||||
Raises a NoneFoundError if the Guild currently has no VoiceClient."""
|
||||
for voice_client in cli.voice_clients:
|
||||
voice_client: discord.VoiceClient
|
||||
if voice_client.guild == guild:
|
||||
return voice_client
|
||||
raise NoneFoundError("No voice clients found")
|
||||
|
||||
async def add_to_music_data(self, url: str, guild: discord.Guild):
|
||||
"""Add a file to the corresponding music_data object."""
|
||||
log.debug(f"Downloading {url} to add to music_data")
|
||||
files: typing.List[RoyalPCMFile] = await asyncify(RoyalPCMFile.create_from_url, url)
|
||||
guild_music_data = self.music_data[guild]
|
||||
for file in files:
|
||||
log.debug(f"Adding {file} to music_data")
|
||||
guild_music_data.add(file)
|
||||
if guild_music_data.now_playing is None:
|
||||
log.debug(f"Starting playback chain")
|
||||
await self.advance_music_data(guild)
|
||||
return DiscordClient
|
||||
|
||||
async def advance_music_data(self, guild: discord.Guild):
|
||||
"""Try to play the next song, while it exists. Otherwise, just return."""
|
||||
guild_music_data = self.music_data[guild]
|
||||
voice_client = self.find_voice_client(guild)
|
||||
next_file: RoyalPCMFile = await guild_music_data.next()
|
||||
if next_file is None:
|
||||
log.debug(f"Ending playback chain")
|
||||
return
|
||||
def _init_bot(self):
|
||||
"""Create a bot instance."""
|
||||
self.bot = self._bot_factory()()
|
||||
|
||||
def advance(error=None):
|
||||
log.debug(f"Deleting {next_file}")
|
||||
next_file.delete_audio_file()
|
||||
loop.create_task(self.advance_music_data(guild))
|
||||
|
||||
log.debug(f"Creating AudioSource of {next_file}")
|
||||
next_source = next_file.create_audio_source()
|
||||
log.debug(f"Starting playback of {next_source}")
|
||||
voice_client.play(next_source, after=advance)
|
||||
def __init__(self, *,
|
||||
discord_config: DiscordConfig,
|
||||
royalnet_config: RoyalnetConfig,
|
||||
database_config: typing.Optional[DatabaseConfig] = None,
|
||||
commands: typing.List[typing.Type[Command]] = None,
|
||||
missing_command: typing.Type[Command] = NullCommand,
|
||||
error_command: typing.Type[Command] = NullCommand):
|
||||
super().__init__(royalnet_config=royalnet_config,
|
||||
database_config=database_config,
|
||||
commands=commands,
|
||||
missing_command=missing_command,
|
||||
error_command=error_command)
|
||||
self._discord_config = discord_config
|
||||
self._init_bot()
|
||||
|
||||
async def run(self):
|
||||
await self.bot.login(self.token)
|
||||
await self.bot.login(self._discord_config.token)
|
||||
await self.bot.connect()
|
||||
# TODO: how to stop?
|
||||
|
||||
|
||||
|
||||
# class DiscordBot:
|
||||
# async def add_to_music_data(self, url: str, guild: discord.Guild):
|
||||
# """Add a file to the corresponding music_data object."""
|
||||
# log.debug(f"Downloading {url} to add to music_data")
|
||||
# files: typing.List[RoyalPCMFile] = await asyncify(RoyalPCMFile.create_from_url, url)
|
||||
# guild_music_data = self.music_data[guild]
|
||||
# for file in files:
|
||||
# log.debug(f"Adding {file} to music_data")
|
||||
# guild_music_data.add(file)
|
||||
# if guild_music_data.now_playing is None:
|
||||
# log.debug(f"Starting playback chain")
|
||||
# await self.advance_music_data(guild)
|
||||
#
|
||||
# async def advance_music_data(self, guild: discord.Guild):
|
||||
# """Try to play the next song, while it exists. Otherwise, just return."""
|
||||
# guild_music_data = self.music_data[guild]
|
||||
# voice_client = self.find_voice_client(guild)
|
||||
# next_file: RoyalPCMFile = await guild_music_data.next()
|
||||
# if next_file is None:
|
||||
# log.debug(f"Ending playback chain")
|
||||
# return
|
||||
#
|
||||
# def advance(error=None):
|
||||
# log.debug(f"Deleting {next_file}")
|
||||
# next_file.delete_audio_file()
|
||||
# loop.create_task(self.advance_music_data(guild))
|
||||
#
|
||||
# log.debug(f"Creating AudioSource of {next_file}")
|
||||
# next_source = next_file.create_audio_source()
|
||||
# log.debug(f"Starting playback of {next_source}")
|
||||
# voice_client.play(next_source, after=advance)
|
||||
#
|
||||
|
||||
#
|
|
@ -1,7 +1,8 @@
|
|||
import sys
|
||||
import typing
|
||||
import asyncio
|
||||
import logging
|
||||
from ..utils import Command, NetworkHandler
|
||||
from ..utils import Command, NetworkHandler, Call
|
||||
from ..commands import NullCommand
|
||||
from ..network import RoyalnetLink, Message, RequestError, RoyalnetConfig
|
||||
from ..database import Alchemy, DatabaseConfig, relationshiplinkchain
|
||||
|
@ -11,10 +12,12 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class GenericBot:
|
||||
"""A generic bot class, to be used as base for the other more specific classes, such as TelegramBot and DiscordBot."""
|
||||
def _init_commands(self,
|
||||
commands: typing.List[typing.Type[Command]],
|
||||
missing_command: typing.Type[Command],
|
||||
error_command: typing.Type[Command]):
|
||||
"""Generate the commands dictionary required to handle incoming messages, and the network_handlers dictionary required to handle incoming requests."""
|
||||
log.debug(f"Now generating commands")
|
||||
self.commands: typing.Dict[str, typing.Type[Command]] = {}
|
||||
self.network_handlers: typing.Dict[typing.Type[Message], typing.Type[NetworkHandler]] = {}
|
||||
|
@ -25,13 +28,19 @@ class GenericBot:
|
|||
self.error_command: typing.Type[Command] = error_command
|
||||
log.debug(f"Successfully generated commands")
|
||||
|
||||
def _call_factory(self) -> typing.Type[Call]:
|
||||
"""Create the Call class, representing a Call command. It should inherit from the utils.Call class."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _init_royalnet(self, royalnet_config: RoyalnetConfig):
|
||||
"""Create a RoyalnetLink, and run it as a task."""
|
||||
self.network: RoyalnetLink = RoyalnetLink(royalnet_config.master_uri, royalnet_config.master_secret, "discord",
|
||||
self._network_handler)
|
||||
log.debug(f"Running RoyalnetLink {self.network}")
|
||||
loop.create_task(self.network.run())
|
||||
|
||||
def _network_handler(self, message: Message) -> Message:
|
||||
"""Handle a single Message received from the RoyalnetLink"""
|
||||
log.debug(f"Received {message} from the RoyalnetLink")
|
||||
try:
|
||||
network_handler = self.network_handlers[message.__class__]
|
||||
|
@ -46,25 +55,33 @@ class GenericBot:
|
|||
return RequestError(exc)
|
||||
|
||||
def _init_database(self, commands: typing.List[typing.Type[Command]], database_config: DatabaseConfig):
|
||||
"""Connect to the database, and create the missing tables required by the selected commands."""
|
||||
log.debug(f"Initializing database")
|
||||
required_tables = set()
|
||||
for command in commands:
|
||||
required_tables = required_tables.union(command.require_alchemy_tables)
|
||||
log.debug(f"Found {len(required_tables)} required tables")
|
||||
self.alchemy = Alchemy(database_config.database_uri, required_tables)
|
||||
self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__)
|
||||
self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__)
|
||||
self.identity_column = self.identity_table.__getattribute__(self.identity_table,
|
||||
database_config.identity_column_name)
|
||||
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
|
||||
log.debug(f"Identity chain is {self.identity_chain}")
|
||||
|
||||
def __init__(self, *,
|
||||
royalnet_config: RoyalnetConfig,
|
||||
royalnet_config: typing.Optional[RoyalnetConfig] = None,
|
||||
database_config: typing.Optional[DatabaseConfig] = None,
|
||||
commands: typing.Optional[typing.List[typing.Type[Command]]] = None,
|
||||
commands: typing.List[typing.Type[Command]] = None,
|
||||
missing_command: typing.Type[Command] = NullCommand,
|
||||
error_command: typing.Type[Command] = NullCommand):
|
||||
if commands is None:
|
||||
commands = []
|
||||
self._init_commands(commands, missing_command=missing_command, error_command=error_command)
|
||||
self._Call = self._call_factory()
|
||||
if royalnet_config is None:
|
||||
self.network = None
|
||||
else:
|
||||
self._init_royalnet(royalnet_config=royalnet_config)
|
||||
if database_config is None:
|
||||
self.alchemy = None
|
||||
|
@ -73,3 +90,22 @@ class GenericBot:
|
|||
self.identity_column = None
|
||||
else:
|
||||
self._init_database(commands=commands, database_config=database_config)
|
||||
|
||||
async def call(self, command_name: str, channel, parameters: typing.List[str] = None, **kwargs):
|
||||
"""Call a command by its string, or missing_command if it doesn't exists, or error_command if an exception is raised during the execution."""
|
||||
if parameters is None:
|
||||
parameters = []
|
||||
try:
|
||||
command: typing.Type[Command] = self.commands[command_name]
|
||||
except KeyError:
|
||||
command = self.missing_command
|
||||
try:
|
||||
await self._Call(channel, command, parameters, **kwargs).run()
|
||||
except Exception as exc:
|
||||
await self._Call(channel, self.error_command,
|
||||
exception_info=sys.exc_info(),
|
||||
previous_command=command, **kwargs).run()
|
||||
|
||||
async def run(self):
|
||||
"""A blocking coroutine that should make the bot start listening to commands and requests."""
|
||||
raise NotImplementedError()
|
|
@ -1,3 +1,4 @@
|
|||
import logging as _logging
|
||||
import traceback
|
||||
from ..utils import Command, Call
|
||||
from ..error import NoneFoundError, \
|
||||
|
@ -9,6 +10,9 @@ from ..error import NoneFoundError, \
|
|||
ExternalError
|
||||
|
||||
|
||||
log = _logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ErrorHandlerCommand(Command):
|
||||
|
||||
command_name = "error_handler"
|
||||
|
@ -46,4 +50,4 @@ class ErrorHandlerCommand(Command):
|
|||
return
|
||||
await call.reply(f"❌ Eccezione non gestita durante l'esecuzione del comando:\n[b]{e_type.__name__}[/b]\n{e_value}")
|
||||
formatted_tb: str = '\n'.join(traceback.format_tb(e_tb))
|
||||
call.logger.error(f"Unhandled exception - {e_type.__name__}: {e_value}\n{formatted_tb}")
|
||||
log.error(f"Unhandled exception - {e_type.__name__}: {e_value}\n{formatted_tb}")
|
||||
|
|
|
@ -36,13 +36,18 @@ class Call:
|
|||
raise NotImplementedError()
|
||||
|
||||
# These parameters / methods should be left alone
|
||||
def __init__(self, channel, command: typing.Type[Command], command_args: list, logger: logging.Logger, **kwargs):
|
||||
def __init__(self,
|
||||
channel,
|
||||
command: typing.Type[Command],
|
||||
command_args: typing.List[str] = None,
|
||||
**kwargs):
|
||||
if command_args is None:
|
||||
command_args = []
|
||||
self.channel = channel
|
||||
self.command = command
|
||||
self.args = CommandArgs(command_args)
|
||||
self.kwargs = kwargs
|
||||
self.session = None
|
||||
self.logger = logger
|
||||
|
||||
async def session_init(self):
|
||||
if not self.command.require_alchemy_tables:
|
||||
|
|
Loading…
Reference in a new issue