mirror of
https://github.com/RYGhub/royalnet.git
synced 2024-11-23 19:44:20 +00:00
sync
This commit is contained in:
parent
756b9d25ed
commit
9176144392
2 changed files with 126 additions and 132 deletions
|
@ -1,35 +1,50 @@
|
|||
import logging
|
||||
import asyncio
|
||||
from typing import Type, Optional, List, Callable, Union
|
||||
from royalnet.commands import Command, CommandInterface, CommandData, CommandArgs, CommandError, InvalidInputError, \
|
||||
UnsupportedError
|
||||
from royalnet.utils import asyncify
|
||||
from .escape import escape
|
||||
from ..serf import Serf
|
||||
|
||||
try:
|
||||
import discord
|
||||
import sentry_sdk
|
||||
import logging as _logging
|
||||
from .generic import GenericBot
|
||||
from royalnet.utils import *
|
||||
from royalnet.error import *
|
||||
from royalnet.bard import *
|
||||
from royalnet.commands import *
|
||||
except ImportError:
|
||||
discord = None
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm.session import Session
|
||||
from ..alchemyconfig import AlchemyConfig
|
||||
except ImportError:
|
||||
Session = None
|
||||
AlchemyConfig = None
|
||||
|
||||
try:
|
||||
from royalnet.herald import Config as HeraldConfig
|
||||
except ImportError:
|
||||
HeraldConfig = None
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
log = _logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MusicData:
|
||||
def __init__(self):
|
||||
self.playmode: playmodes.PlayMode = playmodes.Playlist()
|
||||
self.voice_client: typing.Optional[discord.VoiceClient] = None
|
||||
|
||||
def queue_preview(self):
|
||||
return self.playmode.queue_preview()
|
||||
|
||||
|
||||
class DiscordBot(GenericBot):
|
||||
"""A bot that connects to `Discord <https://discordapp.com/>`_."""
|
||||
class DiscordSerf(Serf):
|
||||
"""A :class:`Serf` that connects to `Discord <https://discordapp.com/>`_ as a bot."""
|
||||
interface_name = "discord"
|
||||
|
||||
def _init_voice(self):
|
||||
"""Initialize the variables needed for the connection to voice chat."""
|
||||
log.debug(f"Creating music_data dict")
|
||||
self.music_data: typing.Dict[discord.Guild, MusicData] = {}
|
||||
def __init__(self, *,
|
||||
alchemy_config: Optional[AlchemyConfig] = None,
|
||||
commands: List[Type[Command]] = None,
|
||||
network_config: Optional[HeraldConfig] = None,
|
||||
secrets_name: str = "__default__"):
|
||||
if discord is None:
|
||||
raise ImportError("'discord' extra is not installed")
|
||||
|
||||
def _interface_factory(self) -> typing.Type[CommandInterface]:
|
||||
super().__init__(alchemy_config=alchemy_config,
|
||||
commands=commands,
|
||||
network_config=network_config,
|
||||
secrets_name=secrets_name)
|
||||
|
||||
def _interface_factory(self) -> Type[CommandInterface]:
|
||||
# noinspection PyPep8Naming
|
||||
GenericInterface = super().interface_factory()
|
||||
|
||||
|
@ -40,22 +55,22 @@ class DiscordBot(GenericBot):
|
|||
|
||||
return DiscordInterface
|
||||
|
||||
def _data_factory(self) -> typing.Type[CommandData]:
|
||||
def _data_factory(self) -> Type[CommandData]:
|
||||
# noinspection PyMethodParameters,PyAbstractClass
|
||||
class DiscordData(CommandData):
|
||||
def __init__(data, interface: CommandInterface, message: discord.Message):
|
||||
super().__init__(interface)
|
||||
def __init__(data, interface: CommandInterface, session, message: discord.Message):
|
||||
super().__init__(interface=interface, session=session)
|
||||
data.message = message
|
||||
|
||||
async def reply(data, text: str):
|
||||
await data.message.channel.send(discord_escape(text))
|
||||
await data.message.channel.send(escape(text))
|
||||
|
||||
async def get_author(data, error_if_none=False):
|
||||
user: discord.Member = data.message.author
|
||||
query = data.session.query(self.master_table)
|
||||
for link in self.identity_chain:
|
||||
query = data.session.query(self._master_table)
|
||||
for link in self._identity_chain:
|
||||
query = query.join(link.mapper.class_)
|
||||
query = query.filter(self.identity_column == user.id)
|
||||
query = query.filter(self._identity_column == user.id)
|
||||
result = await asyncify(query.one_or_none)
|
||||
if result is None and error_if_none:
|
||||
raise CommandError("You must be registered to use this command.")
|
||||
|
@ -66,42 +81,8 @@ class DiscordBot(GenericBot):
|
|||
|
||||
return DiscordData
|
||||
|
||||
def _bot_factory(self) -> typing.Type[discord.Client]:
|
||||
"""Create a custom DiscordClient class inheriting from :py:class:`discord.Client`."""
|
||||
log.debug(f"Creating DiscordClient")
|
||||
|
||||
# noinspection PyMethodParameters
|
||||
class DiscordClient(discord.Client):
|
||||
async def vc_connect_or_move(cli, channel: discord.VoiceChannel):
|
||||
music_data = self.music_data.get(channel.guild)
|
||||
if music_data is None:
|
||||
# Create a MusicData object
|
||||
music_data = MusicData()
|
||||
self.music_data[channel.guild] = music_data
|
||||
# Connect to voice
|
||||
log.debug(f"Connecting to Voice in {channel}")
|
||||
try:
|
||||
music_data.voice_client = await channel.connect(reconnect=False, timeout=10)
|
||||
except Exception:
|
||||
log.warning(f"Failed to connect to Voice in {channel}")
|
||||
del self.music_data[channel.guild]
|
||||
raise
|
||||
else:
|
||||
log.debug(f"Connected to Voice in {channel}")
|
||||
else:
|
||||
if music_data.voice_client is None:
|
||||
# TODO: change exception type
|
||||
raise Exception("Another connection attempt is already in progress.")
|
||||
# Try to move to a different channel
|
||||
voice_client = music_data.voice_client
|
||||
log.debug(f"Moving {voice_client} to {channel}")
|
||||
await voice_client.move_to(channel)
|
||||
log.debug(f"Moved {voice_client} to {channel}")
|
||||
|
||||
async def on_message(cli, message: discord.Message):
|
||||
self.loop.create_task(cli._handle_message(message))
|
||||
|
||||
async def _handle_message(cli, message: discord.Message):
|
||||
async def handle_message(self, message: discord.Message):
|
||||
"""Handle a Discord message by calling a command if appropriate."""
|
||||
text = message.content
|
||||
# Skip non-text messages
|
||||
if not text:
|
||||
|
@ -110,7 +91,7 @@ class DiscordBot(GenericBot):
|
|||
if not text.startswith("!"):
|
||||
return
|
||||
# Skip bot messages
|
||||
author: typing.Union[discord.User] = message.author
|
||||
author: Union[discord.User] = message.author
|
||||
if author.bot:
|
||||
return
|
||||
# Find and clean parameters
|
||||
|
@ -123,13 +104,18 @@ class DiscordBot(GenericBot):
|
|||
except KeyError:
|
||||
# Skip the message
|
||||
return
|
||||
# Prepare data
|
||||
data = self._Data(interface=command.interface, message=message)
|
||||
# Call the command
|
||||
log.debug(f"Calling command '{command.name}'")
|
||||
with message.channel.typing():
|
||||
# Run the command
|
||||
# Open an alchemy session, if available
|
||||
if self.alchemy is not None:
|
||||
session = await asyncify(self.alchemy.Session)
|
||||
else:
|
||||
session = None
|
||||
# Prepare data
|
||||
data = self.Data(interface=command.interface, session=session, message=message)
|
||||
try:
|
||||
# Run the command
|
||||
await command.run(CommandArgs(parameters), data)
|
||||
except InvalidInputError as e:
|
||||
await data.reply(f":warning: {e.message}\n"
|
||||
|
@ -139,44 +125,52 @@ class DiscordBot(GenericBot):
|
|||
except CommandError as e:
|
||||
await data.reply(f":warning: {e.message}")
|
||||
except Exception as e:
|
||||
sentry_sdk.capture_exception(e)
|
||||
error_message = f"🦀 [b]{e.__class__.__name__}[/b] 🦀\n"
|
||||
error_message += '\n'.join(e.args)
|
||||
self.sentry_exc(e)
|
||||
error_message = f"🦀 [b]{e.__class__.__name__}[/b] 🦀\n" \
|
||||
'\n'.join(e.args)
|
||||
await data.reply(error_message)
|
||||
# Close the data session
|
||||
await data.session_close()
|
||||
finally:
|
||||
# Close the alchemy session
|
||||
if session is not None:
|
||||
await asyncify(session.close)
|
||||
|
||||
async def on_connect(cli):
|
||||
log.debug("Connected to Discord")
|
||||
|
||||
async def on_disconnect(cli):
|
||||
log.error("Disconnected from Discord!")
|
||||
def _bot_factory(self) -> Type[discord.Client]:
|
||||
"""Create a custom class inheriting from :py:class:`discord.Client`."""
|
||||
# noinspection PyMethodParameters
|
||||
class DiscordClient(discord.Client):
|
||||
async def on_message(cli, message: discord.Message):
|
||||
"""Handle messages received by passing them to the handle_message method of the bot."""
|
||||
# TODO: keep reference to these tasks somewhere
|
||||
self.loop.create_task(self.handle_message(message))
|
||||
|
||||
async def on_ready(cli) -> None:
|
||||
log.debug("Connection successful, client is ready")
|
||||
"""Change the bot presence to ``online`` when the bot is ready."""
|
||||
await cli.change_presence(status=discord.Status.online)
|
||||
|
||||
def find_guild_by_name(cli, name: str) -> typing.List[discord.Guild]:
|
||||
"""Find the :py:class:`discord.Guild` with the specified name (case insensitive)."""
|
||||
all_guilds: typing.List[discord.Guild] = cli.guilds
|
||||
matching_channels: typing.List[discord.Guild] = []
|
||||
def find_guild(cli, name: str) -> List[discord.Guild]:
|
||||
"""Find the :class:`discord.Guild`s with the specified name (case insensitive).
|
||||
|
||||
Returns:
|
||||
A :class:`list` of :class:`discord.Guild` having the specified name."""
|
||||
all_guilds: List[discord.Guild] = cli.guilds
|
||||
matching_channels: List[discord.Guild] = []
|
||||
for guild in all_guilds:
|
||||
if guild.name.lower() == name.lower():
|
||||
matching_channels.append(guild)
|
||||
return matching_channels
|
||||
|
||||
def find_channel_by_name(cli,
|
||||
def find_channel(cli,
|
||||
name: str,
|
||||
guild: typing.Optional[discord.Guild] = None) -> typing.List[discord.abc.GuildChannel]:
|
||||
"""Find the :py:class:`TextChannel`, :py:class:`VoiceChannel` or :py:class:`CategoryChannel` with the
|
||||
guild: Optional[discord.Guild] = None) -> List[discord.abc.GuildChannel]:
|
||||
"""Find the :class:`TextChannel`, :class:`VoiceChannel` or :class:`CategoryChannel` with the
|
||||
specified name (case insensitive).
|
||||
|
||||
You can specify a guild to only find channels in that specific guild."""
|
||||
You can specify a guild to only search in that specific guild."""
|
||||
if guild is not None:
|
||||
all_channels = guild.channels
|
||||
else:
|
||||
all_channels: typing.List[discord.abc.GuildChannel] = cli.get_all_channels()
|
||||
matching_channels: typing.List[discord.abc.GuildChannel] = []
|
||||
all_channels: List[discord.abc.GuildChannel] = cli.get_all_channels()
|
||||
matching_channels: List[discord.abc.GuildChannel] = []
|
||||
for channel in all_channels:
|
||||
if not (isinstance(channel, discord.TextChannel)
|
||||
or isinstance(channel, discord.VoiceChannel)
|
||||
|
@ -186,8 +180,9 @@ class DiscordBot(GenericBot):
|
|||
matching_channels.append(channel)
|
||||
return matching_channels
|
||||
|
||||
def find_voice_client_by_guild(cli, guild: discord.Guild) -> typing.Optional[discord.VoiceClient]:
|
||||
def find_voice_client(cli, guild: discord.Guild) -> Optional[discord.VoiceClient]:
|
||||
"""Find the :py:class:`discord.VoiceClient` belonging to a specific :py:class:`discord.Guild`."""
|
||||
# TODO: the bug I was looking for might be here
|
||||
for voice_client in cli.voice_clients:
|
||||
if voice_client.guild == guild:
|
||||
return voice_client
|
||||
|
@ -195,6 +190,8 @@ class DiscordBot(GenericBot):
|
|||
|
||||
return DiscordClient
|
||||
|
||||
# TODO: restart from here
|
||||
|
||||
def _init_client(self):
|
||||
"""Create an instance of the DiscordClient class created in :py:func:`royalnet.bots.DiscordBot._bot_factory`."""
|
||||
log.debug(f"Creating DiscordClient instance")
|
||||
|
|
|
@ -32,7 +32,7 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class TelegramSerf(Serf):
|
||||
"""A Serf that connects to `Telegram <https://telegram.org/>`_."""
|
||||
"""A Serf that connects to `Telegram <https://telegram.org/>`_ as a bot."""
|
||||
interface_name = "telegram"
|
||||
|
||||
def __init__(self, *,
|
||||
|
@ -94,9 +94,6 @@ class TelegramSerf(Serf):
|
|||
name = self.interface_name
|
||||
prefix = "/"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
return TelegramInterface
|
||||
|
||||
def data_factory(self) -> Type[CommandData]:
|
||||
|
|
Loading…
Reference in a new issue