mirror of
https://github.com/RYGhub/royalnet.git
synced 2024-11-23 19:44:20 +00:00
A lot of stuff
This commit is contained in:
parent
bad35daf5b
commit
03b3d52f0e
13 changed files with 177 additions and 88 deletions
|
@ -6,6 +6,7 @@ from royalnet.commands import *
|
|||
from royalnet.commands.debug_create import DebugCreateCommand
|
||||
from royalnet.commands.error_handler import ErrorHandlerCommand
|
||||
from royalnet.network import RoyalnetServer
|
||||
from royalnet.database import DatabaseConfig
|
||||
from royalnet.database.tables import Royal, Telegram, Discord
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
@ -20,11 +21,13 @@ commands = [PingCommand, ShipCommand, SmecdsCommand, ColorCommand, CiaoruoziComm
|
|||
KvrollCommand, VideoinfoCommand, SummonCommand, PlayCommand]
|
||||
|
||||
master = RoyalnetServer("localhost", 1234, "sas")
|
||||
# tg_bot = TelegramBot(os.environ["TG_AK"], "ws://localhost:1234", "sas", commands, os.environ["DB_PATH"], Royal, Telegram, "tg_id", error_command=ErrorHandlerCommand)
|
||||
ds_bot = DiscordBot(os.environ["DS_AK"], "ws://localhost:1234", "sas", commands, os.environ["DB_PATH"], Royal, Discord, "discord_id", error_command=ErrorHandlerCommand)
|
||||
tg_db_cfg = DatabaseConfig(os.environ["DB_PATH"], Royal, Telegram, "tg_id")
|
||||
tg_bot = TelegramBot(os.environ["TG_AK"], "ws://localhost:1234", "sas", commands, NullCommand, ErrorHandlerCommand, tg_db_cfg)
|
||||
ds_db_cfg = DatabaseConfig(os.environ["DB_PATH"], Royal, Discord, "discord_id")
|
||||
ds_bot = DiscordBot(os.environ["DS_AK"], "ws://localhost:1234", "sas", commands, NullCommand, ErrorHandlerCommand, ds_db_cfg)
|
||||
loop.run_until_complete(master.run())
|
||||
# Dirty hack, remove me asap
|
||||
# loop.create_task(tg_bot.run())
|
||||
loop.create_task(tg_bot.run())
|
||||
loop.create_task(ds_bot.run())
|
||||
print("Starting loop...")
|
||||
loop.run_forever()
|
||||
|
|
|
@ -4,11 +4,11 @@ import typing
|
|||
import logging as _logging
|
||||
import sys
|
||||
from ..commands import NullCommand
|
||||
from ..utils import asyncify, Call, Command
|
||||
from ..error import UnregisteredError, NoneFoundError, TooManyFoundError
|
||||
from ..utils import asyncify, Call, Command, NetworkHandler
|
||||
from ..error import UnregisteredError, NoneFoundError, TooManyFoundError, InvalidConfigError
|
||||
from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError
|
||||
from ..database import Alchemy, relationshiplinkchain
|
||||
from ..audio import RoyalPCMFile, PlayMode, Playlist, Pool
|
||||
from ..database import Alchemy, relationshiplinkchain, DatabaseConfig
|
||||
from ..audio import RoyalPCMFile, PlayMode, Playlist
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
log = _logging.getLogger(__name__)
|
||||
|
@ -18,31 +18,15 @@ if not discord.opus.is_loaded():
|
|||
log.error("Opus is not loaded. Weird behaviour might emerge.")
|
||||
|
||||
|
||||
class PlayMessage(Message):
|
||||
def __init__(self, url: str, guild_identifier: typing.Optional[str] = None):
|
||||
self.url: str = url
|
||||
self.guild_identifier: typing.Optional[str] = guild_identifier
|
||||
|
||||
|
||||
class SummonMessage(Message):
|
||||
def __init__(self, channel_identifier: typing.Union[int, str],
|
||||
guild_identifier: typing.Optional[typing.Union[int, str]]):
|
||||
self.channel_identifier = channel_identifier
|
||||
self.guild_identifier = guild_identifier
|
||||
|
||||
|
||||
class DiscordBot:
|
||||
def __init__(self,
|
||||
token: str,
|
||||
master_server_uri: str,
|
||||
master_server_secret: str,
|
||||
commands: typing.List[typing.Type[Command]],
|
||||
database_uri: str,
|
||||
master_table,
|
||||
identity_table,
|
||||
identity_column_name: str,
|
||||
missing_command: typing.Type[Command] = NullCommand,
|
||||
error_command: typing.Type[Command] = NullCommand):
|
||||
error_command: typing.Type[Command] = NullCommand,
|
||||
database_config: typing.Optional[DatabaseConfig] = None):
|
||||
self.token = token
|
||||
# Generate commands
|
||||
self.missing_command = missing_command
|
||||
|
@ -52,12 +36,25 @@ class DiscordBot:
|
|||
for command in commands:
|
||||
self.commands[f"!{command.command_name}"] = command
|
||||
required_tables = required_tables.union(command.require_alchemy_tables)
|
||||
# Generate network handlers
|
||||
self.network_handlers: typing.Dict[typing.Type[Message], typing.Type[NetworkHandler]] = {}
|
||||
for command in commands:
|
||||
self.network_handlers = {**self.network_handlers, **command.network_handler_dict()}
|
||||
# Generate the Alchemy database
|
||||
self.alchemy = Alchemy(database_uri, required_tables)
|
||||
self.master_table = self.alchemy.__getattribute__(master_table.__name__)
|
||||
self.identity_table = self.alchemy.__getattribute__(identity_table.__name__)
|
||||
self.identity_column = self.identity_table.__getattribute__(self.identity_table, identity_column_name)
|
||||
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
|
||||
if database_config:
|
||||
self.alchemy = Alchemy(database_config.database_uri, required_tables)
|
||||
self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__)
|
||||
self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__)
|
||||
self.identity_column = self.identity_table.__getattribute__(self.identity_table, database_config.identity_column_name)
|
||||
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
|
||||
else:
|
||||
if required_tables:
|
||||
raise InvalidConfigError("Tables are required by the commands, but Alchemy is not configured")
|
||||
self.alchemy = None
|
||||
self.master_table = None
|
||||
self.identity_table = None
|
||||
self.identity_column = None
|
||||
self.identity_chain = None
|
||||
# Connect to Royalnet
|
||||
self.network: RoyalnetLink = RoyalnetLink(master_server_uri, master_server_secret, "discord",
|
||||
self.network_handler)
|
||||
|
@ -70,6 +67,7 @@ class DiscordBot:
|
|||
interface_name = "discord"
|
||||
interface_obj = self
|
||||
interface_prefix = "!"
|
||||
|
||||
alchemy = self.alchemy
|
||||
|
||||
async def reply(call, text: str):
|
||||
|
@ -213,18 +211,7 @@ class DiscordBot:
|
|||
async def network_handler(self, message: Message) -> Message:
|
||||
"""Handle a Royalnet request."""
|
||||
log.debug(f"Received {message} from Royalnet")
|
||||
if isinstance(message, SummonMessage):
|
||||
return await self.nh_summon(message)
|
||||
elif isinstance(message, PlayMessage):
|
||||
return await self.nh_play(message)
|
||||
|
||||
async def nh_summon(self, message: SummonMessage):
|
||||
"""Handle a summon Royalnet request. That is, join a voice channel, or move to a different one if that is not possible."""
|
||||
channel = self.find_channel(message.channel_identifier)
|
||||
if not isinstance(channel, discord.VoiceChannel):
|
||||
raise NoneFoundError("Channel is not a voice channel")
|
||||
loop.create_task(self.bot.vc_connect_or_move(channel))
|
||||
return RequestSuccessful()
|
||||
return await self.network_handlers[message.__class__].discord(message)
|
||||
|
||||
async def add_to_music_data(self, url: str, guild: discord.Guild):
|
||||
"""Add a file to the corresponding music_data object."""
|
||||
|
@ -257,23 +244,6 @@ class DiscordBot:
|
|||
log.debug(f"Starting playback of {next_source}")
|
||||
voice_client.play(next_source, after=advance)
|
||||
|
||||
async def nh_play(self, message: PlayMessage):
|
||||
"""Handle a play Royalnet request. That is, add audio to a PlayMode."""
|
||||
# Find the matching guild
|
||||
if message.guild_identifier:
|
||||
guild = self.find_guild(message.guild_identifier)
|
||||
else:
|
||||
if len(self.music_data) != 1:
|
||||
raise TooManyFoundError("Multiple guilds found")
|
||||
guild = list(self.music_data)[0]
|
||||
# Ensure the guild has a PlayMode before adding the file to it
|
||||
if not self.music_data.get(guild):
|
||||
# TODO: change Exception
|
||||
raise Exception("No music_data for this guild")
|
||||
# Start downloading
|
||||
loop.create_task(self.add_to_music_data(message.url, guild))
|
||||
return RequestSuccessful()
|
||||
|
||||
async def run(self):
|
||||
await self.bot.login(self.token)
|
||||
await self.bot.connect()
|
||||
|
|
|
@ -5,9 +5,9 @@ import logging as _logging
|
|||
import sys
|
||||
from ..commands import NullCommand
|
||||
from ..utils import asyncify, Call, Command
|
||||
from ..error import UnregisteredError
|
||||
from ..error import UnregisteredError, InvalidConfigError
|
||||
from ..network import RoyalnetLink, Message, RequestError
|
||||
from ..database import Alchemy, relationshiplinkchain
|
||||
from ..database import Alchemy, relationshiplinkchain, DatabaseConfig
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
log = _logging.getLogger(__name__)
|
||||
|
@ -23,12 +23,9 @@ class TelegramBot:
|
|||
master_server_uri: str,
|
||||
master_server_secret: str,
|
||||
commands: typing.List[typing.Type[Command]],
|
||||
database_uri: str,
|
||||
master_table,
|
||||
identity_table,
|
||||
identity_column_name: str,
|
||||
missing_command: typing.Type[Command] = NullCommand,
|
||||
error_command: typing.Type[Command] = NullCommand):
|
||||
error_command: typing.Type[Command] = NullCommand,
|
||||
database_config: typing.Optional[DatabaseConfig] = None):
|
||||
self.bot: telegram.Bot = telegram.Bot(api_key)
|
||||
self.should_run: bool = False
|
||||
self.offset: int = -100
|
||||
|
@ -43,12 +40,20 @@ class TelegramBot:
|
|||
self.commands[f"/{command.command_name}"] = command
|
||||
required_tables = required_tables.union(command.require_alchemy_tables)
|
||||
# Generate the Alchemy database
|
||||
self.alchemy = Alchemy(database_uri, required_tables)
|
||||
self.master_table = self.alchemy.__getattribute__(master_table.__name__)
|
||||
self.identity_table = self.alchemy.__getattribute__(identity_table.__name__)
|
||||
self.identity_column = self.identity_table.__getattribute__(self.identity_table, identity_column_name)
|
||||
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
|
||||
|
||||
if database_config:
|
||||
self.alchemy = Alchemy(database_config.database_uri, required_tables)
|
||||
self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__)
|
||||
self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__)
|
||||
self.identity_column = self.identity_table.__getattribute__(self.identity_table, database_config.identity_column_name)
|
||||
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
|
||||
else:
|
||||
if required_tables:
|
||||
raise InvalidConfigError("Tables are required by the commands, but Alchemy is not configured")
|
||||
self.alchemy = None
|
||||
self.master_table = None
|
||||
self.identity_table = None
|
||||
self.identity_column = None
|
||||
self.identity_chain = None
|
||||
# noinspection PyMethodParameters
|
||||
class TelegramCall(Call):
|
||||
interface_name = "telegram"
|
||||
|
|
|
@ -1,7 +1,42 @@
|
|||
import typing
|
||||
from ..utils import Command, Call
|
||||
import discord
|
||||
import asyncio
|
||||
from ..utils import Command, Call, NetworkHandler
|
||||
from ..network import Message, RequestSuccessful, RequestError
|
||||
from ..bots.discord import PlayMessage
|
||||
from ..error import TooManyFoundError
|
||||
if typing.TYPE_CHECKING:
|
||||
from ..bots import DiscordBot
|
||||
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
class PlayMessage(Message):
|
||||
def __init__(self, url: str, guild_identifier: typing.Optional[str] = None):
|
||||
self.url: str = url
|
||||
self.guild_identifier: typing.Optional[str] = guild_identifier
|
||||
|
||||
|
||||
class PlayNH(NetworkHandler):
|
||||
message_type = PlayMessage
|
||||
|
||||
@classmethod
|
||||
async def nh_play(cls, bot: "DiscordBot", message: PlayMessage):
|
||||
"""Handle a play Royalnet request. That is, add audio to a PlayMode."""
|
||||
# Find the matching guild
|
||||
if message.guild_identifier:
|
||||
guild = bot.find_guild(message.guild_identifier)
|
||||
else:
|
||||
if len(bot.music_data) != 1:
|
||||
raise TooManyFoundError("Multiple guilds found")
|
||||
guild = list(bot.music_data)[0]
|
||||
# Ensure the guild has a PlayMode before adding the file to it
|
||||
if not bot.music_data.get(guild):
|
||||
# TODO: change Exception
|
||||
raise Exception("No music_data for this guild")
|
||||
# Start downloading
|
||||
loop.create_task(bot.add_to_music_data(message.url, guild))
|
||||
return RequestSuccessful()
|
||||
|
||||
|
||||
class PlayCommand(Command):
|
||||
|
@ -9,6 +44,8 @@ class PlayCommand(Command):
|
|||
command_description = "Riproduce una canzone in chat vocale."
|
||||
command_syntax = "[ [guild] ] (url)"
|
||||
|
||||
network_handlers = [PlayNH]
|
||||
|
||||
@classmethod
|
||||
async def common(cls, call: Call):
|
||||
guild, url = call.args.match(r"(?:\[(.+)])?\s*(\S+)\s*")
|
||||
|
|
|
@ -1,8 +1,34 @@
|
|||
import typing
|
||||
import discord
|
||||
from ..utils import Command, Call
|
||||
import asyncio
|
||||
from ..utils import Command, Call, NetworkHandler
|
||||
from ..network import Message, RequestSuccessful, RequestError
|
||||
from ..bots.discord import SummonMessage
|
||||
from ..error import NoneFoundError
|
||||
if typing.TYPE_CHECKING:
|
||||
from ..bots import DiscordBot
|
||||
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
class SummonMessage(Message):
|
||||
def __init__(self, channel_identifier: typing.Union[int, str],
|
||||
guild_identifier: typing.Optional[typing.Union[int, str]] = None):
|
||||
self.channel_identifier = channel_identifier
|
||||
self.guild_identifier = guild_identifier
|
||||
|
||||
|
||||
class SummonNH(NetworkHandler):
|
||||
message_type = SummonMessage
|
||||
|
||||
@classmethod
|
||||
async def discord(cls, bot: "DiscordBot", message: SummonMessage):
|
||||
"""Handle a summon Royalnet request. That is, join a voice channel, or move to a different one if that is not possible."""
|
||||
channel = bot.find_channel(message.channel_identifier)
|
||||
if not isinstance(channel, discord.VoiceChannel):
|
||||
raise NoneFoundError("Channel is not a voice channel")
|
||||
loop.create_task(bot.bot.vc_connect_or_move(channel))
|
||||
return RequestSuccessful()
|
||||
|
||||
|
||||
class SummonCommand(Command):
|
||||
|
@ -11,6 +37,8 @@ class SummonCommand(Command):
|
|||
command_description = "Evoca il bot in un canale vocale."
|
||||
command_syntax = "[channelname]"
|
||||
|
||||
network_handlers = [SummonNH]
|
||||
|
||||
@classmethod
|
||||
async def common(cls, call: Call):
|
||||
channel_name: str = call.args[0].lstrip("#")
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from .alchemy import Alchemy
|
||||
from .relationshiplinkchain import relationshiplinkchain
|
||||
from .databaseconfig import DatabaseConfig
|
||||
|
||||
__all__ = ["Alchemy", "relationshiplinkchain"]
|
||||
__all__ = ["Alchemy", "relationshiplinkchain", "DatabaseConfig"]
|
||||
|
|
|
@ -4,7 +4,8 @@ from sqlalchemy import create_engine
|
|||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from contextlib import contextmanager, asynccontextmanager
|
||||
from ..utils import cdj, asyncify
|
||||
from ..utils import asyncify
|
||||
from ..error import InvalidConfigError
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
|
13
royalnet/database/databaseconfig.py
Normal file
13
royalnet/database/databaseconfig.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
import typing
|
||||
|
||||
|
||||
class DatabaseConfig:
|
||||
def __init__(self,
|
||||
database_uri: str,
|
||||
master_table: typing.Type,
|
||||
identity_table: typing.Type,
|
||||
identity_column_name: str):
|
||||
self.database_uri: str = database_uri
|
||||
self.master_table: typing.Type = master_table
|
||||
self.identity_table: typing.Type = identity_table
|
||||
self.identity_column_name: str = identity_column_name
|
|
@ -6,5 +6,7 @@ from .safeformat import safeformat
|
|||
from .classdictjanitor import cdj
|
||||
from .sleepuntil import sleep_until
|
||||
from .plusformat import plusformat
|
||||
from .networkhandler import NetworkHandler
|
||||
|
||||
__all__ = ["asyncify", "Call", "Command", "safeformat", "cdj", "sleep_until", "plusformat", "CommandArgs"]
|
||||
__all__ = ["asyncify", "Call", "Command", "safeformat", "cdj", "sleep_until", "plusformat", "CommandArgs",
|
||||
"NetworkHandler"]
|
||||
|
|
|
@ -3,8 +3,7 @@ import asyncio
|
|||
import logging
|
||||
from ..network.messages import Message
|
||||
from .command import Command
|
||||
from royalnet.utils import CommandArgs
|
||||
|
||||
from .commandargs import CommandArgs
|
||||
if typing.TYPE_CHECKING:
|
||||
from ..database import Alchemy
|
||||
|
||||
|
@ -57,10 +56,7 @@ class Call:
|
|||
|
||||
async def run(self):
|
||||
await self.session_init()
|
||||
try:
|
||||
coroutine = getattr(self.command, self.interface_name)
|
||||
except AttributeError:
|
||||
coroutine = getattr(self.command, "common")
|
||||
coroutine = getattr(self.command, self.interface_name)
|
||||
try:
|
||||
result = await coroutine(self)
|
||||
finally:
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import typing
|
||||
from ..error import UnsupportedError
|
||||
from ..network import Message
|
||||
if typing.TYPE_CHECKING:
|
||||
from .call import Call
|
||||
from ..utils import NetworkHandler
|
||||
|
||||
|
||||
class Command:
|
||||
|
@ -12,5 +15,21 @@ class Command:
|
|||
|
||||
require_alchemy_tables: typing.Set = set()
|
||||
|
||||
async def common(self, call: "Call"):
|
||||
raise NotImplementedError()
|
||||
network_handlers: typing.List[typing.Type["NetworkHandler"]] = {}
|
||||
|
||||
@classmethod
|
||||
async def common(cls, call: "Call"):
|
||||
raise UnsupportedError()
|
||||
|
||||
@classmethod
|
||||
def network_handler_dict(cls):
|
||||
d = {}
|
||||
for network_handler in cls.network_handlers:
|
||||
d[network_handler.message_type] = network_handler
|
||||
return d
|
||||
|
||||
def __getattribute__(self, item: str):
|
||||
try:
|
||||
return self.__dict__[item]
|
||||
except KeyError:
|
||||
return self.common
|
||||
|
|
14
royalnet/utils/networkhandler.py
Normal file
14
royalnet/utils/networkhandler.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
from ..network import Message
|
||||
from ..error import UnsupportedError
|
||||
|
||||
|
||||
class NetworkHandler:
|
||||
"""The NetworkHandler functions are called when a specific Message type is received."""
|
||||
|
||||
message_type = NotImplemented
|
||||
|
||||
def __getattribute__(self, item: str):
|
||||
try:
|
||||
return self.__dict__[item]
|
||||
except KeyError:
|
||||
raise UnsupportedError()
|
Loading…
Reference in a new issue