1
Fork 0
mirror of https://github.com/RYGhub/royalnet.git synced 2024-11-27 13:34:28 +00:00

EVEN MORE STUFF

This commit is contained in:
Steffo 2019-11-13 15:58:01 +01:00
parent 9949903305
commit 14c3ce4420
16 changed files with 325 additions and 234 deletions

11
TODO.md
View file

@ -1,11 +0,0 @@
# To do:
- [x] alchemy
- [x] bard
- [ ] commands (check for renamed references)
- [ ] interfaces
- [ ] packs (almost)
- [ ] utils
- [x] constellation
- [ ] main
- [ ] dependencies

View file

@ -1,14 +1,23 @@
from typing import Set, Dict, Union, Type from typing import Set, Dict, Union
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.schema import Table
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative.api import DeclarativeMeta
from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager, asynccontextmanager from contextlib import contextmanager, asynccontextmanager
from royalnet.utils import asyncify from royalnet.utils import asyncify
from royalnet.alchemy.errors import TableNotFoundException from royalnet.alchemy.errors import TableNotFoundException
try:
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.schema import Table
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative.api import DeclarativeMeta
from sqlalchemy.orm import sessionmaker
except ImportError:
create_engine = None
Engine = None
Table = None
declarative_base = None
DeclarativeMeta = None
sessionmaker = None
class Alchemy: class Alchemy:
"""A wrapper around ``sqlalchemy.orm`` that allows the instantiation of multiple engines at once while maintaining """A wrapper around ``sqlalchemy.orm`` that allows the instantiation of multiple engines at once while maintaining
@ -22,12 +31,15 @@ class Alchemy:
tables: The :class:`set` of tables to be created and used in the selected database. tables: The :class:`set` of tables to be created and used in the selected database.
Check the tables submodule for more details. Check the tables submodule for more details.
""" """
if create_engine is None:
raise ImportError("'alchemy' extra is not installed")
if database_uri.startswith("sqlite"): if database_uri.startswith("sqlite"):
raise NotImplementedError("sqlite databases aren't supported, as they can't be used in multithreaded" raise NotImplementedError("sqlite databases aren't supported, as they can't be used in multithreaded"
" applications") " applications")
self._engine: Engine = create_engine(database_uri) self._engine: Engine = create_engine(database_uri)
self._Base: DeclarativeMeta = declarative_base(bind=self._engine) self._Base: DeclarativeMeta = declarative_base(bind=self._engine)
self._Session: sessionmaker = sessionmaker(bind=self._engine) self.Session: sessionmaker = sessionmaker(bind=self._engine)
self._tables: Dict[str, Table] = {} self._tables: Dict[str, Table] = {}
for table in tables: for table in tables:
name = table.__name__ name = table.__name__
@ -76,7 +88,7 @@ class Alchemy:
session.commit() session.commit()
""" """
session = self._Session() session = self.Session()
try: try:
yield session yield session
except Exception: except Exception:
@ -99,7 +111,7 @@ class Alchemy:
... ...
# Commit the session # Commit the session
await asyncify(session.commit)""" await asyncify(session.commit)"""
session = await asyncify(self._Session) session = await asyncify(self.Session)
try: try:
yield session yield session
except Exception: except Exception:

View file

@ -1,6 +1,9 @@
from typing import Type try:
from sqlalchemy.inspection import inspect from sqlalchemy.inspection import inspect
from sqlalchemy.schema import Table from sqlalchemy.schema import Table
except ImportError:
inspect = None
Table = None
def table_dfs(starting_table: Table, ending_table: Table) -> tuple: def table_dfs(starting_table: Table, ending_table: Table) -> tuple:
@ -8,6 +11,9 @@ def table_dfs(starting_table: Table, ending_table: Table) -> tuple:
Returns: Returns:
A :class:`tuple` containing the path, starting from the starting table and ending at the ending table.""" A :class:`tuple` containing the path, starting from the starting table and ending at the ending table."""
if inspect is None:
raise ImportError("'alchemy' extra is not installed")
inspected = set() inspected = set()
def search(_mapper, chain): def search(_mapper, chain):

View file

@ -1,5 +1,4 @@
import os import os
import youtube_dl
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
from royalnet.utils import asyncify, MultiLock from royalnet.utils import asyncify, MultiLock
@ -7,6 +6,11 @@ from asyncio import AbstractEventLoop, get_event_loop
from .ytdlinfo import YtdlInfo from .ytdlinfo import YtdlInfo
from .errors import NotFoundError, MultipleFilesError from .errors import NotFoundError, MultipleFilesError
try:
from youtube_dl import YoutubeDL
except ImportError:
youtube_dl = None
class YtdlFile: class YtdlFile:
"""A representation of a file download with ``youtube_dl``.""" """A representation of a file download with ``youtube_dl``."""
@ -60,9 +64,12 @@ class YtdlFile:
async def download_file(self) -> None: async def download_file(self) -> None:
"""Download the file.""" """Download the file."""
if YoutubeDL is None:
raise ImportError("'bard' extra is not installed")
def download(): def download():
"""Download function block to be asyncified.""" """Download function block to be asyncified."""
with youtube_dl.YoutubeDL(self.ytdl_args) as ytdl: with YoutubeDL(self.ytdl_args) as ytdl:
filename = ytdl.prepare_filename(self.info.__dict__) filename = ytdl.prepare_filename(self.info.__dict__)
ytdl.download([self.info.webpage_url]) ytdl.download([self.info.webpage_url])
self.filename = filename self.filename = filename

View file

@ -2,9 +2,13 @@ from asyncio import AbstractEventLoop, get_event_loop
from typing import Optional, Dict, List, Any from typing import Optional, Dict, List, Any
from datetime import datetime, timedelta from datetime import datetime, timedelta
import dateparser import dateparser
from youtube_dl import YoutubeDL
from royalnet.utils import ytdldateformat, asyncify from royalnet.utils import ytdldateformat, asyncify
try:
from youtube_dl import YoutubeDL
except ImportError:
YoutubeDL = None
class YtdlInfo: class YtdlInfo:
"""A wrapper around youtube_dl extracted info.""" """A wrapper around youtube_dl extracted info."""
@ -85,6 +89,9 @@ class YtdlInfo:
Returns: Returns:
A :py:class:`list` containing the infos for the requested videos.""" A :py:class:`list` containing the infos for the requested videos."""
if YoutubeDL is None:
raise ImportError("'bard' extra is not installed")
if loop is None: if loop is None:
loop: AbstractEventLoop = get_event_loop() loop: AbstractEventLoop = get_event_loop()
# So many redundant options! # So many redundant options!

View file

@ -2,9 +2,9 @@ import typing
import re import re
import ffmpeg import ffmpeg
import os import os
from royalnet.utils import asyncify, MultiLock
from .ytdlinfo import YtdlInfo from .ytdlinfo import YtdlInfo
from .ytdlfile import YtdlFile from .ytdlfile import YtdlFile
from royalnet.utils import asyncify, MultiLock
class YtdlMp3: class YtdlMp3:

View file

@ -44,7 +44,8 @@ class CommandArgs(list):
"""Get the arguments as a space-joined string. """Get the arguments as a space-joined string.
Parameters: Parameters:
require_at_least: the minimum amount of arguments required, will raise :py:exc:`royalnet.error.InvalidInputError` if the requirement is not fullfilled. require_at_least: the minimum amount of arguments required, will raise :exc:`InvalidInputError` if the
requirement is not fullfilled.
Raises: Raises:
royalnet.error.InvalidInputError: if there are less than ``require_at_least`` arguments. royalnet.error.InvalidInputError: if there are less than ``require_at_least`` arguments.
@ -84,7 +85,8 @@ class CommandArgs(list):
return match.groups() return match.groups()
def optional(self, index: int, default=None) -> Optional[str]: def optional(self, index: int, default=None) -> Optional[str]:
"""Get the argument at a specific index, but don't raise an error if nothing is found, instead returning the ``default`` value. """Get the argument at a specific index, but don't raise an error if nothing is found, instead returning the
``default`` value.
Parameters: Parameters:
index: The index of the argument you want to retrieve. index: The index of the argument you want to retrieve.

View file

@ -1,36 +1,37 @@
from typing import Dict, Callable from typing import Optional, TYPE_CHECKING
import warnings
from .errors import UnsupportedError from .errors import UnsupportedError
from .commandinterface import CommandInterface from .commandinterface import CommandInterface
from ..utils import asyncify from ..utils import asyncify
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
class CommandData: class CommandData:
def __init__(self, interface: CommandInterface): def __init__(self, interface: CommandInterface, session: Optional["Session"]):
self._interface: CommandInterface = interface self._interface: CommandInterface = interface
if len(self._interface.command.tables) > 0: self._session: Optional["Session"] = session
self.session = self._interface.alchemy._Session()
else: @property
self.session = None def session(self) -> "Session":
"""Get the :class:`Alchemy` :class:`Session`, if it is available.
Raises:
UnsupportedError: if no session is available."""
if self._session is None:
raise UnsupportedError("'session' is not supported")
return self._session
async def session_commit(self): async def session_commit(self):
"""Commit the changes to the session.""" """Commit the changes to the session."""
await asyncify(self.session.commit) await asyncify(self.session.commit)
async def session_close(self):
"""Close the opened session.
Remember to call this when the data is disposed of!"""
if self.session:
await asyncify(self.session.close)
self.session = None
async def reply(self, text: str) -> None: async def reply(self, text: str) -> None:
"""Send a text message to the channel where the call was made. """Send a text message to the channel where the call was made.
Parameters: Parameters:
text: The text to be sent, possibly formatted in the weird undescribed markup that I'm using.""" text: The text to be sent, possibly formatted in the weird undescribed markup that I'm using."""
raise UnsupportedError("'reply' is not supported on this platform") raise UnsupportedError("'reply' is not supported")
async def get_author(self, error_if_none: bool = False): async def get_author(self, error_if_none: bool = False):
"""Try to find the identifier of the user that sent the message. """Try to find the identifier of the user that sent the message.
@ -38,14 +39,7 @@ class CommandData:
Parameters: Parameters:
error_if_none: Raise an exception if this is True and the call has no author.""" error_if_none: Raise an exception if this is True and the call has no author."""
raise UnsupportedError("'get_author' is not supported on this platform") raise UnsupportedError("'get_author' is not supported")
async def keyboard(self, text: str, keyboard: Dict[str, Callable]) -> None:
"""Send a keyboard having the keys of the dict as keys and calling the correspondent values on a press.
The function should be passed the :py:class:`CommandData` instance as a argument."""
warnings.warn("keyboard is deprecated, please avoid using it", category=DeprecationWarning)
raise UnsupportedError("'keyboard' is not supported on this platform")
async def delete_invoking(self, error_if_unavailable=False) -> None: async def delete_invoking(self, error_if_unavailable=False) -> None:
"""Delete the invoking message, if supported by the interface. """Delete the invoking message, if supported by the interface.
@ -55,4 +49,4 @@ class CommandData:
Parameters: Parameters:
error_if_unavailable: if True, raise an exception if the message cannot been deleted.""" error_if_unavailable: if True, raise an exception if the message cannot been deleted."""
if error_if_unavailable: if error_if_unavailable:
raise UnsupportedError("'delete_invoking' is not supported on this platform") raise UnsupportedError("'delete_invoking' is not supported")

View file

@ -1,35 +1,29 @@
import typing from typing import Optional, TYPE_CHECKING, Awaitable, Any, Callable
import asyncio from asyncio import AbstractEventLoop
from .errors import UnsupportedError from .errors import UnsupportedError
if typing.TYPE_CHECKING: if TYPE_CHECKING:
from .command import Command from .command import Command
from ..alchemy import Alchemy from ..alchemy import Alchemy
from ..serf import GenericBot from ..serf import Serf
class CommandInterface: class CommandInterface:
name: str = NotImplemented name: str = NotImplemented
prefix: str = NotImplemented prefix: str = NotImplemented
alchemy: "Alchemy" = NotImplemented alchemy: "Alchemy" = NotImplemented
bot: "GenericBot" = NotImplemented bot: "Serf" = NotImplemented
loop: asyncio.AbstractEventLoop = NotImplemented loop: AbstractEventLoop = NotImplemented
def __init__(self): def __init__(self):
self.command: typing.Optional[Command] = None # Will be bound after the command has been created self.command: Optional[Command] = None # Will be bound after the command has been created
def register_herald_action(self, def register_herald_action(self,
event_name: str, event_name: str,
coroutine: typing.Callable[[typing.Any], typing.Awaitable[typing.Dict]]): coroutine: Callable[[Any], Awaitable[dict]]):
raise UnsupportedError(f"{self.register_herald_action.__name__} is not supported on this platform") raise UnsupportedError(f"{self.register_herald_action.__name__} is not supported on this platform")
def unregister_herald_action(self, event_name: str): def unregister_herald_action(self, event_name: str):
raise UnsupportedError(f"{self.unregister_herald_action.__name__} is not supported on this platform") raise UnsupportedError(f"{self.unregister_herald_action.__name__} is not supported on this platform")
async def call_herald_action(self, destination: str, event_name: str, args: typing.Dict) -> typing.Dict: async def call_herald_action(self, destination: str, event_name: str, args: dict) -> dict:
raise UnsupportedError(f"{self.call_herald_action.__name__} is not supported on this platform") raise UnsupportedError(f"{self.call_herald_action.__name__} is not supported on this platform")
def register_keyboard_key(self, key_name: str, callback: typing.Callable):
raise UnsupportedError(f"{self.register_keyboard_key.__name__} is not supported on this platform")
def unregister_keyboard_key(self, key_name: str):
raise UnsupportedError(f"{self.unregister_keyboard_key.__name__} is not supported on this platform")

View file

@ -1,16 +1,27 @@
import typing import typing
import uvicorn
import logging import logging
import sentry_sdk
from sentry_sdk.integrations.aiohttp import AioHttpIntegration
from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration
from sentry_sdk.integrations.logging import LoggingIntegration
import royalnet import royalnet
import keyring import keyring
from starlette.applications import Starlette
from .star import PageStar, ExceptionStar
from royalnet.alchemy import Alchemy from royalnet.alchemy import Alchemy
from royalnet import __version__ as version from .star import PageStar, ExceptionStar
try:
import uvicorn
from starlette.applications import Starlette
except ImportError:
uvicorn = None
Starlette = None
try:
import sentry_sdk
from sentry_sdk.integrations.aiohttp import AioHttpIntegration
from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration
from sentry_sdk.integrations.logging import LoggingIntegration
except ImportError:
sentry_sdk = None
AioHttpIntegration = None
SqlalchemyIntegration = None
LoggingIntegration = None
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -29,6 +40,9 @@ class Constellation:
exc_stars: typing.List[typing.Type[ExceptionStar]] = None, exc_stars: typing.List[typing.Type[ExceptionStar]] = None,
*, *,
debug: bool = __debug__,): debug: bool = __debug__,):
if Starlette is None:
raise ImportError("'constellation' extra is not installed")
if page_stars is None: if page_stars is None:
page_stars = [] page_stars = []
@ -93,13 +107,18 @@ class Constellation:
address: The IP address this Constellation should bind to. address: The IP address this Constellation should bind to.
port: The port this Constellation should listen for requests on.""" port: The port this Constellation should listen for requests on."""
# Initialize Sentry on the process # Initialize Sentry on the process
if sentry_sdk is None:
log.info("Sentry: not installed")
else:
sentry_dsn = self.get_secret("sentry") sentry_dsn = self.get_secret("sentry")
if sentry_dsn: if not sentry_dsn:
log.info("Sentry: disabled")
else:
# noinspection PyUnreachableCode # noinspection PyUnreachableCode
if __debug__: if __debug__:
release = f"Dev" release = f"Dev"
else: else:
release = f"{version}" release = f"{royalnet.__version__}"
log.debug("Initializing Sentry...") log.debug("Initializing Sentry...")
sentry_sdk.init(sentry_dsn, sentry_sdk.init(sentry_dsn,
integrations=[AioHttpIntegration(), integrations=[AioHttpIntegration(),
@ -107,8 +126,6 @@ class Constellation:
LoggingIntegration(event_level=None)], LoggingIntegration(event_level=None)],
release=release) release=release)
log.info(f"Sentry: enabled (Royalnet {release})") log.info(f"Sentry: enabled (Royalnet {release})")
else:
log.info("Sentry: disabled")
# Run the server # Run the server
log.info(f"Running Constellation on {address}:{port}...") log.info(f"Running Constellation on {address}:{port}...")
self.running = True self.running = True

View file

@ -1,8 +1,13 @@
from starlette.responses import JSONResponse try:
from starlette.responses import JSONResponse
except ImportError:
JSONResponse = None
def shoot(code: int, description: str) -> JSONResponse: def shoot(code: int, description: str) -> JSONResponse:
"""Create a error :class:`JSONResponse` with the passed error code and description.""" """Create a error :class:`JSONResponse` with the passed error code and description."""
if JSONResponse is None:
raise ImportError("'constellation' extra is not installed")
return JSONResponse({ return JSONResponse({
"error": description "error": description
}, status_code=code) }, status_code=code)

View file

@ -1,8 +1,9 @@
from typing import Type, TYPE_CHECKING, List, Union from typing import Type, TYPE_CHECKING, List, Union
from starlette.requests import Request
from starlette.responses import Response
if TYPE_CHECKING: if TYPE_CHECKING:
from .constellation import Constellation from .constellation import Constellation
from starlette.requests import Request
from starlette.responses import Response
class Star: class Star:
@ -15,7 +16,7 @@ class Star:
def __init__(self, constellation: "Constellation"): def __init__(self, constellation: "Constellation"):
self.constellation: "Constellation" = constellation self.constellation: "Constellation" = constellation
async def page(self, request: Request) -> Response: async def page(self, request: "Request") -> "Response":
"""The function generating the :class:`Response` to a web :class:`Request`. """The function generating the :class:`Response` to a web :class:`Request`.
If it raises an error, the corresponding :class:`ExceptionStar` will be used to handle the request instead.""" If it raises an error, the corresponding :class:`ExceptionStar` will be used to handle the request instead."""

View file

@ -1,5 +1,6 @@
from typing import Type from typing import Type, TYPE_CHECKING
from sqlalchemy.schema import Table if TYPE_CHECKING:
from sqlalchemy.schema import Table
class AlchemyConfig: class AlchemyConfig:
@ -15,4 +16,4 @@ class AlchemyConfig:
self.identity_column: str = identity_column self.identity_column: str = identity_column
def __repr__(self): def __repr__(self):
return f"<{self.__class__.__qualname__} for {self.server_url}>" return f"<{self.__class__.__qualname__} for {self.database_url}>"

View file

@ -2,18 +2,40 @@ import logging
from asyncio import Task, AbstractEventLoop from asyncio import Task, AbstractEventLoop
from typing import Type, Optional, Awaitable, Dict, List, Any, Callable, Union, Set from typing import Type, Optional, Awaitable, Dict, List, Any, Callable, Union, Set
from keyring import get_password from keyring import get_password
import sentry_sdk
from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration
from sentry_sdk.integrations.aiohttp import AioHttpIntegration
from sentry_sdk.integrations.logging import LoggingIntegration
from sqlalchemy.schema import Table from sqlalchemy.schema import Table
from royalnet import __version__ as version from royalnet import __version__ as version
from royalnet.commands import Command, CommandInterface, CommandData, CommandError, UnsupportedError from royalnet.commands import Command, CommandInterface, CommandData, CommandError, UnsupportedError
from royalnet.alchemy import Alchemy, table_dfs
from royalnet.herald import Response, ResponseSuccess, Broadcast, ResponseFailure, Request, Link
from royalnet.herald import Config as HeraldConfig
from .alchemyconfig import AlchemyConfig from .alchemyconfig import AlchemyConfig
try:
from royalnet.alchemy import Alchemy, table_dfs
except ImportError:
Alchemy = None
table_dfs = None
try:
from royalnet.herald import Response, ResponseSuccess, Broadcast, ResponseFailure, Request, Link
from royalnet.herald import Config as HeraldConfig
except ImportError:
Response = None
ResponseSuccess = None
Broadcast = None
ResponseFailure = None
Request = None
Link = None
HeraldConfig = None
try:
import sentry_sdk
from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration
from sentry_sdk.integrations.aiohttp import AioHttpIntegration
from sentry_sdk.integrations.logging import LoggingIntegration
except ImportError:
sentry_sdk = None
SqlalchemyIntegration = None
AioHttpIntegration = None
LoggingIntegration = None
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -42,9 +64,14 @@ class Serf:
# TODO: I'm not sure what this is either # TODO: I'm not sure what this is either
self._identity_column: Optional[str] = None self._identity_column: Optional[str] = None
if alchemy_config is not None: if Alchemy is None:
log.info("Alchemy: not installed")
elif alchemy_config is None:
log.info("Alchemy: disabled")
else:
tables = self.find_tables(alchemy_config, commands) tables = self.find_tables(alchemy_config, commands)
self.init_alchemy(alchemy_config, tables) self.init_alchemy(alchemy_config, tables)
log.info(f"Alchemy: {self.alchemy}")
self.Interface: Type[CommandInterface] = self.interface_factory() self.Interface: Type[CommandInterface] = self.interface_factory()
"""The :class:`CommandInterface` class of this Serf.""" """The :class:`CommandInterface` class of this Serf."""
@ -58,6 +85,7 @@ class Serf:
if commands is None: if commands is None:
commands = [] commands = []
self.register_commands(commands) self.register_commands(commands)
log.info(f"Commands: total {len(self.commands)}")
self.herald_handlers: Dict[str, Callable[["Serf", Any], Awaitable[Optional[dict]]]] = {} self.herald_handlers: Dict[str, Callable[["Serf", Any], Awaitable[Optional[dict]]]] = {}
"""A :class:`dict` linking :class:`Request` event names to coroutines returning a :class:`dict` that will be """A :class:`dict` linking :class:`Request` event names to coroutines returning a :class:`dict` that will be
@ -69,8 +97,13 @@ class Serf:
self.herald_task: Optional[Task] = None self.herald_task: Optional[Task] = None
"""A reference to the :class:`asyncio.Task` that runs the :class:`Link`.""" """A reference to the :class:`asyncio.Task` that runs the :class:`Link`."""
if network_config is not None: if Link is None:
log.info("Herald: not installed")
elif network_config is None:
log.info("Herald: disabled")
else:
self.init_network(network_config) self.init_network(network_config)
log.info(f"Herald: {self.herald}")
self.loop: Optional[AbstractEventLoop] = None self.loop: Optional[AbstractEventLoop] = None
"""The event loop this Serf is running on.""" """The event loop this Serf is running on."""
@ -194,7 +227,7 @@ class Serf:
self.commands[f"{interface.prefix}{alias}"] = \ self.commands[f"{interface.prefix}{alias}"] = \
self.commands[f"{interface.prefix}{SelectedCommand.name}"] self.commands[f"{interface.prefix}{SelectedCommand.name}"]
else: else:
log.info(f"Ignoring (already defined): {SelectedCommand.__qualname__} -> {interface.prefix}{alias}") log.warning(f"Ignoring (already defined): {SelectedCommand.__qualname__} -> {interface.prefix}{alias}")
def init_network(self, config: HeraldConfig): def init_network(self, config: HeraldConfig):
"""Create a :py:class:`Link`, and run it as a :py:class:`asyncio.Task`.""" """Create a :py:class:`Link`, and run it as a :py:class:`asyncio.Task`."""
@ -226,23 +259,25 @@ class Serf:
elif isinstance(message, Broadcast): elif isinstance(message, Broadcast):
await network_handler(self, **message.data) await network_handler(self, **message.data)
def init_sentry(self): @staticmethod
sentry_dsn = self.get_secret("sentry") def init_sentry(dsn):
if sentry_dsn:
# noinspection PyUnreachableCode # noinspection PyUnreachableCode
if __debug__: if __debug__:
release = f"Dev" release = f"Dev"
else: else:
release = f"{version}" release = f"{version}"
log.debug("Initializing Sentry...") log.debug("Initializing Sentry...")
sentry_sdk.init(sentry_dsn, sentry_sdk.init(dsn,
integrations=[AioHttpIntegration(), integrations=[AioHttpIntegration(),
SqlalchemyIntegration(), SqlalchemyIntegration(),
LoggingIntegration(event_level=None)], LoggingIntegration(event_level=None)],
release=release) release=release)
log.info(f"Sentry: enabled (Royalnet {release})") log.info(f"Sentry: enabled (Royalnet {release})")
else:
log.info("Sentry: disabled") @staticmethod
def sentry_exc(exc: Exception):
if sentry_sdk is not None:
sentry_sdk.capture_exception(exc)
def get_secret(self, username: str): def get_secret(self, username: str):
"""Get a Royalnet secret from the keyring. """Get a Royalnet secret from the keyring.
@ -259,5 +294,13 @@ class Serf:
"""Blockingly run the Serf. """Blockingly run the Serf.
This should be used as the target of a :class:`multiprocessing.Process`.""" This should be used as the target of a :class:`multiprocessing.Process`."""
self.init_sentry() if sentry_sdk is None:
log.info("Sentry: not installed")
else:
sentry_dsn = self.get_secret("sentry")
if sentry_dsn is None:
log.info("Sentry: disabled")
else:
self.init_sentry(sentry_dsn)
self.loop.run_until_complete(self.run()) self.loop.run_until_complete(self.run())

View file

@ -1 +1,7 @@
from .escape import escape from .escape import escape
from .telegramserf import TelegramSerf
__all__ = [
"escape",
"TelegramSerf"
]

View file

@ -1,20 +1,33 @@
import logging import logging
import asyncio import asyncio
import warnings from typing import Type, Optional, List, Callable
import uuid
from typing import Type, Optional, Dict, List, Tuple, Callable
import telegram
import urllib3
import sentry_sdk
from telegram.utils.request import Request as TRequest
from royalnet.commands import Command, CommandInterface, CommandData, CommandArgs, CommandError, InvalidInputError, \ from royalnet.commands import Command, CommandInterface, CommandData, CommandArgs, CommandError, InvalidInputError, \
UnsupportedError, KeyboardExpiredError UnsupportedError
from royalnet.herald import Config as HeraldConfig
from royalnet.utils import asyncify from royalnet.utils import asyncify
from .escape import escape from .escape import escape
from ..alchemyconfig import AlchemyConfig
from ..serf import Serf from ..serf import Serf
try:
import telegram
import urllib3
from telegram.utils.request import Request as TRequest
except ImportError:
telegram = None
urllib3 = None
TRequest = None
try:
from sqlalchemy.orm.session import Session
from ..alchemyconfig import AlchemyConfig
except ImportError:
Session = None
AlchemyConfig = None
try:
from royalnet.herald import Config as HeraldConfig
except ImportError:
HeraldConfig = None
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -27,6 +40,9 @@ class TelegramSerf(Serf):
commands: List[Type[Command]] = None, commands: List[Type[Command]] = None,
network_config: Optional[HeraldConfig] = None, network_config: Optional[HeraldConfig] = None,
secrets_name: str = "__default__"): secrets_name: str = "__default__"):
if telegram is None:
raise ImportError("'telegram' extra is not installed")
super().__init__(alchemy_config=alchemy_config, super().__init__(alchemy_config=alchemy_config,
commands=commands, commands=commands,
network_config=network_config, network_config=network_config,
@ -65,7 +81,7 @@ class TelegramSerf(Serf):
continue continue
except Exception as error: except Exception as error:
log.error(f"{error.__class__.__qualname__} during {f} (skipping): {error}") log.error(f"{error.__class__.__qualname__} during {f} (skipping): {error}")
sentry_sdk.capture_exception(error) TelegramSerf.sentry_exc(error)
break break
return None return None
@ -80,26 +96,14 @@ class TelegramSerf(Serf):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.keys_callbacks: Dict[..., Callable] = {}
def register_keyboard_key(interface, key_name: ..., callback: Callable):
warnings.warn("register_keyboard_key is deprecated", category=DeprecationWarning)
interface.keys_callbacks[key_name] = callback
def unregister_keyboard_key(interface, key_name: ...):
warnings.warn("unregister_keyboard_key is deprecated", category=DeprecationWarning)
try:
del interface.keys_callbacks[key_name]
except KeyError:
raise KeyError(f"Key '{key_name}' is not registered")
return TelegramInterface return TelegramInterface
def data_factory(self) -> Type[CommandData]: def data_factory(self) -> Type[CommandData]:
# noinspection PyMethodParameters # noinspection PyMethodParameters
class TelegramData(CommandData): class TelegramData(CommandData):
def __init__(data, interface: CommandInterface, update: telegram.Update): def __init__(data, interface: CommandInterface, session, update: telegram.Update):
super().__init__(interface) super().__init__(interface=interface, session=session)
data.update = update data.update = update
async def reply(data, text: str): async def reply(data, text: str):
@ -128,34 +132,39 @@ class TelegramSerf(Serf):
raise CommandError("Command caller is not registered") raise CommandError("Command caller is not registered")
return result return result
async def keyboard(data, text: str, keyboard: Dict[str, Callable]) -> None:
warnings.warn("keyboard is deprecated, please avoid using it", category=DeprecationWarning)
tg_keyboard = []
for key in keyboard:
press_id = uuid.uuid4()
tg_keyboard.append([telegram.InlineKeyboardButton(key, callback_data=str(press_id))])
data._interface.register_keyboard_key(key_name=str(press_id), callback=keyboard[key])
await self.api_call(data.update.effective_chat.send_message,
escape(text),
reply_markup=telegram.InlineKeyboardMarkup(tg_keyboard),
parse_mode="HTML",
disable_web_page_preview=True)
async def delete_invoking(data, error_if_unavailable=False) -> None: async def delete_invoking(data, error_if_unavailable=False) -> None:
message: telegram.Message = data.update.message message: telegram.Message = data.update.message
await self.api_call(message.delete) await self.api_call(message.delete)
return TelegramData return TelegramData
async def _handle_update(self, update: telegram.Update): async def handle_update(self, update: telegram.Update):
"""What should be done when a :class:`telegram.Update` is received?""" """Delegate :class:`telegram.Update` handling to the correct message type submethod."""
# Skip non-message updates
if update.message is not None:
await self._handle_message(update)
elif update.callback_query is not None:
await self._handle_callback_query(update)
async def _handle_message(self, update: telegram.Update): if update.message is not None:
await self.handle_message(update)
elif update.edited_message is not None:
pass
elif update.channel_post is not None:
pass
elif update.edited_channel_post is not None:
pass
elif update.inline_query is not None:
pass
elif update.chosen_inline_result is not None:
pass
elif update.callback_query is not None:
pass
elif update.shipping_query is not None:
pass
elif update.pre_checkout_query is not None:
pass
elif update.poll is not None:
pass
else:
log.warning(f"Unknown update type: {update}")
async def handle_message(self, update: telegram.Update):
"""What should be done when a :class:`telegram.Message` is received?""" """What should be done when a :class:`telegram.Message` is received?"""
message: telegram.Message = update.message message: telegram.Message = update.message
text: str = message.text text: str = message.text
@ -171,16 +180,22 @@ class TelegramSerf(Serf):
# Find and clean parameters # Find and clean parameters
command_text, *parameters = text.split(" ") command_text, *parameters = text.split(" ")
command_name = command_text.replace(f"@{self.client.username}", "").lower() command_name = command_text.replace(f"@{self.client.username}", "").lower()
# Send a typing notification
await self.api_call(update.message.chat.send_action, telegram.ChatAction.TYPING)
# Find the command # Find the command
try: try:
command = self.commands[command_name] command = self.commands[command_name]
except KeyError: except KeyError:
# Skip the message # Skip the message
return return
# Send a typing notification
await self.api_call(update.message.chat.send_action, telegram.ChatAction.TYPING)
# Prepare data # Prepare data
data = self.Data(interface=command.interface, update=update) if self.alchemy is not None:
session = await asyncify(self.alchemy.Session)
else:
session = None
try:
# Create the command data
data = self.Data(interface=command.interface, session=session, update=update)
try: try:
# Run the command # Run the command
await command.run(CommandArgs(parameters), data) await command.run(CommandArgs(parameters), data)
@ -192,63 +207,55 @@ class TelegramSerf(Serf):
except CommandError as e: except CommandError as e:
await data.reply(f"⚠️ {e.message}") await data.reply(f"⚠️ {e.message}")
except Exception as e: except Exception as e:
sentry_sdk.capture_exception(e) self.sentry_exc(e)
error_message = f"🦀 [b]{e.__class__.__name__}[/b] 🦀\n" error_message = f"🦀 [b]{e.__class__.__name__}[/b] 🦀\n" \
error_message += '\n'.join(e.args) '\n'.join(e.args)
await data.reply(error_message) await data.reply(error_message)
finally: finally:
# Close the data session if session is not None:
await data.session_close() await asyncify(session.close)
async def _handle_callback_query(self, update: telegram.Update): async def handle_edited_message(self, update: telegram.Update):
query: telegram.CallbackQuery = update.callback_query pass
source: telegram.Message = query.message
callback: Optional[Callable] = None
command: Optional[Command] = None
for command in self.commands.values():
if query.data in command.interface.keys_callbacks:
callback = command.interface.keys_callbacks[query.data]
break
if callback is None:
await self.api_call(source.edit_reply_markup, reply_markup=None)
await self.api_call(query.answer, text="⛔️ This keyboard has expired.")
return
try:
response = await callback(data=self.Data(interface=command.interface, update=update))
except KeyboardExpiredError as e:
# FIXME: May cause a memory leak, as keys are not deleted after use
await self.safe_api_call(source.edit_reply_markup, reply_markup=None)
if len(e.args) > 0:
await self.safe_api_call(query.answer, text=f"⛔️ {e.args[0]}")
else:
await self.safe_api_call(query.answer, text="⛔️ This keyboard has expired.")
return
except Exception as e:
error_text = f"⛔️ {e.__class__.__name__}\n"
error_text += '\n'.join(e.args)
await self.safe_api_call(query.answer, text=error_text)
else:
await self.safe_api_call(query.answer, text=response)
def _initialize(self): async def handle_channel_post(self, update: telegram.Update):
super()._initialize() pass
self._init_client()
async def handle_edited_channel_post(self, update: telegram.Update):
pass
async def handle_inline_query(self, update: telegram.Update):
pass
async def handle_chosen_inline_result(self, update: telegram.Update):
pass
async def handle_callback_query(self, update: telegram.Update):
pass
async def handle_shipping_query(self, update: telegram.Update):
pass
async def handle_pre_checkout_query(self, update: telegram.Update):
pass
async def handle_poll(self, update: telegram.Update):
pass
async def run(self): async def run(self):
if not self.initialized:
self._initialize()
while True: while True:
# Get the latest 100 updates # Get the latest 100 updates
last_updates: List[telegram.Update] = await self.safe_api_call(self.client.get_updates, last_updates: List[telegram.Update] = await self.api_call(self.client.get_updates,
offset=self._offset, offset=self.update_offset,
timeout=30, timeout=60,
read_latency=5.0) read_latency=5.0)
# Handle updates # Handle updates
for update in last_updates: for update in last_updates:
# TODO: don't lose the reference to the task
# noinspection PyAsyncCall # noinspection PyAsyncCall
self.loop.create_task(self._handle_update(update)) self.loop.create_task(self.handle_update(update))
# Recalculate offset # Recalculate offset
try: try:
self._offset = last_updates[-1].update_id + 1 self.update_offset = last_updates[-1].update_id + 1
except IndexError: except IndexError:
pass pass