diff --git a/royalnet/__init__.py b/royalnet/__init__.py deleted file mode 100644 index 3a58404e..00000000 --- a/royalnet/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from . import audio, \ - bots, \ - commands, \ - database, \ - network, \ - utils, \ - error - -version = "5.0a7" - -__all__ = ["audio", - "bots", - "commands", - "database", - "network", - "utils", - "error"] diff --git a/royalnet/commands/play.py b/royalnet/commands/play.py index 6d187024..21f0908a 100644 --- a/royalnet/commands/play.py +++ b/royalnet/commands/play.py @@ -3,7 +3,7 @@ import asyncio import youtube_dl import ffmpeg from ..utils import Command, Call, NetworkHandler, asyncify -from ..network import Message, RequestSuccessful +from ..network import Request, Data from ..error import TooManyFoundError, NoneFoundError from ..audio import RoyalPCMAudio, YtdlInfo if typing.TYPE_CHECKING: @@ -13,14 +13,16 @@ if typing.TYPE_CHECKING: loop = asyncio.get_event_loop() -class PlayMessage(Message): +class PlayMessage(Data): def __init__(self, url: str, guild_name: typing.Optional[str] = None): + super().__init__() self.url: str = url self.guild_name: typing.Optional[str] = guild_name -class PlaySuccessful(RequestSuccessful): +class PlaySuccessful(Data): def __init__(self, info_list: typing.List[YtdlInfo]): + super().__init__() self.info_list: typing.List[YtdlInfo] = info_list diff --git a/royalnet/network/__init__.py b/royalnet/network/__init__.py index b774f370..68e2f3c4 100644 --- a/royalnet/network/__init__.py +++ b/royalnet/network/__init__.py @@ -1,6 +1,6 @@ """Royalnet realated classes.""" - -from .packages import Package +from .data import Data, Request +from .package import Package from .royalnetlink import RoyalnetLink, NetworkError, NotConnectedError, NotIdentifiedError, ConnectionClosedError from .royalnetserver import RoyalnetServer from .royalnetconfig import RoyalnetConfig @@ -12,4 +12,6 @@ __all__ = ["RoyalnetLink", "Package", "RoyalnetServer", "RoyalnetConfig", - "ConnectionClosedError"] + "ConnectionClosedError", + "Data", + "Request"] diff --git a/royalnet/network/data.py b/royalnet/network/data.py new file mode 100644 index 00000000..63895037 --- /dev/null +++ b/royalnet/network/data.py @@ -0,0 +1,25 @@ +class Data: + """Royalnet data. All fields in this class will be converted to a dict when about to be sent.""" + def __init__(self): + pass + + def to_dict(self): + return self.__dict__ + + +class Request(Data): + """A Royalnet request. It contains the name of the requested handler, in addition to the data.""" + + def __init__(self, handler: str, data: dict): + super().__init__() + self.handler: str = handler + self.data: dict = data + + @staticmethod + def from_dict(d: dict): + return Request(**d) + + def __eq__(self, other): + if isinstance(other, Request): + return self.handler == other.handler and self.data == other.data + return False diff --git a/royalnet/network/message.py b/royalnet/network/message.py deleted file mode 100644 index 091f8249..00000000 --- a/royalnet/network/message.py +++ /dev/null @@ -1,7 +0,0 @@ -from ..utils import classdictjanitor - - -class Message: - """A Royalnet message. All fields of this class will be converted in a dict.""" - - # idk use classdictjanitor diff --git a/royalnet/network/packages.py b/royalnet/network/package.py similarity index 97% rename from royalnet/network/packages.py rename to royalnet/network/package.py index 00a02bdf..282879ef 100644 --- a/royalnet/network/packages.py +++ b/royalnet/network/package.py @@ -4,7 +4,8 @@ import typing class Package: - """A Royalnet package, the data type with which a :py:class:`royalnet.network.RoyalnetLink` communicates with a :py:class:`royalnet.network.RoyalnetServer` or another link. """ + """A Royalnet package, the data type with which a :py:class:`royalnet.network.RoyalnetLink` communicates with a :py:class:`royalnet.network.RoyalnetServer` or another link. + Contains info about the source and the destination.""" def __init__(self, data: dict, diff --git a/royalnet/network/royalnetlink.py b/royalnet/network/royalnetlink.py index ab673c45..439bd7aa 100644 --- a/royalnet/network/royalnetlink.py +++ b/royalnet/network/royalnetlink.py @@ -6,7 +6,7 @@ import math import numbers import logging as _logging import typing -from .packages import Package +from .package import Package default_loop = asyncio.get_event_loop() log = _logging.getLogger(__name__) @@ -24,6 +24,10 @@ class ConnectionClosedError(Exception): """The :py:class:`royalnet.network.RoyalnetLink`'s connection was closed unexpectedly. The link can't be used anymore.""" +class InvalidServerResponseError(Exception): + """The :py:class:`royalnet.network.RoyalnetServer` sent invalid data to the :py:class:`royalnet.network.RoyalnetLink`.""" + + class NetworkError(Exception): def __init__(self, error_data: dict, *args): super().__init__(*args) @@ -100,7 +104,8 @@ class RoyalnetLink: log.info(f"Connection to {self.master_uri} was closed.") # What to do now? Let's just reraise. raise ConnectionClosedError() - assert package.destination == self.nid + if self.identify_event.is_set() and package.destination != self.nid: + raise InvalidServerResponseError("Package is not addressed to this RoyalnetLink.") log.debug(f"Received package: {package}") return package @@ -109,10 +114,13 @@ class RoyalnetLink: log.info(f"Identifying to {self.master_uri}...") await self.websocket.send(f"Identify {self.nid}:{self.link_type}:{self.secret}") response: Package = await self.receive() - assert response.source == "" - if "error" in response.data: - raise ConnectionClosedError(f"Identification error: {response.data['error']}") - assert "success" in response.data + if not response.source == "": + raise InvalidServerResponseError("Received a non-service package before identification.") + if "type" not in response.data: + raise InvalidServerResponseError("Missing 'type' in response data") + if response.data["type"] == "error": + raise ConnectionClosedError(f"Identification error: {response.data['type']}") + assert response.data["type"] == "success" self.identify_event.set() log.info(f"Identified successfully!") diff --git a/royalnet/network/royalnetserver.py b/royalnet/network/royalnetserver.py index 72e52c4b..1b328e8a 100644 --- a/royalnet/network/royalnetserver.py +++ b/royalnet/network/royalnetserver.py @@ -5,7 +5,7 @@ import datetime import uuid import asyncio import logging as _logging -from .packages import Package +from .package import Package default_loop = asyncio.get_event_loop() log = _logging.getLogger(__name__) @@ -52,22 +52,22 @@ class RoyalnetServer: matching = [client for client in self.identified_clients if client.link_type == link_type] return matching or [] - async def listener(self, websocket: websockets.server.WebSocketServerProtocol): + async def listener(self, websocket: websockets.server.WebSocketServerProtocol, path): log.info(f"{websocket.remote_address} connected to the server.") connected_client = ConnectedClient(websocket) # Wait for identification identify_msg = await websocket.recv() log.debug(f"{websocket.remote_address} identified itself with: {identify_msg}.") if not isinstance(identify_msg, str): - await websocket.send(connected_client.send_service("error", "Invalid identification message (not a str)")) + await connected_client.send_service("error", "Invalid identification message (not a str)") return identification = re.match(r"Identify ([^:\s]+):([^:\s]+):([^:\s]+)", identify_msg) if identification is None: - await websocket.send(connected_client.send_service("error", "Invalid identification message (regex failed)")) + await connected_client.send_service("error", "Invalid identification message (regex failed)") return secret = identification.group(3) if secret != self.required_secret: - await websocket.send(connected_client.send_service("error", "Invalid secret")) + await connected_client.send_service("error", "Invalid secret") return # Identification successful connected_client.nid = identification.group(1) diff --git a/tests/test_network.py b/tests/test_network.py index ab1d24ef..48c85ba1 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -1,7 +1,15 @@ import pytest import uuid import asyncio -from royalnet.network import Package, RoyalnetLink, RoyalnetServer, ConnectionClosedError +import logging +from royalnet.network import Package, RoyalnetLink, RoyalnetServer, ConnectionClosedError, Request + + +log = logging.root +stream_handler = logging.StreamHandler() +stream_handler.formatter = logging.Formatter("{asctime}\t{name}\t{levelname}\t{message}", style="{") +log.addHandler(stream_handler) +log.setLevel(logging.WARNING) @pytest.fixture @@ -11,8 +19,7 @@ def async_loop(): loop.close() -@pytest.mark.skip("Not a test") -def echo_request_handler(message): +async def echo_request_handler(message): return message @@ -27,19 +34,24 @@ def test_package_serialization(): assert pkg == Package.from_json_bytes(pkg.to_json_bytes()) +def test_request_creation(): + request = Request("pytest", {"testing": "is fun", "bugs": "are less fun"}) + assert request == Request.from_dict(request.to_dict()) + + def test_links(async_loop: asyncio.AbstractEventLoop): - address, port = "127.0.0.1", 1234 + address, port = "127.0.0.1", 1235 master = RoyalnetServer(address, port, "test") async_loop.run_until_complete(master.start()) # Test invalid secret - wrong_secret_link = RoyalnetLink("ws://127.0.0.1:1234", "invalid", "test", echo_request_handler, loop=async_loop) + wrong_secret_link = RoyalnetLink("ws://127.0.0.1:1235", "invalid", "test", echo_request_handler, loop=async_loop) with pytest.raises(ConnectionClosedError): async_loop.run_until_complete(wrong_secret_link.run()) # Test regular connection - link1 = RoyalnetLink("ws://127.0.0.1:1234", "test", "one", echo_request_handler, loop=async_loop) - link1_run_task = async_loop.create_task(link1.run()) - link2 = RoyalnetLink("ws://127.0.0.1:1234", "test", "two", echo_request_handler, loop=async_loop) - link2_run_task = async_loop.create_task(link2.run()) + link1 = RoyalnetLink("ws://127.0.0.1:1235", "test", "one", echo_request_handler, loop=async_loop) + async_loop.create_task(link1.run()) + link2 = RoyalnetLink("ws://127.0.0.1:1235", "test", "two", echo_request_handler, loop=async_loop) + async_loop.create_task(link2.run()) message = {"ciao": "ciao"} response = async_loop.run_until_complete(link1.request(message, "two")) assert message == response