diff --git a/royalnet/bots/telegram.py b/royalnet/bots/telegram.py index 9d191876..866d9eb1 100644 --- a/royalnet/bots/telegram.py +++ b/royalnet/bots/telegram.py @@ -14,13 +14,14 @@ class TelegramBot: def __init__(self, api_key: str, master_server_uri: str, + master_server_secret: str, commands: typing.List[typing.Type[Command]], missing_command: Command = NullCommand): self.bot: telegram.Bot = telegram.Bot(api_key) self.should_run: bool = False self.offset: int = -100 - self.missing_command: typing.Callable = missing_command - self.network: RoyalnetLink = RoyalnetLink(master_server_uri, "telegram", null) + self.missing_command = missing_command + self.network: RoyalnetLink = RoyalnetLink(master_server_uri, master_server_secret, "telegram", null) # Generate commands self.commands = {} for command in commands: diff --git a/royalnet/network/__init__.py b/royalnet/network/__init__.py index 8393cb8f..9015ede6 100644 --- a/royalnet/network/__init__.py +++ b/royalnet/network/__init__.py @@ -1,6 +1,16 @@ -from .messages import Message, ErrorMessage, InvalidSecretErrorMessage +from .messages import Message, ErrorMessage, InvalidSecretEM, InvalidDestinationEM, InvalidPackageEM +from .packages import Package from .royalnetlink import RoyalnetLink, NetworkError, NotConnectedError, NotIdentifiedError -from .packages import Package, TwoWayPackage +from .royalnetserver import RoyalnetServer -__all__ = ["Message", "ErrorMessage", "InvalidSecretErrorMessage", "RoyalnetLink", "NetworkError", "NotConnectedError", - "NotIdentifiedError", "Package", "TwoWayPackage"] +__all__ = ["Message", + "ErrorMessage", + "InvalidSecretEM", + "InvalidDestinationEM", + "InvalidPackageEM", + "RoyalnetLink", + "NetworkError", + "NotConnectedError", + "NotIdentifiedError", + "Package", + "RoyalnetServer"] diff --git a/royalnet/network/messages.py b/royalnet/network/messages.py index 59824eb2..6b13cfa2 100644 --- a/royalnet/network/messages.py +++ b/royalnet/network/messages.py @@ -12,9 +12,15 @@ class ErrorMessage(Message): self.reason = reason -class BadMessage(ErrorMessage): +class InvalidSecretEM(ErrorMessage): pass -class InvalidSecretErrorMessage(BadMessage): +class InvalidPackageEM(ErrorMessage): pass + + +class InvalidDestinationEM(InvalidPackageEM): + pass + + diff --git a/royalnet/network/packages.py b/royalnet/network/packages.py index ff9d28b8..9e2239c1 100644 --- a/royalnet/network/packages.py +++ b/royalnet/network/packages.py @@ -3,22 +3,15 @@ import uuid class Package: - def __init__(self, data, destination: str, *, conversation_id: str = None): + def __init__(self, data, destination: str, source: str, *, conversation_id: str = None): self.data = data self.destination: str = destination + self.source, = source self.conversation_id = conversation_id or str(uuid.uuid4()) + def reply(self, data) -> "Package": + return Package(data, self.source, self.destination, conversation_id=self.conversation_id) + def pickle(self): return pickle.dumps(self) - -class TwoWayPackage(Package): - def __init__(self, data, destination: str, source: str, *, conversation_id: str = None): - super().__init__(data, destination, conversation_id=conversation_id) - self.source = source - - def reply(self, data) -> Package: - return Package(data, self.source, conversation_id=self.conversation_id) - - def two_way_reply(self, data) -> "TwoWayPackage": - return TwoWayPackage(data, self.source, self.destination, conversation_id=self.conversation_id) diff --git a/royalnet/network/royalnetlink.py b/royalnet/network/royalnetlink.py index d07baf4e..c40419b6 100644 --- a/royalnet/network/royalnetlink.py +++ b/royalnet/network/royalnetlink.py @@ -6,7 +6,7 @@ import functools import typing import pickle from .messages import Message, ErrorMessage -from .packages import Package, TwoWayPackage +from .packages import Package loop = asyncio.get_event_loop() @@ -35,11 +35,12 @@ class PendingRequest: class RoyalnetLink: - def __init__(self, master_uri: str, link_type: str, request_handler): + def __init__(self, master_uri: str, secret: str, link_type: str, request_handler): assert ":" not in link_type self.master_uri: str = master_uri self.link_type: str = link_type self.nid: str = str(uuid.uuid4()) + self.secret: str = secret self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None self.identified: bool = False self.request_handler = request_handler @@ -48,12 +49,12 @@ class RoyalnetLink: async def connect(self): self.websocket = await websockets.connect(self.master_uri) - def requires_connection(self, func): + def requires_connection(func): @functools.wraps(func) - def new_func(*args, **kwargs): + def new_func(self, *args, **kwargs): if self.websocket is None: raise NotConnectedError("Tried to call a method which @requires_connection while not connected") - return func(*args, **kwargs) + return func(self, *args, **kwargs) return new_func @requires_connection @@ -65,7 +66,7 @@ class RoyalnetLink: self.identified = False # What to do now? Let's just reraise. raise - package: typing.Union[Package, TwoWayPackage] = pickle.loads(raw_pickle) + package: typing.Union[Package, Package] = pickle.loads(raw_pickle) assert package.destination == self.nid return package @@ -78,12 +79,12 @@ class RoyalnetLink: raise NetworkError(response, "Server returned error while identifying self") self.identified = True - def requires_identification(self, func): + def requires_identification(func): @functools.wraps(func) - def new_func(*args, **kwargs): + def new_func(self, *args, **kwargs): if not self.identified: raise NotIdentifiedError("Tried to call a method which @requires_identification while not identified") - return func(*args, **kwargs) + return func(self, *args, **kwargs) return new_func @requires_identification @@ -93,7 +94,7 @@ class RoyalnetLink: @requires_identification async def request(self, message, destination): - package = TwoWayPackage(message, destination, self.nid) + package = Package(message, destination, self.nid) request = PendingRequest() self._pending_requests[package.conversation_id] = request await self.send(package) @@ -103,7 +104,7 @@ class RoyalnetLink: raise NetworkError(result, "Server returned error while requesting something") return result - async def run_link(self): + async def run(self): while True: if self.websocket is None: await self.connect() @@ -116,7 +117,8 @@ class RoyalnetLink: request.set(package.data) continue # Package is a request - assert isinstance(package, TwoWayPackage) + assert isinstance(package, Package) response = await self.request_handler(package.data) - response_package: Package = package.reply(response) - await self.send(response_package) + if response is not None: + response_package: Package = package.reply(response) + await self.send(response_package) diff --git a/royalnet/network/royalnetserver.py b/royalnet/network/royalnetserver.py index ec9b4cb3..0452d5ea 100644 --- a/royalnet/network/royalnetserver.py +++ b/royalnet/network/royalnetserver.py @@ -2,8 +2,11 @@ import typing import websockets import re import datetime -from .messages import Message, ErrorMessage, BadMessage, InvalidSecretErrorMessage, IdentifySuccessfulMessage -from .packages import Package, TwoWayPackage +import pickle +import asyncio +import uuid +from .messages import Message, ErrorMessage, InvalidPackageEM, InvalidSecretEM, IdentifySuccessfulMessage +from .packages import Package class ConnectedClient: @@ -17,38 +20,83 @@ class ConnectedClient: def is_identified(self) -> bool: return bool(self.nid) + async def send(self, package: Package): + self.socket.send(package.pickle()) + class RoyalnetServer: - def __init__(self, required_secret: str): + def __init__(self, address: str, port: int, required_secret: str): + self.address: str = address + self.port: int = port self.required_secret: str = required_secret - self.connected_clients: typing.List[ConnectedClient] = {} - self.server: websockets.server.WebSocketServer = websockets.server + self.identified_clients: typing.List[ConnectedClient] = {} - def find_client_by_nid(self, nid: str): - return [client for client in self.connected_clients if client.nid == nid][0] + def find_client(self, *, nid: str=None, link_type: str=None) -> typing.List[ConnectedClient]: + assert not (nid and link_type) + if nid: + matching = [client for client in self.identified_clients if client.nid == nid] + assert len(matching) <= 1 + return matching + if link_type: + matching = [client for client in self.identified_clients if client.link_type == link_type] + return matching or [] async def listener(self, websocket: websockets.server.WebSocketServerProtocol, request_uri: str): connected_client = ConnectedClient(websocket) # Wait for identification identify_msg = websocket.recv() if not isinstance(identify_msg, str): - websocket.send(BadMessage("Invalid identification message (not a str)")) + websocket.send(InvalidPackageEM("Invalid identification message (not a str)")) return identification = re.match(r"Identify ([A-Za-z0-9\-]+):([a-z]+):([A-Za-z0-9\-])", identify_msg) if identification is None: - websocket.send(BadMessage("Invalid identification message (regex failed)")) + websocket.send(InvalidPackageEM("Invalid identification message (regex failed)")) return secret = identification.group(3) if secret != self.required_secret: - websocket.send(InvalidSecretErrorMessage("Invalid secret")) + websocket.send(InvalidSecretEM("Invalid secret")) return # Identification successful connected_client.nid = identification.group(1) connected_client.link_type = identification.group(2) - self.connected_clients.append(connected_client) + self.identified_clients.append(connected_client) + await connected_client.send(Package(IdentifySuccessfulMessage(), connected_client.nid, "__master__")) # Main loop while True: + # Receive packages + raw_pickle = await websocket.recv() + package: Package = pickle.loads(raw_pickle) + # Check if the package destination is the server itself. + if package.destination == "__master__": + # TODO: do stuff + pass + # Otherwise, route the package to its destination + asyncio.create_task(self.route_package(package)) + + def find_destination(self, package: Package) -> typing.List[ConnectedClient]: + """Find a list of destinations for the sent packages""" + # Parse destination + # Is it nothing? + if package.destination == "NULL": + return [] + # Is it the wildcard? + if package.destination == "*": + return self.identified_clients + # Is it a valid nid? + try: + destination = str(uuid.UUID(package.destination)) + except ValueError: pass + else: + return self.find_client(nid=destination) + # Is it a link_type? + return self.find_client(link_type=package.destination) + async def route_package(self, package: Package) -> None: + """Executed every time a package is received and must be routed somewhere.""" + destinations = self.find_destination(package) + for destination in destinations: + await destination.send(package) - + async def run(self): + websockets.serve(self.listener, host=self.address, port=self.port)