1
Fork 0
mirror of https://github.com/RYGhub/royalnet.git synced 2024-11-23 19:44:20 +00:00

A lot of stuff

This commit is contained in:
Steffo 2019-04-18 16:09:02 +02:00
parent bad35daf5b
commit 03b3d52f0e
13 changed files with 177 additions and 88 deletions

View file

@ -6,6 +6,7 @@ from royalnet.commands import *
from royalnet.commands.debug_create import DebugCreateCommand from royalnet.commands.debug_create import DebugCreateCommand
from royalnet.commands.error_handler import ErrorHandlerCommand from royalnet.commands.error_handler import ErrorHandlerCommand
from royalnet.network import RoyalnetServer from royalnet.network import RoyalnetServer
from royalnet.database import DatabaseConfig
from royalnet.database.tables import Royal, Telegram, Discord from royalnet.database.tables import Royal, Telegram, Discord
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -20,11 +21,13 @@ commands = [PingCommand, ShipCommand, SmecdsCommand, ColorCommand, CiaoruoziComm
KvrollCommand, VideoinfoCommand, SummonCommand, PlayCommand] KvrollCommand, VideoinfoCommand, SummonCommand, PlayCommand]
master = RoyalnetServer("localhost", 1234, "sas") master = RoyalnetServer("localhost", 1234, "sas")
# tg_bot = TelegramBot(os.environ["TG_AK"], "ws://localhost:1234", "sas", commands, os.environ["DB_PATH"], Royal, Telegram, "tg_id", error_command=ErrorHandlerCommand) tg_db_cfg = DatabaseConfig(os.environ["DB_PATH"], Royal, Telegram, "tg_id")
ds_bot = DiscordBot(os.environ["DS_AK"], "ws://localhost:1234", "sas", commands, os.environ["DB_PATH"], Royal, Discord, "discord_id", error_command=ErrorHandlerCommand) tg_bot = TelegramBot(os.environ["TG_AK"], "ws://localhost:1234", "sas", commands, NullCommand, ErrorHandlerCommand, tg_db_cfg)
ds_db_cfg = DatabaseConfig(os.environ["DB_PATH"], Royal, Discord, "discord_id")
ds_bot = DiscordBot(os.environ["DS_AK"], "ws://localhost:1234", "sas", commands, NullCommand, ErrorHandlerCommand, ds_db_cfg)
loop.run_until_complete(master.run()) loop.run_until_complete(master.run())
# Dirty hack, remove me asap # Dirty hack, remove me asap
# loop.create_task(tg_bot.run()) loop.create_task(tg_bot.run())
loop.create_task(ds_bot.run()) loop.create_task(ds_bot.run())
print("Starting loop...") print("Starting loop...")
loop.run_forever() loop.run_forever()

View file

@ -4,11 +4,11 @@ import typing
import logging as _logging import logging as _logging
import sys import sys
from ..commands import NullCommand from ..commands import NullCommand
from ..utils import asyncify, Call, Command from ..utils import asyncify, Call, Command, NetworkHandler
from ..error import UnregisteredError, NoneFoundError, TooManyFoundError from ..error import UnregisteredError, NoneFoundError, TooManyFoundError, InvalidConfigError
from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError
from ..database import Alchemy, relationshiplinkchain from ..database import Alchemy, relationshiplinkchain, DatabaseConfig
from ..audio import RoyalPCMFile, PlayMode, Playlist, Pool from ..audio import RoyalPCMFile, PlayMode, Playlist
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
log = _logging.getLogger(__name__) log = _logging.getLogger(__name__)
@ -18,31 +18,15 @@ if not discord.opus.is_loaded():
log.error("Opus is not loaded. Weird behaviour might emerge.") log.error("Opus is not loaded. Weird behaviour might emerge.")
class PlayMessage(Message):
def __init__(self, url: str, guild_identifier: typing.Optional[str] = None):
self.url: str = url
self.guild_identifier: typing.Optional[str] = guild_identifier
class SummonMessage(Message):
def __init__(self, channel_identifier: typing.Union[int, str],
guild_identifier: typing.Optional[typing.Union[int, str]]):
self.channel_identifier = channel_identifier
self.guild_identifier = guild_identifier
class DiscordBot: class DiscordBot:
def __init__(self, def __init__(self,
token: str, token: str,
master_server_uri: str, master_server_uri: str,
master_server_secret: str, master_server_secret: str,
commands: typing.List[typing.Type[Command]], commands: typing.List[typing.Type[Command]],
database_uri: str,
master_table,
identity_table,
identity_column_name: str,
missing_command: typing.Type[Command] = NullCommand, missing_command: typing.Type[Command] = NullCommand,
error_command: typing.Type[Command] = NullCommand): error_command: typing.Type[Command] = NullCommand,
database_config: typing.Optional[DatabaseConfig] = None):
self.token = token self.token = token
# Generate commands # Generate commands
self.missing_command = missing_command self.missing_command = missing_command
@ -52,12 +36,25 @@ class DiscordBot:
for command in commands: for command in commands:
self.commands[f"!{command.command_name}"] = command self.commands[f"!{command.command_name}"] = command
required_tables = required_tables.union(command.require_alchemy_tables) required_tables = required_tables.union(command.require_alchemy_tables)
# Generate network handlers
self.network_handlers: typing.Dict[typing.Type[Message], typing.Type[NetworkHandler]] = {}
for command in commands:
self.network_handlers = {**self.network_handlers, **command.network_handler_dict()}
# Generate the Alchemy database # Generate the Alchemy database
self.alchemy = Alchemy(database_uri, required_tables) if database_config:
self.master_table = self.alchemy.__getattribute__(master_table.__name__) self.alchemy = Alchemy(database_config.database_uri, required_tables)
self.identity_table = self.alchemy.__getattribute__(identity_table.__name__) self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__)
self.identity_column = self.identity_table.__getattribute__(self.identity_table, identity_column_name) self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__)
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table) self.identity_column = self.identity_table.__getattribute__(self.identity_table, database_config.identity_column_name)
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
else:
if required_tables:
raise InvalidConfigError("Tables are required by the commands, but Alchemy is not configured")
self.alchemy = None
self.master_table = None
self.identity_table = None
self.identity_column = None
self.identity_chain = None
# Connect to Royalnet # Connect to Royalnet
self.network: RoyalnetLink = RoyalnetLink(master_server_uri, master_server_secret, "discord", self.network: RoyalnetLink = RoyalnetLink(master_server_uri, master_server_secret, "discord",
self.network_handler) self.network_handler)
@ -70,6 +67,7 @@ class DiscordBot:
interface_name = "discord" interface_name = "discord"
interface_obj = self interface_obj = self
interface_prefix = "!" interface_prefix = "!"
alchemy = self.alchemy alchemy = self.alchemy
async def reply(call, text: str): async def reply(call, text: str):
@ -213,18 +211,7 @@ class DiscordBot:
async def network_handler(self, message: Message) -> Message: async def network_handler(self, message: Message) -> Message:
"""Handle a Royalnet request.""" """Handle a Royalnet request."""
log.debug(f"Received {message} from Royalnet") log.debug(f"Received {message} from Royalnet")
if isinstance(message, SummonMessage): return await self.network_handlers[message.__class__].discord(message)
return await self.nh_summon(message)
elif isinstance(message, PlayMessage):
return await self.nh_play(message)
async def nh_summon(self, message: SummonMessage):
"""Handle a summon Royalnet request. That is, join a voice channel, or move to a different one if that is not possible."""
channel = self.find_channel(message.channel_identifier)
if not isinstance(channel, discord.VoiceChannel):
raise NoneFoundError("Channel is not a voice channel")
loop.create_task(self.bot.vc_connect_or_move(channel))
return RequestSuccessful()
async def add_to_music_data(self, url: str, guild: discord.Guild): async def add_to_music_data(self, url: str, guild: discord.Guild):
"""Add a file to the corresponding music_data object.""" """Add a file to the corresponding music_data object."""
@ -257,23 +244,6 @@ class DiscordBot:
log.debug(f"Starting playback of {next_source}") log.debug(f"Starting playback of {next_source}")
voice_client.play(next_source, after=advance) voice_client.play(next_source, after=advance)
async def nh_play(self, message: PlayMessage):
"""Handle a play Royalnet request. That is, add audio to a PlayMode."""
# Find the matching guild
if message.guild_identifier:
guild = self.find_guild(message.guild_identifier)
else:
if len(self.music_data) != 1:
raise TooManyFoundError("Multiple guilds found")
guild = list(self.music_data)[0]
# Ensure the guild has a PlayMode before adding the file to it
if not self.music_data.get(guild):
# TODO: change Exception
raise Exception("No music_data for this guild")
# Start downloading
loop.create_task(self.add_to_music_data(message.url, guild))
return RequestSuccessful()
async def run(self): async def run(self):
await self.bot.login(self.token) await self.bot.login(self.token)
await self.bot.connect() await self.bot.connect()

View file

@ -5,9 +5,9 @@ import logging as _logging
import sys import sys
from ..commands import NullCommand from ..commands import NullCommand
from ..utils import asyncify, Call, Command from ..utils import asyncify, Call, Command
from ..error import UnregisteredError from ..error import UnregisteredError, InvalidConfigError
from ..network import RoyalnetLink, Message, RequestError from ..network import RoyalnetLink, Message, RequestError
from ..database import Alchemy, relationshiplinkchain from ..database import Alchemy, relationshiplinkchain, DatabaseConfig
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
log = _logging.getLogger(__name__) log = _logging.getLogger(__name__)
@ -23,12 +23,9 @@ class TelegramBot:
master_server_uri: str, master_server_uri: str,
master_server_secret: str, master_server_secret: str,
commands: typing.List[typing.Type[Command]], commands: typing.List[typing.Type[Command]],
database_uri: str,
master_table,
identity_table,
identity_column_name: str,
missing_command: typing.Type[Command] = NullCommand, missing_command: typing.Type[Command] = NullCommand,
error_command: typing.Type[Command] = NullCommand): error_command: typing.Type[Command] = NullCommand,
database_config: typing.Optional[DatabaseConfig] = None):
self.bot: telegram.Bot = telegram.Bot(api_key) self.bot: telegram.Bot = telegram.Bot(api_key)
self.should_run: bool = False self.should_run: bool = False
self.offset: int = -100 self.offset: int = -100
@ -43,12 +40,20 @@ class TelegramBot:
self.commands[f"/{command.command_name}"] = command self.commands[f"/{command.command_name}"] = command
required_tables = required_tables.union(command.require_alchemy_tables) required_tables = required_tables.union(command.require_alchemy_tables)
# Generate the Alchemy database # Generate the Alchemy database
self.alchemy = Alchemy(database_uri, required_tables) if database_config:
self.master_table = self.alchemy.__getattribute__(master_table.__name__) self.alchemy = Alchemy(database_config.database_uri, required_tables)
self.identity_table = self.alchemy.__getattribute__(identity_table.__name__) self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__)
self.identity_column = self.identity_table.__getattribute__(self.identity_table, identity_column_name) self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__)
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table) self.identity_column = self.identity_table.__getattribute__(self.identity_table, database_config.identity_column_name)
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
else:
if required_tables:
raise InvalidConfigError("Tables are required by the commands, but Alchemy is not configured")
self.alchemy = None
self.master_table = None
self.identity_table = None
self.identity_column = None
self.identity_chain = None
# noinspection PyMethodParameters # noinspection PyMethodParameters
class TelegramCall(Call): class TelegramCall(Call):
interface_name = "telegram" interface_name = "telegram"

View file

@ -1,7 +1,42 @@
import typing import typing
from ..utils import Command, Call import discord
import asyncio
from ..utils import Command, Call, NetworkHandler
from ..network import Message, RequestSuccessful, RequestError from ..network import Message, RequestSuccessful, RequestError
from ..bots.discord import PlayMessage from ..error import TooManyFoundError
if typing.TYPE_CHECKING:
from ..bots import DiscordBot
loop = asyncio.get_event_loop()
class PlayMessage(Message):
def __init__(self, url: str, guild_identifier: typing.Optional[str] = None):
self.url: str = url
self.guild_identifier: typing.Optional[str] = guild_identifier
class PlayNH(NetworkHandler):
message_type = PlayMessage
@classmethod
async def nh_play(cls, bot: "DiscordBot", message: PlayMessage):
"""Handle a play Royalnet request. That is, add audio to a PlayMode."""
# Find the matching guild
if message.guild_identifier:
guild = bot.find_guild(message.guild_identifier)
else:
if len(bot.music_data) != 1:
raise TooManyFoundError("Multiple guilds found")
guild = list(bot.music_data)[0]
# Ensure the guild has a PlayMode before adding the file to it
if not bot.music_data.get(guild):
# TODO: change Exception
raise Exception("No music_data for this guild")
# Start downloading
loop.create_task(bot.add_to_music_data(message.url, guild))
return RequestSuccessful()
class PlayCommand(Command): class PlayCommand(Command):
@ -9,6 +44,8 @@ class PlayCommand(Command):
command_description = "Riproduce una canzone in chat vocale." command_description = "Riproduce una canzone in chat vocale."
command_syntax = "[ [guild] ] (url)" command_syntax = "[ [guild] ] (url)"
network_handlers = [PlayNH]
@classmethod @classmethod
async def common(cls, call: Call): async def common(cls, call: Call):
guild, url = call.args.match(r"(?:\[(.+)])?\s*(\S+)\s*") guild, url = call.args.match(r"(?:\[(.+)])?\s*(\S+)\s*")

View file

@ -1,8 +1,34 @@
import typing import typing
import discord import discord
from ..utils import Command, Call import asyncio
from ..utils import Command, Call, NetworkHandler
from ..network import Message, RequestSuccessful, RequestError from ..network import Message, RequestSuccessful, RequestError
from ..bots.discord import SummonMessage from ..error import NoneFoundError
if typing.TYPE_CHECKING:
from ..bots import DiscordBot
loop = asyncio.get_event_loop()
class SummonMessage(Message):
def __init__(self, channel_identifier: typing.Union[int, str],
guild_identifier: typing.Optional[typing.Union[int, str]] = None):
self.channel_identifier = channel_identifier
self.guild_identifier = guild_identifier
class SummonNH(NetworkHandler):
message_type = SummonMessage
@classmethod
async def discord(cls, bot: "DiscordBot", message: SummonMessage):
"""Handle a summon Royalnet request. That is, join a voice channel, or move to a different one if that is not possible."""
channel = bot.find_channel(message.channel_identifier)
if not isinstance(channel, discord.VoiceChannel):
raise NoneFoundError("Channel is not a voice channel")
loop.create_task(bot.bot.vc_connect_or_move(channel))
return RequestSuccessful()
class SummonCommand(Command): class SummonCommand(Command):
@ -11,6 +37,8 @@ class SummonCommand(Command):
command_description = "Evoca il bot in un canale vocale." command_description = "Evoca il bot in un canale vocale."
command_syntax = "[channelname]" command_syntax = "[channelname]"
network_handlers = [SummonNH]
@classmethod @classmethod
async def common(cls, call: Call): async def common(cls, call: Call):
channel_name: str = call.args[0].lstrip("#") channel_name: str = call.args[0].lstrip("#")

View file

@ -1,4 +1,5 @@
from .alchemy import Alchemy from .alchemy import Alchemy
from .relationshiplinkchain import relationshiplinkchain from .relationshiplinkchain import relationshiplinkchain
from .databaseconfig import DatabaseConfig
__all__ = ["Alchemy", "relationshiplinkchain"] __all__ = ["Alchemy", "relationshiplinkchain", "DatabaseConfig"]

View file

@ -4,7 +4,8 @@ from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager, asynccontextmanager from contextlib import contextmanager, asynccontextmanager
from ..utils import cdj, asyncify from ..utils import asyncify
from ..error import InvalidConfigError
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()

View file

@ -0,0 +1,13 @@
import typing
class DatabaseConfig:
def __init__(self,
database_uri: str,
master_table: typing.Type,
identity_table: typing.Type,
identity_column_name: str):
self.database_uri: str = database_uri
self.master_table: typing.Type = master_table
self.identity_table: typing.Type = identity_table
self.identity_column_name: str = identity_column_name

View file

@ -6,5 +6,7 @@ from .safeformat import safeformat
from .classdictjanitor import cdj from .classdictjanitor import cdj
from .sleepuntil import sleep_until from .sleepuntil import sleep_until
from .plusformat import plusformat from .plusformat import plusformat
from .networkhandler import NetworkHandler
__all__ = ["asyncify", "Call", "Command", "safeformat", "cdj", "sleep_until", "plusformat", "CommandArgs"] __all__ = ["asyncify", "Call", "Command", "safeformat", "cdj", "sleep_until", "plusformat", "CommandArgs",
"NetworkHandler"]

View file

@ -3,8 +3,7 @@ import asyncio
import logging import logging
from ..network.messages import Message from ..network.messages import Message
from .command import Command from .command import Command
from royalnet.utils import CommandArgs from .commandargs import CommandArgs
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from ..database import Alchemy from ..database import Alchemy
@ -57,10 +56,7 @@ class Call:
async def run(self): async def run(self):
await self.session_init() await self.session_init()
try: coroutine = getattr(self.command, self.interface_name)
coroutine = getattr(self.command, self.interface_name)
except AttributeError:
coroutine = getattr(self.command, "common")
try: try:
result = await coroutine(self) result = await coroutine(self)
finally: finally:

View file

@ -1,6 +1,9 @@
import typing import typing
from ..error import UnsupportedError
from ..network import Message
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from .call import Call from .call import Call
from ..utils import NetworkHandler
class Command: class Command:
@ -12,5 +15,21 @@ class Command:
require_alchemy_tables: typing.Set = set() require_alchemy_tables: typing.Set = set()
async def common(self, call: "Call"): network_handlers: typing.List[typing.Type["NetworkHandler"]] = {}
raise NotImplementedError()
@classmethod
async def common(cls, call: "Call"):
raise UnsupportedError()
@classmethod
def network_handler_dict(cls):
d = {}
for network_handler in cls.network_handlers:
d[network_handler.message_type] = network_handler
return d
def __getattribute__(self, item: str):
try:
return self.__dict__[item]
except KeyError:
return self.common

View file

@ -35,4 +35,4 @@ class CommandArgs(list):
try: try:
return self[index] return self[index]
except InvalidInputError: except InvalidInputError:
return default return default

View file

@ -0,0 +1,14 @@
from ..network import Message
from ..error import UnsupportedError
class NetworkHandler:
"""The NetworkHandler functions are called when a specific Message type is received."""
message_type = NotImplemented
def __getattribute__(self, item: str):
try:
return self.__dict__[item]
except KeyError:
raise UnsupportedError()