diff --git a/royalnet/network/messages.py b/royalnet/network/messages.py index 6b13cfa2..99696f02 100644 --- a/royalnet/network/messages.py +++ b/royalnet/network/messages.py @@ -1,5 +1,6 @@ class Message: - pass + def __repr__(self): + return f"<{self.__class__.__name__}>" class IdentifySuccessfulMessage(Message): @@ -22,5 +23,3 @@ class InvalidPackageEM(ErrorMessage): class InvalidDestinationEM(InvalidPackageEM): pass - - diff --git a/royalnet/network/packages.py b/royalnet/network/packages.py index 9e2239c1..499ca0a4 100644 --- a/royalnet/network/packages.py +++ b/royalnet/network/packages.py @@ -6,12 +6,14 @@ class Package: def __init__(self, data, destination: str, source: str, *, conversation_id: str = None): self.data = data self.destination: str = destination - self.source, = source + self.source = source self.conversation_id = conversation_id or str(uuid.uuid4()) + def __repr__(self): + return f"" + def reply(self, data) -> "Package": return Package(data, self.source, self.destination, conversation_id=self.conversation_id) def pickle(self): return pickle.dumps(self) - diff --git a/royalnet/network/royalnetlink.py b/royalnet/network/royalnetlink.py index c40419b6..65aa68ab 100644 --- a/royalnet/network/royalnetlink.py +++ b/royalnet/network/royalnetlink.py @@ -1,13 +1,15 @@ import asyncio -from asyncio import Event import websockets import uuid import functools import typing import pickle +import logging from .messages import Message, ErrorMessage from .packages import Package + loop = asyncio.get_event_loop() +log = logging.getLogger(__name__) class NotConnectedError(Exception): @@ -21,19 +23,40 @@ class NotIdentifiedError(Exception): class NetworkError(Exception): def __init__(self, error_msg: ErrorMessage, *args): super().__init__(*args) - self.error_msg = error_msg + self.error_msg: ErrorMessage = error_msg class PendingRequest: def __init__(self): - self.event = Event() - self.data = None + self.event: asyncio.Event = asyncio.Event() + self.data: Message = None + + def __repr__(self): + if self.event.is_set(): + return f"" + return f"" def set(self, data): self.data = data self.event.set() +def requires_connection(func): + @functools.wraps(func) + async def new_func(self, *args, **kwargs): + await self._connect_event.wait() + return await func(self, *args, **kwargs) + return new_func + + +def requires_identification(func): + @functools.wraps(func) + async def new_func(self, *args, **kwargs): + await self._identify_event.wait() + return await func(self, *args, **kwargs) + return new_func + + class RoyalnetLink: def __init__(self, master_uri: str, secret: str, link_type: str, request_handler): assert ":" not in link_type @@ -42,20 +65,16 @@ class RoyalnetLink: self.nid: str = str(uuid.uuid4()) self.secret: str = secret self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None - self.identified: bool = False self.request_handler = request_handler self._pending_requests: typing.Dict[typing.Optional[Message]] = {} + self._connect_event: asyncio.Event = asyncio.Event() + self._identify_event: asyncio.Event = asyncio.Event() async def connect(self): + log.info(f"Connecting to {self.master_uri}...") self.websocket = await websockets.connect(self.master_uri) - - def requires_connection(func): - @functools.wraps(func) - def new_func(self, *args, **kwargs): - if self.websocket is None: - raise NotConnectedError("Tried to call a method which @requires_connection while not connected") - return func(self, *args, **kwargs) - return new_func + self._connect_event.set() + log.info(f"Connected!") @requires_connection async def receive(self) -> Package: @@ -63,34 +82,32 @@ class RoyalnetLink: raw_pickle = await self.websocket.recv() except websockets.ConnectionClosed: self.websocket = None - self.identified = False + self._connect_event.clear() + self._identify_event.clear() + log.info(f"Connection to {self.master_uri} was closed.") # What to do now? Let's just reraise. raise package: typing.Union[Package, Package] = pickle.loads(raw_pickle) assert package.destination == self.nid + log.debug(f"Received package: {package}") return package @requires_connection - async def identify(self, secret) -> None: - await self.websocket.send(f"Identify {self.nid}:{self.link_type}:{secret}") + async def identify(self) -> None: + log.info(f"Identifying to {self.master_uri}...") + await self.websocket.send(f"Identify {self.nid}:{self.link_type}:{self.secret}") response_package = await self.receive() response = response_package.data if isinstance(response, ErrorMessage): raise NetworkError(response, "Server returned error while identifying self") - self.identified = True - - def requires_identification(func): - @functools.wraps(func) - def new_func(self, *args, **kwargs): - if not self.identified: - raise NotIdentifiedError("Tried to call a method which @requires_identification while not identified") - return func(self, *args, **kwargs) - return new_func + self._identify_event.set() + log.info(f"Identified successfully!") @requires_identification async def send(self, package: Package): raw_pickle: bytes = pickle.dumps(package) await self.websocket.send(raw_pickle) + log.debug(f"Sent package: {package}") @requires_identification async def request(self, message, destination): @@ -98,19 +115,22 @@ class RoyalnetLink: request = PendingRequest() self._pending_requests[package.conversation_id] = request await self.send(package) + log.debug(f"Sent request: {message} -> {destination}") await request.event.wait() - result = request.data + result: Message = request.data + log.debug(f"Received response: {request} -> {result}") if isinstance(result, ErrorMessage): raise NetworkError(result, "Server returned error while requesting something") return result async def run(self): + log.debug(f"Running main client loop for {self.nid}.") while True: if self.websocket is None: await self.connect() - if not self.identified: + if not self._identify_event.is_set(): await self.identify() - package: Package = self.receive() + package: Package = await self.receive() # Package is a response if package.conversation_id in self._pending_requests: request = self._pending_requests[package.conversation_id] @@ -118,7 +138,9 @@ class RoyalnetLink: continue # Package is a request assert isinstance(package, Package) + log.debug(f"Received request: {package.source} -> {package.data}") response = await self.request_handler(package.data) if response is not None: response_package: Package = package.reply(response) await self.send(response_package) + log.debug(f"Replied to request: {response_package.data} -> {response_package.destination}") diff --git a/royalnet/network/royalnetserver.py b/royalnet/network/royalnetserver.py index 0452d5ea..54500c4f 100644 --- a/royalnet/network/royalnetserver.py +++ b/royalnet/network/royalnetserver.py @@ -3,11 +3,15 @@ import websockets import re import datetime import pickle -import asyncio import uuid +import asyncio +import logging from .messages import Message, ErrorMessage, InvalidPackageEM, InvalidSecretEM, IdentifySuccessfulMessage from .packages import Package +loop = asyncio.get_event_loop() +log = logging.getLogger(__name__) + class ConnectedClient: def __init__(self, socket: websockets.WebSocketServerProtocol): @@ -21,7 +25,7 @@ class ConnectedClient: return bool(self.nid) async def send(self, package: Package): - self.socket.send(package.pickle()) + await self.socket.send(package.pickle()) class RoyalnetServer: @@ -29,9 +33,9 @@ class RoyalnetServer: self.address: str = address self.port: int = port self.required_secret: str = required_secret - self.identified_clients: typing.List[ConnectedClient] = {} + self.identified_clients: typing.List[ConnectedClient] = [] - def find_client(self, *, nid: str=None, link_type: str=None) -> typing.List[ConnectedClient]: + def find_client(self, *, nid: str = None, link_type: str = None) -> typing.List[ConnectedClient]: assert not (nid and link_type) if nid: matching = [client for client in self.identified_clients if client.nid == nid] @@ -42,13 +46,15 @@ class RoyalnetServer: return matching or [] async def listener(self, websocket: websockets.server.WebSocketServerProtocol, request_uri: str): + log.info(f"{websocket.remote_address} connected to the server.") connected_client = ConnectedClient(websocket) # Wait for identification - identify_msg = websocket.recv() + identify_msg = await websocket.recv() + log.debug(f"{websocket.remote_address} identified itself with: {identify_msg}.") if not isinstance(identify_msg, str): websocket.send(InvalidPackageEM("Invalid identification message (not a str)")) return - identification = re.match(r"Identify ([A-Za-z0-9\-]+):([a-z]+):([A-Za-z0-9\-])", identify_msg) + identification = re.match(r"Identify ([A-Za-z0-9\-]+):([a-z]+):([A-Za-z0-9\-]+)", identify_msg) if identification is None: websocket.send(InvalidPackageEM("Invalid identification message (regex failed)")) return @@ -60,18 +66,21 @@ class RoyalnetServer: connected_client.nid = identification.group(1) connected_client.link_type = identification.group(2) self.identified_clients.append(connected_client) + log.debug(f"{websocket.remote_address} identified successfully as {connected_client.nid} ({connected_client.link_type}).") await connected_client.send(Package(IdentifySuccessfulMessage(), connected_client.nid, "__master__")) + log.debug(f"{connected_client.nid}'s identification confirmed.") # Main loop while True: # Receive packages raw_pickle = await websocket.recv() package: Package = pickle.loads(raw_pickle) + log.debug(f"Received package: {package}") # Check if the package destination is the server itself. if package.destination == "__master__": # TODO: do stuff pass # Otherwise, route the package to its destination - asyncio.create_task(self.route_package(package)) + loop.create_task(self.route_package(package)) def find_destination(self, package: Package) -> typing.List[ConnectedClient]: """Find a list of destinations for the sent packages""" @@ -95,8 +104,11 @@ class RoyalnetServer: async def route_package(self, package: Package) -> None: """Executed every time a package is received and must be routed somewhere.""" destinations = self.find_destination(package) + log.debug(f"Routing package: {package} -> {destinations}") for destination in destinations: - await destination.send(package) + specific_package = Package(package.data, destination.nid, package.source, conversation_id=package.conversation_id) + await destination.send(specific_package) async def run(self): - websockets.serve(self.listener, host=self.address, port=self.port) + log.debug(f"Running main server loop for __master__ on ws://{self.address}:{self.port}") + await websockets.serve(self.listener, host=self.address, port=self.port)