1
Fork 0
mirror of https://github.com/RYGhub/royalnet.git synced 2024-12-17 23:24:20 +00:00

Fix a network bug

This commit is contained in:
Steffo 2019-04-20 02:18:49 +02:00
parent eb19932d7f
commit 8b239c564e
2 changed files with 22 additions and 18 deletions

View file

@ -8,7 +8,7 @@ import logging as _logging
from .messages import Message, ServerErrorMessage, RequestError from .messages import Message, ServerErrorMessage, RequestError
from .packages import Package from .packages import Package
loop = asyncio.get_event_loop() default_loop = asyncio.get_event_loop()
log = _logging.getLogger(__name__) log = _logging.getLogger(__name__)
@ -27,8 +27,8 @@ class NetworkError(Exception):
class PendingRequest: class PendingRequest:
def __init__(self): def __init__(self, *, loop=default_loop):
self.event: asyncio.Event = asyncio.Event() self.event: asyncio.Event = asyncio.Event(loop=loop)
self.data: typing.Optional[Message] = None self.data: typing.Optional[Message] = None
def __repr__(self): def __repr__(self):
@ -58,7 +58,8 @@ def requires_identification(func):
class RoyalnetLink: class RoyalnetLink:
def __init__(self, master_uri: str, secret: str, link_type: str, request_handler): def __init__(self, master_uri: str, secret: str, link_type: str, request_handler, *,
loop: asyncio.AbstractEventLoop = default_loop):
assert ":" not in link_type assert ":" not in link_type
self.master_uri: str = master_uri self.master_uri: str = master_uri
self.link_type: str = link_type self.link_type: str = link_type
@ -67,12 +68,13 @@ class RoyalnetLink:
self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None
self.request_handler = request_handler self.request_handler = request_handler
self._pending_requests: typing.Dict[typing.Optional[Message]] = {} self._pending_requests: typing.Dict[typing.Optional[Message]] = {}
self._connect_event: asyncio.Event = asyncio.Event() self._loop: asyncio.AbstractEventLoop = loop
self.identify_event: asyncio.Event = asyncio.Event() self._connect_event: asyncio.Event = asyncio.Event(loop=self._loop)
self.identify_event: asyncio.Event = asyncio.Event(loop=self._loop)
async def connect(self): async def connect(self):
log.info(f"Connecting to {self.master_uri}...") log.info(f"Connecting to {self.master_uri}...")
self.websocket = await websockets.connect(self.master_uri) self.websocket = await websockets.connect(self.master_uri, loop=self._loop)
self._connect_event.set() self._connect_event.set()
log.info(f"Connected!") log.info(f"Connected!")
@ -112,7 +114,7 @@ class RoyalnetLink:
@requires_identification @requires_identification
async def request(self, message, destination): async def request(self, message, destination):
package = Package(message, destination, self.nid) package = Package(message, destination, self.nid)
request = PendingRequest() request = PendingRequest(loop=self._loop)
self._pending_requests[package.source_conv_id] = request self._pending_requests[package.source_conv_id] = request
await self.send(package) await self.send(package)
log.debug(f"Sent request: {message} -> {destination}") log.debug(f"Sent request: {message} -> {destination}")
@ -141,6 +143,7 @@ class RoyalnetLink:
log.debug(f"Received request {package.source_conv_id}: {package}") log.debug(f"Received request {package.source_conv_id}: {package}")
try: try:
response = await self.request_handler(package.data) response = await self.request_handler(package.data)
assert isinstance(response, Message)
except Exception as exc: except Exception as exc:
response = RequestError(exc=exc) response = RequestError(exc=exc)
return return

View file

@ -9,7 +9,7 @@ import logging as _logging
from .messages import InvalidPackageEM, InvalidSecretEM, IdentifySuccessfulMessage from .messages import InvalidPackageEM, InvalidSecretEM, IdentifySuccessfulMessage
from .packages import Package from .packages import Package
loop = asyncio.get_event_loop() default_loop = asyncio.get_event_loop()
log = _logging.getLogger(__name__) log = _logging.getLogger(__name__)
@ -29,11 +29,12 @@ class ConnectedClient:
class RoyalnetServer: class RoyalnetServer:
def __init__(self, address: str, port: int, required_secret: str): def __init__(self, address: str, port: int, required_secret: str, *, loop: asyncio.AbstractEventLoop = default_loop):
self.address: str = address self.address: str = address
self.port: int = port self.port: int = port
self.required_secret: str = required_secret self.required_secret: str = required_secret
self.identified_clients: typing.List[ConnectedClient] = [] self.identified_clients: typing.List[ConnectedClient] = []
self._loop: asyncio.AbstractEventLoop = loop
def find_client(self, *, nid: str = None, link_type: str = None) -> typing.List[ConnectedClient]: def find_client(self, *, nid: str = None, link_type: str = None) -> typing.List[ConnectedClient]:
assert not (nid and link_type) assert not (nid and link_type)
@ -52,15 +53,15 @@ class RoyalnetServer:
identify_msg = await websocket.recv() identify_msg = await websocket.recv()
log.debug(f"{websocket.remote_address} identified itself with: {identify_msg}.") log.debug(f"{websocket.remote_address} identified itself with: {identify_msg}.")
if not isinstance(identify_msg, str): if not isinstance(identify_msg, str):
websocket.send(InvalidPackageEM("Invalid identification message (not a str)")) await websocket.send(InvalidPackageEM("Invalid identification message (not a str)"))
return return
identification = re.match(r"Identify ([A-Za-z0-9\-]+):([a-z]+):([A-Za-z0-9\-]+)", identify_msg) identification = re.match(r"Identify ([^:\s]+):([^:\s]+):([^:\s]+)", identify_msg)
if identification is None: if identification is None:
websocket.send(InvalidPackageEM("Invalid identification message (regex failed)")) await websocket.send(InvalidPackageEM("Invalid identification message (regex failed)"))
return return
secret = identification.group(3) secret = identification.group(3)
if secret != self.required_secret: if secret != self.required_secret:
websocket.send(InvalidSecretEM("Invalid secret")) await websocket.send(InvalidSecretEM("Invalid secret"))
return return
# Identification successful # Identification successful
connected_client.nid = identification.group(1) connected_client.nid = identification.group(1)
@ -81,7 +82,7 @@ class RoyalnetServer:
pass pass
# Otherwise, route the package to its destination # Otherwise, route the package to its destination
# noinspection PyAsyncCall # noinspection PyAsyncCall
loop.create_task(self.route_package(package)) self._loop.create_task(self.route_package(package))
def find_destination(self, package: Package) -> typing.List[ConnectedClient]: def find_destination(self, package: Package) -> typing.List[ConnectedClient]:
"""Find a list of destinations for the sent packages""" """Find a list of destinations for the sent packages"""
@ -112,7 +113,7 @@ class RoyalnetServer:
async def serve(self): async def serve(self):
await websockets.serve(self.listener, host=self.address, port=self.port) await websockets.serve(self.listener, host=self.address, port=self.port)
async def run(self): async def start(self):
log.debug(f"Running main server loop for __master__ on ws://{self.address}:{self.port}") log.debug(f"Starting main server loop for __master__ on ws://{self.address}:{self.port}")
# noinspection PyAsyncCall # noinspection PyAsyncCall
loop.create_task(self.serve()) self._loop.create_task(self.serve())