1
Fork 0
mirror of https://github.com/RYGhub/royalnet.git synced 2024-11-27 13:34:28 +00:00

More progress

This commit is contained in:
Steffo 2019-03-18 15:32:48 +01:00
parent 3bc8046bc4
commit 11bbb77afe
6 changed files with 106 additions and 46 deletions

View file

@ -14,13 +14,14 @@ class TelegramBot:
def __init__(self, def __init__(self,
api_key: str, api_key: str,
master_server_uri: str, master_server_uri: str,
master_server_secret: str,
commands: typing.List[typing.Type[Command]], commands: typing.List[typing.Type[Command]],
missing_command: Command = NullCommand): missing_command: Command = NullCommand):
self.bot: telegram.Bot = telegram.Bot(api_key) self.bot: telegram.Bot = telegram.Bot(api_key)
self.should_run: bool = False self.should_run: bool = False
self.offset: int = -100 self.offset: int = -100
self.missing_command: typing.Callable = missing_command self.missing_command = missing_command
self.network: RoyalnetLink = RoyalnetLink(master_server_uri, "telegram", null) self.network: RoyalnetLink = RoyalnetLink(master_server_uri, master_server_secret, "telegram", null)
# Generate commands # Generate commands
self.commands = {} self.commands = {}
for command in commands: for command in commands:

View file

@ -1,6 +1,16 @@
from .messages import Message, ErrorMessage, InvalidSecretErrorMessage from .messages import Message, ErrorMessage, InvalidSecretEM, InvalidDestinationEM, InvalidPackageEM
from .packages import Package
from .royalnetlink import RoyalnetLink, NetworkError, NotConnectedError, NotIdentifiedError from .royalnetlink import RoyalnetLink, NetworkError, NotConnectedError, NotIdentifiedError
from .packages import Package, TwoWayPackage from .royalnetserver import RoyalnetServer
__all__ = ["Message", "ErrorMessage", "InvalidSecretErrorMessage", "RoyalnetLink", "NetworkError", "NotConnectedError", __all__ = ["Message",
"NotIdentifiedError", "Package", "TwoWayPackage"] "ErrorMessage",
"InvalidSecretEM",
"InvalidDestinationEM",
"InvalidPackageEM",
"RoyalnetLink",
"NetworkError",
"NotConnectedError",
"NotIdentifiedError",
"Package",
"RoyalnetServer"]

View file

@ -12,9 +12,15 @@ class ErrorMessage(Message):
self.reason = reason self.reason = reason
class BadMessage(ErrorMessage): class InvalidSecretEM(ErrorMessage):
pass pass
class InvalidSecretErrorMessage(BadMessage): class InvalidPackageEM(ErrorMessage):
pass pass
class InvalidDestinationEM(InvalidPackageEM):
pass

View file

@ -3,22 +3,15 @@ import uuid
class Package: class Package:
def __init__(self, data, destination: str, *, conversation_id: str = None): def __init__(self, data, destination: str, source: str, *, conversation_id: str = None):
self.data = data self.data = data
self.destination: str = destination self.destination: str = destination
self.source, = source
self.conversation_id = conversation_id or str(uuid.uuid4()) self.conversation_id = conversation_id or str(uuid.uuid4())
def reply(self, data) -> "Package":
return Package(data, self.source, self.destination, conversation_id=self.conversation_id)
def pickle(self): def pickle(self):
return pickle.dumps(self) return pickle.dumps(self)
class TwoWayPackage(Package):
def __init__(self, data, destination: str, source: str, *, conversation_id: str = None):
super().__init__(data, destination, conversation_id=conversation_id)
self.source = source
def reply(self, data) -> Package:
return Package(data, self.source, conversation_id=self.conversation_id)
def two_way_reply(self, data) -> "TwoWayPackage":
return TwoWayPackage(data, self.source, self.destination, conversation_id=self.conversation_id)

View file

@ -6,7 +6,7 @@ import functools
import typing import typing
import pickle import pickle
from .messages import Message, ErrorMessage from .messages import Message, ErrorMessage
from .packages import Package, TwoWayPackage from .packages import Package
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -35,11 +35,12 @@ class PendingRequest:
class RoyalnetLink: class RoyalnetLink:
def __init__(self, master_uri: str, link_type: str, request_handler): def __init__(self, master_uri: str, secret: str, link_type: str, request_handler):
assert ":" not in link_type assert ":" not in link_type
self.master_uri: str = master_uri self.master_uri: str = master_uri
self.link_type: str = link_type self.link_type: str = link_type
self.nid: str = str(uuid.uuid4()) self.nid: str = str(uuid.uuid4())
self.secret: str = secret
self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None
self.identified: bool = False self.identified: bool = False
self.request_handler = request_handler self.request_handler = request_handler
@ -48,12 +49,12 @@ class RoyalnetLink:
async def connect(self): async def connect(self):
self.websocket = await websockets.connect(self.master_uri) self.websocket = await websockets.connect(self.master_uri)
def requires_connection(self, func): def requires_connection(func):
@functools.wraps(func) @functools.wraps(func)
def new_func(*args, **kwargs): def new_func(self, *args, **kwargs):
if self.websocket is None: if self.websocket is None:
raise NotConnectedError("Tried to call a method which @requires_connection while not connected") raise NotConnectedError("Tried to call a method which @requires_connection while not connected")
return func(*args, **kwargs) return func(self, *args, **kwargs)
return new_func return new_func
@requires_connection @requires_connection
@ -65,7 +66,7 @@ class RoyalnetLink:
self.identified = False self.identified = False
# What to do now? Let's just reraise. # What to do now? Let's just reraise.
raise raise
package: typing.Union[Package, TwoWayPackage] = pickle.loads(raw_pickle) package: typing.Union[Package, Package] = pickle.loads(raw_pickle)
assert package.destination == self.nid assert package.destination == self.nid
return package return package
@ -78,12 +79,12 @@ class RoyalnetLink:
raise NetworkError(response, "Server returned error while identifying self") raise NetworkError(response, "Server returned error while identifying self")
self.identified = True self.identified = True
def requires_identification(self, func): def requires_identification(func):
@functools.wraps(func) @functools.wraps(func)
def new_func(*args, **kwargs): def new_func(self, *args, **kwargs):
if not self.identified: if not self.identified:
raise NotIdentifiedError("Tried to call a method which @requires_identification while not identified") raise NotIdentifiedError("Tried to call a method which @requires_identification while not identified")
return func(*args, **kwargs) return func(self, *args, **kwargs)
return new_func return new_func
@requires_identification @requires_identification
@ -93,7 +94,7 @@ class RoyalnetLink:
@requires_identification @requires_identification
async def request(self, message, destination): async def request(self, message, destination):
package = TwoWayPackage(message, destination, self.nid) package = Package(message, destination, self.nid)
request = PendingRequest() request = PendingRequest()
self._pending_requests[package.conversation_id] = request self._pending_requests[package.conversation_id] = request
await self.send(package) await self.send(package)
@ -103,7 +104,7 @@ class RoyalnetLink:
raise NetworkError(result, "Server returned error while requesting something") raise NetworkError(result, "Server returned error while requesting something")
return result return result
async def run_link(self): async def run(self):
while True: while True:
if self.websocket is None: if self.websocket is None:
await self.connect() await self.connect()
@ -116,7 +117,8 @@ class RoyalnetLink:
request.set(package.data) request.set(package.data)
continue continue
# Package is a request # Package is a request
assert isinstance(package, TwoWayPackage) assert isinstance(package, Package)
response = await self.request_handler(package.data) response = await self.request_handler(package.data)
response_package: Package = package.reply(response) if response is not None:
await self.send(response_package) response_package: Package = package.reply(response)
await self.send(response_package)

View file

@ -2,8 +2,11 @@ import typing
import websockets import websockets
import re import re
import datetime import datetime
from .messages import Message, ErrorMessage, BadMessage, InvalidSecretErrorMessage, IdentifySuccessfulMessage import pickle
from .packages import Package, TwoWayPackage import asyncio
import uuid
from .messages import Message, ErrorMessage, InvalidPackageEM, InvalidSecretEM, IdentifySuccessfulMessage
from .packages import Package
class ConnectedClient: class ConnectedClient:
@ -17,38 +20,83 @@ class ConnectedClient:
def is_identified(self) -> bool: def is_identified(self) -> bool:
return bool(self.nid) return bool(self.nid)
async def send(self, package: Package):
self.socket.send(package.pickle())
class RoyalnetServer: class RoyalnetServer:
def __init__(self, required_secret: str): def __init__(self, address: str, port: int, required_secret: str):
self.address: str = address
self.port: int = port
self.required_secret: str = required_secret self.required_secret: str = required_secret
self.connected_clients: typing.List[ConnectedClient] = {} self.identified_clients: typing.List[ConnectedClient] = {}
self.server: websockets.server.WebSocketServer = websockets.server
def find_client_by_nid(self, nid: str): def find_client(self, *, nid: str=None, link_type: str=None) -> typing.List[ConnectedClient]:
return [client for client in self.connected_clients if client.nid == nid][0] assert not (nid and link_type)
if nid:
matching = [client for client in self.identified_clients if client.nid == nid]
assert len(matching) <= 1
return matching
if link_type:
matching = [client for client in self.identified_clients if client.link_type == link_type]
return matching or []
async def listener(self, websocket: websockets.server.WebSocketServerProtocol, request_uri: str): async def listener(self, websocket: websockets.server.WebSocketServerProtocol, request_uri: str):
connected_client = ConnectedClient(websocket) connected_client = ConnectedClient(websocket)
# Wait for identification # Wait for identification
identify_msg = websocket.recv() identify_msg = websocket.recv()
if not isinstance(identify_msg, str): if not isinstance(identify_msg, str):
websocket.send(BadMessage("Invalid identification message (not a str)")) websocket.send(InvalidPackageEM("Invalid identification message (not a str)"))
return return
identification = re.match(r"Identify ([A-Za-z0-9\-]+):([a-z]+):([A-Za-z0-9\-])", identify_msg) identification = re.match(r"Identify ([A-Za-z0-9\-]+):([a-z]+):([A-Za-z0-9\-])", identify_msg)
if identification is None: if identification is None:
websocket.send(BadMessage("Invalid identification message (regex failed)")) websocket.send(InvalidPackageEM("Invalid identification message (regex failed)"))
return return
secret = identification.group(3) secret = identification.group(3)
if secret != self.required_secret: if secret != self.required_secret:
websocket.send(InvalidSecretErrorMessage("Invalid secret")) websocket.send(InvalidSecretEM("Invalid secret"))
return return
# Identification successful # Identification successful
connected_client.nid = identification.group(1) connected_client.nid = identification.group(1)
connected_client.link_type = identification.group(2) connected_client.link_type = identification.group(2)
self.connected_clients.append(connected_client) self.identified_clients.append(connected_client)
await connected_client.send(Package(IdentifySuccessfulMessage(), connected_client.nid, "__master__"))
# Main loop # Main loop
while True: while True:
# Receive packages
raw_pickle = await websocket.recv()
package: Package = pickle.loads(raw_pickle)
# Check if the package destination is the server itself.
if package.destination == "__master__":
# TODO: do stuff
pass
# Otherwise, route the package to its destination
asyncio.create_task(self.route_package(package))
def find_destination(self, package: Package) -> typing.List[ConnectedClient]:
"""Find a list of destinations for the sent packages"""
# Parse destination
# Is it nothing?
if package.destination == "NULL":
return []
# Is it the wildcard?
if package.destination == "*":
return self.identified_clients
# Is it a valid nid?
try:
destination = str(uuid.UUID(package.destination))
except ValueError:
pass pass
else:
return self.find_client(nid=destination)
# Is it a link_type?
return self.find_client(link_type=package.destination)
async def route_package(self, package: Package) -> None:
"""Executed every time a package is received and must be routed somewhere."""
destinations = self.find_destination(package)
for destination in destinations:
await destination.send(package)
async def run(self):
websockets.serve(self.listener, host=self.address, port=self.port)