1
Fork 0
mirror of https://github.com/RYGhub/royalnet.git synced 2024-11-23 11:34:18 +00:00

Improve upon a few thingies

This commit is contained in:
Steffo 2020-08-18 03:43:11 +02:00
parent 2c97a6cf47
commit a4f2cc5e46
16 changed files with 190 additions and 98 deletions

View file

@ -1,7 +1,9 @@
from contextlib import contextmanager, asynccontextmanager
from typing import *
import contextlib
import logging
from sqlalchemy import create_engine
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative.api import DeclarativeMeta
from sqlalchemy.orm import sessionmaker
@ -15,6 +17,8 @@ if TYPE_CHECKING:
# noinspection PyProtectedMember
from sqlalchemy.engine import Engine
log = logging.getLogger(__name__)
class Alchemy:
"""A wrapper around :mod:`sqlalchemy.orm` that allows the instantiation of multiple engines at once while
@ -42,7 +46,12 @@ class Alchemy:
# noinspection PyTypeChecker
bound_table: Table = type(name, (self._Base, table), {})
self._tables[name] = bound_table
self._Base.metadata.create_all()
# FIXME: Dirty hack
try:
self._Base.metadata.create_all()
except ProgrammingError:
log.warning("Skipping table creation, as it is probably being created by a different process.")
def get(self, table: Union[str, type]) -> DeclarativeMeta:
"""Get the table with a specified name or class.
@ -66,7 +75,7 @@ class Alchemy:
else:
raise TypeError(f"Can't get tables with objects of type '{table.__class__.__qualname__}'")
@contextmanager
@contextlib.contextmanager
def session_cm(self) -> Iterator[Session]:
"""Create a Session as a context manager (that can be used in ``with`` statements).
@ -91,7 +100,7 @@ class Alchemy:
finally:
session.close()
@asynccontextmanager
@contextlib.asynccontextmanager
async def session_acm(self) -> AsyncIterator[Session]:
"""Create a Session as a async context manager (that can be used in ``async with`` statements).

View file

@ -15,12 +15,16 @@ class RoyalnetsyncCommand(rc.Command):
syntax: str = "{username} {password}"
async def run(self, args: rc.CommandArgs, data: rc.CommandData) -> None:
author = await data.get_author(error_if_none=False)
if author is not None:
raise rc.UserError(f"This account is already connected to {author}!")
username = args[0]
password = " ".join(args[1:])
author = await data.get_author(error_if_none=True)
user = await data.find_user(username)
if user is None:
raise rc.UserError("No such user.")
try:
successful = user.test_password(password)
except ValueError:
@ -28,60 +32,61 @@ class RoyalnetsyncCommand(rc.Command):
if not successful:
raise rc.InvalidInputError(f"Invalid password!")
if isinstance(self.serf, rst.TelegramSerf):
import telegram
message: telegram.Message = data.message
from_user: telegram.User = message.from_user
TelegramT = self.alchemy.get(Telegram)
tg_user: Telegram = await ru.asyncify(
data.session.query(TelegramT).filter_by(tg_id=from_user.id).one_or_none
)
if tg_user is None:
# Create
tg_user = TelegramT(
user=author,
tg_id=from_user.id,
first_name=from_user.first_name,
last_name=from_user.last_name,
username=from_user.username
async with data.session_acm() as session:
if isinstance(self.serf, rst.TelegramSerf):
import telegram
message: telegram.Message = data.message
from_user: telegram.User = message.from_user
TelegramT = self.alchemy.get(Telegram)
tg_user: Telegram = await ru.asyncify(
session.query(TelegramT).filter_by(tg_id=from_user.id).one_or_none
)
data.session.add(tg_user)
else:
# Edit
tg_user.first_name = from_user.first_name
tg_user.last_name = from_user.last_name
tg_user.username = from_user.username
await data.session_commit()
await data.reply(f"↔️ Account {tg_user} synced to {author}!")
if tg_user is None:
# Create
tg_user = TelegramT(
user=user,
tg_id=from_user.id,
first_name=from_user.first_name,
last_name=from_user.last_name,
username=from_user.username
)
session.add(tg_user)
else:
# Edit
tg_user.first_name = from_user.first_name
tg_user.last_name = from_user.last_name
tg_user.username = from_user.username
await ru.asyncify(session.commit)
await data.reply(f"↔️ Account {tg_user} synced to {user}!")
elif isinstance(self.serf, rsd.DiscordSerf):
import discord
message: discord.Message = data.message
author: discord.User = message.author
DiscordT = self.alchemy.get(Discord)
ds_user: Discord = await ru.asyncify(
data.session.query(DiscordT).filter_by(discord_id=author.id).one_or_none
)
if ds_user is None:
# Create
ds_user = DiscordT(
user=author,
discord_id=author.id,
username=author.name,
discriminator=author.discriminator,
avatar_url=author.avatar_url
elif isinstance(self.serf, rsd.DiscordSerf):
import discord
message: discord.Message = data.message
ds_author: discord.User = message.author
DiscordT = self.alchemy.get(Discord)
ds_user: Discord = await ru.asyncify(
session.query(DiscordT).filter_by(discord_id=ds_author.id).one_or_none
)
data.session.add(ds_user)
if ds_user is None:
# Create
ds_user = DiscordT(
user=user,
discord_id=ds_author.id,
username=ds_author.name,
discriminator=ds_author.discriminator,
avatar_url=ds_author.avatar_url
)
session.add(ds_user)
else:
# Edit
ds_user.username = ds_author.name
ds_user.discriminator = ds_author.discriminator
ds_user.avatar_url = ds_author.avatar_url
await ru.asyncify(session.commit)
await data.reply(f"↔️ Account {ds_user} synced to {ds_author}!")
elif isinstance(self.serf, rsm.MatrixSerf):
raise rc.UnsupportedError(f"{self} hasn't been implemented for Matrix yet")
else:
# Edit
ds_user.username = author.name
ds_user.discriminator = author.discriminator
ds_user.avatar_url = author.avatar_url
await data.session_commit()
await data.reply(f"↔️ Account {ds_user} synced to {author}!")
elif isinstance(self.serf, rsm.MatrixSerf):
raise rc.UnsupportedError(f"{self} hasn't been implemented for Matrix yet")
else:
raise rc.UnsupportedError(f"Unknown interface: {self.serf.__class__.__qualname__}")
raise rc.UnsupportedError(f"Unknown interface: {self.serf.__class__.__qualname__}")

View file

@ -19,6 +19,6 @@ class RoyalnetversionCommand(Command):
else:
message = f" Royalnet [url=https://github.com/Steffo99/royalnet/releases/tag/{self.royalnet_version}]" \
f"{self.royalnet_version}[/url]\n"
if "69" in royalnet.version.semantic:
if "69" in self.royalnet_version:
message += "(Nice.)"
await data.reply(message)

View file

@ -17,6 +17,7 @@ class ApiUserCreateStar(rca.ApiStar):
tags = ["user"]
@rca.magic
async def post(self, data: rca.ApiData) -> ru.JSON:
"""Create a new Royalnet account."""
UserT = self.alchemy.get(User)

View file

@ -15,7 +15,7 @@ class Discord:
@declared_attr
def user_id(self):
return Column(Integer, ForeignKey("users.uid"))
return Column(Integer, ForeignKey("users.uid"), nullable=False)
@declared_attr
def user(self):

View file

@ -13,7 +13,7 @@ class Matrix:
@declared_attr
def user_id(self):
return Column(Integer, ForeignKey("users.uid"))
return Column(Integer, ForeignKey("users.uid"), nullable=False)
@declared_attr
def user(self):

View file

@ -15,7 +15,7 @@ class Telegram:
@declared_attr
def user_id(self):
return Column(Integer, ForeignKey("users.uid"))
return Column(Integer, ForeignKey("users.uid"), nullable=False)
@declared_attr
def user(self):

View file

@ -40,12 +40,17 @@ class Command(metaclass=abc.ABCMeta):
@property
def alchemy(self) -> "Alchemy":
"""A shortcut for :attr:`.interface.alchemy`."""
"""A shortcut for :attr:`.serf.alchemy`."""
return self.serf.alchemy
@property
def session_acm(self):
"""A shortcut for :attr:`.alchemy.session_acm`."""
return self.alchemy.session_acm
@property
def loop(self) -> aio.AbstractEventLoop:
"""A shortcut for :attr:`.interface.loop`."""
"""A shortcut for :attr:`.serf.loop`."""
return self.serf.loop
@abc.abstractmethod

View file

@ -21,6 +21,16 @@ class CommandData:
def loop(self):
return self.command.serf.loop
@property
def alchemy(self):
"""A shortcut for :attr:`.command.alchemy`."""
return self.command.alchemy
@property
def session_acm(self):
"""A shortcut for :attr:`.alchemy.session_acm`."""
return self.alchemy.session_acm
async def reply(self, text: str) -> None:
"""Send a text message to the channel where the call was made.

View file

@ -220,7 +220,7 @@ class Constellation:
for SelectedEvent in events:
# Initialize the event
try:
event = SelectedEvent(constellation=self, config=pack_cfg)
event = SelectedEvent(parent=self, config=pack_cfg)
except Exception as e:
log.error(f"Skipping: "
f"{SelectedEvent.__qualname__} - {e.__class__.__qualname__} in the initialization.")
@ -262,7 +262,7 @@ class Constellation:
self.starlette.add_route(*self._page_star_wrapper(page_star_instance))
def run_blocking(self):
log.info(f"Running Constellation on https://{self.address}:{self.port}/...")
log.info(f"Running Constellation on http://{self.address}:{self.port}/...")
self.running = True
try:
uvicorn.run(self.starlette, host=self.address, port=self.port, log_config=UVICORN_LOGGING_CONFIG)

View file

@ -73,7 +73,8 @@ class DiscordSerf(Serf):
async def get_author(data, error_if_none=False):
user: "discord.Member" = data.message.author
query = data.session.query(self.master_table)
async with data.session_acm() as session:
query = session.query(self.master_table)
for link in self.identity_chain:
query = query.join(link.mapper.class_)
query = query.filter(self.identity_column == user.id)
@ -135,8 +136,7 @@ class DiscordSerf(Serf):
# noinspection PyMethodMayBeStatic
async def on_message(cli, message: "discord.Message") -> None:
"""Handle messages received by passing them to the handle_message method of the bot."""
# TODO: keep reference to these tasks somewhere
self.loop.create_task(self.handle_message(message))
self.tasks.add(self.handle_message(message))
async def on_ready(cli) -> None:
"""Change the bot presence to ``online`` when the bot is ready."""

View file

@ -36,6 +36,9 @@ class Serf(abc.ABC):
self.loop: Optional[aio.AbstractEventLoop] = loop
"""The event loop this Serf is running on."""
self.tasks: Optional[ru.TaskList] = ru.TaskList(self.loop)
"""A list of all running tasks of the serf. Initialized at the serf start."""
# Import packs
pack_names = packs_cfg["active"]
packs = {}
@ -63,8 +66,8 @@ class Serf(abc.ABC):
"""The identity table containing the interface data (such as the Telegram user data) and that is in a
many-to-one relationship with the master table."""
# TODO: I'm not sure what this is either
self.identity_column: Optional[str] = None
"""The name of the column in the identity table that contains a unique user identifier. (???)"""
# Alchemy
if ra.Alchemy is None:
@ -76,6 +79,7 @@ class Serf(abc.ABC):
tables = set()
for pack in packs.values():
try:
# noinspection PyUnresolvedReferences
tables = tables.union(pack["tables"].available_tables)
except AttributeError:
log.warning(f"Pack `{pack}` does not have the `available_tables` attribute.")
@ -100,12 +104,14 @@ class Serf(abc.ABC):
pack = packs[pack_name]
pack_cfg = packs_cfg.get(pack_name, {})
try:
# noinspection PyUnresolvedReferences
events = pack["events"].available_events
except AttributeError:
log.warning(f"Pack `{pack}` does not have the `available_events` attribute.")
else:
self.register_events(events, pack_cfg)
try:
# noinspection PyUnresolvedReferences
commands = pack["commands"].available_commands
except AttributeError:
log.warning(f"Pack `{pack}` does not have the `available_commands` attribute.")
@ -216,7 +222,7 @@ class Serf(abc.ABC):
for SelectedEvent in events:
# Initialize the event
try:
event = SelectedEvent(serf=self, config=pack_cfg)
event = SelectedEvent(parent=self, config=pack_cfg)
except Exception as e:
log.error(f"Skipping: "
f"{SelectedEvent.__qualname__} - {e.__class__.__qualname__} in the initialization.")
@ -282,8 +288,6 @@ class Serf(abc.ABC):
except Exception as e:
ru.sentry_exc(e)
await data.reply(f"⛔️ [b]{e.__class__.__name__}[/b]\n" + '\n'.join(map(lambda a: repr(a), e.args)))
finally:
await data.session_close()
@staticmethod
async def press(key: rc.KeyboardKey, data: rc.CommandData):
@ -307,12 +311,10 @@ class Serf(abc.ABC):
except Exception as e:
ru.sentry_exc(e)
await data.reply(f"⛔️ [b]{e.__class__.__name__}[/b]\n" + '\n'.join(map(lambda a: repr(a), e.args)))
finally:
await data.session_close()
async def run(self):
"""A coroutine that starts the event loop and handles command calls."""
self.herald_task = self.loop.create_task(self.herald.run())
self.herald_task = self.tasks.add(self.herald.run())
# OVERRIDE THIS METHOD!
@classmethod

View file

@ -127,7 +127,8 @@ class TelegramSerf(Serf):
if error_if_none:
raise rc.CommandError("No command caller for this message")
return None
query = data.session.query(self.master_table)
async with data.session_acm() as session:
query = session.query(self.master_table)
for link in self.identity_chain:
query = query.join(link.mapper.class_)
query = query.filter(self.identity_column == user.id)
@ -189,7 +190,8 @@ class TelegramSerf(Serf):
if error_if_none:
raise rc.CommandError("No command caller for this message")
return None
query = data.session.query(self.master_table)
async with data.session_acm() as session:
query = session.query(self.master_table)
for link in self.identity_chain:
query = query.join(link.mapper.class_)
query = query.filter(self.identity_column == user.id)
@ -206,25 +208,27 @@ class TelegramSerf(Serf):
async def handle_update(self, update: telegram.Update):
"""Delegate :class:`telegram.Update` handling to the correct message type submethod."""
if update.message is not None:
log.debug(f"Handling update as a message")
await self.handle_message(update.message)
elif update.edited_message is not None:
pass
log.debug(f"Update is a edited message, not doing anything")
elif update.channel_post is not None:
pass
log.debug(f"Update is a channel post, not doing anything")
elif update.edited_channel_post is not None:
pass
log.debug(f"Update is a channel edit, not doing anything")
elif update.inline_query is not None:
pass
log.debug(f"Update is a inline query, not doing anything")
elif update.chosen_inline_result is not None:
pass
log.debug(f"Update is a chosen inline result, not doing anything")
elif update.callback_query is not None:
log.debug(f"Handling update as a callback query")
await self.handle_callback_query(update.callback_query)
elif update.shipping_query is not None:
pass
log.debug(f"Update is a shipping query, not doing anything")
elif update.pre_checkout_query is not None:
pass
log.debug(f"Update is a precheckout query, not doing anything")
elif update.poll is not None:
pass
log.debug(f"Update is a poll, not doing anything")
else:
log.warning(f"Unknown update type: {update}")
@ -236,25 +240,31 @@ class TelegramSerf(Serf):
text: str = message.caption
# No text or caption, ignore the message
if text is None:
log.debug("Skipping message as it had no text or caption")
return
# Skip non-command updates
if not text.startswith("/"):
if not text.startswith(self.prefix):
log.debug(f"Skipping message as it didn't start with {self.prefix}")
return
# Find and clean parameters
command_text, *parameters = text.split(" ")
command_name = command_text.replace(f"@{self.client.username}", "").lower()
command_name = command_text.replace(f"@{self.client.username}", "").lstrip(self.prefix).lower()
log.debug(f"Parsed '{command_name}' as command name")
# Find the command
try:
command = self.commands[command_name]
except KeyError:
# Skip the message
log.debug(f"Skipping message as I could not find the command {command_name}")
return
# Send a typing notification
log.debug(f"Sending typing notification")
await self.api_call(message.chat.send_action, telegram.ChatAction.TYPING)
# Prepare data
# noinspection PyArgumentList
data = self.MessageData(command=command, message=message)
# Call the command
log.debug(f"Calling {command}")
await self.call(command, data, parameters)
async def handle_callback_query(self, cbq: telegram.CallbackQuery):
@ -270,16 +280,25 @@ class TelegramSerf(Serf):
async def run(self):
await super().run()
while True:
# Collect ended tasks
self.tasks.collect()
# Get the latest 100 updates
last_updates: List[telegram.Update] = await self.api_call(self.client.get_updates,
offset=self.update_offset,
timeout=60,
read_latency=5.0)
log.debug("Getting updates...")
last_updates: Optional[List[telegram.Update]] = await self.api_call(
self.client.get_updates,
offset=self.update_offset,
timeout=60,
read_latency=5.0
)
# Ensure a list was returned from the API call
if not isinstance(last_updates, list):
log.warning("Received invalid data from get_updates, sleeping for 60 seconds, hoping it fixes itself.")
await aio.sleep(60)
continue
# Handle updates
log.debug("Handling updates...")
for update in last_updates:
# TODO: don't lose the reference to the task
# noinspection PyAsyncCall
self.loop.create_task(self.handle_update(update))
self.tasks.add(self.handle_update(update))
# Recalculate offset
try:
self.update_offset = last_updates[-1].update_id + 1

View file

@ -7,6 +7,7 @@ from .sentry import init_sentry, sentry_exc, sentry_wrap, sentry_async_wrap
from .sleep_until import sleep_until
from .strip_tabs import strip_tabs
from .urluuid import to_urluuid, from_urluuid
from .taskslist import TaskList
__all__ = [
"asyncify",
@ -26,4 +27,5 @@ __all__ = [
"init_logging",
"JSON",
"strip_tabs",
"TaskList",
]

View file

@ -6,7 +6,9 @@ log = logging.getLogger(__name__)
class MultiLock:
"""A lock that can allow both simultaneous access and exclusive access to a resource."""
"""A lock that allows either simultaneous read access or exclusive write access.
Basically, a reimplementation of Rust's `RwLock <https://doc.rust-lang.org/beta/std/sync/struct.RwLock.html>`_ ."""
def __init__(self):
self._counter: int = 0
@ -40,7 +42,6 @@ class MultiLock:
async def exclusive(self):
"""Acquire the lock for exclusive access."""
log.debug(f"Waiting for exclusive lock end: {self}")
# TODO: check if this actually works
await self._exclusive_event.wait()
self._exclusive_event.clear()
log.debug(f"Waiting for normal lock end: {self}")

View file

@ -0,0 +1,38 @@
from typing import *
import asyncio as aio
import logging
from .sentry import sentry_exc
log = logging.getLogger(__name__)
class TaskList:
def __init__(self, loop: aio.AbstractEventLoop):
self.loop: aio.AbstractEventLoop = loop
self.tasks: List[aio.Task] = []
def collect(self):
"""Remove finished tasks from the list."""
log.debug(f"Collecting done tasks")
new_list = []
for task in self.tasks:
try:
task.result()
except aio.CancelledError:
log.warning(f"Task {task} was unexpectedly cancelled.")
except aio.InvalidStateError:
log.debug(f"Task {task} hasn't finished running yet, readding it to the list.")
new_list.append(task)
except Exception as err:
sentry_exc(err)
self.tasks = new_list
def add(self, coroutine: Awaitable[Any], timeout: float = None) -> aio.Task:
"""Add a new task to the list; the task will be cancelled if ``timeout`` seconds pass."""
log.debug(f"Creating new task {coroutine}")
if timeout:
task = self.loop.create_task(aio.wait_for(coroutine, timeout=timeout))
else:
task = self.loop.create_task(coroutine)
self.tasks.append(task)
return task