mirror of
https://github.com/RYGhub/royalnet.git
synced 2024-11-23 19:44:20 +00:00
Some asyncio / loop improvements (i hope)
This commit is contained in:
parent
ce6593a8fd
commit
6ba220869f
13 changed files with 39 additions and 38 deletions
|
@ -10,7 +10,6 @@ from ..network import RoyalnetConfig, Request, ResponseSuccess, ResponseError
|
||||||
from ..database import DatabaseConfig
|
from ..database import DatabaseConfig
|
||||||
from ..audio import PlayMode, Playlist, RoyalPCMAudio
|
from ..audio import PlayMode, Playlist, RoyalPCMAudio
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
log = _logging.getLogger(__name__)
|
log = _logging.getLogger(__name__)
|
||||||
|
|
||||||
# TODO: Load the opus library
|
# TODO: Load the opus library
|
||||||
|
@ -231,7 +230,7 @@ class DiscordBot(GenericBot):
|
||||||
def advance(error=None):
|
def advance(error=None):
|
||||||
if error:
|
if error:
|
||||||
raise Exception(f"Error while advancing music_data: {error}")
|
raise Exception(f"Error while advancing music_data: {error}")
|
||||||
loop.create_task(self.advance_music_data(guild))
|
self.loop.create_task(self.advance_music_data(guild))
|
||||||
|
|
||||||
log.debug(f"Starting playback of {next_source}")
|
log.debug(f"Starting playback of {next_source}")
|
||||||
voice_client.play(next_source, after=advance)
|
voice_client.play(next_source, after=advance)
|
||||||
|
|
|
@ -8,7 +8,6 @@ from ..network import RoyalnetLink, Request, Response, ResponseError, RoyalnetCo
|
||||||
from ..database import Alchemy, DatabaseConfig, relationshiplinkchain
|
from ..database import Alchemy, DatabaseConfig, relationshiplinkchain
|
||||||
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,7 +44,7 @@ class GenericBot:
|
||||||
self.network: RoyalnetLink = RoyalnetLink(royalnet_config.master_uri, royalnet_config.master_secret, self.interface_name,
|
self.network: RoyalnetLink = RoyalnetLink(royalnet_config.master_uri, royalnet_config.master_secret, self.interface_name,
|
||||||
self._network_handler)
|
self._network_handler)
|
||||||
log.debug(f"Running RoyalnetLink {self.network}")
|
log.debug(f"Running RoyalnetLink {self.network}")
|
||||||
loop.create_task(self.network.run())
|
self.loop.create_task(self.network.run())
|
||||||
|
|
||||||
async def _network_handler(self, request_dict: dict) -> dict:
|
async def _network_handler(self, request_dict: dict) -> dict:
|
||||||
"""Handle a single :py:class:`dict` received from the :py:class:`royalnet.network.RoyalnetLink`.
|
"""Handle a single :py:class:`dict` received from the :py:class:`royalnet.network.RoyalnetLink`.
|
||||||
|
@ -101,7 +100,12 @@ class GenericBot:
|
||||||
command_prefix: str,
|
command_prefix: str,
|
||||||
commands: typing.List[typing.Type[Command]] = None,
|
commands: typing.List[typing.Type[Command]] = None,
|
||||||
missing_command: typing.Type[Command] = NullCommand,
|
missing_command: typing.Type[Command] = NullCommand,
|
||||||
error_command: typing.Type[Command] = NullCommand):
|
error_command: typing.Type[Command] = NullCommand,
|
||||||
|
loop: asyncio.AbstractEventLoop = None):
|
||||||
|
if loop is None:
|
||||||
|
self.loop = asyncio.get_event_loop()
|
||||||
|
else:
|
||||||
|
self.loop = loop
|
||||||
if database_config is None:
|
if database_config is None:
|
||||||
self.alchemy = None
|
self.alchemy = None
|
||||||
self.master_table = None
|
self.master_table = None
|
||||||
|
|
|
@ -10,7 +10,7 @@ from ..error import UnregisteredError, InvalidConfigError, RoyalnetResponseError
|
||||||
from ..network import RoyalnetConfig, Request, ResponseSuccess, ResponseError
|
from ..network import RoyalnetConfig, Request, ResponseSuccess, ResponseError
|
||||||
from ..database import DatabaseConfig
|
from ..database import DatabaseConfig
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
log = _logging.getLogger(__name__)
|
log = _logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -121,7 +121,7 @@ class TelegramBot(GenericBot):
|
||||||
# Handle updates
|
# Handle updates
|
||||||
for update in last_updates:
|
for update in last_updates:
|
||||||
# noinspection PyAsyncCall
|
# noinspection PyAsyncCall
|
||||||
loop.create_task(self._handle_update(update))
|
self.loop.create_task(self._handle_update(update))
|
||||||
# Recalculate offset
|
# Recalculate offset
|
||||||
try:
|
try:
|
||||||
self._offset = last_updates[-1].update_id + 1
|
self._offset = last_updates[-1].update_id + 1
|
||||||
|
|
|
@ -8,9 +8,6 @@ if typing.TYPE_CHECKING:
|
||||||
from ..bots import DiscordBot
|
from ..bots import DiscordBot
|
||||||
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
|
|
||||||
class CvNH(NetworkHandler):
|
class CvNH(NetworkHandler):
|
||||||
message_type = "discord_cv"
|
message_type = "discord_cv"
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import asyncio
|
||||||
import logging as _logging
|
import logging as _logging
|
||||||
from ..utils import Command, Call
|
from ..utils import Command, Call
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
log = _logging.getLogger(__name__)
|
log = _logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -11,9 +11,6 @@ if typing.TYPE_CHECKING:
|
||||||
from ..bots import DiscordBot
|
from ..bots import DiscordBot
|
||||||
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
|
|
||||||
class PlayNH(NetworkHandler):
|
class PlayNH(NetworkHandler):
|
||||||
message_type = "music_play"
|
message_type = "music_play"
|
||||||
|
|
||||||
|
|
|
@ -8,9 +8,6 @@ if typing.TYPE_CHECKING:
|
||||||
from ..bots import DiscordBot
|
from ..bots import DiscordBot
|
||||||
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
|
|
||||||
class PlaymodeNH(NetworkHandler):
|
class PlaymodeNH(NetworkHandler):
|
||||||
message_type = "music_playmode"
|
message_type = "music_playmode"
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import typing
|
import typing
|
||||||
import discord
|
import discord
|
||||||
import asyncio
|
|
||||||
from ..utils import Command, Call, NetworkHandler
|
from ..utils import Command, Call, NetworkHandler
|
||||||
from ..network import Request, ResponseSuccess
|
from ..network import Request, ResponseSuccess
|
||||||
from ..error import NoneFoundError
|
from ..error import NoneFoundError
|
||||||
|
@ -8,9 +7,6 @@ if typing.TYPE_CHECKING:
|
||||||
from ..bots import DiscordBot
|
from ..bots import DiscordBot
|
||||||
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
|
|
||||||
class SummonNH(NetworkHandler):
|
class SummonNH(NetworkHandler):
|
||||||
message_type = "music_summon"
|
message_type = "music_summon"
|
||||||
|
|
||||||
|
@ -20,7 +16,7 @@ class SummonNH(NetworkHandler):
|
||||||
channel = bot.client.find_channel_by_name(data["channel_name"])
|
channel = bot.client.find_channel_by_name(data["channel_name"])
|
||||||
if not isinstance(channel, discord.VoiceChannel):
|
if not isinstance(channel, discord.VoiceChannel):
|
||||||
raise NoneFoundError("Channel is not a voice channel")
|
raise NoneFoundError("Channel is not a voice channel")
|
||||||
loop.create_task(bot.client.vc_connect_or_move(channel))
|
bot.loop.create_task(bot.client.vc_connect_or_move(channel))
|
||||||
return ResponseSuccess()
|
return ResponseSuccess()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import typing
|
import typing
|
||||||
import asyncio
|
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
@ -8,8 +7,6 @@ from ..utils import asyncify
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
from ..error import InvalidConfigError
|
from ..error import InvalidConfigError
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
|
|
||||||
class Alchemy:
|
class Alchemy:
|
||||||
"""A wrapper around SQLAlchemy declarative that allows to use multiple databases at once while maintaining a single table-class for both of them."""
|
"""A wrapper around SQLAlchemy declarative that allows to use multiple databases at once while maintaining a single table-class for both of them."""
|
||||||
|
|
|
@ -8,7 +8,7 @@ import logging as _logging
|
||||||
import typing
|
import typing
|
||||||
from .package import Package
|
from .package import Package
|
||||||
|
|
||||||
default_loop = asyncio.get_event_loop()
|
|
||||||
log = _logging.getLogger(__name__)
|
log = _logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -35,7 +35,11 @@ class NetworkError(Exception):
|
||||||
|
|
||||||
|
|
||||||
class PendingRequest:
|
class PendingRequest:
|
||||||
def __init__(self, *, loop=default_loop):
|
def __init__(self, *, loop: asyncio.AbstractEventLoop = None):
|
||||||
|
if loop is None:
|
||||||
|
self.loop = asyncio.get_event_loop()
|
||||||
|
else:
|
||||||
|
self.loop = loop
|
||||||
self.event: asyncio.Event = asyncio.Event(loop=loop)
|
self.event: asyncio.Event = asyncio.Event(loop=loop)
|
||||||
self.data: typing.Optional[dict] = None
|
self.data: typing.Optional[dict] = None
|
||||||
|
|
||||||
|
@ -67,8 +71,9 @@ def requires_identification(func):
|
||||||
|
|
||||||
class RoyalnetLink:
|
class RoyalnetLink:
|
||||||
def __init__(self, master_uri: str, secret: str, link_type: str, request_handler, *,
|
def __init__(self, master_uri: str, secret: str, link_type: str, request_handler, *,
|
||||||
loop: asyncio.AbstractEventLoop = default_loop):
|
loop: asyncio.AbstractEventLoop = None):
|
||||||
assert ":" not in link_type
|
if ":" in link_type:
|
||||||
|
raise ValueError("Link types cannot contain colons.")
|
||||||
self.master_uri: str = master_uri
|
self.master_uri: str = master_uri
|
||||||
self.link_type: str = link_type
|
self.link_type: str = link_type
|
||||||
self.nid: str = str(uuid.uuid4())
|
self.nid: str = str(uuid.uuid4())
|
||||||
|
@ -76,7 +81,10 @@ class RoyalnetLink:
|
||||||
self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None
|
self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None
|
||||||
self.request_handler = request_handler
|
self.request_handler = request_handler
|
||||||
self._pending_requests: typing.Dict[str, PendingRequest] = {}
|
self._pending_requests: typing.Dict[str, PendingRequest] = {}
|
||||||
self._loop: asyncio.AbstractEventLoop = loop
|
if loop is None:
|
||||||
|
self._loop = asyncio.get_event_loop()
|
||||||
|
else:
|
||||||
|
self._loop = loop
|
||||||
self.error_event: asyncio.Event = asyncio.Event(loop=self._loop)
|
self.error_event: asyncio.Event = asyncio.Event(loop=self._loop)
|
||||||
self.connect_event: asyncio.Event = asyncio.Event(loop=self._loop)
|
self.connect_event: asyncio.Event = asyncio.Event(loop=self._loop)
|
||||||
self.identify_event: asyncio.Event = asyncio.Event(loop=self._loop)
|
self.identify_event: asyncio.Event = asyncio.Event(loop=self._loop)
|
||||||
|
|
|
@ -7,7 +7,7 @@ import asyncio
|
||||||
import logging as _logging
|
import logging as _logging
|
||||||
from .package import Package
|
from .package import Package
|
||||||
|
|
||||||
default_loop = asyncio.get_event_loop()
|
|
||||||
log = _logging.getLogger(__name__)
|
log = _logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -35,12 +35,15 @@ class ConnectedClient:
|
||||||
|
|
||||||
|
|
||||||
class RoyalnetServer:
|
class RoyalnetServer:
|
||||||
def __init__(self, address: str, port: int, required_secret: str, *, loop: asyncio.AbstractEventLoop = default_loop):
|
def __init__(self, address: str, port: int, required_secret: str, *, loop: asyncio.AbstractEventLoop = None):
|
||||||
self.address: str = address
|
self.address: str = address
|
||||||
self.port: int = port
|
self.port: int = port
|
||||||
self.required_secret: str = required_secret
|
self.required_secret: str = required_secret
|
||||||
self.identified_clients: typing.List[ConnectedClient] = []
|
self.identified_clients: typing.List[ConnectedClient] = []
|
||||||
self._loop: asyncio.AbstractEventLoop = loop
|
if loop is None:
|
||||||
|
self.loop = asyncio.get_event_loop()
|
||||||
|
else:
|
||||||
|
self.loop = loop
|
||||||
|
|
||||||
def find_client(self, *, nid: str = None, link_type: str = None) -> typing.List[ConnectedClient]:
|
def find_client(self, *, nid: str = None, link_type: str = None) -> typing.List[ConnectedClient]:
|
||||||
assert not (nid and link_type)
|
assert not (nid and link_type)
|
||||||
|
|
|
@ -12,4 +12,5 @@ from .networkhandler import NetworkHandler
|
||||||
from .formatters import andformat, plusformat, fileformat, ytdldateformat, numberemojiformat
|
from .formatters import andformat, plusformat, fileformat, ytdldateformat, numberemojiformat
|
||||||
|
|
||||||
__all__ = ["asyncify", "Call", "Command", "safeformat", "cdj", "sleep_until", "plusformat", "CommandArgs",
|
__all__ = ["asyncify", "Call", "Command", "safeformat", "cdj", "sleep_until", "plusformat", "CommandArgs",
|
||||||
"NetworkHandler", "andformat", "plusformat", "fileformat", "ytdldateformat", "numberemojiformat"]
|
"NetworkHandler", "andformat", "plusformat", "fileformat", "ytdldateformat", "numberemojiformat",
|
||||||
|
"telegram_escape", "discord_escape"]
|
||||||
|
|
|
@ -6,9 +6,6 @@ if typing.TYPE_CHECKING:
|
||||||
from ..database import Alchemy
|
from ..database import Alchemy
|
||||||
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
|
|
||||||
class Call:
|
class Call:
|
||||||
"""A command call. An abstract class, sub-bots should create a new call class from this.
|
"""A command call. An abstract class, sub-bots should create a new call class from this.
|
||||||
|
|
||||||
|
@ -55,6 +52,7 @@ class Call:
|
||||||
channel,
|
channel,
|
||||||
command: typing.Type[Command],
|
command: typing.Type[Command],
|
||||||
command_args: typing.List[str] = None,
|
command_args: typing.List[str] = None,
|
||||||
|
loop: asyncio.AbstractEventLoop = None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Create the call.
|
"""Create the call.
|
||||||
|
|
||||||
|
@ -66,6 +64,10 @@ class Call:
|
||||||
"""
|
"""
|
||||||
if command_args is None:
|
if command_args is None:
|
||||||
command_args = []
|
command_args = []
|
||||||
|
if loop is None:
|
||||||
|
self.loop = asyncio.get_event_loop()
|
||||||
|
else:
|
||||||
|
self.loop = loop
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
self.command = command
|
self.command = command
|
||||||
self.args = CommandArgs(command_args)
|
self.args = CommandArgs(command_args)
|
||||||
|
@ -76,7 +78,7 @@ class Call:
|
||||||
"""If the command requires database access, create a :py:class:`royalnet.database.Alchemy` session for this call, otherwise, do nothing."""
|
"""If the command requires database access, create a :py:class:`royalnet.database.Alchemy` session for this call, otherwise, do nothing."""
|
||||||
if not self.command.require_alchemy_tables:
|
if not self.command.require_alchemy_tables:
|
||||||
return
|
return
|
||||||
self.session = await loop.run_in_executor(None, self.alchemy.Session)
|
self.session = await self.loop.run_in_executor(None, self.alchemy.Session)
|
||||||
|
|
||||||
async def session_end(self):
|
async def session_end(self):
|
||||||
"""Close the previously created :py:class:`royalnet.database.Alchemy` session for this call (if it was created)."""
|
"""Close the previously created :py:class:`royalnet.database.Alchemy` session for this call (if it was created)."""
|
||||||
|
|
Loading…
Reference in a new issue