1
Fork 0
mirror of https://github.com/RYGhub/royalnet.git synced 2024-11-23 19:44:20 +00:00

Some asyncio / loop improvements (i hope)

This commit is contained in:
Steffo 2019-06-13 00:23:49 +02:00
parent ce6593a8fd
commit 6ba220869f
13 changed files with 39 additions and 38 deletions

View file

@ -10,7 +10,6 @@ from ..network import RoyalnetConfig, Request, ResponseSuccess, ResponseError
from ..database import DatabaseConfig from ..database import DatabaseConfig
from ..audio import PlayMode, Playlist, RoyalPCMAudio from ..audio import PlayMode, Playlist, RoyalPCMAudio
loop = asyncio.get_event_loop()
log = _logging.getLogger(__name__) log = _logging.getLogger(__name__)
# TODO: Load the opus library # TODO: Load the opus library
@ -231,7 +230,7 @@ class DiscordBot(GenericBot):
def advance(error=None): def advance(error=None):
if error: if error:
raise Exception(f"Error while advancing music_data: {error}") raise Exception(f"Error while advancing music_data: {error}")
loop.create_task(self.advance_music_data(guild)) self.loop.create_task(self.advance_music_data(guild))
log.debug(f"Starting playback of {next_source}") log.debug(f"Starting playback of {next_source}")
voice_client.play(next_source, after=advance) voice_client.play(next_source, after=advance)

View file

@ -8,7 +8,6 @@ from ..network import RoyalnetLink, Request, Response, ResponseError, RoyalnetCo
from ..database import Alchemy, DatabaseConfig, relationshiplinkchain from ..database import Alchemy, DatabaseConfig, relationshiplinkchain
loop = asyncio.get_event_loop()
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -45,7 +44,7 @@ class GenericBot:
self.network: RoyalnetLink = RoyalnetLink(royalnet_config.master_uri, royalnet_config.master_secret, self.interface_name, self.network: RoyalnetLink = RoyalnetLink(royalnet_config.master_uri, royalnet_config.master_secret, self.interface_name,
self._network_handler) self._network_handler)
log.debug(f"Running RoyalnetLink {self.network}") log.debug(f"Running RoyalnetLink {self.network}")
loop.create_task(self.network.run()) self.loop.create_task(self.network.run())
async def _network_handler(self, request_dict: dict) -> dict: async def _network_handler(self, request_dict: dict) -> dict:
"""Handle a single :py:class:`dict` received from the :py:class:`royalnet.network.RoyalnetLink`. """Handle a single :py:class:`dict` received from the :py:class:`royalnet.network.RoyalnetLink`.
@ -101,7 +100,12 @@ class GenericBot:
command_prefix: str, command_prefix: str,
commands: typing.List[typing.Type[Command]] = None, commands: typing.List[typing.Type[Command]] = None,
missing_command: typing.Type[Command] = NullCommand, missing_command: typing.Type[Command] = NullCommand,
error_command: typing.Type[Command] = NullCommand): error_command: typing.Type[Command] = NullCommand,
loop: asyncio.AbstractEventLoop = None):
if loop is None:
self.loop = asyncio.get_event_loop()
else:
self.loop = loop
if database_config is None: if database_config is None:
self.alchemy = None self.alchemy = None
self.master_table = None self.master_table = None

View file

@ -10,7 +10,7 @@ from ..error import UnregisteredError, InvalidConfigError, RoyalnetResponseError
from ..network import RoyalnetConfig, Request, ResponseSuccess, ResponseError from ..network import RoyalnetConfig, Request, ResponseSuccess, ResponseError
from ..database import DatabaseConfig from ..database import DatabaseConfig
loop = asyncio.get_event_loop()
log = _logging.getLogger(__name__) log = _logging.getLogger(__name__)
@ -121,7 +121,7 @@ class TelegramBot(GenericBot):
# Handle updates # Handle updates
for update in last_updates: for update in last_updates:
# noinspection PyAsyncCall # noinspection PyAsyncCall
loop.create_task(self._handle_update(update)) self.loop.create_task(self._handle_update(update))
# Recalculate offset # Recalculate offset
try: try:
self._offset = last_updates[-1].update_id + 1 self._offset = last_updates[-1].update_id + 1

View file

@ -8,9 +8,6 @@ if typing.TYPE_CHECKING:
from ..bots import DiscordBot from ..bots import DiscordBot
loop = asyncio.get_event_loop()
class CvNH(NetworkHandler): class CvNH(NetworkHandler):
message_type = "discord_cv" message_type = "discord_cv"

View file

@ -2,7 +2,7 @@ import asyncio
import logging as _logging import logging as _logging
from ..utils import Command, Call from ..utils import Command, Call
loop = asyncio.get_event_loop()
log = _logging.getLogger(__name__) log = _logging.getLogger(__name__)

View file

@ -11,9 +11,6 @@ if typing.TYPE_CHECKING:
from ..bots import DiscordBot from ..bots import DiscordBot
loop = asyncio.get_event_loop()
class PlayNH(NetworkHandler): class PlayNH(NetworkHandler):
message_type = "music_play" message_type = "music_play"

View file

@ -8,9 +8,6 @@ if typing.TYPE_CHECKING:
from ..bots import DiscordBot from ..bots import DiscordBot
loop = asyncio.get_event_loop()
class PlaymodeNH(NetworkHandler): class PlaymodeNH(NetworkHandler):
message_type = "music_playmode" message_type = "music_playmode"

View file

@ -1,6 +1,5 @@
import typing import typing
import discord import discord
import asyncio
from ..utils import Command, Call, NetworkHandler from ..utils import Command, Call, NetworkHandler
from ..network import Request, ResponseSuccess from ..network import Request, ResponseSuccess
from ..error import NoneFoundError from ..error import NoneFoundError
@ -8,9 +7,6 @@ if typing.TYPE_CHECKING:
from ..bots import DiscordBot from ..bots import DiscordBot
loop = asyncio.get_event_loop()
class SummonNH(NetworkHandler): class SummonNH(NetworkHandler):
message_type = "music_summon" message_type = "music_summon"
@ -20,7 +16,7 @@ class SummonNH(NetworkHandler):
channel = bot.client.find_channel_by_name(data["channel_name"]) channel = bot.client.find_channel_by_name(data["channel_name"])
if not isinstance(channel, discord.VoiceChannel): if not isinstance(channel, discord.VoiceChannel):
raise NoneFoundError("Channel is not a voice channel") raise NoneFoundError("Channel is not a voice channel")
loop.create_task(bot.client.vc_connect_or_move(channel)) bot.loop.create_task(bot.client.vc_connect_or_move(channel))
return ResponseSuccess() return ResponseSuccess()

View file

@ -1,5 +1,4 @@
import typing import typing
import asyncio
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
@ -8,8 +7,6 @@ from ..utils import asyncify
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
from ..error import InvalidConfigError from ..error import InvalidConfigError
loop = asyncio.get_event_loop()
class Alchemy: class Alchemy:
"""A wrapper around SQLAlchemy declarative that allows to use multiple databases at once while maintaining a single table-class for both of them.""" """A wrapper around SQLAlchemy declarative that allows to use multiple databases at once while maintaining a single table-class for both of them."""

View file

@ -8,7 +8,7 @@ import logging as _logging
import typing import typing
from .package import Package from .package import Package
default_loop = asyncio.get_event_loop()
log = _logging.getLogger(__name__) log = _logging.getLogger(__name__)
@ -35,7 +35,11 @@ class NetworkError(Exception):
class PendingRequest: class PendingRequest:
def __init__(self, *, loop=default_loop): def __init__(self, *, loop: asyncio.AbstractEventLoop = None):
if loop is None:
self.loop = asyncio.get_event_loop()
else:
self.loop = loop
self.event: asyncio.Event = asyncio.Event(loop=loop) self.event: asyncio.Event = asyncio.Event(loop=loop)
self.data: typing.Optional[dict] = None self.data: typing.Optional[dict] = None
@ -67,8 +71,9 @@ def requires_identification(func):
class RoyalnetLink: class RoyalnetLink:
def __init__(self, master_uri: str, secret: str, link_type: str, request_handler, *, def __init__(self, master_uri: str, secret: str, link_type: str, request_handler, *,
loop: asyncio.AbstractEventLoop = default_loop): loop: asyncio.AbstractEventLoop = None):
assert ":" not in link_type if ":" in link_type:
raise ValueError("Link types cannot contain colons.")
self.master_uri: str = master_uri self.master_uri: str = master_uri
self.link_type: str = link_type self.link_type: str = link_type
self.nid: str = str(uuid.uuid4()) self.nid: str = str(uuid.uuid4())
@ -76,7 +81,10 @@ class RoyalnetLink:
self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None
self.request_handler = request_handler self.request_handler = request_handler
self._pending_requests: typing.Dict[str, PendingRequest] = {} self._pending_requests: typing.Dict[str, PendingRequest] = {}
self._loop: asyncio.AbstractEventLoop = loop if loop is None:
self._loop = asyncio.get_event_loop()
else:
self._loop = loop
self.error_event: asyncio.Event = asyncio.Event(loop=self._loop) self.error_event: asyncio.Event = asyncio.Event(loop=self._loop)
self.connect_event: asyncio.Event = asyncio.Event(loop=self._loop) self.connect_event: asyncio.Event = asyncio.Event(loop=self._loop)
self.identify_event: asyncio.Event = asyncio.Event(loop=self._loop) self.identify_event: asyncio.Event = asyncio.Event(loop=self._loop)

View file

@ -7,7 +7,7 @@ import asyncio
import logging as _logging import logging as _logging
from .package import Package from .package import Package
default_loop = asyncio.get_event_loop()
log = _logging.getLogger(__name__) log = _logging.getLogger(__name__)
@ -35,12 +35,15 @@ class ConnectedClient:
class RoyalnetServer: class RoyalnetServer:
def __init__(self, address: str, port: int, required_secret: str, *, loop: asyncio.AbstractEventLoop = default_loop): def __init__(self, address: str, port: int, required_secret: str, *, loop: asyncio.AbstractEventLoop = None):
self.address: str = address self.address: str = address
self.port: int = port self.port: int = port
self.required_secret: str = required_secret self.required_secret: str = required_secret
self.identified_clients: typing.List[ConnectedClient] = [] self.identified_clients: typing.List[ConnectedClient] = []
self._loop: asyncio.AbstractEventLoop = loop if loop is None:
self.loop = asyncio.get_event_loop()
else:
self.loop = loop
def find_client(self, *, nid: str = None, link_type: str = None) -> typing.List[ConnectedClient]: def find_client(self, *, nid: str = None, link_type: str = None) -> typing.List[ConnectedClient]:
assert not (nid and link_type) assert not (nid and link_type)

View file

@ -12,4 +12,5 @@ from .networkhandler import NetworkHandler
from .formatters import andformat, plusformat, fileformat, ytdldateformat, numberemojiformat from .formatters import andformat, plusformat, fileformat, ytdldateformat, numberemojiformat
__all__ = ["asyncify", "Call", "Command", "safeformat", "cdj", "sleep_until", "plusformat", "CommandArgs", __all__ = ["asyncify", "Call", "Command", "safeformat", "cdj", "sleep_until", "plusformat", "CommandArgs",
"NetworkHandler", "andformat", "plusformat", "fileformat", "ytdldateformat", "numberemojiformat"] "NetworkHandler", "andformat", "plusformat", "fileformat", "ytdldateformat", "numberemojiformat",
"telegram_escape", "discord_escape"]

View file

@ -6,9 +6,6 @@ if typing.TYPE_CHECKING:
from ..database import Alchemy from ..database import Alchemy
loop = asyncio.get_event_loop()
class Call: class Call:
"""A command call. An abstract class, sub-bots should create a new call class from this. """A command call. An abstract class, sub-bots should create a new call class from this.
@ -55,6 +52,7 @@ class Call:
channel, channel,
command: typing.Type[Command], command: typing.Type[Command],
command_args: typing.List[str] = None, command_args: typing.List[str] = None,
loop: asyncio.AbstractEventLoop = None,
**kwargs): **kwargs):
"""Create the call. """Create the call.
@ -66,6 +64,10 @@ class Call:
""" """
if command_args is None: if command_args is None:
command_args = [] command_args = []
if loop is None:
self.loop = asyncio.get_event_loop()
else:
self.loop = loop
self.channel = channel self.channel = channel
self.command = command self.command = command
self.args = CommandArgs(command_args) self.args = CommandArgs(command_args)
@ -76,7 +78,7 @@ class Call:
"""If the command requires database access, create a :py:class:`royalnet.database.Alchemy` session for this call, otherwise, do nothing.""" """If the command requires database access, create a :py:class:`royalnet.database.Alchemy` session for this call, otherwise, do nothing."""
if not self.command.require_alchemy_tables: if not self.command.require_alchemy_tables:
return return
self.session = await loop.run_in_executor(None, self.alchemy.Session) self.session = await self.loop.run_in_executor(None, self.alchemy.Session)
async def session_end(self): async def session_end(self):
"""Close the previously created :py:class:`royalnet.database.Alchemy` session for this call (if it was created).""" """Close the previously created :py:class:`royalnet.database.Alchemy` session for this call (if it was created)."""