mirror of
https://github.com/RYGhub/royalnet.git
synced 2024-11-23 19:44:20 +00:00
Improve upon a few thingies
This commit is contained in:
parent
2c97a6cf47
commit
a4f2cc5e46
16 changed files with 190 additions and 98 deletions
|
@ -1,7 +1,9 @@
|
|||
from contextlib import contextmanager, asynccontextmanager
|
||||
from typing import *
|
||||
import contextlib
|
||||
import logging
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.exc import ProgrammingError
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.ext.declarative.api import DeclarativeMeta
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
@ -15,6 +17,8 @@ if TYPE_CHECKING:
|
|||
# noinspection PyProtectedMember
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Alchemy:
|
||||
"""A wrapper around :mod:`sqlalchemy.orm` that allows the instantiation of multiple engines at once while
|
||||
|
@ -42,7 +46,12 @@ class Alchemy:
|
|||
# noinspection PyTypeChecker
|
||||
bound_table: Table = type(name, (self._Base, table), {})
|
||||
self._tables[name] = bound_table
|
||||
self._Base.metadata.create_all()
|
||||
# FIXME: Dirty hack
|
||||
try:
|
||||
self._Base.metadata.create_all()
|
||||
except ProgrammingError:
|
||||
log.warning("Skipping table creation, as it is probably being created by a different process.")
|
||||
|
||||
|
||||
def get(self, table: Union[str, type]) -> DeclarativeMeta:
|
||||
"""Get the table with a specified name or class.
|
||||
|
@ -66,7 +75,7 @@ class Alchemy:
|
|||
else:
|
||||
raise TypeError(f"Can't get tables with objects of type '{table.__class__.__qualname__}'")
|
||||
|
||||
@contextmanager
|
||||
@contextlib.contextmanager
|
||||
def session_cm(self) -> Iterator[Session]:
|
||||
"""Create a Session as a context manager (that can be used in ``with`` statements).
|
||||
|
||||
|
@ -91,7 +100,7 @@ class Alchemy:
|
|||
finally:
|
||||
session.close()
|
||||
|
||||
@asynccontextmanager
|
||||
@contextlib.asynccontextmanager
|
||||
async def session_acm(self) -> AsyncIterator[Session]:
|
||||
"""Create a Session as a async context manager (that can be used in ``async with`` statements).
|
||||
|
||||
|
|
|
@ -15,12 +15,16 @@ class RoyalnetsyncCommand(rc.Command):
|
|||
syntax: str = "{username} {password}"
|
||||
|
||||
async def run(self, args: rc.CommandArgs, data: rc.CommandData) -> None:
|
||||
author = await data.get_author(error_if_none=False)
|
||||
if author is not None:
|
||||
raise rc.UserError(f"This account is already connected to {author}!")
|
||||
|
||||
username = args[0]
|
||||
password = " ".join(args[1:])
|
||||
|
||||
author = await data.get_author(error_if_none=True)
|
||||
|
||||
user = await data.find_user(username)
|
||||
if user is None:
|
||||
raise rc.UserError("No such user.")
|
||||
try:
|
||||
successful = user.test_password(password)
|
||||
except ValueError:
|
||||
|
@ -28,60 +32,61 @@ class RoyalnetsyncCommand(rc.Command):
|
|||
if not successful:
|
||||
raise rc.InvalidInputError(f"Invalid password!")
|
||||
|
||||
if isinstance(self.serf, rst.TelegramSerf):
|
||||
import telegram
|
||||
message: telegram.Message = data.message
|
||||
from_user: telegram.User = message.from_user
|
||||
TelegramT = self.alchemy.get(Telegram)
|
||||
tg_user: Telegram = await ru.asyncify(
|
||||
data.session.query(TelegramT).filter_by(tg_id=from_user.id).one_or_none
|
||||
)
|
||||
if tg_user is None:
|
||||
# Create
|
||||
tg_user = TelegramT(
|
||||
user=author,
|
||||
tg_id=from_user.id,
|
||||
first_name=from_user.first_name,
|
||||
last_name=from_user.last_name,
|
||||
username=from_user.username
|
||||
async with data.session_acm() as session:
|
||||
if isinstance(self.serf, rst.TelegramSerf):
|
||||
import telegram
|
||||
message: telegram.Message = data.message
|
||||
from_user: telegram.User = message.from_user
|
||||
TelegramT = self.alchemy.get(Telegram)
|
||||
tg_user: Telegram = await ru.asyncify(
|
||||
session.query(TelegramT).filter_by(tg_id=from_user.id).one_or_none
|
||||
)
|
||||
data.session.add(tg_user)
|
||||
else:
|
||||
# Edit
|
||||
tg_user.first_name = from_user.first_name
|
||||
tg_user.last_name = from_user.last_name
|
||||
tg_user.username = from_user.username
|
||||
await data.session_commit()
|
||||
await data.reply(f"↔️ Account {tg_user} synced to {author}!")
|
||||
if tg_user is None:
|
||||
# Create
|
||||
tg_user = TelegramT(
|
||||
user=user,
|
||||
tg_id=from_user.id,
|
||||
first_name=from_user.first_name,
|
||||
last_name=from_user.last_name,
|
||||
username=from_user.username
|
||||
)
|
||||
session.add(tg_user)
|
||||
else:
|
||||
# Edit
|
||||
tg_user.first_name = from_user.first_name
|
||||
tg_user.last_name = from_user.last_name
|
||||
tg_user.username = from_user.username
|
||||
await ru.asyncify(session.commit)
|
||||
await data.reply(f"↔️ Account {tg_user} synced to {user}!")
|
||||
|
||||
elif isinstance(self.serf, rsd.DiscordSerf):
|
||||
import discord
|
||||
message: discord.Message = data.message
|
||||
author: discord.User = message.author
|
||||
DiscordT = self.alchemy.get(Discord)
|
||||
ds_user: Discord = await ru.asyncify(
|
||||
data.session.query(DiscordT).filter_by(discord_id=author.id).one_or_none
|
||||
)
|
||||
if ds_user is None:
|
||||
# Create
|
||||
ds_user = DiscordT(
|
||||
user=author,
|
||||
discord_id=author.id,
|
||||
username=author.name,
|
||||
discriminator=author.discriminator,
|
||||
avatar_url=author.avatar_url
|
||||
elif isinstance(self.serf, rsd.DiscordSerf):
|
||||
import discord
|
||||
message: discord.Message = data.message
|
||||
ds_author: discord.User = message.author
|
||||
DiscordT = self.alchemy.get(Discord)
|
||||
ds_user: Discord = await ru.asyncify(
|
||||
session.query(DiscordT).filter_by(discord_id=ds_author.id).one_or_none
|
||||
)
|
||||
data.session.add(ds_user)
|
||||
if ds_user is None:
|
||||
# Create
|
||||
ds_user = DiscordT(
|
||||
user=user,
|
||||
discord_id=ds_author.id,
|
||||
username=ds_author.name,
|
||||
discriminator=ds_author.discriminator,
|
||||
avatar_url=ds_author.avatar_url
|
||||
)
|
||||
session.add(ds_user)
|
||||
else:
|
||||
# Edit
|
||||
ds_user.username = ds_author.name
|
||||
ds_user.discriminator = ds_author.discriminator
|
||||
ds_user.avatar_url = ds_author.avatar_url
|
||||
await ru.asyncify(session.commit)
|
||||
await data.reply(f"↔️ Account {ds_user} synced to {ds_author}!")
|
||||
|
||||
elif isinstance(self.serf, rsm.MatrixSerf):
|
||||
raise rc.UnsupportedError(f"{self} hasn't been implemented for Matrix yet")
|
||||
|
||||
else:
|
||||
# Edit
|
||||
ds_user.username = author.name
|
||||
ds_user.discriminator = author.discriminator
|
||||
ds_user.avatar_url = author.avatar_url
|
||||
await data.session_commit()
|
||||
await data.reply(f"↔️ Account {ds_user} synced to {author}!")
|
||||
|
||||
elif isinstance(self.serf, rsm.MatrixSerf):
|
||||
raise rc.UnsupportedError(f"{self} hasn't been implemented for Matrix yet")
|
||||
|
||||
else:
|
||||
raise rc.UnsupportedError(f"Unknown interface: {self.serf.__class__.__qualname__}")
|
||||
raise rc.UnsupportedError(f"Unknown interface: {self.serf.__class__.__qualname__}")
|
||||
|
|
|
@ -19,6 +19,6 @@ class RoyalnetversionCommand(Command):
|
|||
else:
|
||||
message = f"ℹ️ Royalnet [url=https://github.com/Steffo99/royalnet/releases/tag/{self.royalnet_version}]" \
|
||||
f"{self.royalnet_version}[/url]\n"
|
||||
if "69" in royalnet.version.semantic:
|
||||
if "69" in self.royalnet_version:
|
||||
message += "(Nice.)"
|
||||
await data.reply(message)
|
||||
|
|
|
@ -17,6 +17,7 @@ class ApiUserCreateStar(rca.ApiStar):
|
|||
|
||||
tags = ["user"]
|
||||
|
||||
@rca.magic
|
||||
async def post(self, data: rca.ApiData) -> ru.JSON:
|
||||
"""Create a new Royalnet account."""
|
||||
UserT = self.alchemy.get(User)
|
||||
|
|
|
@ -15,7 +15,7 @@ class Discord:
|
|||
|
||||
@declared_attr
|
||||
def user_id(self):
|
||||
return Column(Integer, ForeignKey("users.uid"))
|
||||
return Column(Integer, ForeignKey("users.uid"), nullable=False)
|
||||
|
||||
@declared_attr
|
||||
def user(self):
|
||||
|
|
|
@ -13,7 +13,7 @@ class Matrix:
|
|||
|
||||
@declared_attr
|
||||
def user_id(self):
|
||||
return Column(Integer, ForeignKey("users.uid"))
|
||||
return Column(Integer, ForeignKey("users.uid"), nullable=False)
|
||||
|
||||
@declared_attr
|
||||
def user(self):
|
||||
|
|
|
@ -15,7 +15,7 @@ class Telegram:
|
|||
|
||||
@declared_attr
|
||||
def user_id(self):
|
||||
return Column(Integer, ForeignKey("users.uid"))
|
||||
return Column(Integer, ForeignKey("users.uid"), nullable=False)
|
||||
|
||||
@declared_attr
|
||||
def user(self):
|
||||
|
|
|
@ -40,12 +40,17 @@ class Command(metaclass=abc.ABCMeta):
|
|||
|
||||
@property
|
||||
def alchemy(self) -> "Alchemy":
|
||||
"""A shortcut for :attr:`.interface.alchemy`."""
|
||||
"""A shortcut for :attr:`.serf.alchemy`."""
|
||||
return self.serf.alchemy
|
||||
|
||||
@property
|
||||
def session_acm(self):
|
||||
"""A shortcut for :attr:`.alchemy.session_acm`."""
|
||||
return self.alchemy.session_acm
|
||||
|
||||
@property
|
||||
def loop(self) -> aio.AbstractEventLoop:
|
||||
"""A shortcut for :attr:`.interface.loop`."""
|
||||
"""A shortcut for :attr:`.serf.loop`."""
|
||||
return self.serf.loop
|
||||
|
||||
@abc.abstractmethod
|
||||
|
|
|
@ -21,6 +21,16 @@ class CommandData:
|
|||
def loop(self):
|
||||
return self.command.serf.loop
|
||||
|
||||
@property
|
||||
def alchemy(self):
|
||||
"""A shortcut for :attr:`.command.alchemy`."""
|
||||
return self.command.alchemy
|
||||
|
||||
@property
|
||||
def session_acm(self):
|
||||
"""A shortcut for :attr:`.alchemy.session_acm`."""
|
||||
return self.alchemy.session_acm
|
||||
|
||||
async def reply(self, text: str) -> None:
|
||||
"""Send a text message to the channel where the call was made.
|
||||
|
||||
|
|
|
@ -220,7 +220,7 @@ class Constellation:
|
|||
for SelectedEvent in events:
|
||||
# Initialize the event
|
||||
try:
|
||||
event = SelectedEvent(constellation=self, config=pack_cfg)
|
||||
event = SelectedEvent(parent=self, config=pack_cfg)
|
||||
except Exception as e:
|
||||
log.error(f"Skipping: "
|
||||
f"{SelectedEvent.__qualname__} - {e.__class__.__qualname__} in the initialization.")
|
||||
|
@ -262,7 +262,7 @@ class Constellation:
|
|||
self.starlette.add_route(*self._page_star_wrapper(page_star_instance))
|
||||
|
||||
def run_blocking(self):
|
||||
log.info(f"Running Constellation on https://{self.address}:{self.port}/...")
|
||||
log.info(f"Running Constellation on http://{self.address}:{self.port}/...")
|
||||
self.running = True
|
||||
try:
|
||||
uvicorn.run(self.starlette, host=self.address, port=self.port, log_config=UVICORN_LOGGING_CONFIG)
|
||||
|
|
|
@ -73,7 +73,8 @@ class DiscordSerf(Serf):
|
|||
|
||||
async def get_author(data, error_if_none=False):
|
||||
user: "discord.Member" = data.message.author
|
||||
query = data.session.query(self.master_table)
|
||||
async with data.session_acm() as session:
|
||||
query = session.query(self.master_table)
|
||||
for link in self.identity_chain:
|
||||
query = query.join(link.mapper.class_)
|
||||
query = query.filter(self.identity_column == user.id)
|
||||
|
@ -135,8 +136,7 @@ class DiscordSerf(Serf):
|
|||
# noinspection PyMethodMayBeStatic
|
||||
async def on_message(cli, message: "discord.Message") -> None:
|
||||
"""Handle messages received by passing them to the handle_message method of the bot."""
|
||||
# TODO: keep reference to these tasks somewhere
|
||||
self.loop.create_task(self.handle_message(message))
|
||||
self.tasks.add(self.handle_message(message))
|
||||
|
||||
async def on_ready(cli) -> None:
|
||||
"""Change the bot presence to ``online`` when the bot is ready."""
|
||||
|
|
|
@ -36,6 +36,9 @@ class Serf(abc.ABC):
|
|||
self.loop: Optional[aio.AbstractEventLoop] = loop
|
||||
"""The event loop this Serf is running on."""
|
||||
|
||||
self.tasks: Optional[ru.TaskList] = ru.TaskList(self.loop)
|
||||
"""A list of all running tasks of the serf. Initialized at the serf start."""
|
||||
|
||||
# Import packs
|
||||
pack_names = packs_cfg["active"]
|
||||
packs = {}
|
||||
|
@ -63,8 +66,8 @@ class Serf(abc.ABC):
|
|||
"""The identity table containing the interface data (such as the Telegram user data) and that is in a
|
||||
many-to-one relationship with the master table."""
|
||||
|
||||
# TODO: I'm not sure what this is either
|
||||
self.identity_column: Optional[str] = None
|
||||
"""The name of the column in the identity table that contains a unique user identifier. (???)"""
|
||||
|
||||
# Alchemy
|
||||
if ra.Alchemy is None:
|
||||
|
@ -76,6 +79,7 @@ class Serf(abc.ABC):
|
|||
tables = set()
|
||||
for pack in packs.values():
|
||||
try:
|
||||
# noinspection PyUnresolvedReferences
|
||||
tables = tables.union(pack["tables"].available_tables)
|
||||
except AttributeError:
|
||||
log.warning(f"Pack `{pack}` does not have the `available_tables` attribute.")
|
||||
|
@ -100,12 +104,14 @@ class Serf(abc.ABC):
|
|||
pack = packs[pack_name]
|
||||
pack_cfg = packs_cfg.get(pack_name, {})
|
||||
try:
|
||||
# noinspection PyUnresolvedReferences
|
||||
events = pack["events"].available_events
|
||||
except AttributeError:
|
||||
log.warning(f"Pack `{pack}` does not have the `available_events` attribute.")
|
||||
else:
|
||||
self.register_events(events, pack_cfg)
|
||||
try:
|
||||
# noinspection PyUnresolvedReferences
|
||||
commands = pack["commands"].available_commands
|
||||
except AttributeError:
|
||||
log.warning(f"Pack `{pack}` does not have the `available_commands` attribute.")
|
||||
|
@ -216,7 +222,7 @@ class Serf(abc.ABC):
|
|||
for SelectedEvent in events:
|
||||
# Initialize the event
|
||||
try:
|
||||
event = SelectedEvent(serf=self, config=pack_cfg)
|
||||
event = SelectedEvent(parent=self, config=pack_cfg)
|
||||
except Exception as e:
|
||||
log.error(f"Skipping: "
|
||||
f"{SelectedEvent.__qualname__} - {e.__class__.__qualname__} in the initialization.")
|
||||
|
@ -282,8 +288,6 @@ class Serf(abc.ABC):
|
|||
except Exception as e:
|
||||
ru.sentry_exc(e)
|
||||
await data.reply(f"⛔️ [b]{e.__class__.__name__}[/b]\n" + '\n'.join(map(lambda a: repr(a), e.args)))
|
||||
finally:
|
||||
await data.session_close()
|
||||
|
||||
@staticmethod
|
||||
async def press(key: rc.KeyboardKey, data: rc.CommandData):
|
||||
|
@ -307,12 +311,10 @@ class Serf(abc.ABC):
|
|||
except Exception as e:
|
||||
ru.sentry_exc(e)
|
||||
await data.reply(f"⛔️ [b]{e.__class__.__name__}[/b]\n" + '\n'.join(map(lambda a: repr(a), e.args)))
|
||||
finally:
|
||||
await data.session_close()
|
||||
|
||||
async def run(self):
|
||||
"""A coroutine that starts the event loop and handles command calls."""
|
||||
self.herald_task = self.loop.create_task(self.herald.run())
|
||||
self.herald_task = self.tasks.add(self.herald.run())
|
||||
# OVERRIDE THIS METHOD!
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -127,7 +127,8 @@ class TelegramSerf(Serf):
|
|||
if error_if_none:
|
||||
raise rc.CommandError("No command caller for this message")
|
||||
return None
|
||||
query = data.session.query(self.master_table)
|
||||
async with data.session_acm() as session:
|
||||
query = session.query(self.master_table)
|
||||
for link in self.identity_chain:
|
||||
query = query.join(link.mapper.class_)
|
||||
query = query.filter(self.identity_column == user.id)
|
||||
|
@ -189,7 +190,8 @@ class TelegramSerf(Serf):
|
|||
if error_if_none:
|
||||
raise rc.CommandError("No command caller for this message")
|
||||
return None
|
||||
query = data.session.query(self.master_table)
|
||||
async with data.session_acm() as session:
|
||||
query = session.query(self.master_table)
|
||||
for link in self.identity_chain:
|
||||
query = query.join(link.mapper.class_)
|
||||
query = query.filter(self.identity_column == user.id)
|
||||
|
@ -206,25 +208,27 @@ class TelegramSerf(Serf):
|
|||
async def handle_update(self, update: telegram.Update):
|
||||
"""Delegate :class:`telegram.Update` handling to the correct message type submethod."""
|
||||
if update.message is not None:
|
||||
log.debug(f"Handling update as a message")
|
||||
await self.handle_message(update.message)
|
||||
elif update.edited_message is not None:
|
||||
pass
|
||||
log.debug(f"Update is a edited message, not doing anything")
|
||||
elif update.channel_post is not None:
|
||||
pass
|
||||
log.debug(f"Update is a channel post, not doing anything")
|
||||
elif update.edited_channel_post is not None:
|
||||
pass
|
||||
log.debug(f"Update is a channel edit, not doing anything")
|
||||
elif update.inline_query is not None:
|
||||
pass
|
||||
log.debug(f"Update is a inline query, not doing anything")
|
||||
elif update.chosen_inline_result is not None:
|
||||
pass
|
||||
log.debug(f"Update is a chosen inline result, not doing anything")
|
||||
elif update.callback_query is not None:
|
||||
log.debug(f"Handling update as a callback query")
|
||||
await self.handle_callback_query(update.callback_query)
|
||||
elif update.shipping_query is not None:
|
||||
pass
|
||||
log.debug(f"Update is a shipping query, not doing anything")
|
||||
elif update.pre_checkout_query is not None:
|
||||
pass
|
||||
log.debug(f"Update is a precheckout query, not doing anything")
|
||||
elif update.poll is not None:
|
||||
pass
|
||||
log.debug(f"Update is a poll, not doing anything")
|
||||
else:
|
||||
log.warning(f"Unknown update type: {update}")
|
||||
|
||||
|
@ -236,25 +240,31 @@ class TelegramSerf(Serf):
|
|||
text: str = message.caption
|
||||
# No text or caption, ignore the message
|
||||
if text is None:
|
||||
log.debug("Skipping message as it had no text or caption")
|
||||
return
|
||||
# Skip non-command updates
|
||||
if not text.startswith("/"):
|
||||
if not text.startswith(self.prefix):
|
||||
log.debug(f"Skipping message as it didn't start with {self.prefix}")
|
||||
return
|
||||
# Find and clean parameters
|
||||
command_text, *parameters = text.split(" ")
|
||||
command_name = command_text.replace(f"@{self.client.username}", "").lower()
|
||||
command_name = command_text.replace(f"@{self.client.username}", "").lstrip(self.prefix).lower()
|
||||
log.debug(f"Parsed '{command_name}' as command name")
|
||||
# Find the command
|
||||
try:
|
||||
command = self.commands[command_name]
|
||||
except KeyError:
|
||||
# Skip the message
|
||||
log.debug(f"Skipping message as I could not find the command {command_name}")
|
||||
return
|
||||
# Send a typing notification
|
||||
log.debug(f"Sending typing notification")
|
||||
await self.api_call(message.chat.send_action, telegram.ChatAction.TYPING)
|
||||
# Prepare data
|
||||
# noinspection PyArgumentList
|
||||
data = self.MessageData(command=command, message=message)
|
||||
# Call the command
|
||||
log.debug(f"Calling {command}")
|
||||
await self.call(command, data, parameters)
|
||||
|
||||
async def handle_callback_query(self, cbq: telegram.CallbackQuery):
|
||||
|
@ -270,16 +280,25 @@ class TelegramSerf(Serf):
|
|||
async def run(self):
|
||||
await super().run()
|
||||
while True:
|
||||
# Collect ended tasks
|
||||
self.tasks.collect()
|
||||
# Get the latest 100 updates
|
||||
last_updates: List[telegram.Update] = await self.api_call(self.client.get_updates,
|
||||
offset=self.update_offset,
|
||||
timeout=60,
|
||||
read_latency=5.0)
|
||||
log.debug("Getting updates...")
|
||||
last_updates: Optional[List[telegram.Update]] = await self.api_call(
|
||||
self.client.get_updates,
|
||||
offset=self.update_offset,
|
||||
timeout=60,
|
||||
read_latency=5.0
|
||||
)
|
||||
# Ensure a list was returned from the API call
|
||||
if not isinstance(last_updates, list):
|
||||
log.warning("Received invalid data from get_updates, sleeping for 60 seconds, hoping it fixes itself.")
|
||||
await aio.sleep(60)
|
||||
continue
|
||||
# Handle updates
|
||||
log.debug("Handling updates...")
|
||||
for update in last_updates:
|
||||
# TODO: don't lose the reference to the task
|
||||
# noinspection PyAsyncCall
|
||||
self.loop.create_task(self.handle_update(update))
|
||||
self.tasks.add(self.handle_update(update))
|
||||
# Recalculate offset
|
||||
try:
|
||||
self.update_offset = last_updates[-1].update_id + 1
|
||||
|
|
|
@ -7,6 +7,7 @@ from .sentry import init_sentry, sentry_exc, sentry_wrap, sentry_async_wrap
|
|||
from .sleep_until import sleep_until
|
||||
from .strip_tabs import strip_tabs
|
||||
from .urluuid import to_urluuid, from_urluuid
|
||||
from .taskslist import TaskList
|
||||
|
||||
__all__ = [
|
||||
"asyncify",
|
||||
|
@ -26,4 +27,5 @@ __all__ = [
|
|||
"init_logging",
|
||||
"JSON",
|
||||
"strip_tabs",
|
||||
"TaskList",
|
||||
]
|
||||
|
|
|
@ -6,7 +6,9 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class MultiLock:
|
||||
"""A lock that can allow both simultaneous access and exclusive access to a resource."""
|
||||
"""A lock that allows either simultaneous read access or exclusive write access.
|
||||
|
||||
Basically, a reimplementation of Rust's `RwLock <https://doc.rust-lang.org/beta/std/sync/struct.RwLock.html>`_ ."""
|
||||
|
||||
def __init__(self):
|
||||
self._counter: int = 0
|
||||
|
@ -40,7 +42,6 @@ class MultiLock:
|
|||
async def exclusive(self):
|
||||
"""Acquire the lock for exclusive access."""
|
||||
log.debug(f"Waiting for exclusive lock end: {self}")
|
||||
# TODO: check if this actually works
|
||||
await self._exclusive_event.wait()
|
||||
self._exclusive_event.clear()
|
||||
log.debug(f"Waiting for normal lock end: {self}")
|
||||
|
|
38
royalnet/utils/taskslist.py
Normal file
38
royalnet/utils/taskslist.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
from typing import *
|
||||
import asyncio as aio
|
||||
import logging
|
||||
from .sentry import sentry_exc
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskList:
|
||||
def __init__(self, loop: aio.AbstractEventLoop):
|
||||
self.loop: aio.AbstractEventLoop = loop
|
||||
self.tasks: List[aio.Task] = []
|
||||
|
||||
def collect(self):
|
||||
"""Remove finished tasks from the list."""
|
||||
log.debug(f"Collecting done tasks")
|
||||
new_list = []
|
||||
for task in self.tasks:
|
||||
try:
|
||||
task.result()
|
||||
except aio.CancelledError:
|
||||
log.warning(f"Task {task} was unexpectedly cancelled.")
|
||||
except aio.InvalidStateError:
|
||||
log.debug(f"Task {task} hasn't finished running yet, readding it to the list.")
|
||||
new_list.append(task)
|
||||
except Exception as err:
|
||||
sentry_exc(err)
|
||||
self.tasks = new_list
|
||||
|
||||
def add(self, coroutine: Awaitable[Any], timeout: float = None) -> aio.Task:
|
||||
"""Add a new task to the list; the task will be cancelled if ``timeout`` seconds pass."""
|
||||
log.debug(f"Creating new task {coroutine}")
|
||||
if timeout:
|
||||
task = self.loop.create_task(aio.wait_for(coroutine, timeout=timeout))
|
||||
else:
|
||||
task = self.loop.create_task(coroutine)
|
||||
self.tasks.append(task)
|
||||
return task
|
Loading…
Reference in a new issue