diff --git a/royalnet/network/__init__.py b/royalnet/network/__init__.py index 42227ee1..b774f370 100644 --- a/royalnet/network/__init__.py +++ b/royalnet/network/__init__.py @@ -1,23 +1,15 @@ """Royalnet realated classes.""" -from .messages import Message, ServerErrorMessage, InvalidSecretEM, InvalidDestinationEM, InvalidPackageEM, RequestSuccessful, RequestError, Reply from .packages import Package -from .royalnetlink import RoyalnetLink, NetworkError, NotConnectedError, NotIdentifiedError +from .royalnetlink import RoyalnetLink, NetworkError, NotConnectedError, NotIdentifiedError, ConnectionClosedError from .royalnetserver import RoyalnetServer from .royalnetconfig import RoyalnetConfig -__all__ = ["Message", - "ServerErrorMessage", - "InvalidSecretEM", - "InvalidDestinationEM", - "InvalidPackageEM", - "RoyalnetLink", +__all__ = ["RoyalnetLink", "NetworkError", "NotConnectedError", "NotIdentifiedError", "Package", "RoyalnetServer", - "RequestSuccessful", - "RequestError", "RoyalnetConfig", - "Reply"] + "ConnectionClosedError"] diff --git a/royalnet/network/messages.py b/royalnet/network/messages.py deleted file mode 100644 index 132231f2..00000000 --- a/royalnet/network/messages.py +++ /dev/null @@ -1,78 +0,0 @@ -import typing -import pickle -from ..error import RoyalnetError - - -class Message: - """A message sent through the Royalnet.""" - def __repr__(self): - return f"<{self.__class__.__name__}>" - - -class IdentifySuccessfulMessage(Message): - """The Royalnet identification step was successful.""" - - -class ServerErrorMessage(Message): - """Something went wrong in the connection to the :py:class:`royalnet.network.RoyalnetServer`.""" - def __init__(self, reason): - super().__init__() - self.reason = reason - - -class InvalidSecretEM(ServerErrorMessage): - """The sent secret was incorrect. - - This message terminates connection to the :py:class:`royalnet.network.RoyalnetServer`.""" - - -class InvalidPackageEM(ServerErrorMessage): - """The sent :py:class:`royalnet.network.Package` was invalid.""" - - -class InvalidDestinationEM(InvalidPackageEM): - """The :py:class:`royalnet.network.Package` destination was invalid or not found.""" - - -class Reply(Message): - """A reply to a request sent through the Royalnet.""" - - def raise_on_error(self) -> None: - """If the reply is an error, raise an error, otherwise, do nothing. - - Raises: - A :py:exc:`RoyalnetError`, if the Reply is an error, otherwise, nothing.""" - raise NotImplementedError() - - -class RequestSuccessful(Reply): - """The sent request was successful.""" - - def raise_on_error(self) -> None: - """If the reply is an error, raise an error, otherwise, do nothing. - - Does nothing.""" - pass - - -class RequestError(Reply): - """The sent request wasn't successful.""" - - def __init__(self, exc: typing.Optional[Exception] = None): - """Create a RequestError. - - Parameters: - exc: The exception that caused the error in the request.""" - try: - pickle.dumps(exc) - except TypeError: - self.exc: Exception = Exception(repr(exc)) - else: - self.exc = exc - - def raise_on_error(self) -> None: - """If the reply is an error, raise an error, otherwise, do nothing. - - Raises: - Always raises a :py:exc:`royalnet.error.RoyalnetError`, containing the exception that caused the error.""" - raise RoyalnetError(exc=self.exc) diff --git a/royalnet/network/packages.py b/royalnet/network/packages.py index 766499b5..00a02bdf 100644 --- a/royalnet/network/packages.py +++ b/royalnet/network/packages.py @@ -7,7 +7,7 @@ class Package: """A Royalnet package, the data type with which a :py:class:`royalnet.network.RoyalnetLink` communicates with a :py:class:`royalnet.network.RoyalnetServer` or another link. """ def __init__(self, - data: typing.Union[None, int, float, str, list, dict], + data: dict, *, source: str, destination: str, @@ -22,7 +22,7 @@ class Package: source_conv_id: The conversation id of the node that created this package. Akin to the sequence number on IP packets. destination_conv_id: The conversation id of the node that this Package is a reply to.""" # TODO: something is not right in these type hints. Check them. - self.data: typing.Union[None, int, float, str, list, dict] = data + self.data: dict = data self.source: str = source self.source_conv_id: str = source_conv_id or str(uuid.uuid4()) self.destination: str = destination diff --git a/royalnet/network/royalnetlink.py b/royalnet/network/royalnetlink.py index 5edab494..f7ec976b 100644 --- a/royalnet/network/royalnetlink.py +++ b/royalnet/network/royalnetlink.py @@ -2,10 +2,10 @@ import asyncio import websockets import uuid import functools -import typing -import pickle +import math +import numbers import logging as _logging -from .messages import Message, ServerErrorMessage, RequestError +import typing from .packages import Package default_loop = asyncio.get_event_loop() @@ -20,16 +20,20 @@ class NotIdentifiedError(Exception): """The :py:class:`royalnet.network.RoyalnetLink` has not identified yet to a :py:class:`royalnet.network.RoyalnetServer`.""" +class ConnectionClosedError(Exception): + """The :py:class:`royalnet.network.RoyalnetLink`'s connection was closed unexpectedly. The link can't be used anymore.""" + + class NetworkError(Exception): - def __init__(self, error_msg: ServerErrorMessage, *args): + def __init__(self, error_data: dict, *args): super().__init__(*args) - self.error_msg: ServerErrorMessage = error_msg + self.error_data: dict = error_data class PendingRequest: def __init__(self, *, loop=default_loop): self.event: asyncio.Event = asyncio.Event(loop=loop) - self.data: typing.Optional[Message] = None + self.data: typing.Optional[dict] = None def __repr__(self): if self.event.is_set(): @@ -44,7 +48,7 @@ class PendingRequest: def requires_connection(func): @functools.wraps(func) async def new_func(self, *args, **kwargs): - await self._connect_event.wait() + await self.connect_event.wait() return await func(self, *args, **kwargs) return new_func @@ -67,29 +71,35 @@ class RoyalnetLink: self.secret: str = secret self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None self.request_handler = request_handler - self._pending_requests: typing.Dict[str, typing.Optional[Message]] = {} + self._pending_requests: typing.Dict[str, PendingRequest] = {} self._loop: asyncio.AbstractEventLoop = loop - self._connect_event: asyncio.Event = asyncio.Event(loop=self._loop) + self.error_event: asyncio.Event = asyncio.Event(loop=self._loop) + self.connect_event: asyncio.Event = asyncio.Event(loop=self._loop) self.identify_event: asyncio.Event = asyncio.Event(loop=self._loop) async def connect(self): + """Connect to the :py:class:`royalnet.network.RoyalnetServer` at ``self.master_uri``.""" log.info(f"Connecting to {self.master_uri}...") self.websocket = await websockets.connect(self.master_uri, loop=self._loop) - self._connect_event.set() + self.connect_event.set() log.info(f"Connected!") @requires_connection async def receive(self) -> Package: + """Recieve a :py:class:`Package` from the :py:class:`royalnet.network.RoyalnetServer`. + + Raises: + :py:exc:`royalnet.network.royalnetlink.ConnectionClosedError` if the connection closes.""" try: - raw_pickle = await self.websocket.recv() + jbytes = await self.websocket.recv() + package: Package = Package.from_json_bytes(jbytes) except websockets.ConnectionClosed: - self.websocket = None - self._connect_event.clear() + self.error_event.set() + self.connect_event.clear() self.identify_event.clear() log.info(f"Connection to {self.master_uri} was closed.") # What to do now? Let's just reraise. - raise - package: typing.Union[Package, Package] = pickle.loads(raw_pickle) + raise ConnectionClosedError("") assert package.destination == self.nid log.debug(f"Received package: {package}") return package @@ -107,13 +117,12 @@ class RoyalnetLink: @requires_identification async def send(self, package: Package): - raw_pickle: bytes = pickle.dumps(package) - await self.websocket.send(raw_pickle) + await self.websocket.send(package.to_json_bytes()) log.debug(f"Sent package: {package}") @requires_identification async def request(self, message, destination): - package = Package(message, destination, self.nid) + package = Package(message, source=self.nid, destination=destination) request = PendingRequest(loop=self._loop) self._pending_requests[package.source_conv_id] = request await self.send(package) @@ -125,10 +134,14 @@ class RoyalnetLink: raise NetworkError(result, "Server returned error while requesting something") return result - async def run(self): + async def run(self, loops: numbers.Real = math.inf): + """Blockingly run the Link.""" log.debug(f"Running main client loop for {self.nid}.") - while True: - if self.websocket is None: + if self.error_event.is_set(): + raise ConnectionClosedError("RoyalnetLinks can't be rerun after an error.") + while loops: + loops -= 1 + if not self.connect_event.is_set(): await self.connect() if not self.identify_event.is_set(): await self.identify() diff --git a/royalnet/network/royalnetserver.py b/royalnet/network/royalnetserver.py index c61166d4..2a55f14a 100644 --- a/royalnet/network/royalnetserver.py +++ b/royalnet/network/royalnetserver.py @@ -6,7 +6,6 @@ import pickle import uuid import asyncio import logging as _logging -from .messages import InvalidPackageEM, InvalidSecretEM, IdentifySuccessfulMessage from .packages import Package default_loop = asyncio.get_event_loop() @@ -26,9 +25,14 @@ class ConnectedClient: """Has the client sent a valid identification package?""" return bool(self.nid) + async def send_server_error(self, reason: str): + await self.send(Package({"error": reason}, + source="", + destination=self.nid)) + async def send(self, package: Package): """Send a :py:class:`royalnet.network.Package` to the :py:class:`royalnet.network.RoyalnetLink`.""" - await self.socket.send(package.pickle()) + await self.socket.send(package.to_json_bytes()) class RoyalnetServer: @@ -49,22 +53,22 @@ class RoyalnetServer: matching = [client for client in self.identified_clients if client.link_type == link_type] return matching or [] - async def listener(self, websocket: websockets.server.WebSocketServerProtocol, request_uri: str): + async def listener(self, websocket: websockets.server.WebSocketServerProtocol): log.info(f"{websocket.remote_address} connected to the server.") connected_client = ConnectedClient(websocket) # Wait for identification identify_msg = await websocket.recv() log.debug(f"{websocket.remote_address} identified itself with: {identify_msg}.") if not isinstance(identify_msg, str): - await websocket.send(InvalidPackageEM("Invalid identification message (not a str)")) + await websocket.send(connected_client.send_server_error("Invalid identification message (not a str)")) return identification = re.match(r"Identify ([^:\s]+):([^:\s]+):([^:\s]+)", identify_msg) if identification is None: - await websocket.send(InvalidPackageEM("Invalid identification message (regex failed)")) + await websocket.send(connected_client.send_server_error("Invalid identification message (regex failed)")) return secret = identification.group(3) if secret != self.required_secret: - await websocket.send(InvalidSecretEM("Invalid secret")) + await websocket.send(connected_client.send_server_error("Invalid secret")) return # Identification successful connected_client.nid = identification.group(1) diff --git a/tests/test_network.py b/tests/test_network.py index a37600ee..ab1d24ef 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -1,10 +1,23 @@ import pytest import uuid -from royalnet.network.packages import Package +import asyncio +from royalnet.network import Package, RoyalnetLink, RoyalnetServer, ConnectionClosedError + + +@pytest.fixture +def async_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + +@pytest.mark.skip("Not a test") +def echo_request_handler(message): + return message def test_package_serialization(): - pkg = Package("ciao", + pkg = Package({"ciao": "ciao"}, source=str(uuid.uuid4()), destination=str(uuid.uuid4()), source_conv_id=str(uuid.uuid4()), @@ -12,3 +25,21 @@ def test_package_serialization(): assert pkg == Package.from_dict(pkg.to_dict()) assert pkg == Package.from_json_string(pkg.to_json_string()) assert pkg == Package.from_json_bytes(pkg.to_json_bytes()) + + +def test_links(async_loop: asyncio.AbstractEventLoop): + address, port = "127.0.0.1", 1234 + master = RoyalnetServer(address, port, "test") + async_loop.run_until_complete(master.start()) + # Test invalid secret + wrong_secret_link = RoyalnetLink("ws://127.0.0.1:1234", "invalid", "test", echo_request_handler, loop=async_loop) + with pytest.raises(ConnectionClosedError): + async_loop.run_until_complete(wrong_secret_link.run()) + # Test regular connection + link1 = RoyalnetLink("ws://127.0.0.1:1234", "test", "one", echo_request_handler, loop=async_loop) + link1_run_task = async_loop.create_task(link1.run()) + link2 = RoyalnetLink("ws://127.0.0.1:1234", "test", "two", echo_request_handler, loop=async_loop) + link2_run_task = async_loop.create_task(link2.run()) + message = {"ciao": "ciao"} + response = async_loop.run_until_complete(link1.request(message, "two")) + assert message == response