mirror of
https://github.com/RYGhub/royalnet.git
synced 2024-11-23 19:44:20 +00:00
Some asyncio / loop improvements (i hope)
This commit is contained in:
parent
ce6593a8fd
commit
6ba220869f
13 changed files with 39 additions and 38 deletions
|
@ -10,7 +10,6 @@ from ..network import RoyalnetConfig, Request, ResponseSuccess, ResponseError
|
|||
from ..database import DatabaseConfig
|
||||
from ..audio import PlayMode, Playlist, RoyalPCMAudio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
log = _logging.getLogger(__name__)
|
||||
|
||||
# TODO: Load the opus library
|
||||
|
@ -231,7 +230,7 @@ class DiscordBot(GenericBot):
|
|||
def advance(error=None):
|
||||
if error:
|
||||
raise Exception(f"Error while advancing music_data: {error}")
|
||||
loop.create_task(self.advance_music_data(guild))
|
||||
self.loop.create_task(self.advance_music_data(guild))
|
||||
|
||||
log.debug(f"Starting playback of {next_source}")
|
||||
voice_client.play(next_source, after=advance)
|
||||
|
|
|
@ -8,7 +8,6 @@ from ..network import RoyalnetLink, Request, Response, ResponseError, RoyalnetCo
|
|||
from ..database import Alchemy, DatabaseConfig, relationshiplinkchain
|
||||
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -45,7 +44,7 @@ class GenericBot:
|
|||
self.network: RoyalnetLink = RoyalnetLink(royalnet_config.master_uri, royalnet_config.master_secret, self.interface_name,
|
||||
self._network_handler)
|
||||
log.debug(f"Running RoyalnetLink {self.network}")
|
||||
loop.create_task(self.network.run())
|
||||
self.loop.create_task(self.network.run())
|
||||
|
||||
async def _network_handler(self, request_dict: dict) -> dict:
|
||||
"""Handle a single :py:class:`dict` received from the :py:class:`royalnet.network.RoyalnetLink`.
|
||||
|
@ -101,7 +100,12 @@ class GenericBot:
|
|||
command_prefix: str,
|
||||
commands: typing.List[typing.Type[Command]] = None,
|
||||
missing_command: typing.Type[Command] = NullCommand,
|
||||
error_command: typing.Type[Command] = NullCommand):
|
||||
error_command: typing.Type[Command] = NullCommand,
|
||||
loop: asyncio.AbstractEventLoop = None):
|
||||
if loop is None:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
else:
|
||||
self.loop = loop
|
||||
if database_config is None:
|
||||
self.alchemy = None
|
||||
self.master_table = None
|
||||
|
|
|
@ -10,7 +10,7 @@ from ..error import UnregisteredError, InvalidConfigError, RoyalnetResponseError
|
|||
from ..network import RoyalnetConfig, Request, ResponseSuccess, ResponseError
|
||||
from ..database import DatabaseConfig
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
log = _logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -121,7 +121,7 @@ class TelegramBot(GenericBot):
|
|||
# Handle updates
|
||||
for update in last_updates:
|
||||
# noinspection PyAsyncCall
|
||||
loop.create_task(self._handle_update(update))
|
||||
self.loop.create_task(self._handle_update(update))
|
||||
# Recalculate offset
|
||||
try:
|
||||
self._offset = last_updates[-1].update_id + 1
|
||||
|
|
|
@ -8,9 +8,6 @@ if typing.TYPE_CHECKING:
|
|||
from ..bots import DiscordBot
|
||||
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
class CvNH(NetworkHandler):
|
||||
message_type = "discord_cv"
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ import asyncio
|
|||
import logging as _logging
|
||||
from ..utils import Command, Call
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
log = _logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
|
@ -11,9 +11,6 @@ if typing.TYPE_CHECKING:
|
|||
from ..bots import DiscordBot
|
||||
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
class PlayNH(NetworkHandler):
|
||||
message_type = "music_play"
|
||||
|
||||
|
|
|
@ -8,9 +8,6 @@ if typing.TYPE_CHECKING:
|
|||
from ..bots import DiscordBot
|
||||
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
class PlaymodeNH(NetworkHandler):
|
||||
message_type = "music_playmode"
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import typing
|
||||
import discord
|
||||
import asyncio
|
||||
from ..utils import Command, Call, NetworkHandler
|
||||
from ..network import Request, ResponseSuccess
|
||||
from ..error import NoneFoundError
|
||||
|
@ -8,9 +7,6 @@ if typing.TYPE_CHECKING:
|
|||
from ..bots import DiscordBot
|
||||
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
class SummonNH(NetworkHandler):
|
||||
message_type = "music_summon"
|
||||
|
||||
|
@ -20,7 +16,7 @@ class SummonNH(NetworkHandler):
|
|||
channel = bot.client.find_channel_by_name(data["channel_name"])
|
||||
if not isinstance(channel, discord.VoiceChannel):
|
||||
raise NoneFoundError("Channel is not a voice channel")
|
||||
loop.create_task(bot.client.vc_connect_or_move(channel))
|
||||
bot.loop.create_task(bot.client.vc_connect_or_move(channel))
|
||||
return ResponseSuccess()
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import typing
|
||||
import asyncio
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
@ -8,8 +7,6 @@ from ..utils import asyncify
|
|||
# noinspection PyUnresolvedReferences
|
||||
from ..error import InvalidConfigError
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
class Alchemy:
|
||||
"""A wrapper around SQLAlchemy declarative that allows to use multiple databases at once while maintaining a single table-class for both of them."""
|
||||
|
|
|
@ -8,7 +8,7 @@ import logging as _logging
|
|||
import typing
|
||||
from .package import Package
|
||||
|
||||
default_loop = asyncio.get_event_loop()
|
||||
|
||||
log = _logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -35,7 +35,11 @@ class NetworkError(Exception):
|
|||
|
||||
|
||||
class PendingRequest:
|
||||
def __init__(self, *, loop=default_loop):
|
||||
def __init__(self, *, loop: asyncio.AbstractEventLoop = None):
|
||||
if loop is None:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
else:
|
||||
self.loop = loop
|
||||
self.event: asyncio.Event = asyncio.Event(loop=loop)
|
||||
self.data: typing.Optional[dict] = None
|
||||
|
||||
|
@ -67,8 +71,9 @@ def requires_identification(func):
|
|||
|
||||
class RoyalnetLink:
|
||||
def __init__(self, master_uri: str, secret: str, link_type: str, request_handler, *,
|
||||
loop: asyncio.AbstractEventLoop = default_loop):
|
||||
assert ":" not in link_type
|
||||
loop: asyncio.AbstractEventLoop = None):
|
||||
if ":" in link_type:
|
||||
raise ValueError("Link types cannot contain colons.")
|
||||
self.master_uri: str = master_uri
|
||||
self.link_type: str = link_type
|
||||
self.nid: str = str(uuid.uuid4())
|
||||
|
@ -76,7 +81,10 @@ class RoyalnetLink:
|
|||
self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None
|
||||
self.request_handler = request_handler
|
||||
self._pending_requests: typing.Dict[str, PendingRequest] = {}
|
||||
self._loop: asyncio.AbstractEventLoop = loop
|
||||
if loop is None:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
else:
|
||||
self._loop = loop
|
||||
self.error_event: asyncio.Event = asyncio.Event(loop=self._loop)
|
||||
self.connect_event: asyncio.Event = asyncio.Event(loop=self._loop)
|
||||
self.identify_event: asyncio.Event = asyncio.Event(loop=self._loop)
|
||||
|
|
|
@ -7,7 +7,7 @@ import asyncio
|
|||
import logging as _logging
|
||||
from .package import Package
|
||||
|
||||
default_loop = asyncio.get_event_loop()
|
||||
|
||||
log = _logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -35,12 +35,15 @@ class ConnectedClient:
|
|||
|
||||
|
||||
class RoyalnetServer:
|
||||
def __init__(self, address: str, port: int, required_secret: str, *, loop: asyncio.AbstractEventLoop = default_loop):
|
||||
def __init__(self, address: str, port: int, required_secret: str, *, loop: asyncio.AbstractEventLoop = None):
|
||||
self.address: str = address
|
||||
self.port: int = port
|
||||
self.required_secret: str = required_secret
|
||||
self.identified_clients: typing.List[ConnectedClient] = []
|
||||
self._loop: asyncio.AbstractEventLoop = loop
|
||||
if loop is None:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
else:
|
||||
self.loop = loop
|
||||
|
||||
def find_client(self, *, nid: str = None, link_type: str = None) -> typing.List[ConnectedClient]:
|
||||
assert not (nid and link_type)
|
||||
|
|
|
@ -12,4 +12,5 @@ from .networkhandler import NetworkHandler
|
|||
from .formatters import andformat, plusformat, fileformat, ytdldateformat, numberemojiformat
|
||||
|
||||
__all__ = ["asyncify", "Call", "Command", "safeformat", "cdj", "sleep_until", "plusformat", "CommandArgs",
|
||||
"NetworkHandler", "andformat", "plusformat", "fileformat", "ytdldateformat", "numberemojiformat"]
|
||||
"NetworkHandler", "andformat", "plusformat", "fileformat", "ytdldateformat", "numberemojiformat",
|
||||
"telegram_escape", "discord_escape"]
|
||||
|
|
|
@ -6,9 +6,6 @@ if typing.TYPE_CHECKING:
|
|||
from ..database import Alchemy
|
||||
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
class Call:
|
||||
"""A command call. An abstract class, sub-bots should create a new call class from this.
|
||||
|
||||
|
@ -55,6 +52,7 @@ class Call:
|
|||
channel,
|
||||
command: typing.Type[Command],
|
||||
command_args: typing.List[str] = None,
|
||||
loop: asyncio.AbstractEventLoop = None,
|
||||
**kwargs):
|
||||
"""Create the call.
|
||||
|
||||
|
@ -66,6 +64,10 @@ class Call:
|
|||
"""
|
||||
if command_args is None:
|
||||
command_args = []
|
||||
if loop is None:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
else:
|
||||
self.loop = loop
|
||||
self.channel = channel
|
||||
self.command = command
|
||||
self.args = CommandArgs(command_args)
|
||||
|
@ -76,7 +78,7 @@ class Call:
|
|||
"""If the command requires database access, create a :py:class:`royalnet.database.Alchemy` session for this call, otherwise, do nothing."""
|
||||
if not self.command.require_alchemy_tables:
|
||||
return
|
||||
self.session = await loop.run_in_executor(None, self.alchemy.Session)
|
||||
self.session = await self.loop.run_in_executor(None, self.alchemy.Session)
|
||||
|
||||
async def session_end(self):
|
||||
"""Close the previously created :py:class:`royalnet.database.Alchemy` session for this call (if it was created)."""
|
||||
|
|
Loading…
Reference in a new issue