mirror of
https://github.com/RYGhub/royalnet.git
synced 2024-11-23 19:44:20 +00:00
Fix a network bug
This commit is contained in:
parent
eb19932d7f
commit
8b239c564e
2 changed files with 22 additions and 18 deletions
|
@ -8,7 +8,7 @@ import logging as _logging
|
||||||
from .messages import Message, ServerErrorMessage, RequestError
|
from .messages import Message, ServerErrorMessage, RequestError
|
||||||
from .packages import Package
|
from .packages import Package
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
default_loop = asyncio.get_event_loop()
|
||||||
log = _logging.getLogger(__name__)
|
log = _logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,8 +27,8 @@ class NetworkError(Exception):
|
||||||
|
|
||||||
|
|
||||||
class PendingRequest:
|
class PendingRequest:
|
||||||
def __init__(self):
|
def __init__(self, *, loop=default_loop):
|
||||||
self.event: asyncio.Event = asyncio.Event()
|
self.event: asyncio.Event = asyncio.Event(loop=loop)
|
||||||
self.data: typing.Optional[Message] = None
|
self.data: typing.Optional[Message] = None
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
@ -58,7 +58,8 @@ def requires_identification(func):
|
||||||
|
|
||||||
|
|
||||||
class RoyalnetLink:
|
class RoyalnetLink:
|
||||||
def __init__(self, master_uri: str, secret: str, link_type: str, request_handler):
|
def __init__(self, master_uri: str, secret: str, link_type: str, request_handler, *,
|
||||||
|
loop: asyncio.AbstractEventLoop = default_loop):
|
||||||
assert ":" not in link_type
|
assert ":" not in link_type
|
||||||
self.master_uri: str = master_uri
|
self.master_uri: str = master_uri
|
||||||
self.link_type: str = link_type
|
self.link_type: str = link_type
|
||||||
|
@ -67,12 +68,13 @@ class RoyalnetLink:
|
||||||
self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None
|
self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None
|
||||||
self.request_handler = request_handler
|
self.request_handler = request_handler
|
||||||
self._pending_requests: typing.Dict[typing.Optional[Message]] = {}
|
self._pending_requests: typing.Dict[typing.Optional[Message]] = {}
|
||||||
self._connect_event: asyncio.Event = asyncio.Event()
|
self._loop: asyncio.AbstractEventLoop = loop
|
||||||
self.identify_event: asyncio.Event = asyncio.Event()
|
self._connect_event: asyncio.Event = asyncio.Event(loop=self._loop)
|
||||||
|
self.identify_event: asyncio.Event = asyncio.Event(loop=self._loop)
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
log.info(f"Connecting to {self.master_uri}...")
|
log.info(f"Connecting to {self.master_uri}...")
|
||||||
self.websocket = await websockets.connect(self.master_uri)
|
self.websocket = await websockets.connect(self.master_uri, loop=self._loop)
|
||||||
self._connect_event.set()
|
self._connect_event.set()
|
||||||
log.info(f"Connected!")
|
log.info(f"Connected!")
|
||||||
|
|
||||||
|
@ -112,7 +114,7 @@ class RoyalnetLink:
|
||||||
@requires_identification
|
@requires_identification
|
||||||
async def request(self, message, destination):
|
async def request(self, message, destination):
|
||||||
package = Package(message, destination, self.nid)
|
package = Package(message, destination, self.nid)
|
||||||
request = PendingRequest()
|
request = PendingRequest(loop=self._loop)
|
||||||
self._pending_requests[package.source_conv_id] = request
|
self._pending_requests[package.source_conv_id] = request
|
||||||
await self.send(package)
|
await self.send(package)
|
||||||
log.debug(f"Sent request: {message} -> {destination}")
|
log.debug(f"Sent request: {message} -> {destination}")
|
||||||
|
@ -141,6 +143,7 @@ class RoyalnetLink:
|
||||||
log.debug(f"Received request {package.source_conv_id}: {package}")
|
log.debug(f"Received request {package.source_conv_id}: {package}")
|
||||||
try:
|
try:
|
||||||
response = await self.request_handler(package.data)
|
response = await self.request_handler(package.data)
|
||||||
|
assert isinstance(response, Message)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
response = RequestError(exc=exc)
|
response = RequestError(exc=exc)
|
||||||
return
|
return
|
||||||
|
|
|
@ -9,7 +9,7 @@ import logging as _logging
|
||||||
from .messages import InvalidPackageEM, InvalidSecretEM, IdentifySuccessfulMessage
|
from .messages import InvalidPackageEM, InvalidSecretEM, IdentifySuccessfulMessage
|
||||||
from .packages import Package
|
from .packages import Package
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
default_loop = asyncio.get_event_loop()
|
||||||
log = _logging.getLogger(__name__)
|
log = _logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,11 +29,12 @@ class ConnectedClient:
|
||||||
|
|
||||||
|
|
||||||
class RoyalnetServer:
|
class RoyalnetServer:
|
||||||
def __init__(self, address: str, port: int, required_secret: str):
|
def __init__(self, address: str, port: int, required_secret: str, *, loop: asyncio.AbstractEventLoop = default_loop):
|
||||||
self.address: str = address
|
self.address: str = address
|
||||||
self.port: int = port
|
self.port: int = port
|
||||||
self.required_secret: str = required_secret
|
self.required_secret: str = required_secret
|
||||||
self.identified_clients: typing.List[ConnectedClient] = []
|
self.identified_clients: typing.List[ConnectedClient] = []
|
||||||
|
self._loop: asyncio.AbstractEventLoop = loop
|
||||||
|
|
||||||
def find_client(self, *, nid: str = None, link_type: str = None) -> typing.List[ConnectedClient]:
|
def find_client(self, *, nid: str = None, link_type: str = None) -> typing.List[ConnectedClient]:
|
||||||
assert not (nid and link_type)
|
assert not (nid and link_type)
|
||||||
|
@ -52,15 +53,15 @@ class RoyalnetServer:
|
||||||
identify_msg = await websocket.recv()
|
identify_msg = await websocket.recv()
|
||||||
log.debug(f"{websocket.remote_address} identified itself with: {identify_msg}.")
|
log.debug(f"{websocket.remote_address} identified itself with: {identify_msg}.")
|
||||||
if not isinstance(identify_msg, str):
|
if not isinstance(identify_msg, str):
|
||||||
websocket.send(InvalidPackageEM("Invalid identification message (not a str)"))
|
await websocket.send(InvalidPackageEM("Invalid identification message (not a str)"))
|
||||||
return
|
return
|
||||||
identification = re.match(r"Identify ([A-Za-z0-9\-]+):([a-z]+):([A-Za-z0-9\-]+)", identify_msg)
|
identification = re.match(r"Identify ([^:\s]+):([^:\s]+):([^:\s]+)", identify_msg)
|
||||||
if identification is None:
|
if identification is None:
|
||||||
websocket.send(InvalidPackageEM("Invalid identification message (regex failed)"))
|
await websocket.send(InvalidPackageEM("Invalid identification message (regex failed)"))
|
||||||
return
|
return
|
||||||
secret = identification.group(3)
|
secret = identification.group(3)
|
||||||
if secret != self.required_secret:
|
if secret != self.required_secret:
|
||||||
websocket.send(InvalidSecretEM("Invalid secret"))
|
await websocket.send(InvalidSecretEM("Invalid secret"))
|
||||||
return
|
return
|
||||||
# Identification successful
|
# Identification successful
|
||||||
connected_client.nid = identification.group(1)
|
connected_client.nid = identification.group(1)
|
||||||
|
@ -81,7 +82,7 @@ class RoyalnetServer:
|
||||||
pass
|
pass
|
||||||
# Otherwise, route the package to its destination
|
# Otherwise, route the package to its destination
|
||||||
# noinspection PyAsyncCall
|
# noinspection PyAsyncCall
|
||||||
loop.create_task(self.route_package(package))
|
self._loop.create_task(self.route_package(package))
|
||||||
|
|
||||||
def find_destination(self, package: Package) -> typing.List[ConnectedClient]:
|
def find_destination(self, package: Package) -> typing.List[ConnectedClient]:
|
||||||
"""Find a list of destinations for the sent packages"""
|
"""Find a list of destinations for the sent packages"""
|
||||||
|
@ -112,7 +113,7 @@ class RoyalnetServer:
|
||||||
async def serve(self):
|
async def serve(self):
|
||||||
await websockets.serve(self.listener, host=self.address, port=self.port)
|
await websockets.serve(self.listener, host=self.address, port=self.port)
|
||||||
|
|
||||||
async def run(self):
|
async def start(self):
|
||||||
log.debug(f"Running main server loop for __master__ on ws://{self.address}:{self.port}")
|
log.debug(f"Starting main server loop for __master__ on ws://{self.address}:{self.port}")
|
||||||
# noinspection PyAsyncCall
|
# noinspection PyAsyncCall
|
||||||
loop.create_task(self.serve())
|
self._loop.create_task(self.serve())
|
||||||
|
|
Loading…
Reference in a new issue