1
Fork 0
mirror of https://github.com/RYGhub/royalnet.git synced 2024-11-27 13:34:28 +00:00
royalnet/royalherald/server.py

146 lines
6.1 KiB
Python

import asyncio
import datetime
import hmac
import logging as _logging
import re
import typing
import uuid

import websockets

from .package import Package
# Logger for this module, named after the module (``__name__``).
log = _logging.getLogger(__name__)
class ConnectedClient:
    """The :py:class:`Server`-side record of a single connected :py:class:`Link`.

    Attributes:
        socket: The websocket through which the link is connected.
        nid: The nid the link identified with, or ``None`` until the server assigns it.
        link_type: The link type the link identified with, or ``None`` until the server assigns it.
        connection_datetime: Naive local time at which this object was created.
    """

    def __init__(self, socket: websockets.WebSocketServerProtocol):
        self.connection_datetime: datetime.datetime = datetime.datetime.now()
        self.socket: websockets.WebSocketServerProtocol = socket
        self.link_type: typing.Optional[str] = None
        self.nid: typing.Optional[str] = None

    @property
    def is_identified(self) -> bool:
        """Whether a non-empty nid has been assigned, i.e. identification succeeded."""
        return bool(self.nid)

    async def send_service(self, msg_type: str, message: str):
        """Send a service package originating from ``<server>`` to this client.

        Parameters:
            msg_type: Stored under the ``"type"`` key of the package data (e.g. ``"error"``, ``"success"``).
            message: Stored under the ``"service"`` key of the package data.
        """
        service_data = {"type": msg_type, "service": message}
        service_package = Package(service_data, source="<server>", destination=self.nid)
        await self.send(service_package)

    async def send(self, package: Package):
        """Serialize ``package`` with ``to_json_bytes`` and write it to this client's socket."""
        payload = package.to_json_bytes()
        await self.socket.send(payload)
class Server:
    """Websocket server that authenticates :py:class:`Link` clients and relays :py:class:`Package` objects.

    A client must first send a text message of the form ``Identify <nid>:<link_type>:<secret>``.
    Once identified, it is tracked in :py:attr:`identified_clients`, and every package it sends is
    routed to the destination(s) named in the package.
    """

    def __init__(self, address: str, port: int, required_secret: str, *, loop: asyncio.AbstractEventLoop = None):
        """
        Parameters:
            address: The host address to listen on.
            port: The port to listen on.
            required_secret: The secret clients must present in their identification message.
            loop: The event loop to run on; if ``None``, :py:meth:`serve` and :py:meth:`run_blocking`
                fall back to :py:func:`asyncio.get_event_loop`.
        """
        self.address: str = address
        self.port: int = port
        self.required_secret: str = required_secret
        # Clients that completed identification; each is removed when its connection ends.
        self.identified_clients: typing.List[ConnectedClient] = []
        self.loop: typing.Optional[asyncio.AbstractEventLoop] = loop

    def find_client(self, *, nid: str = None, link_type: str = None) -> typing.List[ConnectedClient]:
        """Find identified clients by nid or by link type.

        Parameters:
            nid: The nid of the client to find.
            link_type: The link type of the clients to find.
        Returns:
            A :py:class:`list` of matching :py:class:`ConnectedClient`. It is empty if nothing matches
            or if neither filter is given.
        Raises:
            ValueError: If both ``nid`` and ``link_type`` are given.
        """
        if nid and link_type:
            raise ValueError("find_client accepts either nid or link_type, not both")
        if nid:
            matching = [client for client in self.identified_clients if client.nid == nid]
            # Internal invariant: at most one client per nid.
            # NOTE(review): _identify does not reject a second client reusing a connected nid.
            assert len(matching) <= 1
            return matching
        if link_type:
            return [client for client in self.identified_clients if client.link_type == link_type]
        return []

    async def listener(self, websocket: websockets.server.WebSocketServerProtocol, path):
        """Handle one client connection for its whole lifetime.

        The client is identified first (see :py:meth:`_identify`). On success it is added to
        :py:attr:`identified_clients` and its packages are routed until the connection ends. The
        client is then removed from :py:attr:`identified_clients`, however the connection ended.

        Parameters:
            websocket: The connection to the client.
            path: The request path passed in by ``websockets.serve``; unused.
        """
        log.info(f"{websocket.remote_address} connected to the server.")
        connected_client = ConnectedClient(websocket)
        try:
            if not await self._identify(connected_client):
                return
            self.identified_clients.append(connected_client)
            log.debug(f"{websocket.remote_address} identified successfully as {connected_client.nid}"
                      f" ({connected_client.link_type}).")
            await connected_client.send_service("success", "Identification successful!")
            log.debug(f"{connected_client.nid}'s identification confirmed.")
            await self._receive_packages(connected_client)
        except websockets.exceptions.ConnectionClosed:
            log.info(f"{websocket.remote_address} disconnected from the server.")
        finally:
            # Drop the client so that routing stops targeting its closed socket.
            if connected_client in self.identified_clients:
                self.identified_clients.remove(connected_client)

    async def _identify(self, connected_client: ConnectedClient) -> bool:
        """Read and validate the identification message, filling in nid and link_type on success.

        Returns:
            ``True`` if the client identified correctly. Otherwise an ``error`` service message has
            been sent to the client and ``False`` is returned.
        """
        websocket = connected_client.socket
        identify_msg = await websocket.recv()
        if not isinstance(identify_msg, str):
            log.debug(f"{websocket.remote_address} sent a non-text identification message.")
            await connected_client.send_service("error", "Invalid identification message (not a str)")
            return False
        identification = re.match(r"Identify ([^:\s]+):([^:\s]+):([^:\s]+)", identify_msg)
        if identification is None:
            log.debug(f"{websocket.remote_address} sent a malformed identification message.")
            await connected_client.send_service("error", "Invalid identification message (regex failed)")
            return False
        nid, link_type, secret = identification.groups()
        # Log only the parsed fields, so that the secret never ends up in the logs.
        log.debug(f"{websocket.remote_address} is identifying itself as {nid} ({link_type}).")
        # Constant-time comparison, so response timing does not leak how much of the secret matched.
        # Encoded to bytes because compare_digest rejects non-ASCII str arguments.
        if not hmac.compare_digest(secret.encode("utf-8"), self.required_secret.encode("utf-8")):
            await connected_client.send_service("error", "Invalid secret")
            return False
        connected_client.nid = nid
        connected_client.link_type = link_type
        return True

    async def _receive_packages(self, connected_client: ConnectedClient) -> None:
        """Receive packages from an identified client until the connection raises, scheduling each for routing."""
        while True:
            raw_bytes = await connected_client.socket.recv()
            package: Package = Package.from_json_bytes(raw_bytes)
            log.debug(f"Received package: {package}")
            if package.destination == "<server>":
                # TODO: handle packages addressed to the server itself; they are not routed to clients.
                continue
            # NOTE(review): package.source is not checked against connected_client.nid.
            # Routing runs in its own task on the running loop; this loop does not wait for delivery.
            asyncio.ensure_future(self.route_package(package))

    def find_destination(self, package: Package) -> typing.List[ConnectedClient]:
        """Find a list of destinations for the package.

        ``package.destination`` is interpreted, in order, as:
        ``"<none>"`` (no recipients), ``"<all>"`` (every identified client), a UUID (the client with
        that nid), or otherwise a link type.

        Parameters:
            package: The package to find the destination of.
        Returns:
            A new :py:class:`list` of :py:class:`ConnectedClient` to send the package to."""
        destination = package.destination
        if destination == "<none>":
            return []
        if destination == "<all>":
            # Return a copy: route_package awaits while iterating, and clients may disconnect meanwhile.
            return list(self.identified_clients)
        try:
            nid = str(uuid.UUID(destination))
        except (TypeError, ValueError):
            # Not a UUID string; fall through and treat it as a link type.
            pass
        else:
            return self.find_client(nid=nid)
        return self.find_client(link_type=destination)

    async def route_package(self, package: Package) -> None:
        """Deliver a package to every destination returned by :py:meth:`find_destination`.

        Each recipient receives its own :py:class:`Package` copy, whose ``destination`` is the
        recipient's nid. A recipient whose connection has closed is skipped with a warning, so the
        remaining recipients still get the package.
        """
        destinations = self.find_destination(package)
        log.debug(f"Routing package: {package} -> {destinations}")
        for destination in destinations:
            specific_package = Package(package.data,
                                       source=package.source,
                                       destination=destination.nid,
                                       source_conv_id=package.source_conv_id,
                                       destination_conv_id=package.destination_conv_id)
            try:
                await destination.send(specific_package)
            except websockets.exceptions.ConnectionClosed:
                log.warning(f"Could not deliver package to {destination.nid}: connection closed.")

    def serve(self):
        """Start the server on :py:attr:`loop` (defaulting to the current event loop) and run that loop forever."""
        if self.loop is None:
            self.loop = asyncio.get_event_loop()
        log.debug(f"Serving on ws://{self.address}:{self.port}")
        self.loop.run_until_complete(self.run())
        self.loop.run_forever()

    async def run(self):
        """Await ``websockets.serve`` with :py:meth:`listener` as the connection handler."""
        await websockets.serve(self.listener,
                               host=self.address,
                               port=self.port,
                               loop=self.loop)

    def run_blocking(self):
        """Ensure an event loop is set, then block in :py:meth:`serve`."""
        if self.loop is None:
            self.loop = asyncio.get_event_loop()
        self.serve()