diff --git a/requirements.txt b/requirements.txt index 506efd03..b4556fcd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ python-telegram-bot>=11.1.0 +websockets>=7.0 diff --git a/royalnet/network/__init__.py b/royalnet/network/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/royalnet/network/messages.py b/royalnet/network/messages.py new file mode 100644 index 00000000..8d1b5b48 --- /dev/null +++ b/royalnet/network/messages.py @@ -0,0 +1,16 @@ +class Message: + pass + + +class IdentifySuccessfulMessage(Message): + pass + + +class ErrorMessage(Message): + def __init__(self, reason): + super().__init__() + self.reason = reason + + +class InvalidSecretErrorMessage(ErrorMessage): + pass diff --git a/royalnet/network/packages.py b/royalnet/network/packages.py new file mode 100644 index 00000000..ff9d28b8 --- /dev/null +++ b/royalnet/network/packages.py @@ -0,0 +1,24 @@ +import pickle +import uuid + + +class Package: + def __init__(self, data, destination: str, *, conversation_id: str = None): + self.data = data + self.destination: str = destination + self.conversation_id = conversation_id or str(uuid.uuid4()) + + def pickle(self): + return pickle.dumps(self) + + +class TwoWayPackage(Package): + def __init__(self, data, destination: str, source: str, *, conversation_id: str = None): + super().__init__(data, destination, conversation_id=conversation_id) + self.source = source + + def reply(self, data) -> Package: + return Package(data, self.source, conversation_id=self.conversation_id) + + def two_way_reply(self, data) -> "TwoWayPackage": + return TwoWayPackage(data, self.source, self.destination, conversation_id=self.conversation_id) diff --git a/royalnet/network/royalnetlink.py b/royalnet/network/royalnetlink.py new file mode 100644 index 00000000..4d62e19c --- /dev/null +++ b/royalnet/network/royalnetlink.py @@ -0,0 +1,120 @@ +import asyncio +from asyncio import Event +import websockets +import uuid +import functools +import typing +import pickle +from .messages import Message, IdentifyMessage, ErrorMessage +from .packages import Package, TwoWayPackage +loop = asyncio.get_event_loop() + + +class NotConnectedError(Exception): + pass + + +class NotIdentifiedError(Exception): + pass + + +class NetworkError(Exception): + def __init__(self, error_msg: ErrorMessage, *args): + super().__init__(*args) + self.error_msg = error_msg + + +class PendingRequest: + def __init__(self): + self.event = Event() + self.data = None + + def set(self, data): + self.data = data + self.event.set() + + +class RoyalnetLink: + def __init__(self, master_uri: str, request_handler): + self.master_uri: str = master_uri + self.nid: str = str(uuid.uuid4()) + self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None + self.identified: bool = False + self.request_handler = request_handler + self._pending_requests: typing.Dict[typing.Optional[Message]] = {} + + async def connect(self): + self.websocket = await websockets.connect(self.master_uri) + + def requires_connection(self, func): + @functools.wraps(func) + def new_func(*args, **kwargs): + if self.websocket is None: + raise NotConnectedError("Tried to call a method which @requires_connection while not connected") + return func(*args, **kwargs) + return new_func + + @requires_connection + async def receive(self) -> Package: + try: + raw_pickle = await self.websocket.recv() + except websockets.ConnectionClosed: + self.websocket = None + self.identified = False + # What to do now? Let's just reraise. + raise + package: typing.Union[Package, TwoWayPackage] = pickle.loads(raw_pickle) + assert package.destination == self.nid + return package + + @requires_connection + async def identify(self, secret) -> None: + await self.websocket.send(f"Identify: {self.nid}:{secret}") + response_package = await self.receive() + response = response_package.data + if isinstance(response, ErrorMessage): + raise NetworkError(response, "Server returned error while identifying self") + self.identified = True + + def requires_identification(self, func): + @functools.wraps(func) + def new_func(*args, **kwargs): + if not self.identified: + raise NotIdentifiedError("Tried to call a method which @requires_identification while not identified") + return func(*args, **kwargs) + return new_func + + @requires_identification + async def send(self, package: Package): + raw_pickle: bytes = pickle.dumps(package) + await self.websocket.send(raw_pickle) + + @requires_identification + async def request(self, message, destination): + package = TwoWayPackage(message, destination, self.nid) + request = PendingRequest() + self._pending_requests[package.conversation_id] = request + await self.send(package) + await request.event.wait() + result = request.data + if isinstance(result, ErrorMessage): + raise NetworkError(result, "Server returned error while requesting something") + return result + + async def run_link(self): + while True: + if self.websocket is None: + await self.connect() + if not self.identified: + await self.identify() + package: Package = self.receive() + # Package is a response + if package.conversation_id in self._pending_requests: + request = self._pending_requests[package.conversation_id] + request.set(package.data) + continue + # Package is a request + assert isinstance(package, TwoWayPackage) + response = await self.request_handler(package.data) + response_package: Package = package.reply(response) + await self.send(response_package) diff --git a/royalnet/utils/networkdict.py b/royalnet/utils/networkdict.py deleted file mode 100644 index 2711e1b2..00000000 --- a/royalnet/utils/networkdict.py +++ /dev/null @@ -1,40 +0,0 @@ -import uuid -import typing -from asyncio import Event - - -class RoyalnetData: - """A class to hold data to be sent to the Royalnet.""" - def __init__(self, data): - self.uuid = str(uuid.uuid4()) - self.request = data - self.event = Event() - self.response = None - - def send(self): - """TODO EVERYTHING""" - - - -class RoyalnetWait: - """A class that represents a data request sent to the Royalnet.""" - def __init__(self): - self.event = Event() - self.data = None - - def receive(self, data): - self.data = data - self.event.set() - - async def get(self): - await self.event.wait() - return self.data - - -class RoyalnetDict: - """A dictionary used to asyncrounosly hold data received from the Royalnet.""" - - def __init__(self): - self.dict: typing.Dict[str, RoyalnetRequest] = {} - - async def request(self, data: RoyalnetWait):