diff --git a/royalnet/bots/discord.py b/royalnet/bots/discord.py index 31226655..50b41255 100644 --- a/royalnet/bots/discord.py +++ b/royalnet/bots/discord.py @@ -4,11 +4,9 @@ import typing import logging as _logging import sys from ..commands import NullCommand -from ..commands.summon import SummonMessage -from ..commands.play import PlayMessage from ..utils import asyncify, Call, Command -from royalnet.error import UnregisteredError -from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError +from ..error import UnregisteredError, NoneFoundError, TooManyFoundError +from ..network import RoyalnetLink, Message, RequestSuccessful from ..database import Alchemy, relationshiplinkchain from ..audio import RoyalAudioFile @@ -20,6 +18,19 @@ if not discord.opus.is_loaded(): log.error("Opus is not loaded. Weird behaviour might emerge.") +class PlayMessage(Message): + def __init__(self, url: str, channel_identifier: typing.Optional[typing.Union[int, str]] = None): + self.url: str = url + self.channel_identifier: typing.Optional[typing.Union[int, str]] = channel_identifier + + +class SummonMessage(Message): + def __init__(self, channel_identifier: typing.Union[int, str], + guild_identifier: typing.Optional[typing.Union[int, str]]): + self.channel_identifier = channel_identifier + self.guild_identifier = guild_identifier + + class DiscordBot: def __init__(self, token: str, @@ -33,9 +44,9 @@ class DiscordBot: missing_command: typing.Type[Command] = NullCommand, error_command: typing.Type[Command] = NullCommand): self.token = token + # Generate commands self.missing_command = missing_command self.error_command = error_command - # Generate commands self.commands = {} required_tables = set() for command in commands: @@ -47,8 +58,9 @@ class DiscordBot: self.identity_table = self.alchemy.__getattribute__(identity_table.__name__) self.identity_column = self.identity_table.__getattribute__(self.identity_table, identity_column_name) self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table) - - self.network: RoyalnetLink = RoyalnetLink(master_server_uri, master_server_secret, "discord", self.network_handler) + # Connect to Royalnet + self.network: RoyalnetLink = RoyalnetLink(master_server_uri, master_server_secret, "discord", + self.network_handler) loop.create_task(self.network.run()) # noinspection PyMethodParameters @@ -60,16 +72,16 @@ class DiscordBot: async def reply(call, text: str): escaped_text = text.replace("*", "\\*") \ - .replace("_", "\\_") \ - .replace("`", "\\`") \ - .replace("[b]", "**") \ - .replace("[/b]", "**") \ - .replace("[i]", "_") \ - .replace("[/i]", "_") \ - .replace("[u]", "__") \ - .replace("[/u]", "__") \ - .replace("[c]", "`") \ - .replace("[/c]", "`") + .replace("_", "\\_") \ + .replace("`", "\\`") \ + .replace("[b]", "**") \ + .replace("[/b]", "**") \ + .replace("[i]", "_") \ + .replace("[/i]", "_") \ + .replace("[u]", "__") \ + .replace("[/u]", "__") \ + .replace("[c]", "`") \ + .replace("[/c]", "`") await call.channel.send(escaped_text) async def net_request(call, message: Message, destination: str): @@ -98,6 +110,7 @@ class DiscordBot: await channel.connect() except discord.errors.ClientException: # Move to the selected channel, instead of connecting + # noinspection PyUnusedLocal for voice_client in self.bot.voice_clients: voice_client: discord.VoiceClient if voice_client.guild != channel.guild: @@ -133,38 +146,71 @@ class DiscordBot: self.DiscordClient = DiscordClient self.bot = self.DiscordClient() + def find_guild(self, identifier: typing.Union[str, int]): + """Find the Guild with the specified identifier. Names are case-insensitive.""" + if isinstance(identifier, str): + all_guilds: typing.List[discord.Guild] = self.bot.guilds + matching_channels: typing.List[discord.Guild] = [] + for guild in all_guilds: + if guild.name.lower() == identifier.lower(): + matching_channels.append(guild) + if len(matching_channels) == 0: + raise NoneFoundError("No channels were found") + elif len(matching_channels) > 1: + raise TooManyFoundError("Too many channels were found") + return matching_channels[0] + elif isinstance(identifier, int): + return self.bot.get_guild(identifier) + raise TypeError("Invalid identifier type, should be str or int") + + def find_channel(self, identifier: typing.Union[str, int], guild: typing.Optional[discord.Guild] = None): + """Find the GuildChannel with the specified identifier. Names are case-insensitive.""" + if isinstance(identifier, str): + if guild is not None: + all_channels = guild.channels + else: + all_channels: typing.List[discord.abc.GuildChannel] = self.bot.get_all_channels() + matching_channels: typing.List[discord.abc.GuildChannel] = [] + for channel in all_channels: + if not (isinstance(channel, discord.TextChannel) + or isinstance(channel, discord.VoiceChannel) + or isinstance(channel, discord.CategoryChannel)): + continue + if channel.name.lower() == identifier.lower(): + matching_channels.append(channel) + if len(matching_channels) == 0: + raise NoneFoundError("No channels were found") + elif len(matching_channels) > 1: + raise TooManyFoundError("Too many channels were found") + return matching_channels[0] + elif isinstance(identifier, int): + channel: typing.List[discord.abc.GuildChannel] = self.bot.get_channel(identifier) + if ((isinstance(channel, discord.TextChannel) + or isinstance(channel, discord.VoiceChannel) + or isinstance(channel, discord.CategoryChannel)) + and guild): + assert channel.guild == guild + return channel + raise TypeError("Invalid identifier type, should be str or int") + async def network_handler(self, message: Message) -> Message: + """Handle a Royalnet request.""" if isinstance(message, SummonMessage): return await self.nh_summon(message) elif isinstance(message, PlayMessage): return await self.nh_play(message) async def nh_summon(self, message: SummonMessage): - channels: typing.List[discord.abc.GuildChannel] = self.bot.get_all_channels() - matching_channels: typing.List[discord.VoiceChannel] = [] - for channel in channels: - if isinstance(channel, discord.VoiceChannel): - if channel.name == message.channel_name: - matching_channels.append(channel) - if len(matching_channels) == 0: - return RequestError("No channels with a matching name found") - elif len(matching_channels) > 1: - return RequestError("Multiple channels with a matching name found") - matching_channel = matching_channels[0] - await self.bot.vc_connect_or_move(matching_channel) + """Handle a summon Royalnet request. That is, join a voice channel, or move to a different one if that is not possible.""" + channel = self.find_channel(message.channel_identifier) + if not isinstance(channel, discord.VoiceChannel): + raise NoneFoundError("Channel is not a voice channel") + loop.create_task(self.bot.vc_connect_or_move(channel)) return RequestSuccessful() async def nh_play(self, message: PlayMessage): - # TODO: actually do what's intended to do - # Download the audio - file = await asyncify(RoyalAudioFile.create_from_url, message.url) - # Get the audio source - audio_source = file[0].as_audio_source() - # Play the audio source - for voice_client in self.bot.voice_clients: - voice_client: discord.VoiceClient - voice_client.play(audio_source) - return RequestError() + """Handle a play Royalnet request. That is, add audio to a PlayMode.""" + raise async def run(self): await self.bot.login(self.token) diff --git a/royalnet/commands/play.py b/royalnet/commands/play.py index 09c569f1..5910840b 100644 --- a/royalnet/commands/play.py +++ b/royalnet/commands/play.py @@ -1,11 +1,7 @@ import typing from ..utils import Command, Call from ..network import Message, RequestSuccessful, RequestError - - -class PlayMessage(Message): - def __init__(self, url: str): - self.url: str = url +from ..bots.discord import PlayMessage class PlayCommand(Command): @@ -17,10 +13,5 @@ class PlayCommand(Command): async def common(cls, call: Call): url: str = call.args[0] response: typing.Union[RequestSuccessful, RequestError] = await call.net_request(PlayMessage(url), "discord") - if isinstance(response, RequestSuccessful): - await call.reply(f"✅ Richiesta la riproduzione di [c]{url}[/c].") - return - elif isinstance(response, RequestError): - await call.reply(f"⚠️ Si è verificato un'errore nella richiesta di riproduzione:\n[c]{response.reason}[/c]") - return - raise TypeError(f"Received unexpected response in the PlayCommand: {response.__class__.__name__}") + response.raise_on_error() + await call.reply(f"✅ Richiesta la riproduzione di [c]{url}[/c].") diff --git a/royalnet/commands/summon.py b/royalnet/commands/summon.py index 8951b167..4d80280f 100644 --- a/royalnet/commands/summon.py +++ b/royalnet/commands/summon.py @@ -2,11 +2,7 @@ import typing import discord from ..utils import Command, Call from ..network import Message, RequestSuccessful, RequestError - - -class SummonMessage(Message): - def __init__(self, channel_name: str): - self.channel_name: str = channel_name +from ..bots.discord import SummonMessage class SummonCommand(Command): @@ -19,13 +15,8 @@ class SummonCommand(Command): async def common(cls, call: Call): channel_name: str = call.args[0].lstrip("#") response: typing.Union[RequestSuccessful, RequestError] = await call.net_request(SummonMessage(channel_name), "discord") - if isinstance(response, RequestError): - await call.reply(f"⚠️ Si è verificato un'errore nella richiesta di connessione:\n[c]{response.exc}[/c]") - return - elif isinstance(response, RequestSuccessful): - await call.reply(f"✅ Mi sono connesso in [c]#{channel_name}[/c].") - return - raise TypeError(f"Received unexpected response type while summoning the bot: {response.__class__.__name__}") + response.raise_on_error() + await call.reply(f"✅ Mi sono connesso in [c]#{channel_name}[/c].") @classmethod async def discord(cls, call: Call): diff --git a/royalnet/network/messages.py b/royalnet/network/messages.py index 49dd5af1..d63b9c20 100644 --- a/royalnet/network/messages.py +++ b/royalnet/network/messages.py @@ -2,6 +2,9 @@ class Message: def __repr__(self): return f"<{self.__class__.__name__}>" + def raise_on_error(self): + pass + class IdentifySuccessfulMessage(Message): pass @@ -32,3 +35,6 @@ class RequestSuccessful(Message): class RequestError(Message): def __init__(self, exc: Exception): self.exc = exc + + def raise_on_error(self): + raise self.exc