mirror of
https://github.com/RYGhub/royalnet.git
synced 2024-11-23 19:44:20 +00:00
Refactor most of the DiscordBot
This commit is contained in:
parent
f4fc7dd971
commit
e677cbe9b3
4 changed files with 185 additions and 154 deletions
|
@ -3,10 +3,11 @@ import asyncio
|
||||||
import typing
|
import typing
|
||||||
import logging as _logging
|
import logging as _logging
|
||||||
import sys
|
import sys
|
||||||
|
from .generic import GenericBot
|
||||||
from ..commands import NullCommand
|
from ..commands import NullCommand
|
||||||
from ..utils import asyncify, Call, Command, NetworkHandler
|
from ..utils import asyncify, Call, Command
|
||||||
from ..error import UnregisteredError, NoneFoundError, TooManyFoundError, InvalidConfigError
|
from ..error import UnregisteredError, NoneFoundError, TooManyFoundError, InvalidConfigError
|
||||||
from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError
|
from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError, RoyalnetConfig
|
||||||
from ..database import Alchemy, relationshiplinkchain, DatabaseConfig
|
from ..database import Alchemy, relationshiplinkchain, DatabaseConfig
|
||||||
from ..audio import RoyalPCMFile, PlayMode, Playlist
|
from ..audio import RoyalPCMFile, PlayMode, Playlist
|
||||||
|
|
||||||
|
@ -18,34 +19,16 @@ if not discord.opus.is_loaded():
|
||||||
log.error("Opus is not loaded. Weird behaviour might emerge.")
|
log.error("Opus is not loaded. Weird behaviour might emerge.")
|
||||||
|
|
||||||
|
|
||||||
class DiscordBot:
|
class DiscordConfig:
|
||||||
def __init__(self,
|
def __init__(self, token: str):
|
||||||
token: str,
|
|
||||||
master_server_uri: str,
|
|
||||||
master_server_secret: str,
|
|
||||||
commands: typing.List[typing.Type[Command]],
|
|
||||||
missing_command: typing.Type[Command] = NullCommand,
|
|
||||||
error_command: typing.Type[Command] = NullCommand,
|
|
||||||
database_config: typing.Optional[DatabaseConfig] = None):
|
|
||||||
self.token = token
|
self.token = token
|
||||||
# Generate the Alchemy database
|
|
||||||
if database_config:
|
|
||||||
self.alchemy = Alchemy(database_config.database_uri, required_tables)
|
class DiscordBot(GenericBot):
|
||||||
self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__)
|
def _init_voice(self):
|
||||||
self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__)
|
|
||||||
self.identity_column = self.identity_table.__getattribute__(self.identity_table, database_config.identity_column_name)
|
|
||||||
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
|
|
||||||
else:
|
|
||||||
if required_tables:
|
|
||||||
raise InvalidConfigError("Tables are required by the _commands, but Alchemy is not configured")
|
|
||||||
self.alchemy = None
|
|
||||||
self.master_table = None
|
|
||||||
self.identity_table = None
|
|
||||||
self.identity_column = None
|
|
||||||
self.identity_chain = None
|
|
||||||
# Create the PlayModes dictionary
|
|
||||||
self.music_data: typing.Dict[discord.Guild, PlayMode] = {}
|
self.music_data: typing.Dict[discord.Guild, PlayMode] = {}
|
||||||
|
|
||||||
|
def _call_factory(self) -> typing.Type[Call]:
|
||||||
# noinspection PyMethodParameters
|
# noinspection PyMethodParameters
|
||||||
class DiscordCall(Call):
|
class DiscordCall(Call):
|
||||||
interface_name = "discord"
|
interface_name = "discord"
|
||||||
|
@ -55,6 +38,7 @@ class DiscordBot:
|
||||||
alchemy = self.alchemy
|
alchemy = self.alchemy
|
||||||
|
|
||||||
async def reply(call, text: str):
|
async def reply(call, text: str):
|
||||||
|
# TODO: don't escape characters inside [c][/c] blocks
|
||||||
escaped_text = text.replace("*", "\\*") \
|
escaped_text = text.replace("*", "\\*") \
|
||||||
.replace("_", "\\_") \
|
.replace("_", "\\_") \
|
||||||
.replace("`", "\\`") \
|
.replace("`", "\\`") \
|
||||||
|
@ -69,6 +53,8 @@ class DiscordBot:
|
||||||
await call.channel.send(escaped_text)
|
await call.channel.send(escaped_text)
|
||||||
|
|
||||||
async def net_request(call, message: Message, destination: str):
|
async def net_request(call, message: Message, destination: str):
|
||||||
|
if self.network is None:
|
||||||
|
raise InvalidConfigError("Royalnet is not enabled on this bot")
|
||||||
response = await self.network.request(message, destination)
|
response = await self.network.request(message, destination)
|
||||||
if isinstance(response, RequestError):
|
if isinstance(response, RequestError):
|
||||||
raise response.exc
|
raise response.exc
|
||||||
|
@ -86,19 +72,20 @@ class DiscordBot:
|
||||||
raise UnregisteredError("Author is not registered")
|
raise UnregisteredError("Author is not registered")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
self.DiscordCall = DiscordCall
|
return DiscordCall
|
||||||
|
|
||||||
|
def _bot_factory(self) -> typing.Type[discord.Client]:
|
||||||
|
"""Create a new DiscordClient class based on this DiscordBot."""
|
||||||
# noinspection PyMethodParameters
|
# noinspection PyMethodParameters
|
||||||
class DiscordClient(discord.Client):
|
class DiscordClient(discord.Client):
|
||||||
@staticmethod
|
async def vc_connect_or_move(cli, channel: discord.VoiceChannel):
|
||||||
async def vc_connect_or_move(channel: discord.VoiceChannel):
|
|
||||||
# Connect to voice chat
|
# Connect to voice chat
|
||||||
try:
|
try:
|
||||||
await channel.connect()
|
await channel.connect()
|
||||||
except discord.errors.ClientException:
|
except discord.errors.ClientException:
|
||||||
# Move to the selected channel, instead of connecting
|
# Move to the selected channel, instead of connecting
|
||||||
# noinspection PyUnusedLocal
|
# noinspection PyUnusedLocal
|
||||||
for voice_client in self.bot.voice_clients:
|
for voice_client in cli.voice_clients:
|
||||||
voice_client: discord.VoiceClient
|
voice_client: discord.VoiceClient
|
||||||
if voice_client.guild != channel.guild:
|
if voice_client.guild != channel.guild:
|
||||||
continue
|
continue
|
||||||
|
@ -107,124 +94,123 @@ class DiscordBot:
|
||||||
if not self.music_data.get(channel.guild):
|
if not self.music_data.get(channel.guild):
|
||||||
self.music_data[channel.guild] = Playlist()
|
self.music_data[channel.guild] = Playlist()
|
||||||
|
|
||||||
async def on_message(cli, message: discord.Message):
|
@staticmethod # Not really static because of the self reference
|
||||||
|
async def on_message(message: discord.Message):
|
||||||
text = message.content
|
text = message.content
|
||||||
# Skip non-text messages
|
# Skip non-text messages
|
||||||
if not text:
|
if not text:
|
||||||
return
|
return
|
||||||
# Find and clean parameters
|
# Find and clean parameters
|
||||||
command_text, *parameters = text.split(" ")
|
command_text, *parameters = text.split(" ")
|
||||||
# Find the function
|
|
||||||
try:
|
|
||||||
selected_command = self.commands[command_text]
|
|
||||||
except KeyError:
|
|
||||||
# Skip inexistent _commands
|
|
||||||
selected_command = self.missing_command
|
|
||||||
log.error(f"Running {selected_command}")
|
|
||||||
# Call the command
|
# Call the command
|
||||||
try:
|
await self.call(command_text, message.channel, parameters, message=message)
|
||||||
return await self.DiscordCall(message.channel, selected_command, parameters, log,
|
|
||||||
message=message).run()
|
|
||||||
except Exception as exc:
|
|
||||||
try:
|
|
||||||
return await self.DiscordCall(message.channel, self.error_command, parameters, log,
|
|
||||||
message=message,
|
|
||||||
exception_info=sys.exc_info(),
|
|
||||||
previous_command=selected_command).run()
|
|
||||||
except Exception as exc2:
|
|
||||||
log.error(f"Exception in error handler command: {exc2}")
|
|
||||||
|
|
||||||
self.DiscordClient = DiscordClient
|
def find_guild_by_name(cli, name: str) -> discord.Guild:
|
||||||
self.bot = self.DiscordClient()
|
"""Find the Guild with the specified name. Case-insensitive.
|
||||||
|
Will raise a NoneFoundError if no channels are found, or a TooManyFoundError if more than one is found."""
|
||||||
def find_guild(self, identifier: typing.Union[str, int]) -> discord.Guild:
|
all_guilds: typing.List[discord.Guild] = cli.guilds
|
||||||
"""Find the Guild with the specified identifier. Names are case-insensitive."""
|
|
||||||
if isinstance(identifier, str):
|
|
||||||
all_guilds: typing.List[discord.Guild] = self.bot.guilds
|
|
||||||
matching_channels: typing.List[discord.Guild] = []
|
matching_channels: typing.List[discord.Guild] = []
|
||||||
for guild in all_guilds:
|
for guild in all_guilds:
|
||||||
if guild.name.lower() == identifier.lower():
|
if guild.name.lower() == name.lower():
|
||||||
matching_channels.append(guild)
|
matching_channels.append(guild)
|
||||||
if len(matching_channels) == 0:
|
if len(matching_channels) == 0:
|
||||||
raise NoneFoundError("No channels were found")
|
raise NoneFoundError("No channels were found")
|
||||||
elif len(matching_channels) > 1:
|
elif len(matching_channels) > 1:
|
||||||
raise TooManyFoundError("Too many channels were found")
|
raise TooManyFoundError("Too many channels were found")
|
||||||
return matching_channels[0]
|
return matching_channels[0]
|
||||||
elif isinstance(identifier, int):
|
|
||||||
return self.bot.get_guild(identifier)
|
|
||||||
raise TypeError("Invalid identifier type, should be str or int")
|
|
||||||
|
|
||||||
def find_channel(self,
|
def find_channel_by_name(cli,
|
||||||
identifier: typing.Union[str, int],
|
name: str,
|
||||||
guild: typing.Optional[discord.Guild] = None) -> discord.abc.GuildChannel:
|
guild: typing.Optional[discord.Guild] = None) -> discord.abc.GuildChannel:
|
||||||
"""Find the GuildChannel with the specified identifier. Names are case-insensitive."""
|
"""Find the TextChannel, VoiceChannel or CategoryChannel with the specified name. Case-insensitive.
|
||||||
if isinstance(identifier, str):
|
Guild is optional, but the method will raise a TooManyFoundError if none is specified and there is more than one channel with the same name.
|
||||||
|
Will also raise a NoneFoundError if no channels are found."""
|
||||||
if guild is not None:
|
if guild is not None:
|
||||||
all_channels = guild.channels
|
all_channels = guild.channels
|
||||||
else:
|
else:
|
||||||
all_channels: typing.List[discord.abc.GuildChannel] = self.bot.get_all_channels()
|
all_channels: typing.List[discord.abc.GuildChannel] = cli.get_all_channels()
|
||||||
matching_channels: typing.List[discord.abc.GuildChannel] = []
|
matching_channels: typing.List[discord.abc.GuildChannel] = []
|
||||||
for channel in all_channels:
|
for channel in all_channels:
|
||||||
if not (isinstance(channel, discord.TextChannel)
|
if not (isinstance(channel, discord.TextChannel)
|
||||||
or isinstance(channel, discord.VoiceChannel)
|
or isinstance(channel, discord.VoiceChannel)
|
||||||
or isinstance(channel, discord.CategoryChannel)):
|
or isinstance(channel, discord.CategoryChannel)):
|
||||||
continue
|
continue
|
||||||
if channel.name.lower() == identifier.lower():
|
if channel.name.lower() == name.lower():
|
||||||
matching_channels.append(channel)
|
matching_channels.append(channel)
|
||||||
if len(matching_channels) == 0:
|
if len(matching_channels) == 0:
|
||||||
raise NoneFoundError("No channels were found")
|
raise NoneFoundError("No channels were found")
|
||||||
elif len(matching_channels) > 1:
|
elif len(matching_channels) > 1:
|
||||||
raise TooManyFoundError("Too many channels were found")
|
raise TooManyFoundError("Too many channels were found")
|
||||||
return matching_channels[0]
|
return matching_channels[0]
|
||||||
elif isinstance(identifier, int):
|
|
||||||
channel: typing.List[discord.abc.GuildChannel] = self.bot.get_channel(identifier)
|
|
||||||
if ((isinstance(channel, discord.TextChannel)
|
|
||||||
or isinstance(channel, discord.VoiceChannel)
|
|
||||||
or isinstance(channel, discord.CategoryChannel))
|
|
||||||
and guild):
|
|
||||||
assert channel.guild == guild
|
|
||||||
return channel
|
|
||||||
raise TypeError("Invalid identifier type, should be str or int")
|
|
||||||
|
|
||||||
def find_voice_client(self, guild: discord.Guild):
|
def find_voice_client_by_guild(cli, guild: discord.Guild):
|
||||||
for voice_client in self.bot.voice_clients:
|
"""Find the VoiceClient belonging to a specific Guild.
|
||||||
|
Raises a NoneFoundError if the Guild currently has no VoiceClient."""
|
||||||
|
for voice_client in cli.voice_clients:
|
||||||
voice_client: discord.VoiceClient
|
voice_client: discord.VoiceClient
|
||||||
if voice_client.guild == guild:
|
if voice_client.guild == guild:
|
||||||
return voice_client
|
return voice_client
|
||||||
raise NoneFoundError("No voice clients found")
|
raise NoneFoundError("No voice clients found")
|
||||||
|
|
||||||
async def add_to_music_data(self, url: str, guild: discord.Guild):
|
return DiscordClient
|
||||||
"""Add a file to the corresponding music_data object."""
|
|
||||||
log.debug(f"Downloading {url} to add to music_data")
|
|
||||||
files: typing.List[RoyalPCMFile] = await asyncify(RoyalPCMFile.create_from_url, url)
|
|
||||||
guild_music_data = self.music_data[guild]
|
|
||||||
for file in files:
|
|
||||||
log.debug(f"Adding {file} to music_data")
|
|
||||||
guild_music_data.add(file)
|
|
||||||
if guild_music_data.now_playing is None:
|
|
||||||
log.debug(f"Starting playback chain")
|
|
||||||
await self.advance_music_data(guild)
|
|
||||||
|
|
||||||
async def advance_music_data(self, guild: discord.Guild):
|
def _init_bot(self):
|
||||||
"""Try to play the next song, while it exists. Otherwise, just return."""
|
"""Create a bot instance."""
|
||||||
guild_music_data = self.music_data[guild]
|
self.bot = self._bot_factory()()
|
||||||
voice_client = self.find_voice_client(guild)
|
|
||||||
next_file: RoyalPCMFile = await guild_music_data.next()
|
|
||||||
if next_file is None:
|
|
||||||
log.debug(f"Ending playback chain")
|
|
||||||
return
|
|
||||||
|
|
||||||
def advance(error=None):
|
def __init__(self, *,
|
||||||
log.debug(f"Deleting {next_file}")
|
discord_config: DiscordConfig,
|
||||||
next_file.delete_audio_file()
|
royalnet_config: RoyalnetConfig,
|
||||||
loop.create_task(self.advance_music_data(guild))
|
database_config: typing.Optional[DatabaseConfig] = None,
|
||||||
|
commands: typing.List[typing.Type[Command]] = None,
|
||||||
log.debug(f"Creating AudioSource of {next_file}")
|
missing_command: typing.Type[Command] = NullCommand,
|
||||||
next_source = next_file.create_audio_source()
|
error_command: typing.Type[Command] = NullCommand):
|
||||||
log.debug(f"Starting playback of {next_source}")
|
super().__init__(royalnet_config=royalnet_config,
|
||||||
voice_client.play(next_source, after=advance)
|
database_config=database_config,
|
||||||
|
commands=commands,
|
||||||
|
missing_command=missing_command,
|
||||||
|
error_command=error_command)
|
||||||
|
self._discord_config = discord_config
|
||||||
|
self._init_bot()
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
await self.bot.login(self.token)
|
await self.bot.login(self._discord_config.token)
|
||||||
await self.bot.connect()
|
await self.bot.connect()
|
||||||
# TODO: how to stop?
|
# TODO: how to stop?
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# class DiscordBot:
|
||||||
|
# async def add_to_music_data(self, url: str, guild: discord.Guild):
|
||||||
|
# """Add a file to the corresponding music_data object."""
|
||||||
|
# log.debug(f"Downloading {url} to add to music_data")
|
||||||
|
# files: typing.List[RoyalPCMFile] = await asyncify(RoyalPCMFile.create_from_url, url)
|
||||||
|
# guild_music_data = self.music_data[guild]
|
||||||
|
# for file in files:
|
||||||
|
# log.debug(f"Adding {file} to music_data")
|
||||||
|
# guild_music_data.add(file)
|
||||||
|
# if guild_music_data.now_playing is None:
|
||||||
|
# log.debug(f"Starting playback chain")
|
||||||
|
# await self.advance_music_data(guild)
|
||||||
|
#
|
||||||
|
# async def advance_music_data(self, guild: discord.Guild):
|
||||||
|
# """Try to play the next song, while it exists. Otherwise, just return."""
|
||||||
|
# guild_music_data = self.music_data[guild]
|
||||||
|
# voice_client = self.find_voice_client(guild)
|
||||||
|
# next_file: RoyalPCMFile = await guild_music_data.next()
|
||||||
|
# if next_file is None:
|
||||||
|
# log.debug(f"Ending playback chain")
|
||||||
|
# return
|
||||||
|
#
|
||||||
|
# def advance(error=None):
|
||||||
|
# log.debug(f"Deleting {next_file}")
|
||||||
|
# next_file.delete_audio_file()
|
||||||
|
# loop.create_task(self.advance_music_data(guild))
|
||||||
|
#
|
||||||
|
# log.debug(f"Creating AudioSource of {next_file}")
|
||||||
|
# next_source = next_file.create_audio_source()
|
||||||
|
# log.debug(f"Starting playback of {next_source}")
|
||||||
|
# voice_client.play(next_source, after=advance)
|
||||||
|
#
|
||||||
|
|
||||||
|
#
|
|
@ -1,7 +1,8 @@
|
||||||
|
import sys
|
||||||
import typing
|
import typing
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from ..utils import Command, NetworkHandler
|
from ..utils import Command, NetworkHandler, Call
|
||||||
from ..commands import NullCommand
|
from ..commands import NullCommand
|
||||||
from ..network import RoyalnetLink, Message, RequestError, RoyalnetConfig
|
from ..network import RoyalnetLink, Message, RequestError, RoyalnetConfig
|
||||||
from ..database import Alchemy, DatabaseConfig, relationshiplinkchain
|
from ..database import Alchemy, DatabaseConfig, relationshiplinkchain
|
||||||
|
@ -11,10 +12,12 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GenericBot:
|
class GenericBot:
|
||||||
|
"""A generic bot class, to be used as base for the other more specific classes, such as TelegramBot and DiscordBot."""
|
||||||
def _init_commands(self,
|
def _init_commands(self,
|
||||||
commands: typing.List[typing.Type[Command]],
|
commands: typing.List[typing.Type[Command]],
|
||||||
missing_command: typing.Type[Command],
|
missing_command: typing.Type[Command],
|
||||||
error_command: typing.Type[Command]):
|
error_command: typing.Type[Command]):
|
||||||
|
"""Generate the commands dictionary required to handle incoming messages, and the network_handlers dictionary required to handle incoming requests."""
|
||||||
log.debug(f"Now generating commands")
|
log.debug(f"Now generating commands")
|
||||||
self.commands: typing.Dict[str, typing.Type[Command]] = {}
|
self.commands: typing.Dict[str, typing.Type[Command]] = {}
|
||||||
self.network_handlers: typing.Dict[typing.Type[Message], typing.Type[NetworkHandler]] = {}
|
self.network_handlers: typing.Dict[typing.Type[Message], typing.Type[NetworkHandler]] = {}
|
||||||
|
@ -25,13 +28,19 @@ class GenericBot:
|
||||||
self.error_command: typing.Type[Command] = error_command
|
self.error_command: typing.Type[Command] = error_command
|
||||||
log.debug(f"Successfully generated commands")
|
log.debug(f"Successfully generated commands")
|
||||||
|
|
||||||
|
def _call_factory(self) -> typing.Type[Call]:
|
||||||
|
"""Create the Call class, representing a Call command. It should inherit from the utils.Call class."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def _init_royalnet(self, royalnet_config: RoyalnetConfig):
|
def _init_royalnet(self, royalnet_config: RoyalnetConfig):
|
||||||
|
"""Create a RoyalnetLink, and run it as a task."""
|
||||||
self.network: RoyalnetLink = RoyalnetLink(royalnet_config.master_uri, royalnet_config.master_secret, "discord",
|
self.network: RoyalnetLink = RoyalnetLink(royalnet_config.master_uri, royalnet_config.master_secret, "discord",
|
||||||
self._network_handler)
|
self._network_handler)
|
||||||
log.debug(f"Running RoyalnetLink {self.network}")
|
log.debug(f"Running RoyalnetLink {self.network}")
|
||||||
loop.create_task(self.network.run())
|
loop.create_task(self.network.run())
|
||||||
|
|
||||||
def _network_handler(self, message: Message) -> Message:
|
def _network_handler(self, message: Message) -> Message:
|
||||||
|
"""Handle a single Message received from the RoyalnetLink"""
|
||||||
log.debug(f"Received {message} from the RoyalnetLink")
|
log.debug(f"Received {message} from the RoyalnetLink")
|
||||||
try:
|
try:
|
||||||
network_handler = self.network_handlers[message.__class__]
|
network_handler = self.network_handlers[message.__class__]
|
||||||
|
@ -46,25 +55,33 @@ class GenericBot:
|
||||||
return RequestError(exc)
|
return RequestError(exc)
|
||||||
|
|
||||||
def _init_database(self, commands: typing.List[typing.Type[Command]], database_config: DatabaseConfig):
|
def _init_database(self, commands: typing.List[typing.Type[Command]], database_config: DatabaseConfig):
|
||||||
|
"""Connect to the database, and create the missing tables required by the selected commands."""
|
||||||
|
log.debug(f"Initializing database")
|
||||||
required_tables = set()
|
required_tables = set()
|
||||||
for command in commands:
|
for command in commands:
|
||||||
required_tables = required_tables.union(command.require_alchemy_tables)
|
required_tables = required_tables.union(command.require_alchemy_tables)
|
||||||
|
log.debug(f"Found {len(required_tables)} required tables")
|
||||||
self.alchemy = Alchemy(database_config.database_uri, required_tables)
|
self.alchemy = Alchemy(database_config.database_uri, required_tables)
|
||||||
self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__)
|
self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__)
|
||||||
self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__)
|
self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__)
|
||||||
self.identity_column = self.identity_table.__getattribute__(self.identity_table,
|
self.identity_column = self.identity_table.__getattribute__(self.identity_table,
|
||||||
database_config.identity_column_name)
|
database_config.identity_column_name)
|
||||||
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
|
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
|
||||||
|
log.debug(f"Identity chain is {self.identity_chain}")
|
||||||
|
|
||||||
def __init__(self, *,
|
def __init__(self, *,
|
||||||
royalnet_config: RoyalnetConfig,
|
royalnet_config: typing.Optional[RoyalnetConfig] = None,
|
||||||
database_config: typing.Optional[DatabaseConfig] = None,
|
database_config: typing.Optional[DatabaseConfig] = None,
|
||||||
commands: typing.Optional[typing.List[typing.Type[Command]]] = None,
|
commands: typing.List[typing.Type[Command]] = None,
|
||||||
missing_command: typing.Type[Command] = NullCommand,
|
missing_command: typing.Type[Command] = NullCommand,
|
||||||
error_command: typing.Type[Command] = NullCommand):
|
error_command: typing.Type[Command] = NullCommand):
|
||||||
if commands is None:
|
if commands is None:
|
||||||
commands = []
|
commands = []
|
||||||
self._init_commands(commands, missing_command=missing_command, error_command=error_command)
|
self._init_commands(commands, missing_command=missing_command, error_command=error_command)
|
||||||
|
self._Call = self._call_factory()
|
||||||
|
if royalnet_config is None:
|
||||||
|
self.network = None
|
||||||
|
else:
|
||||||
self._init_royalnet(royalnet_config=royalnet_config)
|
self._init_royalnet(royalnet_config=royalnet_config)
|
||||||
if database_config is None:
|
if database_config is None:
|
||||||
self.alchemy = None
|
self.alchemy = None
|
||||||
|
@ -73,3 +90,22 @@ class GenericBot:
|
||||||
self.identity_column = None
|
self.identity_column = None
|
||||||
else:
|
else:
|
||||||
self._init_database(commands=commands, database_config=database_config)
|
self._init_database(commands=commands, database_config=database_config)
|
||||||
|
|
||||||
|
async def call(self, command_name: str, channel, parameters: typing.List[str] = None, **kwargs):
|
||||||
|
"""Call a command by its string, or missing_command if it doesn't exists, or error_command if an exception is raised during the execution."""
|
||||||
|
if parameters is None:
|
||||||
|
parameters = []
|
||||||
|
try:
|
||||||
|
command: typing.Type[Command] = self.commands[command_name]
|
||||||
|
except KeyError:
|
||||||
|
command = self.missing_command
|
||||||
|
try:
|
||||||
|
await self._Call(channel, command, parameters, **kwargs).run()
|
||||||
|
except Exception as exc:
|
||||||
|
await self._Call(channel, self.error_command,
|
||||||
|
exception_info=sys.exc_info(),
|
||||||
|
previous_command=command, **kwargs).run()
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""A blocking coroutine that should make the bot start listening to commands and requests."""
|
||||||
|
raise NotImplementedError()
|
|
@ -1,3 +1,4 @@
|
||||||
|
import logging as _logging
|
||||||
import traceback
|
import traceback
|
||||||
from ..utils import Command, Call
|
from ..utils import Command, Call
|
||||||
from ..error import NoneFoundError, \
|
from ..error import NoneFoundError, \
|
||||||
|
@ -9,6 +10,9 @@ from ..error import NoneFoundError, \
|
||||||
ExternalError
|
ExternalError
|
||||||
|
|
||||||
|
|
||||||
|
log = _logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ErrorHandlerCommand(Command):
|
class ErrorHandlerCommand(Command):
|
||||||
|
|
||||||
command_name = "error_handler"
|
command_name = "error_handler"
|
||||||
|
@ -46,4 +50,4 @@ class ErrorHandlerCommand(Command):
|
||||||
return
|
return
|
||||||
await call.reply(f"❌ Eccezione non gestita durante l'esecuzione del comando:\n[b]{e_type.__name__}[/b]\n{e_value}")
|
await call.reply(f"❌ Eccezione non gestita durante l'esecuzione del comando:\n[b]{e_type.__name__}[/b]\n{e_value}")
|
||||||
formatted_tb: str = '\n'.join(traceback.format_tb(e_tb))
|
formatted_tb: str = '\n'.join(traceback.format_tb(e_tb))
|
||||||
call.logger.error(f"Unhandled exception - {e_type.__name__}: {e_value}\n{formatted_tb}")
|
log.error(f"Unhandled exception - {e_type.__name__}: {e_value}\n{formatted_tb}")
|
||||||
|
|
|
@ -36,13 +36,18 @@ class Call:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
# These parameters / methods should be left alone
|
# These parameters / methods should be left alone
|
||||||
def __init__(self, channel, command: typing.Type[Command], command_args: list, logger: logging.Logger, **kwargs):
|
def __init__(self,
|
||||||
|
channel,
|
||||||
|
command: typing.Type[Command],
|
||||||
|
command_args: typing.List[str] = None,
|
||||||
|
**kwargs):
|
||||||
|
if command_args is None:
|
||||||
|
command_args = []
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
self.command = command
|
self.command = command
|
||||||
self.args = CommandArgs(command_args)
|
self.args = CommandArgs(command_args)
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
self.session = None
|
self.session = None
|
||||||
self.logger = logger
|
|
||||||
|
|
||||||
async def session_init(self):
|
async def session_init(self):
|
||||||
if not self.command.require_alchemy_tables:
|
if not self.command.require_alchemy_tables:
|
||||||
|
|
Loading…
Reference in a new issue