diff --git a/royalgames.py b/royalgames.py index b9eb6d42..5563c102 100644 --- a/royalgames.py +++ b/royalgames.py @@ -6,6 +6,7 @@ from royalnet.commands import * from royalnet.commands.debug_create import DebugCreateCommand from royalnet.commands.error_handler import ErrorHandlerCommand from royalnet.network import RoyalnetServer +from royalnet.database import DatabaseConfig from royalnet.database.tables import Royal, Telegram, Discord loop = asyncio.get_event_loop() @@ -20,11 +21,13 @@ commands = [PingCommand, ShipCommand, SmecdsCommand, ColorCommand, CiaoruoziComm KvrollCommand, VideoinfoCommand, SummonCommand, PlayCommand] master = RoyalnetServer("localhost", 1234, "sas") -# tg_bot = TelegramBot(os.environ["TG_AK"], "ws://localhost:1234", "sas", commands, os.environ["DB_PATH"], Royal, Telegram, "tg_id", error_command=ErrorHandlerCommand) -ds_bot = DiscordBot(os.environ["DS_AK"], "ws://localhost:1234", "sas", commands, os.environ["DB_PATH"], Royal, Discord, "discord_id", error_command=ErrorHandlerCommand) +tg_db_cfg = DatabaseConfig(os.environ["DB_PATH"], Royal, Telegram, "tg_id") +tg_bot = TelegramBot(os.environ["TG_AK"], "ws://localhost:1234", "sas", commands, NullCommand, ErrorHandlerCommand, tg_db_cfg) +ds_db_cfg = DatabaseConfig(os.environ["DB_PATH"], Royal, Discord, "discord_id") +ds_bot = DiscordBot(os.environ["DS_AK"], "ws://localhost:1234", "sas", commands, NullCommand, ErrorHandlerCommand, ds_db_cfg) loop.run_until_complete(master.run()) # Dirty hack, remove me asap -# loop.create_task(tg_bot.run()) +loop.create_task(tg_bot.run()) loop.create_task(ds_bot.run()) print("Starting loop...") loop.run_forever() diff --git a/royalnet/bots/discord.py b/royalnet/bots/discord.py index 86f51f61..40d33525 100644 --- a/royalnet/bots/discord.py +++ b/royalnet/bots/discord.py @@ -4,11 +4,11 @@ import typing import logging as _logging import sys from ..commands import NullCommand -from ..utils import asyncify, Call, Command -from ..error import UnregisteredError, NoneFoundError, TooManyFoundError +from ..utils import asyncify, Call, Command, NetworkHandler +from ..error import UnregisteredError, NoneFoundError, TooManyFoundError, InvalidConfigError from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError -from ..database import Alchemy, relationshiplinkchain -from ..audio import RoyalPCMFile, PlayMode, Playlist, Pool +from ..database import Alchemy, relationshiplinkchain, DatabaseConfig +from ..audio import RoyalPCMFile, PlayMode, Playlist loop = asyncio.get_event_loop() log = _logging.getLogger(__name__) @@ -18,31 +18,15 @@ if not discord.opus.is_loaded(): log.error("Opus is not loaded. Weird behaviour might emerge.") -class PlayMessage(Message): - def __init__(self, url: str, guild_identifier: typing.Optional[str] = None): - self.url: str = url - self.guild_identifier: typing.Optional[str] = guild_identifier - - -class SummonMessage(Message): - def __init__(self, channel_identifier: typing.Union[int, str], - guild_identifier: typing.Optional[typing.Union[int, str]]): - self.channel_identifier = channel_identifier - self.guild_identifier = guild_identifier - - class DiscordBot: def __init__(self, token: str, master_server_uri: str, master_server_secret: str, commands: typing.List[typing.Type[Command]], - database_uri: str, - master_table, - identity_table, - identity_column_name: str, missing_command: typing.Type[Command] = NullCommand, - error_command: typing.Type[Command] = NullCommand): + error_command: typing.Type[Command] = NullCommand, + database_config: typing.Optional[DatabaseConfig] = None): self.token = token # Generate commands self.missing_command = missing_command @@ -52,12 +36,25 @@ class DiscordBot: for command in commands: self.commands[f"!{command.command_name}"] = command required_tables = required_tables.union(command.require_alchemy_tables) + # Generate network handlers + self.network_handlers: typing.Dict[typing.Type[Message], typing.Type[NetworkHandler]] = {} + for command in commands: + self.network_handlers = {**self.network_handlers, **command.network_handler_dict()} # Generate the Alchemy database - self.alchemy = Alchemy(database_uri, required_tables) - self.master_table = self.alchemy.__getattribute__(master_table.__name__) - self.identity_table = self.alchemy.__getattribute__(identity_table.__name__) - self.identity_column = self.identity_table.__getattribute__(self.identity_table, identity_column_name) - self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table) + if database_config: + self.alchemy = Alchemy(database_config.database_uri, required_tables) + self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__) + self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__) + self.identity_column = self.identity_table.__getattribute__(self.identity_table, database_config.identity_column_name) + self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table) + else: + if required_tables: + raise InvalidConfigError("Tables are required by the commands, but Alchemy is not configured") + self.alchemy = None + self.master_table = None + self.identity_table = None + self.identity_column = None + self.identity_chain = None # Connect to Royalnet self.network: RoyalnetLink = RoyalnetLink(master_server_uri, master_server_secret, "discord", self.network_handler) @@ -70,6 +67,7 @@ class DiscordBot: interface_name = "discord" interface_obj = self interface_prefix = "!" + alchemy = self.alchemy async def reply(call, text: str): @@ -213,18 +211,7 @@ class DiscordBot: async def network_handler(self, message: Message) -> Message: """Handle a Royalnet request.""" log.debug(f"Received {message} from Royalnet") - if isinstance(message, SummonMessage): - return await self.nh_summon(message) - elif isinstance(message, PlayMessage): - return await self.nh_play(message) - - async def nh_summon(self, message: SummonMessage): - """Handle a summon Royalnet request. That is, join a voice channel, or move to a different one if that is not possible.""" - channel = self.find_channel(message.channel_identifier) - if not isinstance(channel, discord.VoiceChannel): - raise NoneFoundError("Channel is not a voice channel") - loop.create_task(self.bot.vc_connect_or_move(channel)) - return RequestSuccessful() + return await self.network_handlers[message.__class__].discord(message) async def add_to_music_data(self, url: str, guild: discord.Guild): """Add a file to the corresponding music_data object.""" @@ -257,23 +244,6 @@ class DiscordBot: log.debug(f"Starting playback of {next_source}") voice_client.play(next_source, after=advance) - async def nh_play(self, message: PlayMessage): - """Handle a play Royalnet request. That is, add audio to a PlayMode.""" - # Find the matching guild - if message.guild_identifier: - guild = self.find_guild(message.guild_identifier) - else: - if len(self.music_data) != 1: - raise TooManyFoundError("Multiple guilds found") - guild = list(self.music_data)[0] - # Ensure the guild has a PlayMode before adding the file to it - if not self.music_data.get(guild): - # TODO: change Exception - raise Exception("No music_data for this guild") - # Start downloading - loop.create_task(self.add_to_music_data(message.url, guild)) - return RequestSuccessful() - async def run(self): await self.bot.login(self.token) await self.bot.connect() diff --git a/royalnet/bots/telegram.py b/royalnet/bots/telegram.py index 0ecc2ed9..64c44e76 100644 --- a/royalnet/bots/telegram.py +++ b/royalnet/bots/telegram.py @@ -5,9 +5,9 @@ import logging as _logging import sys from ..commands import NullCommand from ..utils import asyncify, Call, Command -from ..error import UnregisteredError +from ..error import UnregisteredError, InvalidConfigError from ..network import RoyalnetLink, Message, RequestError -from ..database import Alchemy, relationshiplinkchain +from ..database import Alchemy, relationshiplinkchain, DatabaseConfig loop = asyncio.get_event_loop() log = _logging.getLogger(__name__) @@ -23,12 +23,9 @@ class TelegramBot: master_server_uri: str, master_server_secret: str, commands: typing.List[typing.Type[Command]], - database_uri: str, - master_table, - identity_table, - identity_column_name: str, missing_command: typing.Type[Command] = NullCommand, - error_command: typing.Type[Command] = NullCommand): + error_command: typing.Type[Command] = NullCommand, + database_config: typing.Optional[DatabaseConfig] = None): self.bot: telegram.Bot = telegram.Bot(api_key) self.should_run: bool = False self.offset: int = -100 @@ -43,12 +40,20 @@ class TelegramBot: self.commands[f"/{command.command_name}"] = command required_tables = required_tables.union(command.require_alchemy_tables) # Generate the Alchemy database - self.alchemy = Alchemy(database_uri, required_tables) - self.master_table = self.alchemy.__getattribute__(master_table.__name__) - self.identity_table = self.alchemy.__getattribute__(identity_table.__name__) - self.identity_column = self.identity_table.__getattribute__(self.identity_table, identity_column_name) - self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table) - + if database_config: + self.alchemy = Alchemy(database_config.database_uri, required_tables) + self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__) + self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__) + self.identity_column = self.identity_table.__getattribute__(self.identity_table, database_config.identity_column_name) + self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table) + else: + if required_tables: + raise InvalidConfigError("Tables are required by the commands, but Alchemy is not configured") + self.alchemy = None + self.master_table = None + self.identity_table = None + self.identity_column = None + self.identity_chain = None # noinspection PyMethodParameters class TelegramCall(Call): interface_name = "telegram" diff --git a/royalnet/commands/play.py b/royalnet/commands/play.py index 02851695..c4c05c91 100644 --- a/royalnet/commands/play.py +++ b/royalnet/commands/play.py @@ -1,7 +1,42 @@ import typing -from ..utils import Command, Call +import discord +import asyncio +from ..utils import Command, Call, NetworkHandler from ..network import Message, RequestSuccessful, RequestError -from ..bots.discord import PlayMessage +from ..error import TooManyFoundError +if typing.TYPE_CHECKING: + from ..bots import DiscordBot + + +loop = asyncio.get_event_loop() + + +class PlayMessage(Message): + def __init__(self, url: str, guild_identifier: typing.Optional[str] = None): + self.url: str = url + self.guild_identifier: typing.Optional[str] = guild_identifier + + +class PlayNH(NetworkHandler): + message_type = PlayMessage + + @classmethod + async def nh_play(cls, bot: "DiscordBot", message: PlayMessage): + """Handle a play Royalnet request. That is, add audio to a PlayMode.""" + # Find the matching guild + if message.guild_identifier: + guild = bot.find_guild(message.guild_identifier) + else: + if len(bot.music_data) != 1: + raise TooManyFoundError("Multiple guilds found") + guild = list(bot.music_data)[0] + # Ensure the guild has a PlayMode before adding the file to it + if not bot.music_data.get(guild): + # TODO: change Exception + raise Exception("No music_data for this guild") + # Start downloading + loop.create_task(bot.add_to_music_data(message.url, guild)) + return RequestSuccessful() class PlayCommand(Command): @@ -9,6 +44,8 @@ class PlayCommand(Command): command_description = "Riproduce una canzone in chat vocale." command_syntax = "[ [guild] ] (url)" + network_handlers = [PlayNH] + @classmethod async def common(cls, call: Call): guild, url = call.args.match(r"(?:\[(.+)])?\s*(\S+)\s*") diff --git a/royalnet/commands/summon.py b/royalnet/commands/summon.py index 4d80280f..7aeb5eb5 100644 --- a/royalnet/commands/summon.py +++ b/royalnet/commands/summon.py @@ -1,8 +1,34 @@ import typing import discord -from ..utils import Command, Call +import asyncio +from ..utils import Command, Call, NetworkHandler from ..network import Message, RequestSuccessful, RequestError -from ..bots.discord import SummonMessage +from ..error import NoneFoundError +if typing.TYPE_CHECKING: + from ..bots import DiscordBot + + +loop = asyncio.get_event_loop() + + +class SummonMessage(Message): + def __init__(self, channel_identifier: typing.Union[int, str], + guild_identifier: typing.Optional[typing.Union[int, str]] = None): + self.channel_identifier = channel_identifier + self.guild_identifier = guild_identifier + + +class SummonNH(NetworkHandler): + message_type = SummonMessage + + @classmethod + async def discord(cls, bot: "DiscordBot", message: SummonMessage): + """Handle a summon Royalnet request. That is, join a voice channel, or move to a different one if that is not possible.""" + channel = bot.find_channel(message.channel_identifier) + if not isinstance(channel, discord.VoiceChannel): + raise NoneFoundError("Channel is not a voice channel") + loop.create_task(bot.bot.vc_connect_or_move(channel)) + return RequestSuccessful() class SummonCommand(Command): @@ -11,6 +37,8 @@ class SummonCommand(Command): command_description = "Evoca il bot in un canale vocale." command_syntax = "[channelname]" + network_handlers = [SummonNH] + @classmethod async def common(cls, call: Call): channel_name: str = call.args[0].lstrip("#") diff --git a/royalnet/database/__init__.py b/royalnet/database/__init__.py index 70a54417..3081a4b0 100644 --- a/royalnet/database/__init__.py +++ b/royalnet/database/__init__.py @@ -1,4 +1,5 @@ from .alchemy import Alchemy from .relationshiplinkchain import relationshiplinkchain +from .databaseconfig import DatabaseConfig -__all__ = ["Alchemy", "relationshiplinkchain"] +__all__ = ["Alchemy", "relationshiplinkchain", "DatabaseConfig"] diff --git a/royalnet/database/alchemy.py b/royalnet/database/alchemy.py index ac1ce1a8..d99eea7b 100644 --- a/royalnet/database/alchemy.py +++ b/royalnet/database/alchemy.py @@ -4,7 +4,8 @@ from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from contextlib import contextmanager, asynccontextmanager -from ..utils import cdj, asyncify +from ..utils import asyncify +from ..error import InvalidConfigError loop = asyncio.get_event_loop() diff --git a/royalnet/database/databaseconfig.py b/royalnet/database/databaseconfig.py new file mode 100644 index 00000000..c0f12165 --- /dev/null +++ b/royalnet/database/databaseconfig.py @@ -0,0 +1,13 @@ +import typing + + +class DatabaseConfig: + def __init__(self, + database_uri: str, + master_table: typing.Type, + identity_table: typing.Type, + identity_column_name: str): + self.database_uri: str = database_uri + self.master_table: typing.Type = master_table + self.identity_table: typing.Type = identity_table + self.identity_column_name: str = identity_column_name diff --git a/royalnet/utils/__init__.py b/royalnet/utils/__init__.py index 925bd023..edcb1256 100644 --- a/royalnet/utils/__init__.py +++ b/royalnet/utils/__init__.py @@ -6,5 +6,7 @@ from .safeformat import safeformat from .classdictjanitor import cdj from .sleepuntil import sleep_until from .plusformat import plusformat +from .networkhandler import NetworkHandler -__all__ = ["asyncify", "Call", "Command", "safeformat", "cdj", "sleep_until", "plusformat", "CommandArgs"] +__all__ = ["asyncify", "Call", "Command", "safeformat", "cdj", "sleep_until", "plusformat", "CommandArgs", + "NetworkHandler"] diff --git a/royalnet/utils/call.py b/royalnet/utils/call.py index e3bd6ea1..eb68ae6f 100644 --- a/royalnet/utils/call.py +++ b/royalnet/utils/call.py @@ -3,8 +3,7 @@ import asyncio import logging from ..network.messages import Message from .command import Command -from royalnet.utils import CommandArgs - +from .commandargs import CommandArgs if typing.TYPE_CHECKING: from ..database import Alchemy @@ -57,10 +56,7 @@ class Call: async def run(self): await self.session_init() - try: - coroutine = getattr(self.command, self.interface_name) - except AttributeError: - coroutine = getattr(self.command, "common") + coroutine = getattr(self.command, self.interface_name) try: result = await coroutine(self) finally: diff --git a/royalnet/utils/command.py b/royalnet/utils/command.py index d1f0c077..a406746e 100644 --- a/royalnet/utils/command.py +++ b/royalnet/utils/command.py @@ -1,6 +1,9 @@ import typing +from ..error import UnsupportedError +from ..network import Message if typing.TYPE_CHECKING: from .call import Call + from ..utils import NetworkHandler class Command: @@ -12,5 +15,21 @@ class Command: require_alchemy_tables: typing.Set = set() - async def common(self, call: "Call"): - raise NotImplementedError() + network_handlers: typing.List[typing.Type["NetworkHandler"]] = {} + + @classmethod + async def common(cls, call: "Call"): + raise UnsupportedError() + + @classmethod + def network_handler_dict(cls): + d = {} + for network_handler in cls.network_handlers: + d[network_handler.message_type] = network_handler + return d + + def __getattribute__(self, item: str): + try: + return self.__dict__[item] + except KeyError: + return self.common diff --git a/royalnet/utils/commandargs.py b/royalnet/utils/commandargs.py index 490c201f..63a02b2f 100644 --- a/royalnet/utils/commandargs.py +++ b/royalnet/utils/commandargs.py @@ -35,4 +35,4 @@ class CommandArgs(list): try: return self[index] except InvalidInputError: - return default \ No newline at end of file + return default diff --git a/royalnet/utils/networkhandler.py b/royalnet/utils/networkhandler.py new file mode 100644 index 00000000..36cf4b49 --- /dev/null +++ b/royalnet/utils/networkhandler.py @@ -0,0 +1,14 @@ +from ..network import Message +from ..error import UnsupportedError + + +class NetworkHandler: + """The NetworkHandler functions are called when a specific Message type is received.""" + + message_type = NotImplemented + + def __getattribute__(self, item: str): + try: + return self.__dict__[item] + except KeyError: + raise UnsupportedError()