mirror of
https://github.com/RYGhub/royalnet.git
synced 2024-11-23 19:44:20 +00:00
Improve upon a few thingies
This commit is contained in:
parent
2c97a6cf47
commit
a4f2cc5e46
16 changed files with 190 additions and 98 deletions
|
@ -1,7 +1,9 @@
|
||||||
from contextlib import contextmanager, asynccontextmanager
|
|
||||||
from typing import *
|
from typing import *
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.exc import ProgrammingError
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.ext.declarative.api import DeclarativeMeta
|
from sqlalchemy.ext.declarative.api import DeclarativeMeta
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
@ -15,6 +17,8 @@ if TYPE_CHECKING:
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Alchemy:
|
class Alchemy:
|
||||||
"""A wrapper around :mod:`sqlalchemy.orm` that allows the instantiation of multiple engines at once while
|
"""A wrapper around :mod:`sqlalchemy.orm` that allows the instantiation of multiple engines at once while
|
||||||
|
@ -42,7 +46,12 @@ class Alchemy:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
bound_table: Table = type(name, (self._Base, table), {})
|
bound_table: Table = type(name, (self._Base, table), {})
|
||||||
self._tables[name] = bound_table
|
self._tables[name] = bound_table
|
||||||
self._Base.metadata.create_all()
|
# FIXME: Dirty hack
|
||||||
|
try:
|
||||||
|
self._Base.metadata.create_all()
|
||||||
|
except ProgrammingError:
|
||||||
|
log.warning("Skipping table creation, as it is probably being created by a different process.")
|
||||||
|
|
||||||
|
|
||||||
def get(self, table: Union[str, type]) -> DeclarativeMeta:
|
def get(self, table: Union[str, type]) -> DeclarativeMeta:
|
||||||
"""Get the table with a specified name or class.
|
"""Get the table with a specified name or class.
|
||||||
|
@ -66,7 +75,7 @@ class Alchemy:
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Can't get tables with objects of type '{table.__class__.__qualname__}'")
|
raise TypeError(f"Can't get tables with objects of type '{table.__class__.__qualname__}'")
|
||||||
|
|
||||||
@contextmanager
|
@contextlib.contextmanager
|
||||||
def session_cm(self) -> Iterator[Session]:
|
def session_cm(self) -> Iterator[Session]:
|
||||||
"""Create a Session as a context manager (that can be used in ``with`` statements).
|
"""Create a Session as a context manager (that can be used in ``with`` statements).
|
||||||
|
|
||||||
|
@ -91,7 +100,7 @@ class Alchemy:
|
||||||
finally:
|
finally:
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
@asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def session_acm(self) -> AsyncIterator[Session]:
|
async def session_acm(self) -> AsyncIterator[Session]:
|
||||||
"""Create a Session as a async context manager (that can be used in ``async with`` statements).
|
"""Create a Session as a async context manager (that can be used in ``async with`` statements).
|
||||||
|
|
||||||
|
|
|
@ -15,12 +15,16 @@ class RoyalnetsyncCommand(rc.Command):
|
||||||
syntax: str = "{username} {password}"
|
syntax: str = "{username} {password}"
|
||||||
|
|
||||||
async def run(self, args: rc.CommandArgs, data: rc.CommandData) -> None:
|
async def run(self, args: rc.CommandArgs, data: rc.CommandData) -> None:
|
||||||
|
author = await data.get_author(error_if_none=False)
|
||||||
|
if author is not None:
|
||||||
|
raise rc.UserError(f"This account is already connected to {author}!")
|
||||||
|
|
||||||
username = args[0]
|
username = args[0]
|
||||||
password = " ".join(args[1:])
|
password = " ".join(args[1:])
|
||||||
|
|
||||||
author = await data.get_author(error_if_none=True)
|
|
||||||
|
|
||||||
user = await data.find_user(username)
|
user = await data.find_user(username)
|
||||||
|
if user is None:
|
||||||
|
raise rc.UserError("No such user.")
|
||||||
try:
|
try:
|
||||||
successful = user.test_password(password)
|
successful = user.test_password(password)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -28,60 +32,61 @@ class RoyalnetsyncCommand(rc.Command):
|
||||||
if not successful:
|
if not successful:
|
||||||
raise rc.InvalidInputError(f"Invalid password!")
|
raise rc.InvalidInputError(f"Invalid password!")
|
||||||
|
|
||||||
if isinstance(self.serf, rst.TelegramSerf):
|
async with data.session_acm() as session:
|
||||||
import telegram
|
if isinstance(self.serf, rst.TelegramSerf):
|
||||||
message: telegram.Message = data.message
|
import telegram
|
||||||
from_user: telegram.User = message.from_user
|
message: telegram.Message = data.message
|
||||||
TelegramT = self.alchemy.get(Telegram)
|
from_user: telegram.User = message.from_user
|
||||||
tg_user: Telegram = await ru.asyncify(
|
TelegramT = self.alchemy.get(Telegram)
|
||||||
data.session.query(TelegramT).filter_by(tg_id=from_user.id).one_or_none
|
tg_user: Telegram = await ru.asyncify(
|
||||||
)
|
session.query(TelegramT).filter_by(tg_id=from_user.id).one_or_none
|
||||||
if tg_user is None:
|
|
||||||
# Create
|
|
||||||
tg_user = TelegramT(
|
|
||||||
user=author,
|
|
||||||
tg_id=from_user.id,
|
|
||||||
first_name=from_user.first_name,
|
|
||||||
last_name=from_user.last_name,
|
|
||||||
username=from_user.username
|
|
||||||
)
|
)
|
||||||
data.session.add(tg_user)
|
if tg_user is None:
|
||||||
else:
|
# Create
|
||||||
# Edit
|
tg_user = TelegramT(
|
||||||
tg_user.first_name = from_user.first_name
|
user=user,
|
||||||
tg_user.last_name = from_user.last_name
|
tg_id=from_user.id,
|
||||||
tg_user.username = from_user.username
|
first_name=from_user.first_name,
|
||||||
await data.session_commit()
|
last_name=from_user.last_name,
|
||||||
await data.reply(f"↔️ Account {tg_user} synced to {author}!")
|
username=from_user.username
|
||||||
|
)
|
||||||
|
session.add(tg_user)
|
||||||
|
else:
|
||||||
|
# Edit
|
||||||
|
tg_user.first_name = from_user.first_name
|
||||||
|
tg_user.last_name = from_user.last_name
|
||||||
|
tg_user.username = from_user.username
|
||||||
|
await ru.asyncify(session.commit)
|
||||||
|
await data.reply(f"↔️ Account {tg_user} synced to {user}!")
|
||||||
|
|
||||||
elif isinstance(self.serf, rsd.DiscordSerf):
|
elif isinstance(self.serf, rsd.DiscordSerf):
|
||||||
import discord
|
import discord
|
||||||
message: discord.Message = data.message
|
message: discord.Message = data.message
|
||||||
author: discord.User = message.author
|
ds_author: discord.User = message.author
|
||||||
DiscordT = self.alchemy.get(Discord)
|
DiscordT = self.alchemy.get(Discord)
|
||||||
ds_user: Discord = await ru.asyncify(
|
ds_user: Discord = await ru.asyncify(
|
||||||
data.session.query(DiscordT).filter_by(discord_id=author.id).one_or_none
|
session.query(DiscordT).filter_by(discord_id=ds_author.id).one_or_none
|
||||||
)
|
|
||||||
if ds_user is None:
|
|
||||||
# Create
|
|
||||||
ds_user = DiscordT(
|
|
||||||
user=author,
|
|
||||||
discord_id=author.id,
|
|
||||||
username=author.name,
|
|
||||||
discriminator=author.discriminator,
|
|
||||||
avatar_url=author.avatar_url
|
|
||||||
)
|
)
|
||||||
data.session.add(ds_user)
|
if ds_user is None:
|
||||||
|
# Create
|
||||||
|
ds_user = DiscordT(
|
||||||
|
user=user,
|
||||||
|
discord_id=ds_author.id,
|
||||||
|
username=ds_author.name,
|
||||||
|
discriminator=ds_author.discriminator,
|
||||||
|
avatar_url=ds_author.avatar_url
|
||||||
|
)
|
||||||
|
session.add(ds_user)
|
||||||
|
else:
|
||||||
|
# Edit
|
||||||
|
ds_user.username = ds_author.name
|
||||||
|
ds_user.discriminator = ds_author.discriminator
|
||||||
|
ds_user.avatar_url = ds_author.avatar_url
|
||||||
|
await ru.asyncify(session.commit)
|
||||||
|
await data.reply(f"↔️ Account {ds_user} synced to {ds_author}!")
|
||||||
|
|
||||||
|
elif isinstance(self.serf, rsm.MatrixSerf):
|
||||||
|
raise rc.UnsupportedError(f"{self} hasn't been implemented for Matrix yet")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Edit
|
raise rc.UnsupportedError(f"Unknown interface: {self.serf.__class__.__qualname__}")
|
||||||
ds_user.username = author.name
|
|
||||||
ds_user.discriminator = author.discriminator
|
|
||||||
ds_user.avatar_url = author.avatar_url
|
|
||||||
await data.session_commit()
|
|
||||||
await data.reply(f"↔️ Account {ds_user} synced to {author}!")
|
|
||||||
|
|
||||||
elif isinstance(self.serf, rsm.MatrixSerf):
|
|
||||||
raise rc.UnsupportedError(f"{self} hasn't been implemented for Matrix yet")
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise rc.UnsupportedError(f"Unknown interface: {self.serf.__class__.__qualname__}")
|
|
||||||
|
|
|
@ -19,6 +19,6 @@ class RoyalnetversionCommand(Command):
|
||||||
else:
|
else:
|
||||||
message = f"ℹ️ Royalnet [url=https://github.com/Steffo99/royalnet/releases/tag/{self.royalnet_version}]" \
|
message = f"ℹ️ Royalnet [url=https://github.com/Steffo99/royalnet/releases/tag/{self.royalnet_version}]" \
|
||||||
f"{self.royalnet_version}[/url]\n"
|
f"{self.royalnet_version}[/url]\n"
|
||||||
if "69" in royalnet.version.semantic:
|
if "69" in self.royalnet_version:
|
||||||
message += "(Nice.)"
|
message += "(Nice.)"
|
||||||
await data.reply(message)
|
await data.reply(message)
|
||||||
|
|
|
@ -17,6 +17,7 @@ class ApiUserCreateStar(rca.ApiStar):
|
||||||
|
|
||||||
tags = ["user"]
|
tags = ["user"]
|
||||||
|
|
||||||
|
@rca.magic
|
||||||
async def post(self, data: rca.ApiData) -> ru.JSON:
|
async def post(self, data: rca.ApiData) -> ru.JSON:
|
||||||
"""Create a new Royalnet account."""
|
"""Create a new Royalnet account."""
|
||||||
UserT = self.alchemy.get(User)
|
UserT = self.alchemy.get(User)
|
||||||
|
|
|
@ -15,7 +15,7 @@ class Discord:
|
||||||
|
|
||||||
@declared_attr
|
@declared_attr
|
||||||
def user_id(self):
|
def user_id(self):
|
||||||
return Column(Integer, ForeignKey("users.uid"))
|
return Column(Integer, ForeignKey("users.uid"), nullable=False)
|
||||||
|
|
||||||
@declared_attr
|
@declared_attr
|
||||||
def user(self):
|
def user(self):
|
||||||
|
|
|
@ -13,7 +13,7 @@ class Matrix:
|
||||||
|
|
||||||
@declared_attr
|
@declared_attr
|
||||||
def user_id(self):
|
def user_id(self):
|
||||||
return Column(Integer, ForeignKey("users.uid"))
|
return Column(Integer, ForeignKey("users.uid"), nullable=False)
|
||||||
|
|
||||||
@declared_attr
|
@declared_attr
|
||||||
def user(self):
|
def user(self):
|
||||||
|
|
|
@ -15,7 +15,7 @@ class Telegram:
|
||||||
|
|
||||||
@declared_attr
|
@declared_attr
|
||||||
def user_id(self):
|
def user_id(self):
|
||||||
return Column(Integer, ForeignKey("users.uid"))
|
return Column(Integer, ForeignKey("users.uid"), nullable=False)
|
||||||
|
|
||||||
@declared_attr
|
@declared_attr
|
||||||
def user(self):
|
def user(self):
|
||||||
|
|
|
@ -40,12 +40,17 @@ class Command(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def alchemy(self) -> "Alchemy":
|
def alchemy(self) -> "Alchemy":
|
||||||
"""A shortcut for :attr:`.interface.alchemy`."""
|
"""A shortcut for :attr:`.serf.alchemy`."""
|
||||||
return self.serf.alchemy
|
return self.serf.alchemy
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session_acm(self):
|
||||||
|
"""A shortcut for :attr:`.alchemy.session_acm`."""
|
||||||
|
return self.alchemy.session_acm
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def loop(self) -> aio.AbstractEventLoop:
|
def loop(self) -> aio.AbstractEventLoop:
|
||||||
"""A shortcut for :attr:`.interface.loop`."""
|
"""A shortcut for :attr:`.serf.loop`."""
|
||||||
return self.serf.loop
|
return self.serf.loop
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
|
|
@ -21,6 +21,16 @@ class CommandData:
|
||||||
def loop(self):
|
def loop(self):
|
||||||
return self.command.serf.loop
|
return self.command.serf.loop
|
||||||
|
|
||||||
|
@property
|
||||||
|
def alchemy(self):
|
||||||
|
"""A shortcut for :attr:`.command.alchemy`."""
|
||||||
|
return self.command.alchemy
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session_acm(self):
|
||||||
|
"""A shortcut for :attr:`.alchemy.session_acm`."""
|
||||||
|
return self.alchemy.session_acm
|
||||||
|
|
||||||
async def reply(self, text: str) -> None:
|
async def reply(self, text: str) -> None:
|
||||||
"""Send a text message to the channel where the call was made.
|
"""Send a text message to the channel where the call was made.
|
||||||
|
|
||||||
|
|
|
@ -220,7 +220,7 @@ class Constellation:
|
||||||
for SelectedEvent in events:
|
for SelectedEvent in events:
|
||||||
# Initialize the event
|
# Initialize the event
|
||||||
try:
|
try:
|
||||||
event = SelectedEvent(constellation=self, config=pack_cfg)
|
event = SelectedEvent(parent=self, config=pack_cfg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Skipping: "
|
log.error(f"Skipping: "
|
||||||
f"{SelectedEvent.__qualname__} - {e.__class__.__qualname__} in the initialization.")
|
f"{SelectedEvent.__qualname__} - {e.__class__.__qualname__} in the initialization.")
|
||||||
|
@ -262,7 +262,7 @@ class Constellation:
|
||||||
self.starlette.add_route(*self._page_star_wrapper(page_star_instance))
|
self.starlette.add_route(*self._page_star_wrapper(page_star_instance))
|
||||||
|
|
||||||
def run_blocking(self):
|
def run_blocking(self):
|
||||||
log.info(f"Running Constellation on https://{self.address}:{self.port}/...")
|
log.info(f"Running Constellation on http://{self.address}:{self.port}/...")
|
||||||
self.running = True
|
self.running = True
|
||||||
try:
|
try:
|
||||||
uvicorn.run(self.starlette, host=self.address, port=self.port, log_config=UVICORN_LOGGING_CONFIG)
|
uvicorn.run(self.starlette, host=self.address, port=self.port, log_config=UVICORN_LOGGING_CONFIG)
|
||||||
|
|
|
@ -73,7 +73,8 @@ class DiscordSerf(Serf):
|
||||||
|
|
||||||
async def get_author(data, error_if_none=False):
|
async def get_author(data, error_if_none=False):
|
||||||
user: "discord.Member" = data.message.author
|
user: "discord.Member" = data.message.author
|
||||||
query = data.session.query(self.master_table)
|
async with data.session_acm() as session:
|
||||||
|
query = session.query(self.master_table)
|
||||||
for link in self.identity_chain:
|
for link in self.identity_chain:
|
||||||
query = query.join(link.mapper.class_)
|
query = query.join(link.mapper.class_)
|
||||||
query = query.filter(self.identity_column == user.id)
|
query = query.filter(self.identity_column == user.id)
|
||||||
|
@ -135,8 +136,7 @@ class DiscordSerf(Serf):
|
||||||
# noinspection PyMethodMayBeStatic
|
# noinspection PyMethodMayBeStatic
|
||||||
async def on_message(cli, message: "discord.Message") -> None:
|
async def on_message(cli, message: "discord.Message") -> None:
|
||||||
"""Handle messages received by passing them to the handle_message method of the bot."""
|
"""Handle messages received by passing them to the handle_message method of the bot."""
|
||||||
# TODO: keep reference to these tasks somewhere
|
self.tasks.add(self.handle_message(message))
|
||||||
self.loop.create_task(self.handle_message(message))
|
|
||||||
|
|
||||||
async def on_ready(cli) -> None:
|
async def on_ready(cli) -> None:
|
||||||
"""Change the bot presence to ``online`` when the bot is ready."""
|
"""Change the bot presence to ``online`` when the bot is ready."""
|
||||||
|
|
|
@ -36,6 +36,9 @@ class Serf(abc.ABC):
|
||||||
self.loop: Optional[aio.AbstractEventLoop] = loop
|
self.loop: Optional[aio.AbstractEventLoop] = loop
|
||||||
"""The event loop this Serf is running on."""
|
"""The event loop this Serf is running on."""
|
||||||
|
|
||||||
|
self.tasks: Optional[ru.TaskList] = ru.TaskList(self.loop)
|
||||||
|
"""A list of all running tasks of the serf. Initialized at the serf start."""
|
||||||
|
|
||||||
# Import packs
|
# Import packs
|
||||||
pack_names = packs_cfg["active"]
|
pack_names = packs_cfg["active"]
|
||||||
packs = {}
|
packs = {}
|
||||||
|
@ -63,8 +66,8 @@ class Serf(abc.ABC):
|
||||||
"""The identity table containing the interface data (such as the Telegram user data) and that is in a
|
"""The identity table containing the interface data (such as the Telegram user data) and that is in a
|
||||||
many-to-one relationship with the master table."""
|
many-to-one relationship with the master table."""
|
||||||
|
|
||||||
# TODO: I'm not sure what this is either
|
|
||||||
self.identity_column: Optional[str] = None
|
self.identity_column: Optional[str] = None
|
||||||
|
"""The name of the column in the identity table that contains a unique user identifier. (???)"""
|
||||||
|
|
||||||
# Alchemy
|
# Alchemy
|
||||||
if ra.Alchemy is None:
|
if ra.Alchemy is None:
|
||||||
|
@ -76,6 +79,7 @@ class Serf(abc.ABC):
|
||||||
tables = set()
|
tables = set()
|
||||||
for pack in packs.values():
|
for pack in packs.values():
|
||||||
try:
|
try:
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
tables = tables.union(pack["tables"].available_tables)
|
tables = tables.union(pack["tables"].available_tables)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
log.warning(f"Pack `{pack}` does not have the `available_tables` attribute.")
|
log.warning(f"Pack `{pack}` does not have the `available_tables` attribute.")
|
||||||
|
@ -100,12 +104,14 @@ class Serf(abc.ABC):
|
||||||
pack = packs[pack_name]
|
pack = packs[pack_name]
|
||||||
pack_cfg = packs_cfg.get(pack_name, {})
|
pack_cfg = packs_cfg.get(pack_name, {})
|
||||||
try:
|
try:
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
events = pack["events"].available_events
|
events = pack["events"].available_events
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
log.warning(f"Pack `{pack}` does not have the `available_events` attribute.")
|
log.warning(f"Pack `{pack}` does not have the `available_events` attribute.")
|
||||||
else:
|
else:
|
||||||
self.register_events(events, pack_cfg)
|
self.register_events(events, pack_cfg)
|
||||||
try:
|
try:
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
commands = pack["commands"].available_commands
|
commands = pack["commands"].available_commands
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
log.warning(f"Pack `{pack}` does not have the `available_commands` attribute.")
|
log.warning(f"Pack `{pack}` does not have the `available_commands` attribute.")
|
||||||
|
@ -216,7 +222,7 @@ class Serf(abc.ABC):
|
||||||
for SelectedEvent in events:
|
for SelectedEvent in events:
|
||||||
# Initialize the event
|
# Initialize the event
|
||||||
try:
|
try:
|
||||||
event = SelectedEvent(serf=self, config=pack_cfg)
|
event = SelectedEvent(parent=self, config=pack_cfg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Skipping: "
|
log.error(f"Skipping: "
|
||||||
f"{SelectedEvent.__qualname__} - {e.__class__.__qualname__} in the initialization.")
|
f"{SelectedEvent.__qualname__} - {e.__class__.__qualname__} in the initialization.")
|
||||||
|
@ -282,8 +288,6 @@ class Serf(abc.ABC):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
ru.sentry_exc(e)
|
ru.sentry_exc(e)
|
||||||
await data.reply(f"⛔️ [b]{e.__class__.__name__}[/b]\n" + '\n'.join(map(lambda a: repr(a), e.args)))
|
await data.reply(f"⛔️ [b]{e.__class__.__name__}[/b]\n" + '\n'.join(map(lambda a: repr(a), e.args)))
|
||||||
finally:
|
|
||||||
await data.session_close()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def press(key: rc.KeyboardKey, data: rc.CommandData):
|
async def press(key: rc.KeyboardKey, data: rc.CommandData):
|
||||||
|
@ -307,12 +311,10 @@ class Serf(abc.ABC):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
ru.sentry_exc(e)
|
ru.sentry_exc(e)
|
||||||
await data.reply(f"⛔️ [b]{e.__class__.__name__}[/b]\n" + '\n'.join(map(lambda a: repr(a), e.args)))
|
await data.reply(f"⛔️ [b]{e.__class__.__name__}[/b]\n" + '\n'.join(map(lambda a: repr(a), e.args)))
|
||||||
finally:
|
|
||||||
await data.session_close()
|
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""A coroutine that starts the event loop and handles command calls."""
|
"""A coroutine that starts the event loop and handles command calls."""
|
||||||
self.herald_task = self.loop.create_task(self.herald.run())
|
self.herald_task = self.tasks.add(self.herald.run())
|
||||||
# OVERRIDE THIS METHOD!
|
# OVERRIDE THIS METHOD!
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -127,7 +127,8 @@ class TelegramSerf(Serf):
|
||||||
if error_if_none:
|
if error_if_none:
|
||||||
raise rc.CommandError("No command caller for this message")
|
raise rc.CommandError("No command caller for this message")
|
||||||
return None
|
return None
|
||||||
query = data.session.query(self.master_table)
|
async with data.session_acm() as session:
|
||||||
|
query = session.query(self.master_table)
|
||||||
for link in self.identity_chain:
|
for link in self.identity_chain:
|
||||||
query = query.join(link.mapper.class_)
|
query = query.join(link.mapper.class_)
|
||||||
query = query.filter(self.identity_column == user.id)
|
query = query.filter(self.identity_column == user.id)
|
||||||
|
@ -189,7 +190,8 @@ class TelegramSerf(Serf):
|
||||||
if error_if_none:
|
if error_if_none:
|
||||||
raise rc.CommandError("No command caller for this message")
|
raise rc.CommandError("No command caller for this message")
|
||||||
return None
|
return None
|
||||||
query = data.session.query(self.master_table)
|
async with data.session_acm() as session:
|
||||||
|
query = session.query(self.master_table)
|
||||||
for link in self.identity_chain:
|
for link in self.identity_chain:
|
||||||
query = query.join(link.mapper.class_)
|
query = query.join(link.mapper.class_)
|
||||||
query = query.filter(self.identity_column == user.id)
|
query = query.filter(self.identity_column == user.id)
|
||||||
|
@ -206,25 +208,27 @@ class TelegramSerf(Serf):
|
||||||
async def handle_update(self, update: telegram.Update):
|
async def handle_update(self, update: telegram.Update):
|
||||||
"""Delegate :class:`telegram.Update` handling to the correct message type submethod."""
|
"""Delegate :class:`telegram.Update` handling to the correct message type submethod."""
|
||||||
if update.message is not None:
|
if update.message is not None:
|
||||||
|
log.debug(f"Handling update as a message")
|
||||||
await self.handle_message(update.message)
|
await self.handle_message(update.message)
|
||||||
elif update.edited_message is not None:
|
elif update.edited_message is not None:
|
||||||
pass
|
log.debug(f"Update is a edited message, not doing anything")
|
||||||
elif update.channel_post is not None:
|
elif update.channel_post is not None:
|
||||||
pass
|
log.debug(f"Update is a channel post, not doing anything")
|
||||||
elif update.edited_channel_post is not None:
|
elif update.edited_channel_post is not None:
|
||||||
pass
|
log.debug(f"Update is a channel edit, not doing anything")
|
||||||
elif update.inline_query is not None:
|
elif update.inline_query is not None:
|
||||||
pass
|
log.debug(f"Update is a inline query, not doing anything")
|
||||||
elif update.chosen_inline_result is not None:
|
elif update.chosen_inline_result is not None:
|
||||||
pass
|
log.debug(f"Update is a chosen inline result, not doing anything")
|
||||||
elif update.callback_query is not None:
|
elif update.callback_query is not None:
|
||||||
|
log.debug(f"Handling update as a callback query")
|
||||||
await self.handle_callback_query(update.callback_query)
|
await self.handle_callback_query(update.callback_query)
|
||||||
elif update.shipping_query is not None:
|
elif update.shipping_query is not None:
|
||||||
pass
|
log.debug(f"Update is a shipping query, not doing anything")
|
||||||
elif update.pre_checkout_query is not None:
|
elif update.pre_checkout_query is not None:
|
||||||
pass
|
log.debug(f"Update is a precheckout query, not doing anything")
|
||||||
elif update.poll is not None:
|
elif update.poll is not None:
|
||||||
pass
|
log.debug(f"Update is a poll, not doing anything")
|
||||||
else:
|
else:
|
||||||
log.warning(f"Unknown update type: {update}")
|
log.warning(f"Unknown update type: {update}")
|
||||||
|
|
||||||
|
@ -236,25 +240,31 @@ class TelegramSerf(Serf):
|
||||||
text: str = message.caption
|
text: str = message.caption
|
||||||
# No text or caption, ignore the message
|
# No text or caption, ignore the message
|
||||||
if text is None:
|
if text is None:
|
||||||
|
log.debug("Skipping message as it had no text or caption")
|
||||||
return
|
return
|
||||||
# Skip non-command updates
|
# Skip non-command updates
|
||||||
if not text.startswith("/"):
|
if not text.startswith(self.prefix):
|
||||||
|
log.debug(f"Skipping message as it didn't start with {self.prefix}")
|
||||||
return
|
return
|
||||||
# Find and clean parameters
|
# Find and clean parameters
|
||||||
command_text, *parameters = text.split(" ")
|
command_text, *parameters = text.split(" ")
|
||||||
command_name = command_text.replace(f"@{self.client.username}", "").lower()
|
command_name = command_text.replace(f"@{self.client.username}", "").lstrip(self.prefix).lower()
|
||||||
|
log.debug(f"Parsed '{command_name}' as command name")
|
||||||
# Find the command
|
# Find the command
|
||||||
try:
|
try:
|
||||||
command = self.commands[command_name]
|
command = self.commands[command_name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
# Skip the message
|
# Skip the message
|
||||||
|
log.debug(f"Skipping message as I could not find the command {command_name}")
|
||||||
return
|
return
|
||||||
# Send a typing notification
|
# Send a typing notification
|
||||||
|
log.debug(f"Sending typing notification")
|
||||||
await self.api_call(message.chat.send_action, telegram.ChatAction.TYPING)
|
await self.api_call(message.chat.send_action, telegram.ChatAction.TYPING)
|
||||||
# Prepare data
|
# Prepare data
|
||||||
# noinspection PyArgumentList
|
# noinspection PyArgumentList
|
||||||
data = self.MessageData(command=command, message=message)
|
data = self.MessageData(command=command, message=message)
|
||||||
# Call the command
|
# Call the command
|
||||||
|
log.debug(f"Calling {command}")
|
||||||
await self.call(command, data, parameters)
|
await self.call(command, data, parameters)
|
||||||
|
|
||||||
async def handle_callback_query(self, cbq: telegram.CallbackQuery):
|
async def handle_callback_query(self, cbq: telegram.CallbackQuery):
|
||||||
|
@ -270,16 +280,25 @@ class TelegramSerf(Serf):
|
||||||
async def run(self):
|
async def run(self):
|
||||||
await super().run()
|
await super().run()
|
||||||
while True:
|
while True:
|
||||||
|
# Collect ended tasks
|
||||||
|
self.tasks.collect()
|
||||||
# Get the latest 100 updates
|
# Get the latest 100 updates
|
||||||
last_updates: List[telegram.Update] = await self.api_call(self.client.get_updates,
|
log.debug("Getting updates...")
|
||||||
offset=self.update_offset,
|
last_updates: Optional[List[telegram.Update]] = await self.api_call(
|
||||||
timeout=60,
|
self.client.get_updates,
|
||||||
read_latency=5.0)
|
offset=self.update_offset,
|
||||||
|
timeout=60,
|
||||||
|
read_latency=5.0
|
||||||
|
)
|
||||||
|
# Ensure a list was returned from the API call
|
||||||
|
if not isinstance(last_updates, list):
|
||||||
|
log.warning("Received invalid data from get_updates, sleeping for 60 seconds, hoping it fixes itself.")
|
||||||
|
await aio.sleep(60)
|
||||||
|
continue
|
||||||
# Handle updates
|
# Handle updates
|
||||||
|
log.debug("Handling updates...")
|
||||||
for update in last_updates:
|
for update in last_updates:
|
||||||
# TODO: don't lose the reference to the task
|
self.tasks.add(self.handle_update(update))
|
||||||
# noinspection PyAsyncCall
|
|
||||||
self.loop.create_task(self.handle_update(update))
|
|
||||||
# Recalculate offset
|
# Recalculate offset
|
||||||
try:
|
try:
|
||||||
self.update_offset = last_updates[-1].update_id + 1
|
self.update_offset = last_updates[-1].update_id + 1
|
||||||
|
|
|
@ -7,6 +7,7 @@ from .sentry import init_sentry, sentry_exc, sentry_wrap, sentry_async_wrap
|
||||||
from .sleep_until import sleep_until
|
from .sleep_until import sleep_until
|
||||||
from .strip_tabs import strip_tabs
|
from .strip_tabs import strip_tabs
|
||||||
from .urluuid import to_urluuid, from_urluuid
|
from .urluuid import to_urluuid, from_urluuid
|
||||||
|
from .taskslist import TaskList
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"asyncify",
|
"asyncify",
|
||||||
|
@ -26,4 +27,5 @@ __all__ = [
|
||||||
"init_logging",
|
"init_logging",
|
||||||
"JSON",
|
"JSON",
|
||||||
"strip_tabs",
|
"strip_tabs",
|
||||||
|
"TaskList",
|
||||||
]
|
]
|
||||||
|
|
|
@ -6,7 +6,9 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MultiLock:
|
class MultiLock:
|
||||||
"""A lock that can allow both simultaneous access and exclusive access to a resource."""
|
"""A lock that allows either simultaneous read access or exclusive write access.
|
||||||
|
|
||||||
|
Basically, a reimplementation of Rust's `RwLock <https://doc.rust-lang.org/beta/std/sync/struct.RwLock.html>`_ ."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._counter: int = 0
|
self._counter: int = 0
|
||||||
|
@ -40,7 +42,6 @@ class MultiLock:
|
||||||
async def exclusive(self):
|
async def exclusive(self):
|
||||||
"""Acquire the lock for exclusive access."""
|
"""Acquire the lock for exclusive access."""
|
||||||
log.debug(f"Waiting for exclusive lock end: {self}")
|
log.debug(f"Waiting for exclusive lock end: {self}")
|
||||||
# TODO: check if this actually works
|
|
||||||
await self._exclusive_event.wait()
|
await self._exclusive_event.wait()
|
||||||
self._exclusive_event.clear()
|
self._exclusive_event.clear()
|
||||||
log.debug(f"Waiting for normal lock end: {self}")
|
log.debug(f"Waiting for normal lock end: {self}")
|
||||||
|
|
38
royalnet/utils/taskslist.py
Normal file
38
royalnet/utils/taskslist.py
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
from typing import *
|
||||||
|
import asyncio as aio
|
||||||
|
import logging
|
||||||
|
from .sentry import sentry_exc
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskList:
|
||||||
|
def __init__(self, loop: aio.AbstractEventLoop):
|
||||||
|
self.loop: aio.AbstractEventLoop = loop
|
||||||
|
self.tasks: List[aio.Task] = []
|
||||||
|
|
||||||
|
def collect(self):
|
||||||
|
"""Remove finished tasks from the list."""
|
||||||
|
log.debug(f"Collecting done tasks")
|
||||||
|
new_list = []
|
||||||
|
for task in self.tasks:
|
||||||
|
try:
|
||||||
|
task.result()
|
||||||
|
except aio.CancelledError:
|
||||||
|
log.warning(f"Task {task} was unexpectedly cancelled.")
|
||||||
|
except aio.InvalidStateError:
|
||||||
|
log.debug(f"Task {task} hasn't finished running yet, readding it to the list.")
|
||||||
|
new_list.append(task)
|
||||||
|
except Exception as err:
|
||||||
|
sentry_exc(err)
|
||||||
|
self.tasks = new_list
|
||||||
|
|
||||||
|
def add(self, coroutine: Awaitable[Any], timeout: float = None) -> aio.Task:
|
||||||
|
"""Add a new task to the list; the task will be cancelled if ``timeout`` seconds pass."""
|
||||||
|
log.debug(f"Creating new task {coroutine}")
|
||||||
|
if timeout:
|
||||||
|
task = self.loop.create_task(aio.wait_for(coroutine, timeout=timeout))
|
||||||
|
else:
|
||||||
|
task = self.loop.create_task(coroutine)
|
||||||
|
self.tasks.append(task)
|
||||||
|
return task
|
Loading…
Reference in a new issue