diff --git a/royalnet/network/royalnetlink.py b/royalnet/network/royalnetlink.py index bffe8db4..2be3d189 100644 --- a/royalnet/network/royalnetlink.py +++ b/royalnet/network/royalnetlink.py @@ -8,7 +8,7 @@ import logging as _logging from .messages import Message, ServerErrorMessage, RequestError from .packages import Package -loop = asyncio.get_event_loop() +default_loop = asyncio.get_event_loop() log = _logging.getLogger(__name__) @@ -27,8 +27,8 @@ class NetworkError(Exception): class PendingRequest: - def __init__(self): - self.event: asyncio.Event = asyncio.Event() + def __init__(self, *, loop=default_loop): + self.event: asyncio.Event = asyncio.Event(loop=loop) self.data: typing.Optional[Message] = None def __repr__(self): @@ -58,7 +58,8 @@ def requires_identification(func): class RoyalnetLink: - def __init__(self, master_uri: str, secret: str, link_type: str, request_handler): + def __init__(self, master_uri: str, secret: str, link_type: str, request_handler, *, + loop: asyncio.AbstractEventLoop = default_loop): assert ":" not in link_type self.master_uri: str = master_uri self.link_type: str = link_type @@ -67,12 +68,13 @@ class RoyalnetLink: self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None self.request_handler = request_handler self._pending_requests: typing.Dict[typing.Optional[Message]] = {} - self._connect_event: asyncio.Event = asyncio.Event() - self.identify_event: asyncio.Event = asyncio.Event() + self._loop: asyncio.AbstractEventLoop = loop + self._connect_event: asyncio.Event = asyncio.Event(loop=self._loop) + self.identify_event: asyncio.Event = asyncio.Event(loop=self._loop) async def connect(self): log.info(f"Connecting to {self.master_uri}...") - self.websocket = await websockets.connect(self.master_uri) + self.websocket = await websockets.connect(self.master_uri, loop=self._loop) self._connect_event.set() log.info(f"Connected!") @@ -112,7 +114,7 @@ class RoyalnetLink: @requires_identification async def request(self, message, destination): package = Package(message, destination, self.nid) - request = PendingRequest() + request = PendingRequest(loop=self._loop) self._pending_requests[package.source_conv_id] = request await self.send(package) log.debug(f"Sent request: {message} -> {destination}") @@ -141,6 +143,7 @@ class RoyalnetLink: log.debug(f"Received request {package.source_conv_id}: {package}") try: response = await self.request_handler(package.data) + assert isinstance(response, Message) except Exception as exc: response = RequestError(exc=exc) return diff --git a/royalnet/network/royalnetserver.py b/royalnet/network/royalnetserver.py index 3eb889dc..d63f92e0 100644 --- a/royalnet/network/royalnetserver.py +++ b/royalnet/network/royalnetserver.py @@ -9,7 +9,7 @@ import logging as _logging from .messages import InvalidPackageEM, InvalidSecretEM, IdentifySuccessfulMessage from .packages import Package -loop = asyncio.get_event_loop() +default_loop = asyncio.get_event_loop() log = _logging.getLogger(__name__) @@ -29,11 +29,12 @@ class ConnectedClient: class RoyalnetServer: - def __init__(self, address: str, port: int, required_secret: str): + def __init__(self, address: str, port: int, required_secret: str, *, loop: asyncio.AbstractEventLoop = default_loop): self.address: str = address self.port: int = port self.required_secret: str = required_secret self.identified_clients: typing.List[ConnectedClient] = [] + self._loop: asyncio.AbstractEventLoop = loop def find_client(self, *, nid: str = None, link_type: str = None) -> typing.List[ConnectedClient]: assert not (nid and link_type) @@ -52,15 +53,15 @@ class RoyalnetServer: identify_msg = await websocket.recv() log.debug(f"{websocket.remote_address} identified itself with: {identify_msg}.") if not isinstance(identify_msg, str): - websocket.send(InvalidPackageEM("Invalid identification message (not a str)")) + await websocket.send(InvalidPackageEM("Invalid identification message (not a str)")) return - identification = re.match(r"Identify ([A-Za-z0-9\-]+):([a-z]+):([A-Za-z0-9\-]+)", identify_msg) + identification = re.match(r"Identify ([^:\s]+):([^:\s]+):([^:\s]+)", identify_msg) if identification is None: - websocket.send(InvalidPackageEM("Invalid identification message (regex failed)")) + await websocket.send(InvalidPackageEM("Invalid identification message (regex failed)")) return secret = identification.group(3) if secret != self.required_secret: - websocket.send(InvalidSecretEM("Invalid secret")) + await websocket.send(InvalidSecretEM("Invalid secret")) return # Identification successful connected_client.nid = identification.group(1) @@ -81,7 +82,7 @@ class RoyalnetServer: pass # Otherwise, route the package to its destination # noinspection PyAsyncCall - loop.create_task(self.route_package(package)) + self._loop.create_task(self.route_package(package)) def find_destination(self, package: Package) -> typing.List[ConnectedClient]: """Find a list of destinations for the sent packages""" @@ -112,7 +113,7 @@ class RoyalnetServer: async def serve(self): await websockets.serve(self.listener, host=self.address, port=self.port) - async def run(self): - log.debug(f"Running main server loop for __master__ on ws://{self.address}:{self.port}") + async def start(self): + log.debug(f"Starting main server loop for __master__ on ws://{self.address}:{self.port}") # noinspection PyAsyncCall - loop.create_task(self.serve()) + self._loop.create_task(self.serve())