mirror of
https://github.com/RYGhub/royalnet.git
synced 2024-11-23 19:44:20 +00:00
A lot of stuff
This commit is contained in:
parent
bad35daf5b
commit
03b3d52f0e
13 changed files with 177 additions and 88 deletions
|
@ -6,6 +6,7 @@ from royalnet.commands import *
|
||||||
from royalnet.commands.debug_create import DebugCreateCommand
|
from royalnet.commands.debug_create import DebugCreateCommand
|
||||||
from royalnet.commands.error_handler import ErrorHandlerCommand
|
from royalnet.commands.error_handler import ErrorHandlerCommand
|
||||||
from royalnet.network import RoyalnetServer
|
from royalnet.network import RoyalnetServer
|
||||||
|
from royalnet.database import DatabaseConfig
|
||||||
from royalnet.database.tables import Royal, Telegram, Discord
|
from royalnet.database.tables import Royal, Telegram, Discord
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
@ -20,11 +21,13 @@ commands = [PingCommand, ShipCommand, SmecdsCommand, ColorCommand, CiaoruoziComm
|
||||||
KvrollCommand, VideoinfoCommand, SummonCommand, PlayCommand]
|
KvrollCommand, VideoinfoCommand, SummonCommand, PlayCommand]
|
||||||
|
|
||||||
master = RoyalnetServer("localhost", 1234, "sas")
|
master = RoyalnetServer("localhost", 1234, "sas")
|
||||||
# tg_bot = TelegramBot(os.environ["TG_AK"], "ws://localhost:1234", "sas", commands, os.environ["DB_PATH"], Royal, Telegram, "tg_id", error_command=ErrorHandlerCommand)
|
tg_db_cfg = DatabaseConfig(os.environ["DB_PATH"], Royal, Telegram, "tg_id")
|
||||||
ds_bot = DiscordBot(os.environ["DS_AK"], "ws://localhost:1234", "sas", commands, os.environ["DB_PATH"], Royal, Discord, "discord_id", error_command=ErrorHandlerCommand)
|
tg_bot = TelegramBot(os.environ["TG_AK"], "ws://localhost:1234", "sas", commands, NullCommand, ErrorHandlerCommand, tg_db_cfg)
|
||||||
|
ds_db_cfg = DatabaseConfig(os.environ["DB_PATH"], Royal, Discord, "discord_id")
|
||||||
|
ds_bot = DiscordBot(os.environ["DS_AK"], "ws://localhost:1234", "sas", commands, NullCommand, ErrorHandlerCommand, ds_db_cfg)
|
||||||
loop.run_until_complete(master.run())
|
loop.run_until_complete(master.run())
|
||||||
# Dirty hack, remove me asap
|
# Dirty hack, remove me asap
|
||||||
# loop.create_task(tg_bot.run())
|
loop.create_task(tg_bot.run())
|
||||||
loop.create_task(ds_bot.run())
|
loop.create_task(ds_bot.run())
|
||||||
print("Starting loop...")
|
print("Starting loop...")
|
||||||
loop.run_forever()
|
loop.run_forever()
|
||||||
|
|
|
@ -4,11 +4,11 @@ import typing
|
||||||
import logging as _logging
|
import logging as _logging
|
||||||
import sys
|
import sys
|
||||||
from ..commands import NullCommand
|
from ..commands import NullCommand
|
||||||
from ..utils import asyncify, Call, Command
|
from ..utils import asyncify, Call, Command, NetworkHandler
|
||||||
from ..error import UnregisteredError, NoneFoundError, TooManyFoundError
|
from ..error import UnregisteredError, NoneFoundError, TooManyFoundError, InvalidConfigError
|
||||||
from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError
|
from ..network import RoyalnetLink, Message, RequestSuccessful, RequestError
|
||||||
from ..database import Alchemy, relationshiplinkchain
|
from ..database import Alchemy, relationshiplinkchain, DatabaseConfig
|
||||||
from ..audio import RoyalPCMFile, PlayMode, Playlist, Pool
|
from ..audio import RoyalPCMFile, PlayMode, Playlist
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
log = _logging.getLogger(__name__)
|
log = _logging.getLogger(__name__)
|
||||||
|
@ -18,31 +18,15 @@ if not discord.opus.is_loaded():
|
||||||
log.error("Opus is not loaded. Weird behaviour might emerge.")
|
log.error("Opus is not loaded. Weird behaviour might emerge.")
|
||||||
|
|
||||||
|
|
||||||
class PlayMessage(Message):
|
|
||||||
def __init__(self, url: str, guild_identifier: typing.Optional[str] = None):
|
|
||||||
self.url: str = url
|
|
||||||
self.guild_identifier: typing.Optional[str] = guild_identifier
|
|
||||||
|
|
||||||
|
|
||||||
class SummonMessage(Message):
|
|
||||||
def __init__(self, channel_identifier: typing.Union[int, str],
|
|
||||||
guild_identifier: typing.Optional[typing.Union[int, str]]):
|
|
||||||
self.channel_identifier = channel_identifier
|
|
||||||
self.guild_identifier = guild_identifier
|
|
||||||
|
|
||||||
|
|
||||||
class DiscordBot:
|
class DiscordBot:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
token: str,
|
token: str,
|
||||||
master_server_uri: str,
|
master_server_uri: str,
|
||||||
master_server_secret: str,
|
master_server_secret: str,
|
||||||
commands: typing.List[typing.Type[Command]],
|
commands: typing.List[typing.Type[Command]],
|
||||||
database_uri: str,
|
|
||||||
master_table,
|
|
||||||
identity_table,
|
|
||||||
identity_column_name: str,
|
|
||||||
missing_command: typing.Type[Command] = NullCommand,
|
missing_command: typing.Type[Command] = NullCommand,
|
||||||
error_command: typing.Type[Command] = NullCommand):
|
error_command: typing.Type[Command] = NullCommand,
|
||||||
|
database_config: typing.Optional[DatabaseConfig] = None):
|
||||||
self.token = token
|
self.token = token
|
||||||
# Generate commands
|
# Generate commands
|
||||||
self.missing_command = missing_command
|
self.missing_command = missing_command
|
||||||
|
@ -52,12 +36,25 @@ class DiscordBot:
|
||||||
for command in commands:
|
for command in commands:
|
||||||
self.commands[f"!{command.command_name}"] = command
|
self.commands[f"!{command.command_name}"] = command
|
||||||
required_tables = required_tables.union(command.require_alchemy_tables)
|
required_tables = required_tables.union(command.require_alchemy_tables)
|
||||||
|
# Generate network handlers
|
||||||
|
self.network_handlers: typing.Dict[typing.Type[Message], typing.Type[NetworkHandler]] = {}
|
||||||
|
for command in commands:
|
||||||
|
self.network_handlers = {**self.network_handlers, **command.network_handler_dict()}
|
||||||
# Generate the Alchemy database
|
# Generate the Alchemy database
|
||||||
self.alchemy = Alchemy(database_uri, required_tables)
|
if database_config:
|
||||||
self.master_table = self.alchemy.__getattribute__(master_table.__name__)
|
self.alchemy = Alchemy(database_config.database_uri, required_tables)
|
||||||
self.identity_table = self.alchemy.__getattribute__(identity_table.__name__)
|
self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__)
|
||||||
self.identity_column = self.identity_table.__getattribute__(self.identity_table, identity_column_name)
|
self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__)
|
||||||
|
self.identity_column = self.identity_table.__getattribute__(self.identity_table, database_config.identity_column_name)
|
||||||
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
|
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
|
||||||
|
else:
|
||||||
|
if required_tables:
|
||||||
|
raise InvalidConfigError("Tables are required by the commands, but Alchemy is not configured")
|
||||||
|
self.alchemy = None
|
||||||
|
self.master_table = None
|
||||||
|
self.identity_table = None
|
||||||
|
self.identity_column = None
|
||||||
|
self.identity_chain = None
|
||||||
# Connect to Royalnet
|
# Connect to Royalnet
|
||||||
self.network: RoyalnetLink = RoyalnetLink(master_server_uri, master_server_secret, "discord",
|
self.network: RoyalnetLink = RoyalnetLink(master_server_uri, master_server_secret, "discord",
|
||||||
self.network_handler)
|
self.network_handler)
|
||||||
|
@ -70,6 +67,7 @@ class DiscordBot:
|
||||||
interface_name = "discord"
|
interface_name = "discord"
|
||||||
interface_obj = self
|
interface_obj = self
|
||||||
interface_prefix = "!"
|
interface_prefix = "!"
|
||||||
|
|
||||||
alchemy = self.alchemy
|
alchemy = self.alchemy
|
||||||
|
|
||||||
async def reply(call, text: str):
|
async def reply(call, text: str):
|
||||||
|
@ -213,18 +211,7 @@ class DiscordBot:
|
||||||
async def network_handler(self, message: Message) -> Message:
|
async def network_handler(self, message: Message) -> Message:
|
||||||
"""Handle a Royalnet request."""
|
"""Handle a Royalnet request."""
|
||||||
log.debug(f"Received {message} from Royalnet")
|
log.debug(f"Received {message} from Royalnet")
|
||||||
if isinstance(message, SummonMessage):
|
return await self.network_handlers[message.__class__].discord(message)
|
||||||
return await self.nh_summon(message)
|
|
||||||
elif isinstance(message, PlayMessage):
|
|
||||||
return await self.nh_play(message)
|
|
||||||
|
|
||||||
async def nh_summon(self, message: SummonMessage):
|
|
||||||
"""Handle a summon Royalnet request. That is, join a voice channel, or move to a different one if that is not possible."""
|
|
||||||
channel = self.find_channel(message.channel_identifier)
|
|
||||||
if not isinstance(channel, discord.VoiceChannel):
|
|
||||||
raise NoneFoundError("Channel is not a voice channel")
|
|
||||||
loop.create_task(self.bot.vc_connect_or_move(channel))
|
|
||||||
return RequestSuccessful()
|
|
||||||
|
|
||||||
async def add_to_music_data(self, url: str, guild: discord.Guild):
|
async def add_to_music_data(self, url: str, guild: discord.Guild):
|
||||||
"""Add a file to the corresponding music_data object."""
|
"""Add a file to the corresponding music_data object."""
|
||||||
|
@ -257,23 +244,6 @@ class DiscordBot:
|
||||||
log.debug(f"Starting playback of {next_source}")
|
log.debug(f"Starting playback of {next_source}")
|
||||||
voice_client.play(next_source, after=advance)
|
voice_client.play(next_source, after=advance)
|
||||||
|
|
||||||
async def nh_play(self, message: PlayMessage):
|
|
||||||
"""Handle a play Royalnet request. That is, add audio to a PlayMode."""
|
|
||||||
# Find the matching guild
|
|
||||||
if message.guild_identifier:
|
|
||||||
guild = self.find_guild(message.guild_identifier)
|
|
||||||
else:
|
|
||||||
if len(self.music_data) != 1:
|
|
||||||
raise TooManyFoundError("Multiple guilds found")
|
|
||||||
guild = list(self.music_data)[0]
|
|
||||||
# Ensure the guild has a PlayMode before adding the file to it
|
|
||||||
if not self.music_data.get(guild):
|
|
||||||
# TODO: change Exception
|
|
||||||
raise Exception("No music_data for this guild")
|
|
||||||
# Start downloading
|
|
||||||
loop.create_task(self.add_to_music_data(message.url, guild))
|
|
||||||
return RequestSuccessful()
|
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
await self.bot.login(self.token)
|
await self.bot.login(self.token)
|
||||||
await self.bot.connect()
|
await self.bot.connect()
|
||||||
|
|
|
@ -5,9 +5,9 @@ import logging as _logging
|
||||||
import sys
|
import sys
|
||||||
from ..commands import NullCommand
|
from ..commands import NullCommand
|
||||||
from ..utils import asyncify, Call, Command
|
from ..utils import asyncify, Call, Command
|
||||||
from ..error import UnregisteredError
|
from ..error import UnregisteredError, InvalidConfigError
|
||||||
from ..network import RoyalnetLink, Message, RequestError
|
from ..network import RoyalnetLink, Message, RequestError
|
||||||
from ..database import Alchemy, relationshiplinkchain
|
from ..database import Alchemy, relationshiplinkchain, DatabaseConfig
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
log = _logging.getLogger(__name__)
|
log = _logging.getLogger(__name__)
|
||||||
|
@ -23,12 +23,9 @@ class TelegramBot:
|
||||||
master_server_uri: str,
|
master_server_uri: str,
|
||||||
master_server_secret: str,
|
master_server_secret: str,
|
||||||
commands: typing.List[typing.Type[Command]],
|
commands: typing.List[typing.Type[Command]],
|
||||||
database_uri: str,
|
|
||||||
master_table,
|
|
||||||
identity_table,
|
|
||||||
identity_column_name: str,
|
|
||||||
missing_command: typing.Type[Command] = NullCommand,
|
missing_command: typing.Type[Command] = NullCommand,
|
||||||
error_command: typing.Type[Command] = NullCommand):
|
error_command: typing.Type[Command] = NullCommand,
|
||||||
|
database_config: typing.Optional[DatabaseConfig] = None):
|
||||||
self.bot: telegram.Bot = telegram.Bot(api_key)
|
self.bot: telegram.Bot = telegram.Bot(api_key)
|
||||||
self.should_run: bool = False
|
self.should_run: bool = False
|
||||||
self.offset: int = -100
|
self.offset: int = -100
|
||||||
|
@ -43,12 +40,20 @@ class TelegramBot:
|
||||||
self.commands[f"/{command.command_name}"] = command
|
self.commands[f"/{command.command_name}"] = command
|
||||||
required_tables = required_tables.union(command.require_alchemy_tables)
|
required_tables = required_tables.union(command.require_alchemy_tables)
|
||||||
# Generate the Alchemy database
|
# Generate the Alchemy database
|
||||||
self.alchemy = Alchemy(database_uri, required_tables)
|
if database_config:
|
||||||
self.master_table = self.alchemy.__getattribute__(master_table.__name__)
|
self.alchemy = Alchemy(database_config.database_uri, required_tables)
|
||||||
self.identity_table = self.alchemy.__getattribute__(identity_table.__name__)
|
self.master_table = self.alchemy.__getattribute__(database_config.master_table.__name__)
|
||||||
self.identity_column = self.identity_table.__getattribute__(self.identity_table, identity_column_name)
|
self.identity_table = self.alchemy.__getattribute__(database_config.identity_table.__name__)
|
||||||
|
self.identity_column = self.identity_table.__getattribute__(self.identity_table, database_config.identity_column_name)
|
||||||
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
|
self.identity_chain = relationshiplinkchain(self.master_table, self.identity_table)
|
||||||
|
else:
|
||||||
|
if required_tables:
|
||||||
|
raise InvalidConfigError("Tables are required by the commands, but Alchemy is not configured")
|
||||||
|
self.alchemy = None
|
||||||
|
self.master_table = None
|
||||||
|
self.identity_table = None
|
||||||
|
self.identity_column = None
|
||||||
|
self.identity_chain = None
|
||||||
# noinspection PyMethodParameters
|
# noinspection PyMethodParameters
|
||||||
class TelegramCall(Call):
|
class TelegramCall(Call):
|
||||||
interface_name = "telegram"
|
interface_name = "telegram"
|
||||||
|
|
|
@ -1,7 +1,42 @@
|
||||||
import typing
|
import typing
|
||||||
from ..utils import Command, Call
|
import discord
|
||||||
|
import asyncio
|
||||||
|
from ..utils import Command, Call, NetworkHandler
|
||||||
from ..network import Message, RequestSuccessful, RequestError
|
from ..network import Message, RequestSuccessful, RequestError
|
||||||
from ..bots.discord import PlayMessage
|
from ..error import TooManyFoundError
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from ..bots import DiscordBot
|
||||||
|
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
|
||||||
|
class PlayMessage(Message):
|
||||||
|
def __init__(self, url: str, guild_identifier: typing.Optional[str] = None):
|
||||||
|
self.url: str = url
|
||||||
|
self.guild_identifier: typing.Optional[str] = guild_identifier
|
||||||
|
|
||||||
|
|
||||||
|
class PlayNH(NetworkHandler):
|
||||||
|
message_type = PlayMessage
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def nh_play(cls, bot: "DiscordBot", message: PlayMessage):
|
||||||
|
"""Handle a play Royalnet request. That is, add audio to a PlayMode."""
|
||||||
|
# Find the matching guild
|
||||||
|
if message.guild_identifier:
|
||||||
|
guild = bot.find_guild(message.guild_identifier)
|
||||||
|
else:
|
||||||
|
if len(bot.music_data) != 1:
|
||||||
|
raise TooManyFoundError("Multiple guilds found")
|
||||||
|
guild = list(bot.music_data)[0]
|
||||||
|
# Ensure the guild has a PlayMode before adding the file to it
|
||||||
|
if not bot.music_data.get(guild):
|
||||||
|
# TODO: change Exception
|
||||||
|
raise Exception("No music_data for this guild")
|
||||||
|
# Start downloading
|
||||||
|
loop.create_task(bot.add_to_music_data(message.url, guild))
|
||||||
|
return RequestSuccessful()
|
||||||
|
|
||||||
|
|
||||||
class PlayCommand(Command):
|
class PlayCommand(Command):
|
||||||
|
@ -9,6 +44,8 @@ class PlayCommand(Command):
|
||||||
command_description = "Riproduce una canzone in chat vocale."
|
command_description = "Riproduce una canzone in chat vocale."
|
||||||
command_syntax = "[ [guild] ] (url)"
|
command_syntax = "[ [guild] ] (url)"
|
||||||
|
|
||||||
|
network_handlers = [PlayNH]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def common(cls, call: Call):
|
async def common(cls, call: Call):
|
||||||
guild, url = call.args.match(r"(?:\[(.+)])?\s*(\S+)\s*")
|
guild, url = call.args.match(r"(?:\[(.+)])?\s*(\S+)\s*")
|
||||||
|
|
|
@ -1,8 +1,34 @@
|
||||||
import typing
|
import typing
|
||||||
import discord
|
import discord
|
||||||
from ..utils import Command, Call
|
import asyncio
|
||||||
|
from ..utils import Command, Call, NetworkHandler
|
||||||
from ..network import Message, RequestSuccessful, RequestError
|
from ..network import Message, RequestSuccessful, RequestError
|
||||||
from ..bots.discord import SummonMessage
|
from ..error import NoneFoundError
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from ..bots import DiscordBot
|
||||||
|
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
|
||||||
|
class SummonMessage(Message):
|
||||||
|
def __init__(self, channel_identifier: typing.Union[int, str],
|
||||||
|
guild_identifier: typing.Optional[typing.Union[int, str]] = None):
|
||||||
|
self.channel_identifier = channel_identifier
|
||||||
|
self.guild_identifier = guild_identifier
|
||||||
|
|
||||||
|
|
||||||
|
class SummonNH(NetworkHandler):
|
||||||
|
message_type = SummonMessage
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def discord(cls, bot: "DiscordBot", message: SummonMessage):
|
||||||
|
"""Handle a summon Royalnet request. That is, join a voice channel, or move to a different one if that is not possible."""
|
||||||
|
channel = bot.find_channel(message.channel_identifier)
|
||||||
|
if not isinstance(channel, discord.VoiceChannel):
|
||||||
|
raise NoneFoundError("Channel is not a voice channel")
|
||||||
|
loop.create_task(bot.bot.vc_connect_or_move(channel))
|
||||||
|
return RequestSuccessful()
|
||||||
|
|
||||||
|
|
||||||
class SummonCommand(Command):
|
class SummonCommand(Command):
|
||||||
|
@ -11,6 +37,8 @@ class SummonCommand(Command):
|
||||||
command_description = "Evoca il bot in un canale vocale."
|
command_description = "Evoca il bot in un canale vocale."
|
||||||
command_syntax = "[channelname]"
|
command_syntax = "[channelname]"
|
||||||
|
|
||||||
|
network_handlers = [SummonNH]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def common(cls, call: Call):
|
async def common(cls, call: Call):
|
||||||
channel_name: str = call.args[0].lstrip("#")
|
channel_name: str = call.args[0].lstrip("#")
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from .alchemy import Alchemy
|
from .alchemy import Alchemy
|
||||||
from .relationshiplinkchain import relationshiplinkchain
|
from .relationshiplinkchain import relationshiplinkchain
|
||||||
|
from .databaseconfig import DatabaseConfig
|
||||||
|
|
||||||
__all__ = ["Alchemy", "relationshiplinkchain"]
|
__all__ = ["Alchemy", "relationshiplinkchain", "DatabaseConfig"]
|
||||||
|
|
|
@ -4,7 +4,8 @@ from sqlalchemy import create_engine
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from contextlib import contextmanager, asynccontextmanager
|
from contextlib import contextmanager, asynccontextmanager
|
||||||
from ..utils import cdj, asyncify
|
from ..utils import asyncify
|
||||||
|
from ..error import InvalidConfigError
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
|
13
royalnet/database/databaseconfig.py
Normal file
13
royalnet/database/databaseconfig.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
import typing
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseConfig:
|
||||||
|
def __init__(self,
|
||||||
|
database_uri: str,
|
||||||
|
master_table: typing.Type,
|
||||||
|
identity_table: typing.Type,
|
||||||
|
identity_column_name: str):
|
||||||
|
self.database_uri: str = database_uri
|
||||||
|
self.master_table: typing.Type = master_table
|
||||||
|
self.identity_table: typing.Type = identity_table
|
||||||
|
self.identity_column_name: str = identity_column_name
|
|
@ -6,5 +6,7 @@ from .safeformat import safeformat
|
||||||
from .classdictjanitor import cdj
|
from .classdictjanitor import cdj
|
||||||
from .sleepuntil import sleep_until
|
from .sleepuntil import sleep_until
|
||||||
from .plusformat import plusformat
|
from .plusformat import plusformat
|
||||||
|
from .networkhandler import NetworkHandler
|
||||||
|
|
||||||
__all__ = ["asyncify", "Call", "Command", "safeformat", "cdj", "sleep_until", "plusformat", "CommandArgs"]
|
__all__ = ["asyncify", "Call", "Command", "safeformat", "cdj", "sleep_until", "plusformat", "CommandArgs",
|
||||||
|
"NetworkHandler"]
|
||||||
|
|
|
@ -3,8 +3,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
from ..network.messages import Message
|
from ..network.messages import Message
|
||||||
from .command import Command
|
from .command import Command
|
||||||
from royalnet.utils import CommandArgs
|
from .commandargs import CommandArgs
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from ..database import Alchemy
|
from ..database import Alchemy
|
||||||
|
|
||||||
|
@ -57,10 +56,7 @@ class Call:
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
await self.session_init()
|
await self.session_init()
|
||||||
try:
|
|
||||||
coroutine = getattr(self.command, self.interface_name)
|
coroutine = getattr(self.command, self.interface_name)
|
||||||
except AttributeError:
|
|
||||||
coroutine = getattr(self.command, "common")
|
|
||||||
try:
|
try:
|
||||||
result = await coroutine(self)
|
result = await coroutine(self)
|
||||||
finally:
|
finally:
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
import typing
|
import typing
|
||||||
|
from ..error import UnsupportedError
|
||||||
|
from ..network import Message
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from .call import Call
|
from .call import Call
|
||||||
|
from ..utils import NetworkHandler
|
||||||
|
|
||||||
|
|
||||||
class Command:
|
class Command:
|
||||||
|
@ -12,5 +15,21 @@ class Command:
|
||||||
|
|
||||||
require_alchemy_tables: typing.Set = set()
|
require_alchemy_tables: typing.Set = set()
|
||||||
|
|
||||||
async def common(self, call: "Call"):
|
network_handlers: typing.List[typing.Type["NetworkHandler"]] = {}
|
||||||
raise NotImplementedError()
|
|
||||||
|
@classmethod
|
||||||
|
async def common(cls, call: "Call"):
|
||||||
|
raise UnsupportedError()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def network_handler_dict(cls):
|
||||||
|
d = {}
|
||||||
|
for network_handler in cls.network_handlers:
|
||||||
|
d[network_handler.message_type] = network_handler
|
||||||
|
return d
|
||||||
|
|
||||||
|
def __getattribute__(self, item: str):
|
||||||
|
try:
|
||||||
|
return self.__dict__[item]
|
||||||
|
except KeyError:
|
||||||
|
return self.common
|
||||||
|
|
14
royalnet/utils/networkhandler.py
Normal file
14
royalnet/utils/networkhandler.py
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
from ..network import Message
|
||||||
|
from ..error import UnsupportedError
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkHandler:
|
||||||
|
"""The NetworkHandler functions are called when a specific Message type is received."""
|
||||||
|
|
||||||
|
message_type = NotImplemented
|
||||||
|
|
||||||
|
def __getattribute__(self, item: str):
|
||||||
|
try:
|
||||||
|
return self.__dict__[item]
|
||||||
|
except KeyError:
|
||||||
|
raise UnsupportedError()
|
Loading…
Reference in a new issue