diff --git a/royalnet/bots/discord.py b/royalnet/bots/discord.py index 62a7df92..db29fad9 100644 --- a/royalnet/bots/discord.py +++ b/royalnet/bots/discord.py @@ -10,7 +10,6 @@ from ..network import RoyalnetConfig, Request, ResponseSuccess, ResponseError from ..database import DatabaseConfig from ..audio import PlayMode, Playlist, RoyalPCMAudio -loop = asyncio.get_event_loop() log = _logging.getLogger(__name__) # TODO: Load the opus library @@ -231,7 +230,7 @@ class DiscordBot(GenericBot): def advance(error=None): if error: raise Exception(f"Error while advancing music_data: {error}") - loop.create_task(self.advance_music_data(guild)) + self.loop.create_task(self.advance_music_data(guild)) log.debug(f"Starting playback of {next_source}") voice_client.play(next_source, after=advance) diff --git a/royalnet/bots/generic.py b/royalnet/bots/generic.py index cb119a14..3c26fbff 100644 --- a/royalnet/bots/generic.py +++ b/royalnet/bots/generic.py @@ -8,7 +8,6 @@ from ..network import RoyalnetLink, Request, Response, ResponseError, RoyalnetCo from ..database import Alchemy, DatabaseConfig, relationshiplinkchain -loop = asyncio.get_event_loop() log = logging.getLogger(__name__) @@ -45,7 +44,7 @@ class GenericBot: self.network: RoyalnetLink = RoyalnetLink(royalnet_config.master_uri, royalnet_config.master_secret, self.interface_name, self._network_handler) log.debug(f"Running RoyalnetLink {self.network}") - loop.create_task(self.network.run()) + self.loop.create_task(self.network.run()) async def _network_handler(self, request_dict: dict) -> dict: """Handle a single :py:class:`dict` received from the :py:class:`royalnet.network.RoyalnetLink`. @@ -101,7 +100,12 @@ class GenericBot: command_prefix: str, commands: typing.List[typing.Type[Command]] = None, missing_command: typing.Type[Command] = NullCommand, - error_command: typing.Type[Command] = NullCommand): + error_command: typing.Type[Command] = NullCommand, + loop: asyncio.AbstractEventLoop = None): + if loop is None: + self.loop = asyncio.get_event_loop() + else: + self.loop = loop if database_config is None: self.alchemy = None self.master_table = None diff --git a/royalnet/bots/telegram.py b/royalnet/bots/telegram.py index 430f979d..019be23a 100644 --- a/royalnet/bots/telegram.py +++ b/royalnet/bots/telegram.py @@ -10,7 +10,7 @@ from ..error import UnregisteredError, InvalidConfigError, RoyalnetResponseError from ..network import RoyalnetConfig, Request, ResponseSuccess, ResponseError from ..database import DatabaseConfig -loop = asyncio.get_event_loop() + log = _logging.getLogger(__name__) @@ -121,7 +121,7 @@ class TelegramBot(GenericBot): # Handle updates for update in last_updates: # noinspection PyAsyncCall - loop.create_task(self._handle_update(update)) + self.loop.create_task(self._handle_update(update)) # Recalculate offset try: self._offset = last_updates[-1].update_id + 1 diff --git a/royalnet/commands/cv.py b/royalnet/commands/cv.py index 338e5bf8..e6719ba7 100644 --- a/royalnet/commands/cv.py +++ b/royalnet/commands/cv.py @@ -8,9 +8,6 @@ if typing.TYPE_CHECKING: from ..bots import DiscordBot -loop = asyncio.get_event_loop() - - class CvNH(NetworkHandler): message_type = "discord_cv" diff --git a/royalnet/commands/missing.py b/royalnet/commands/missing.py index d5eb8392..6a5670fd 100644 --- a/royalnet/commands/missing.py +++ b/royalnet/commands/missing.py @@ -2,7 +2,7 @@ import asyncio import logging as _logging from ..utils import Command, Call -loop = asyncio.get_event_loop() + log = _logging.getLogger(__name__) diff --git a/royalnet/commands/play.py b/royalnet/commands/play.py index ee4f94d9..0b0f59f5 100644 --- a/royalnet/commands/play.py +++ b/royalnet/commands/play.py @@ -11,9 +11,6 @@ if typing.TYPE_CHECKING: from ..bots import DiscordBot -loop = asyncio.get_event_loop() - - class PlayNH(NetworkHandler): message_type = "music_play" diff --git a/royalnet/commands/playmode.py b/royalnet/commands/playmode.py index 8bc25a55..df4ac9f2 100644 --- a/royalnet/commands/playmode.py +++ b/royalnet/commands/playmode.py @@ -8,9 +8,6 @@ if typing.TYPE_CHECKING: from ..bots import DiscordBot -loop = asyncio.get_event_loop() - - class PlaymodeNH(NetworkHandler): message_type = "music_playmode" diff --git a/royalnet/commands/summon.py b/royalnet/commands/summon.py index a1329a81..f2cc2493 100644 --- a/royalnet/commands/summon.py +++ b/royalnet/commands/summon.py @@ -1,6 +1,5 @@ import typing import discord -import asyncio from ..utils import Command, Call, NetworkHandler from ..network import Request, ResponseSuccess from ..error import NoneFoundError @@ -8,9 +7,6 @@ if typing.TYPE_CHECKING: from ..bots import DiscordBot -loop = asyncio.get_event_loop() - - class SummonNH(NetworkHandler): message_type = "music_summon" @@ -20,7 +16,7 @@ class SummonNH(NetworkHandler): channel = bot.client.find_channel_by_name(data["channel_name"]) if not isinstance(channel, discord.VoiceChannel): raise NoneFoundError("Channel is not a voice channel") - loop.create_task(bot.client.vc_connect_or_move(channel)) + bot.loop.create_task(bot.client.vc_connect_or_move(channel)) return ResponseSuccess() diff --git a/royalnet/database/alchemy.py b/royalnet/database/alchemy.py index c413f343..d2b86469 100644 --- a/royalnet/database/alchemy.py +++ b/royalnet/database/alchemy.py @@ -1,5 +1,4 @@ import typing -import asyncio from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker @@ -8,8 +7,6 @@ from ..utils import asyncify # noinspection PyUnresolvedReferences from ..error import InvalidConfigError -loop = asyncio.get_event_loop() - class Alchemy: """A wrapper around SQLAlchemy declarative that allows to use multiple databases at once while maintaining a single table-class for both of them.""" diff --git a/royalnet/network/royalnetlink.py b/royalnet/network/royalnetlink.py index 992eeb89..1d0989d6 100644 --- a/royalnet/network/royalnetlink.py +++ b/royalnet/network/royalnetlink.py @@ -8,7 +8,7 @@ import logging as _logging import typing from .package import Package -default_loop = asyncio.get_event_loop() + log = _logging.getLogger(__name__) @@ -35,7 +35,11 @@ class NetworkError(Exception): class PendingRequest: - def __init__(self, *, loop=default_loop): + def __init__(self, *, loop: asyncio.AbstractEventLoop = None): + if loop is None: + self.loop = asyncio.get_event_loop() + else: + self.loop = loop self.event: asyncio.Event = asyncio.Event(loop=loop) self.data: typing.Optional[dict] = None @@ -67,8 +71,9 @@ def requires_identification(func): class RoyalnetLink: def __init__(self, master_uri: str, secret: str, link_type: str, request_handler, *, - loop: asyncio.AbstractEventLoop = default_loop): - assert ":" not in link_type + loop: asyncio.AbstractEventLoop = None): + if ":" in link_type: + raise ValueError("Link types cannot contain colons.") self.master_uri: str = master_uri self.link_type: str = link_type self.nid: str = str(uuid.uuid4()) @@ -76,7 +81,10 @@ class RoyalnetLink: self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None self.request_handler = request_handler self._pending_requests: typing.Dict[str, PendingRequest] = {} - self._loop: asyncio.AbstractEventLoop = loop + if loop is None: + self._loop = asyncio.get_event_loop() + else: + self._loop = loop self.error_event: asyncio.Event = asyncio.Event(loop=self._loop) self.connect_event: asyncio.Event = asyncio.Event(loop=self._loop) self.identify_event: asyncio.Event = asyncio.Event(loop=self._loop) diff --git a/royalnet/network/royalnetserver.py b/royalnet/network/royalnetserver.py index 1b328e8a..1b501930 100644 --- a/royalnet/network/royalnetserver.py +++ b/royalnet/network/royalnetserver.py @@ -7,7 +7,7 @@ import asyncio import logging as _logging from .package import Package -default_loop = asyncio.get_event_loop() + log = _logging.getLogger(__name__) @@ -35,12 +35,15 @@ class ConnectedClient: class RoyalnetServer: - def __init__(self, address: str, port: int, required_secret: str, *, loop: asyncio.AbstractEventLoop = default_loop): + def __init__(self, address: str, port: int, required_secret: str, *, loop: asyncio.AbstractEventLoop = None): self.address: str = address self.port: int = port self.required_secret: str = required_secret self.identified_clients: typing.List[ConnectedClient] = [] - self._loop: asyncio.AbstractEventLoop = loop + if loop is None: + self.loop = asyncio.get_event_loop() + else: + self.loop = loop def find_client(self, *, nid: str = None, link_type: str = None) -> typing.List[ConnectedClient]: assert not (nid and link_type) diff --git a/royalnet/utils/__init__.py b/royalnet/utils/__init__.py index 770b7196..353b5fe7 100644 --- a/royalnet/utils/__init__.py +++ b/royalnet/utils/__init__.py @@ -12,4 +12,5 @@ from .networkhandler import NetworkHandler from .formatters import andformat, plusformat, fileformat, ytdldateformat, numberemojiformat __all__ = ["asyncify", "Call", "Command", "safeformat", "cdj", "sleep_until", "plusformat", "CommandArgs", - "NetworkHandler", "andformat", "plusformat", "fileformat", "ytdldateformat", "numberemojiformat"] + "NetworkHandler", "andformat", "plusformat", "fileformat", "ytdldateformat", "numberemojiformat", + "telegram_escape", "discord_escape"] diff --git a/royalnet/utils/call.py b/royalnet/utils/call.py index f45c8c82..b9ccf58c 100644 --- a/royalnet/utils/call.py +++ b/royalnet/utils/call.py @@ -6,9 +6,6 @@ if typing.TYPE_CHECKING: from ..database import Alchemy -loop = asyncio.get_event_loop() - - class Call: """A command call. An abstract class, sub-bots should create a new call class from this. @@ -55,6 +52,7 @@ class Call: channel, command: typing.Type[Command], command_args: typing.List[str] = None, + loop: asyncio.AbstractEventLoop = None, **kwargs): """Create the call. @@ -66,6 +64,10 @@ class Call: """ if command_args is None: command_args = [] + if loop is None: + self.loop = asyncio.get_event_loop() + else: + self.loop = loop self.channel = channel self.command = command self.args = CommandArgs(command_args) @@ -76,7 +78,7 @@ class Call: """If the command requires database access, create a :py:class:`royalnet.database.Alchemy` session for this call, otherwise, do nothing.""" if not self.command.require_alchemy_tables: return - self.session = await loop.run_in_executor(None, self.alchemy.Session) + self.session = await self.loop.run_in_executor(None, self.alchemy.Session) async def session_end(self): """Close the previously created :py:class:`royalnet.database.Alchemy` session for this call (if it was created)."""