mirror of
https://github.com/RYGhub/royalnet.git
synced 2024-11-23 11:34:18 +00:00
Basically complete network branch
This commit is contained in:
parent
11bbb77afe
commit
879c0ce953
4 changed files with 77 additions and 42 deletions
|
@ -1,5 +1,6 @@
|
|||
class Message:
|
||||
pass
|
||||
def __repr__(self):
|
||||
return f"<{self.__class__.__name__}>"
|
||||
|
||||
|
||||
class IdentifySuccessfulMessage(Message):
|
||||
|
@ -22,5 +23,3 @@ class InvalidPackageEM(ErrorMessage):
|
|||
|
||||
class InvalidDestinationEM(InvalidPackageEM):
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
@ -6,12 +6,14 @@ class Package:
|
|||
def __init__(self, data, destination: str, source: str, *, conversation_id: str = None):
|
||||
self.data = data
|
||||
self.destination: str = destination
|
||||
self.source, = source
|
||||
self.source = source
|
||||
self.conversation_id = conversation_id or str(uuid.uuid4())
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Package to {self.destination}: {self.data.__class__.__name__}>"
|
||||
|
||||
def reply(self, data) -> "Package":
|
||||
return Package(data, self.source, self.destination, conversation_id=self.conversation_id)
|
||||
|
||||
def pickle(self):
|
||||
return pickle.dumps(self)
|
||||
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
import asyncio
|
||||
from asyncio import Event
|
||||
import websockets
|
||||
import uuid
|
||||
import functools
|
||||
import typing
|
||||
import pickle
|
||||
import logging
|
||||
from .messages import Message, ErrorMessage
|
||||
from .packages import Package
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NotConnectedError(Exception):
|
||||
|
@ -21,19 +23,40 @@ class NotIdentifiedError(Exception):
|
|||
class NetworkError(Exception):
|
||||
def __init__(self, error_msg: ErrorMessage, *args):
|
||||
super().__init__(*args)
|
||||
self.error_msg = error_msg
|
||||
self.error_msg: ErrorMessage = error_msg
|
||||
|
||||
|
||||
class PendingRequest:
|
||||
def __init__(self):
|
||||
self.event = Event()
|
||||
self.data = None
|
||||
self.event: asyncio.Event = asyncio.Event()
|
||||
self.data: Message = None
|
||||
|
||||
def __repr__(self):
|
||||
if self.event.is_set():
|
||||
return f"<PendingRequest: {self.data.__class__.__name__}>"
|
||||
return f"<PendingRequest>"
|
||||
|
||||
def set(self, data):
|
||||
self.data = data
|
||||
self.event.set()
|
||||
|
||||
|
||||
def requires_connection(func):
|
||||
@functools.wraps(func)
|
||||
async def new_func(self, *args, **kwargs):
|
||||
await self._connect_event.wait()
|
||||
return await func(self, *args, **kwargs)
|
||||
return new_func
|
||||
|
||||
|
||||
def requires_identification(func):
|
||||
@functools.wraps(func)
|
||||
async def new_func(self, *args, **kwargs):
|
||||
await self._identify_event.wait()
|
||||
return await func(self, *args, **kwargs)
|
||||
return new_func
|
||||
|
||||
|
||||
class RoyalnetLink:
|
||||
def __init__(self, master_uri: str, secret: str, link_type: str, request_handler):
|
||||
assert ":" not in link_type
|
||||
|
@ -42,20 +65,16 @@ class RoyalnetLink:
|
|||
self.nid: str = str(uuid.uuid4())
|
||||
self.secret: str = secret
|
||||
self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None
|
||||
self.identified: bool = False
|
||||
self.request_handler = request_handler
|
||||
self._pending_requests: typing.Dict[typing.Optional[Message]] = {}
|
||||
self._connect_event: asyncio.Event = asyncio.Event()
|
||||
self._identify_event: asyncio.Event = asyncio.Event()
|
||||
|
||||
async def connect(self):
|
||||
log.info(f"Connecting to {self.master_uri}...")
|
||||
self.websocket = await websockets.connect(self.master_uri)
|
||||
|
||||
def requires_connection(func):
|
||||
@functools.wraps(func)
|
||||
def new_func(self, *args, **kwargs):
|
||||
if self.websocket is None:
|
||||
raise NotConnectedError("Tried to call a method which @requires_connection while not connected")
|
||||
return func(self, *args, **kwargs)
|
||||
return new_func
|
||||
self._connect_event.set()
|
||||
log.info(f"Connected!")
|
||||
|
||||
@requires_connection
|
||||
async def receive(self) -> Package:
|
||||
|
@ -63,34 +82,32 @@ class RoyalnetLink:
|
|||
raw_pickle = await self.websocket.recv()
|
||||
except websockets.ConnectionClosed:
|
||||
self.websocket = None
|
||||
self.identified = False
|
||||
self._connect_event.clear()
|
||||
self._identify_event.clear()
|
||||
log.info(f"Connection to {self.master_uri} was closed.")
|
||||
# What to do now? Let's just reraise.
|
||||
raise
|
||||
package: typing.Union[Package, Package] = pickle.loads(raw_pickle)
|
||||
assert package.destination == self.nid
|
||||
log.debug(f"Received package: {package}")
|
||||
return package
|
||||
|
||||
@requires_connection
|
||||
async def identify(self, secret) -> None:
|
||||
await self.websocket.send(f"Identify {self.nid}:{self.link_type}:{secret}")
|
||||
async def identify(self) -> None:
|
||||
log.info(f"Identifying to {self.master_uri}...")
|
||||
await self.websocket.send(f"Identify {self.nid}:{self.link_type}:{self.secret}")
|
||||
response_package = await self.receive()
|
||||
response = response_package.data
|
||||
if isinstance(response, ErrorMessage):
|
||||
raise NetworkError(response, "Server returned error while identifying self")
|
||||
self.identified = True
|
||||
|
||||
def requires_identification(func):
|
||||
@functools.wraps(func)
|
||||
def new_func(self, *args, **kwargs):
|
||||
if not self.identified:
|
||||
raise NotIdentifiedError("Tried to call a method which @requires_identification while not identified")
|
||||
return func(self, *args, **kwargs)
|
||||
return new_func
|
||||
self._identify_event.set()
|
||||
log.info(f"Identified successfully!")
|
||||
|
||||
@requires_identification
|
||||
async def send(self, package: Package):
|
||||
raw_pickle: bytes = pickle.dumps(package)
|
||||
await self.websocket.send(raw_pickle)
|
||||
log.debug(f"Sent package: {package}")
|
||||
|
||||
@requires_identification
|
||||
async def request(self, message, destination):
|
||||
|
@ -98,19 +115,22 @@ class RoyalnetLink:
|
|||
request = PendingRequest()
|
||||
self._pending_requests[package.conversation_id] = request
|
||||
await self.send(package)
|
||||
log.debug(f"Sent request: {message} -> {destination}")
|
||||
await request.event.wait()
|
||||
result = request.data
|
||||
result: Message = request.data
|
||||
log.debug(f"Received response: {request} -> {result}")
|
||||
if isinstance(result, ErrorMessage):
|
||||
raise NetworkError(result, "Server returned error while requesting something")
|
||||
return result
|
||||
|
||||
async def run(self):
|
||||
log.debug(f"Running main client loop for {self.nid}.")
|
||||
while True:
|
||||
if self.websocket is None:
|
||||
await self.connect()
|
||||
if not self.identified:
|
||||
if not self._identify_event.is_set():
|
||||
await self.identify()
|
||||
package: Package = self.receive()
|
||||
package: Package = await self.receive()
|
||||
# Package is a response
|
||||
if package.conversation_id in self._pending_requests:
|
||||
request = self._pending_requests[package.conversation_id]
|
||||
|
@ -118,7 +138,9 @@ class RoyalnetLink:
|
|||
continue
|
||||
# Package is a request
|
||||
assert isinstance(package, Package)
|
||||
log.debug(f"Received request: {package.source} -> {package.data}")
|
||||
response = await self.request_handler(package.data)
|
||||
if response is not None:
|
||||
response_package: Package = package.reply(response)
|
||||
await self.send(response_package)
|
||||
log.debug(f"Replied to request: {response_package.data} -> {response_package.destination}")
|
||||
|
|
|
@ -3,11 +3,15 @@ import websockets
|
|||
import re
|
||||
import datetime
|
||||
import pickle
|
||||
import asyncio
|
||||
import uuid
|
||||
import asyncio
|
||||
import logging
|
||||
from .messages import Message, ErrorMessage, InvalidPackageEM, InvalidSecretEM, IdentifySuccessfulMessage
|
||||
from .packages import Package
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectedClient:
|
||||
def __init__(self, socket: websockets.WebSocketServerProtocol):
|
||||
|
@ -21,7 +25,7 @@ class ConnectedClient:
|
|||
return bool(self.nid)
|
||||
|
||||
async def send(self, package: Package):
|
||||
self.socket.send(package.pickle())
|
||||
await self.socket.send(package.pickle())
|
||||
|
||||
|
||||
class RoyalnetServer:
|
||||
|
@ -29,9 +33,9 @@ class RoyalnetServer:
|
|||
self.address: str = address
|
||||
self.port: int = port
|
||||
self.required_secret: str = required_secret
|
||||
self.identified_clients: typing.List[ConnectedClient] = {}
|
||||
self.identified_clients: typing.List[ConnectedClient] = []
|
||||
|
||||
def find_client(self, *, nid: str=None, link_type: str=None) -> typing.List[ConnectedClient]:
|
||||
def find_client(self, *, nid: str = None, link_type: str = None) -> typing.List[ConnectedClient]:
|
||||
assert not (nid and link_type)
|
||||
if nid:
|
||||
matching = [client for client in self.identified_clients if client.nid == nid]
|
||||
|
@ -42,13 +46,15 @@ class RoyalnetServer:
|
|||
return matching or []
|
||||
|
||||
async def listener(self, websocket: websockets.server.WebSocketServerProtocol, request_uri: str):
|
||||
log.info(f"{websocket.remote_address} connected to the server.")
|
||||
connected_client = ConnectedClient(websocket)
|
||||
# Wait for identification
|
||||
identify_msg = websocket.recv()
|
||||
identify_msg = await websocket.recv()
|
||||
log.debug(f"{websocket.remote_address} identified itself with: {identify_msg}.")
|
||||
if not isinstance(identify_msg, str):
|
||||
websocket.send(InvalidPackageEM("Invalid identification message (not a str)"))
|
||||
return
|
||||
identification = re.match(r"Identify ([A-Za-z0-9\-]+):([a-z]+):([A-Za-z0-9\-])", identify_msg)
|
||||
identification = re.match(r"Identify ([A-Za-z0-9\-]+):([a-z]+):([A-Za-z0-9\-]+)", identify_msg)
|
||||
if identification is None:
|
||||
websocket.send(InvalidPackageEM("Invalid identification message (regex failed)"))
|
||||
return
|
||||
|
@ -60,18 +66,21 @@ class RoyalnetServer:
|
|||
connected_client.nid = identification.group(1)
|
||||
connected_client.link_type = identification.group(2)
|
||||
self.identified_clients.append(connected_client)
|
||||
log.debug(f"{websocket.remote_address} identified successfully as {connected_client.nid} ({connected_client.link_type}).")
|
||||
await connected_client.send(Package(IdentifySuccessfulMessage(), connected_client.nid, "__master__"))
|
||||
log.debug(f"{connected_client.nid}'s identification confirmed.")
|
||||
# Main loop
|
||||
while True:
|
||||
# Receive packages
|
||||
raw_pickle = await websocket.recv()
|
||||
package: Package = pickle.loads(raw_pickle)
|
||||
log.debug(f"Received package: {package}")
|
||||
# Check if the package destination is the server itself.
|
||||
if package.destination == "__master__":
|
||||
# TODO: do stuff
|
||||
pass
|
||||
# Otherwise, route the package to its destination
|
||||
asyncio.create_task(self.route_package(package))
|
||||
loop.create_task(self.route_package(package))
|
||||
|
||||
def find_destination(self, package: Package) -> typing.List[ConnectedClient]:
|
||||
"""Find a list of destinations for the sent packages"""
|
||||
|
@ -95,8 +104,11 @@ class RoyalnetServer:
|
|||
async def route_package(self, package: Package) -> None:
|
||||
"""Executed every time a package is received and must be routed somewhere."""
|
||||
destinations = self.find_destination(package)
|
||||
log.debug(f"Routing package: {package} -> {destinations}")
|
||||
for destination in destinations:
|
||||
await destination.send(package)
|
||||
specific_package = Package(package.data, destination.nid, package.source, conversation_id=package.conversation_id)
|
||||
await destination.send(specific_package)
|
||||
|
||||
async def run(self):
|
||||
websockets.serve(self.listener, host=self.address, port=self.port)
|
||||
log.debug(f"Running main server loop for __master__ on ws://{self.address}:{self.port}")
|
||||
await websockets.serve(self.listener, host=self.address, port=self.port)
|
||||
|
|
Loading…
Reference in a new issue