1
Fork 0
mirror of https://github.com/RYGhub/royalnet.git synced 2024-11-23 19:44:20 +00:00

Some asyncio / loop improvements (i hope)

This commit is contained in:
Steffo 2019-06-13 00:23:49 +02:00
parent ce6593a8fd
commit 6ba220869f
13 changed files with 39 additions and 38 deletions

View file

@ -10,7 +10,6 @@ from ..network import RoyalnetConfig, Request, ResponseSuccess, ResponseError
from ..database import DatabaseConfig
from ..audio import PlayMode, Playlist, RoyalPCMAudio
loop = asyncio.get_event_loop()
log = _logging.getLogger(__name__)
# TODO: Load the opus library
@ -231,7 +230,7 @@ class DiscordBot(GenericBot):
def advance(error=None):
if error:
raise Exception(f"Error while advancing music_data: {error}")
loop.create_task(self.advance_music_data(guild))
self.loop.create_task(self.advance_music_data(guild))
log.debug(f"Starting playback of {next_source}")
voice_client.play(next_source, after=advance)

View file

@ -8,7 +8,6 @@ from ..network import RoyalnetLink, Request, Response, ResponseError, RoyalnetCo
from ..database import Alchemy, DatabaseConfig, relationshiplinkchain
loop = asyncio.get_event_loop()
log = logging.getLogger(__name__)
@ -45,7 +44,7 @@ class GenericBot:
self.network: RoyalnetLink = RoyalnetLink(royalnet_config.master_uri, royalnet_config.master_secret, self.interface_name,
self._network_handler)
log.debug(f"Running RoyalnetLink {self.network}")
loop.create_task(self.network.run())
self.loop.create_task(self.network.run())
async def _network_handler(self, request_dict: dict) -> dict:
"""Handle a single :py:class:`dict` received from the :py:class:`royalnet.network.RoyalnetLink`.
@ -101,7 +100,12 @@ class GenericBot:
command_prefix: str,
commands: typing.List[typing.Type[Command]] = None,
missing_command: typing.Type[Command] = NullCommand,
error_command: typing.Type[Command] = NullCommand):
error_command: typing.Type[Command] = NullCommand,
loop: asyncio.AbstractEventLoop = None):
if loop is None:
self.loop = asyncio.get_event_loop()
else:
self.loop = loop
if database_config is None:
self.alchemy = None
self.master_table = None

View file

@ -10,7 +10,7 @@ from ..error import UnregisteredError, InvalidConfigError, RoyalnetResponseError
from ..network import RoyalnetConfig, Request, ResponseSuccess, ResponseError
from ..database import DatabaseConfig
loop = asyncio.get_event_loop()
log = _logging.getLogger(__name__)
@ -121,7 +121,7 @@ class TelegramBot(GenericBot):
# Handle updates
for update in last_updates:
# noinspection PyAsyncCall
loop.create_task(self._handle_update(update))
self.loop.create_task(self._handle_update(update))
# Recalculate offset
try:
self._offset = last_updates[-1].update_id + 1

View file

@ -8,9 +8,6 @@ if typing.TYPE_CHECKING:
from ..bots import DiscordBot
loop = asyncio.get_event_loop()
class CvNH(NetworkHandler):
message_type = "discord_cv"

View file

@ -2,7 +2,7 @@ import asyncio
import logging as _logging
from ..utils import Command, Call
loop = asyncio.get_event_loop()
log = _logging.getLogger(__name__)

View file

@ -11,9 +11,6 @@ if typing.TYPE_CHECKING:
from ..bots import DiscordBot
loop = asyncio.get_event_loop()
class PlayNH(NetworkHandler):
message_type = "music_play"

View file

@ -8,9 +8,6 @@ if typing.TYPE_CHECKING:
from ..bots import DiscordBot
loop = asyncio.get_event_loop()
class PlaymodeNH(NetworkHandler):
message_type = "music_playmode"

View file

@ -1,6 +1,5 @@
import typing
import discord
import asyncio
from ..utils import Command, Call, NetworkHandler
from ..network import Request, ResponseSuccess
from ..error import NoneFoundError
@ -8,9 +7,6 @@ if typing.TYPE_CHECKING:
from ..bots import DiscordBot
loop = asyncio.get_event_loop()
class SummonNH(NetworkHandler):
message_type = "music_summon"
@ -20,7 +16,7 @@ class SummonNH(NetworkHandler):
channel = bot.client.find_channel_by_name(data["channel_name"])
if not isinstance(channel, discord.VoiceChannel):
raise NoneFoundError("Channel is not a voice channel")
loop.create_task(bot.client.vc_connect_or_move(channel))
bot.loop.create_task(bot.client.vc_connect_or_move(channel))
return ResponseSuccess()

View file

@ -1,5 +1,4 @@
import typing
import asyncio
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
@ -8,8 +7,6 @@ from ..utils import asyncify
# noinspection PyUnresolvedReferences
from ..error import InvalidConfigError
loop = asyncio.get_event_loop()
class Alchemy:
"""A wrapper around SQLAlchemy declarative that allows to use multiple databases at once while maintaining a single table-class for both of them."""

View file

@ -8,7 +8,7 @@ import logging as _logging
import typing
from .package import Package
default_loop = asyncio.get_event_loop()
log = _logging.getLogger(__name__)
@ -35,7 +35,11 @@ class NetworkError(Exception):
class PendingRequest:
def __init__(self, *, loop=default_loop):
def __init__(self, *, loop: asyncio.AbstractEventLoop = None):
if loop is None:
self.loop = asyncio.get_event_loop()
else:
self.loop = loop
self.event: asyncio.Event = asyncio.Event(loop=loop)
self.data: typing.Optional[dict] = None
@ -67,8 +71,9 @@ def requires_identification(func):
class RoyalnetLink:
def __init__(self, master_uri: str, secret: str, link_type: str, request_handler, *,
loop: asyncio.AbstractEventLoop = default_loop):
assert ":" not in link_type
loop: asyncio.AbstractEventLoop = None):
if ":" in link_type:
raise ValueError("Link types cannot contain colons.")
self.master_uri: str = master_uri
self.link_type: str = link_type
self.nid: str = str(uuid.uuid4())
@ -76,7 +81,10 @@ class RoyalnetLink:
self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None
self.request_handler = request_handler
self._pending_requests: typing.Dict[str, PendingRequest] = {}
self._loop: asyncio.AbstractEventLoop = loop
if loop is None:
self._loop = asyncio.get_event_loop()
else:
self._loop = loop
self.error_event: asyncio.Event = asyncio.Event(loop=self._loop)
self.connect_event: asyncio.Event = asyncio.Event(loop=self._loop)
self.identify_event: asyncio.Event = asyncio.Event(loop=self._loop)

View file

@ -7,7 +7,7 @@ import asyncio
import logging as _logging
from .package import Package
default_loop = asyncio.get_event_loop()
log = _logging.getLogger(__name__)
@ -35,12 +35,15 @@ class ConnectedClient:
class RoyalnetServer:
def __init__(self, address: str, port: int, required_secret: str, *, loop: asyncio.AbstractEventLoop = default_loop):
def __init__(self, address: str, port: int, required_secret: str, *, loop: asyncio.AbstractEventLoop = None):
self.address: str = address
self.port: int = port
self.required_secret: str = required_secret
self.identified_clients: typing.List[ConnectedClient] = []
self._loop: asyncio.AbstractEventLoop = loop
if loop is None:
self.loop = asyncio.get_event_loop()
else:
self.loop = loop
def find_client(self, *, nid: str = None, link_type: str = None) -> typing.List[ConnectedClient]:
assert not (nid and link_type)

View file

@ -12,4 +12,5 @@ from .networkhandler import NetworkHandler
from .formatters import andformat, plusformat, fileformat, ytdldateformat, numberemojiformat
__all__ = ["asyncify", "Call", "Command", "safeformat", "cdj", "sleep_until", "plusformat", "CommandArgs",
"NetworkHandler", "andformat", "plusformat", "fileformat", "ytdldateformat", "numberemojiformat"]
"NetworkHandler", "andformat", "plusformat", "fileformat", "ytdldateformat", "numberemojiformat",
"telegram_escape", "discord_escape"]

View file

@ -6,9 +6,6 @@ if typing.TYPE_CHECKING:
from ..database import Alchemy
loop = asyncio.get_event_loop()
class Call:
"""A command call. An abstract class, sub-bots should create a new call class from this.
@ -55,6 +52,7 @@ class Call:
channel,
command: typing.Type[Command],
command_args: typing.List[str] = None,
loop: asyncio.AbstractEventLoop = None,
**kwargs):
"""Create the call.
@ -66,6 +64,10 @@ class Call:
"""
if command_args is None:
command_args = []
if loop is None:
self.loop = asyncio.get_event_loop()
else:
self.loop = loop
self.channel = channel
self.command = command
self.args = CommandArgs(command_args)
@ -76,7 +78,7 @@ class Call:
"""If the command requires database access, create a :py:class:`royalnet.database.Alchemy` session for this call, otherwise, do nothing."""
if not self.command.require_alchemy_tables:
return
self.session = await loop.run_in_executor(None, self.alchemy.Session)
self.session = await self.loop.run_in_executor(None, self.alchemy.Session)
async def session_end(self):
"""Close the previously created :py:class:`royalnet.database.Alchemy` session for this call (if it was created)."""