mirror of
https://github.com/RYGhub/royalnet.git
synced 2024-11-23 19:44:20 +00:00
More progress
This commit is contained in:
parent
3bc8046bc4
commit
11bbb77afe
6 changed files with 106 additions and 46 deletions
|
@ -14,13 +14,14 @@ class TelegramBot:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
master_server_uri: str,
|
master_server_uri: str,
|
||||||
|
master_server_secret: str,
|
||||||
commands: typing.List[typing.Type[Command]],
|
commands: typing.List[typing.Type[Command]],
|
||||||
missing_command: Command = NullCommand):
|
missing_command: Command = NullCommand):
|
||||||
self.bot: telegram.Bot = telegram.Bot(api_key)
|
self.bot: telegram.Bot = telegram.Bot(api_key)
|
||||||
self.should_run: bool = False
|
self.should_run: bool = False
|
||||||
self.offset: int = -100
|
self.offset: int = -100
|
||||||
self.missing_command: typing.Callable = missing_command
|
self.missing_command = missing_command
|
||||||
self.network: RoyalnetLink = RoyalnetLink(master_server_uri, "telegram", null)
|
self.network: RoyalnetLink = RoyalnetLink(master_server_uri, master_server_secret, "telegram", null)
|
||||||
# Generate commands
|
# Generate commands
|
||||||
self.commands = {}
|
self.commands = {}
|
||||||
for command in commands:
|
for command in commands:
|
||||||
|
|
|
@ -1,6 +1,16 @@
|
||||||
from .messages import Message, ErrorMessage, InvalidSecretErrorMessage
|
from .messages import Message, ErrorMessage, InvalidSecretEM, InvalidDestinationEM, InvalidPackageEM
|
||||||
|
from .packages import Package
|
||||||
from .royalnetlink import RoyalnetLink, NetworkError, NotConnectedError, NotIdentifiedError
|
from .royalnetlink import RoyalnetLink, NetworkError, NotConnectedError, NotIdentifiedError
|
||||||
from .packages import Package, TwoWayPackage
|
from .royalnetserver import RoyalnetServer
|
||||||
|
|
||||||
__all__ = ["Message", "ErrorMessage", "InvalidSecretErrorMessage", "RoyalnetLink", "NetworkError", "NotConnectedError",
|
__all__ = ["Message",
|
||||||
"NotIdentifiedError", "Package", "TwoWayPackage"]
|
"ErrorMessage",
|
||||||
|
"InvalidSecretEM",
|
||||||
|
"InvalidDestinationEM",
|
||||||
|
"InvalidPackageEM",
|
||||||
|
"RoyalnetLink",
|
||||||
|
"NetworkError",
|
||||||
|
"NotConnectedError",
|
||||||
|
"NotIdentifiedError",
|
||||||
|
"Package",
|
||||||
|
"RoyalnetServer"]
|
||||||
|
|
|
@ -12,9 +12,15 @@ class ErrorMessage(Message):
|
||||||
self.reason = reason
|
self.reason = reason
|
||||||
|
|
||||||
|
|
||||||
class BadMessage(ErrorMessage):
|
class InvalidSecretEM(ErrorMessage):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InvalidSecretErrorMessage(BadMessage):
|
class InvalidPackageEM(ErrorMessage):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidDestinationEM(InvalidPackageEM):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,22 +3,15 @@ import uuid
|
||||||
|
|
||||||
|
|
||||||
class Package:
|
class Package:
|
||||||
def __init__(self, data, destination: str, *, conversation_id: str = None):
|
def __init__(self, data, destination: str, source: str, *, conversation_id: str = None):
|
||||||
self.data = data
|
self.data = data
|
||||||
self.destination: str = destination
|
self.destination: str = destination
|
||||||
|
self.source, = source
|
||||||
self.conversation_id = conversation_id or str(uuid.uuid4())
|
self.conversation_id = conversation_id or str(uuid.uuid4())
|
||||||
|
|
||||||
|
def reply(self, data) -> "Package":
|
||||||
|
return Package(data, self.source, self.destination, conversation_id=self.conversation_id)
|
||||||
|
|
||||||
def pickle(self):
|
def pickle(self):
|
||||||
return pickle.dumps(self)
|
return pickle.dumps(self)
|
||||||
|
|
||||||
|
|
||||||
class TwoWayPackage(Package):
|
|
||||||
def __init__(self, data, destination: str, source: str, *, conversation_id: str = None):
|
|
||||||
super().__init__(data, destination, conversation_id=conversation_id)
|
|
||||||
self.source = source
|
|
||||||
|
|
||||||
def reply(self, data) -> Package:
|
|
||||||
return Package(data, self.source, conversation_id=self.conversation_id)
|
|
||||||
|
|
||||||
def two_way_reply(self, data) -> "TwoWayPackage":
|
|
||||||
return TwoWayPackage(data, self.source, self.destination, conversation_id=self.conversation_id)
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import functools
|
||||||
import typing
|
import typing
|
||||||
import pickle
|
import pickle
|
||||||
from .messages import Message, ErrorMessage
|
from .messages import Message, ErrorMessage
|
||||||
from .packages import Package, TwoWayPackage
|
from .packages import Package
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
|
||||||
|
@ -35,11 +35,12 @@ class PendingRequest:
|
||||||
|
|
||||||
|
|
||||||
class RoyalnetLink:
|
class RoyalnetLink:
|
||||||
def __init__(self, master_uri: str, link_type: str, request_handler):
|
def __init__(self, master_uri: str, secret: str, link_type: str, request_handler):
|
||||||
assert ":" not in link_type
|
assert ":" not in link_type
|
||||||
self.master_uri: str = master_uri
|
self.master_uri: str = master_uri
|
||||||
self.link_type: str = link_type
|
self.link_type: str = link_type
|
||||||
self.nid: str = str(uuid.uuid4())
|
self.nid: str = str(uuid.uuid4())
|
||||||
|
self.secret: str = secret
|
||||||
self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None
|
self.websocket: typing.Optional[websockets.WebSocketClientProtocol] = None
|
||||||
self.identified: bool = False
|
self.identified: bool = False
|
||||||
self.request_handler = request_handler
|
self.request_handler = request_handler
|
||||||
|
@ -48,12 +49,12 @@ class RoyalnetLink:
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
self.websocket = await websockets.connect(self.master_uri)
|
self.websocket = await websockets.connect(self.master_uri)
|
||||||
|
|
||||||
def requires_connection(self, func):
|
def requires_connection(func):
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def new_func(*args, **kwargs):
|
def new_func(self, *args, **kwargs):
|
||||||
if self.websocket is None:
|
if self.websocket is None:
|
||||||
raise NotConnectedError("Tried to call a method which @requires_connection while not connected")
|
raise NotConnectedError("Tried to call a method which @requires_connection while not connected")
|
||||||
return func(*args, **kwargs)
|
return func(self, *args, **kwargs)
|
||||||
return new_func
|
return new_func
|
||||||
|
|
||||||
@requires_connection
|
@requires_connection
|
||||||
|
@ -65,7 +66,7 @@ class RoyalnetLink:
|
||||||
self.identified = False
|
self.identified = False
|
||||||
# What to do now? Let's just reraise.
|
# What to do now? Let's just reraise.
|
||||||
raise
|
raise
|
||||||
package: typing.Union[Package, TwoWayPackage] = pickle.loads(raw_pickle)
|
package: typing.Union[Package, Package] = pickle.loads(raw_pickle)
|
||||||
assert package.destination == self.nid
|
assert package.destination == self.nid
|
||||||
return package
|
return package
|
||||||
|
|
||||||
|
@ -78,12 +79,12 @@ class RoyalnetLink:
|
||||||
raise NetworkError(response, "Server returned error while identifying self")
|
raise NetworkError(response, "Server returned error while identifying self")
|
||||||
self.identified = True
|
self.identified = True
|
||||||
|
|
||||||
def requires_identification(self, func):
|
def requires_identification(func):
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def new_func(*args, **kwargs):
|
def new_func(self, *args, **kwargs):
|
||||||
if not self.identified:
|
if not self.identified:
|
||||||
raise NotIdentifiedError("Tried to call a method which @requires_identification while not identified")
|
raise NotIdentifiedError("Tried to call a method which @requires_identification while not identified")
|
||||||
return func(*args, **kwargs)
|
return func(self, *args, **kwargs)
|
||||||
return new_func
|
return new_func
|
||||||
|
|
||||||
@requires_identification
|
@requires_identification
|
||||||
|
@ -93,7 +94,7 @@ class RoyalnetLink:
|
||||||
|
|
||||||
@requires_identification
|
@requires_identification
|
||||||
async def request(self, message, destination):
|
async def request(self, message, destination):
|
||||||
package = TwoWayPackage(message, destination, self.nid)
|
package = Package(message, destination, self.nid)
|
||||||
request = PendingRequest()
|
request = PendingRequest()
|
||||||
self._pending_requests[package.conversation_id] = request
|
self._pending_requests[package.conversation_id] = request
|
||||||
await self.send(package)
|
await self.send(package)
|
||||||
|
@ -103,7 +104,7 @@ class RoyalnetLink:
|
||||||
raise NetworkError(result, "Server returned error while requesting something")
|
raise NetworkError(result, "Server returned error while requesting something")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def run_link(self):
|
async def run(self):
|
||||||
while True:
|
while True:
|
||||||
if self.websocket is None:
|
if self.websocket is None:
|
||||||
await self.connect()
|
await self.connect()
|
||||||
|
@ -116,7 +117,8 @@ class RoyalnetLink:
|
||||||
request.set(package.data)
|
request.set(package.data)
|
||||||
continue
|
continue
|
||||||
# Package is a request
|
# Package is a request
|
||||||
assert isinstance(package, TwoWayPackage)
|
assert isinstance(package, Package)
|
||||||
response = await self.request_handler(package.data)
|
response = await self.request_handler(package.data)
|
||||||
|
if response is not None:
|
||||||
response_package: Package = package.reply(response)
|
response_package: Package = package.reply(response)
|
||||||
await self.send(response_package)
|
await self.send(response_package)
|
||||||
|
|
|
@ -2,8 +2,11 @@ import typing
|
||||||
import websockets
|
import websockets
|
||||||
import re
|
import re
|
||||||
import datetime
|
import datetime
|
||||||
from .messages import Message, ErrorMessage, BadMessage, InvalidSecretErrorMessage, IdentifySuccessfulMessage
|
import pickle
|
||||||
from .packages import Package, TwoWayPackage
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from .messages import Message, ErrorMessage, InvalidPackageEM, InvalidSecretEM, IdentifySuccessfulMessage
|
||||||
|
from .packages import Package
|
||||||
|
|
||||||
|
|
||||||
class ConnectedClient:
|
class ConnectedClient:
|
||||||
|
@ -17,38 +20,83 @@ class ConnectedClient:
|
||||||
def is_identified(self) -> bool:
|
def is_identified(self) -> bool:
|
||||||
return bool(self.nid)
|
return bool(self.nid)
|
||||||
|
|
||||||
|
async def send(self, package: Package):
|
||||||
|
self.socket.send(package.pickle())
|
||||||
|
|
||||||
|
|
||||||
class RoyalnetServer:
|
class RoyalnetServer:
|
||||||
def __init__(self, required_secret: str):
|
def __init__(self, address: str, port: int, required_secret: str):
|
||||||
|
self.address: str = address
|
||||||
|
self.port: int = port
|
||||||
self.required_secret: str = required_secret
|
self.required_secret: str = required_secret
|
||||||
self.connected_clients: typing.List[ConnectedClient] = {}
|
self.identified_clients: typing.List[ConnectedClient] = {}
|
||||||
self.server: websockets.server.WebSocketServer = websockets.server
|
|
||||||
|
|
||||||
def find_client_by_nid(self, nid: str):
|
def find_client(self, *, nid: str=None, link_type: str=None) -> typing.List[ConnectedClient]:
|
||||||
return [client for client in self.connected_clients if client.nid == nid][0]
|
assert not (nid and link_type)
|
||||||
|
if nid:
|
||||||
|
matching = [client for client in self.identified_clients if client.nid == nid]
|
||||||
|
assert len(matching) <= 1
|
||||||
|
return matching
|
||||||
|
if link_type:
|
||||||
|
matching = [client for client in self.identified_clients if client.link_type == link_type]
|
||||||
|
return matching or []
|
||||||
|
|
||||||
async def listener(self, websocket: websockets.server.WebSocketServerProtocol, request_uri: str):
|
async def listener(self, websocket: websockets.server.WebSocketServerProtocol, request_uri: str):
|
||||||
connected_client = ConnectedClient(websocket)
|
connected_client = ConnectedClient(websocket)
|
||||||
# Wait for identification
|
# Wait for identification
|
||||||
identify_msg = websocket.recv()
|
identify_msg = websocket.recv()
|
||||||
if not isinstance(identify_msg, str):
|
if not isinstance(identify_msg, str):
|
||||||
websocket.send(BadMessage("Invalid identification message (not a str)"))
|
websocket.send(InvalidPackageEM("Invalid identification message (not a str)"))
|
||||||
return
|
return
|
||||||
identification = re.match(r"Identify ([A-Za-z0-9\-]+):([a-z]+):([A-Za-z0-9\-])", identify_msg)
|
identification = re.match(r"Identify ([A-Za-z0-9\-]+):([a-z]+):([A-Za-z0-9\-])", identify_msg)
|
||||||
if identification is None:
|
if identification is None:
|
||||||
websocket.send(BadMessage("Invalid identification message (regex failed)"))
|
websocket.send(InvalidPackageEM("Invalid identification message (regex failed)"))
|
||||||
return
|
return
|
||||||
secret = identification.group(3)
|
secret = identification.group(3)
|
||||||
if secret != self.required_secret:
|
if secret != self.required_secret:
|
||||||
websocket.send(InvalidSecretErrorMessage("Invalid secret"))
|
websocket.send(InvalidSecretEM("Invalid secret"))
|
||||||
return
|
return
|
||||||
# Identification successful
|
# Identification successful
|
||||||
connected_client.nid = identification.group(1)
|
connected_client.nid = identification.group(1)
|
||||||
connected_client.link_type = identification.group(2)
|
connected_client.link_type = identification.group(2)
|
||||||
self.connected_clients.append(connected_client)
|
self.identified_clients.append(connected_client)
|
||||||
|
await connected_client.send(Package(IdentifySuccessfulMessage(), connected_client.nid, "__master__"))
|
||||||
# Main loop
|
# Main loop
|
||||||
while True:
|
while True:
|
||||||
|
# Receive packages
|
||||||
|
raw_pickle = await websocket.recv()
|
||||||
|
package: Package = pickle.loads(raw_pickle)
|
||||||
|
# Check if the package destination is the server itself.
|
||||||
|
if package.destination == "__master__":
|
||||||
|
# TODO: do stuff
|
||||||
pass
|
pass
|
||||||
|
# Otherwise, route the package to its destination
|
||||||
|
asyncio.create_task(self.route_package(package))
|
||||||
|
|
||||||
|
def find_destination(self, package: Package) -> typing.List[ConnectedClient]:
|
||||||
|
"""Find a list of destinations for the sent packages"""
|
||||||
|
# Parse destination
|
||||||
|
# Is it nothing?
|
||||||
|
if package.destination == "NULL":
|
||||||
|
return []
|
||||||
|
# Is it the wildcard?
|
||||||
|
if package.destination == "*":
|
||||||
|
return self.identified_clients
|
||||||
|
# Is it a valid nid?
|
||||||
|
try:
|
||||||
|
destination = str(uuid.UUID(package.destination))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
return self.find_client(nid=destination)
|
||||||
|
# Is it a link_type?
|
||||||
|
return self.find_client(link_type=package.destination)
|
||||||
|
|
||||||
|
async def route_package(self, package: Package) -> None:
|
||||||
|
"""Executed every time a package is received and must be routed somewhere."""
|
||||||
|
destinations = self.find_destination(package)
|
||||||
|
for destination in destinations:
|
||||||
|
await destination.send(package)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
websockets.serve(self.listener, host=self.address, port=self.port)
|
||||||
|
|
Loading…
Reference in a new issue