diff --git a/poetry.lock b/poetry.lock index defd91fc..bf67f560 100644 --- a/poetry.lock +++ b/poetry.lock @@ -512,7 +512,7 @@ multidict = ">=4.0" [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "c30fdf09f35a2430d64998b966b2648d1bed68a9a1e58e45339f5a1a53895263" +content-hash = "331ed6c6b9a070807ccc984fcda1e4d8c08da999f49a96195a7f3c25356fbd73" [metadata.files] aiohttp = [ diff --git a/pyproject.toml b/pyproject.toml index df7a474a..213c92e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ SQLAlchemy-Utils = "^0.37.0" arrow = "^1.0.3" colour = "^0.1.5" royalspells = "^3.2" +async-timeout = "^3.0.1" [tool.poetry.dev-dependencies] diff --git a/royalpack/commands/login.py b/royalpack/commands/login.py index f2aca537..6cabf837 100644 --- a/royalpack/commands/login.py +++ b/royalpack/commands/login.py @@ -1,20 +1,216 @@ +import royalnet.royaltyping as t import royalnet.engineer as engi import sqlalchemy.sql as ss import sqlalchemy.orm as so import royalpack.database as db import royalpack.config as cfg -import royalnet_telethon import royalnet_telethon.bullet.contents import aiohttp import asyncio import logging import arrow -import datetime +import async_timeout + log = logging.getLogger(__name__) # FIXME: Properly handle errors in this function! + +async def enforce_private_message(msg: engi.Message) -> engi.Channel: + """ + Get the private chat for an user and notify them of the switch. + + :param msg: The :class:`~.engi.Message` to reply to. + :return: The private :class:`~.engi.Channel`. + """ + + log.debug("Sliding into DMs...") + + sender: engi.User = await msg.sender + current: engi.Channel = await msg.channel + private: engi.Channel = await sender.slide() + if hash(current) != hash(private): + await msg.reply(text="πŸ‘€ Ti sto inviando un messaggio in chat privata contenente le istruzioni per il login!") + return private + + +async def device_code_request( + http_session: aiohttp.ClientSession, + client_id: str, + device_url: str, + scopes: list[str], +) -> t.JSON: + """ + Request a OAuth2 device code (which can be exchanged for an access token once the user has given us the + authorization to do so). + + :param http_session: The :class:`aiohttp.ClientSession` to use. + :param client_id: The OAuth2 Client ID. + :param device_url: The URL where device codes can be obtained. + :param scopes: A :class:`list` of scopes to require from the user. + + :return: The JSON response received from the Identity Provider. + """ + + log.debug("Requesting device code...") + + async with http_session.post(device_url, data={ + "client_id": client_id, + "scope": " ".join(scopes), + }) as request: + return await request.json() + + +async def prompt_login(channel: engi.Channel, verification_url: str, user_code: str) -> None: + """ + Ask the user to login. + + :param channel: The :class:`~.engi.Channel` to send the message in. + :param verification_url: The URL where the user can approve / reject the token. + :param user_code: Human-friendly view of the device code. + """ + + log.debug("Asking user to login...") + + await channel.send_message( + text=f"🌍 Effettua il RYGlogin al seguente URL, poi premi Confirm:\n" + f"{verification_url}\n" + f"\n" + f"(Codice: {user_code})" + ) + + +async def device_code_exchange( + http_session: aiohttp.ClientSession, + client_id: str, + token_url: str, + device_code: str, + sleep_time: float, +): + """ + Check if the user has authorized the device code, and try to exchange it for an access token. + + :return: The JSON response received from the Identity Provider. + """ + + log.debug("Starting validation process...") + + while True: + log.debug(f"Sleeping for {sleep_time}s...") + await asyncio.sleep(sleep_time) + + async with http_session.post(token_url, data={ + "client_id": client_id, + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + "device_code": device_code, + }) as request: + response = await request.json() + if "error" in response: + log.debug(f"Response returned error {response['error']!r}, retrying...") + continue + elif "access_token" in response: + log.debug(f"Obtained access token!") + return response + else: + log.error(f"Didn't get an access token, but didn't get an error either?!") + continue + + +async def get_user_info( + http_session: aiohttp.ClientSession, + userinfo_url: str, + token_type: str, + access_token: str, +): + """ + Get the userinfo of an user. + + :param http_session: The :class:`aiohttp.ClientSession` to use. + :param userinfo_url: The URL where the user info is obtained. + :param token_type: The type of the token returned by the Identity Provider (usually ``"Bearer"``) + :param access_token: The access token to use. + :return: + """ + + log.debug("Getting user info...") + + async with http_session.post(userinfo_url, headers={ + "Authorization": f"{token_type} {access_token}" + }) as request: + return await request.json() + + +async def notify_expiration(channel: engi.Channel) -> None: + """ + Notify the user of the device code expiration. + + :param channel: The :class:`~.engi.Channel` to send the message in. + """ + + log.debug("Notifying the user of the expiration...") + + await channel.send_message( + text=f"πŸ•’ Il codice dispositivo Γ¨ scaduto e il login Γ¨ stato annullato. " + f"Fai il login piΓΉ in fretta la prossima volta! :)", + ) + + +async def register_user_generic( + session: so.Session, + user_info: dict[str, t.Any], +) -> db.User: + """ + Sync the user info with the data inside the database. + + :param session: The :class:`~.so.Session` to use. + :param user_info: The user_info obtained by the Identity Provider. + :return: The created/updated :class:`.db.User`. + """ + + log.debug("Syncing generic user...") + + user = db.User( + sub=user_info['sub'], + last_update=arrow.now(), + name=user_info['name'], + nickname=user_info['nickname'], + avatar=user_info['picture'], + email=user_info['email'], + ) + session.merge(user) + return user + + +async def register_user_telethon( + session: so.Session, + user_info: dict[str, t.Any], + telethon_user, +) -> db.TelegramAccount: + """ + Sync an user's Telegram account via a Telethon message. + + :param session: The :class:`~.so.Session` to use. + :param user_info: The user_info obtained by the Identity Provider. + :param telethon_user: The telethon user to base the user data on. + :return: The created/updated :class:`~.db.TelegramAccount` + """ + + log.debug("Syncing telethon user...") + + tg = db.TelegramAccount( + user_fk=user_info["sub"], + id=telethon_user.id, + first_name=telethon_user.first_name, + last_name=telethon_user.last_name, + username=telethon_user.username, + avatar_url=None, # TODO: avatars + ) + session.merge(tg) + return tg + + + @engi.use_database(db.lazy_session_class) @engi.TeleportingConversation async def login(*, _msg: engi.Message, _session: so.Session, _imp, **__): @@ -24,119 +220,56 @@ async def login(*, _msg: engi.Message, _session: so.Session, _imp, **__): log.debug("Evaluating config...") config = cfg.lazy_config.evaluate() - log.debug("Sliding into DMs...") - sender: engi.User = await _msg.sender - current: engi.Channel = await _msg.channel - private: engi.Channel = await sender.slide() - if hash(current) != hash(private): - await _msg.reply(text="πŸ‘€ Ti ho inviato un messaggio in chat privata contenente le istruzioni per il login!") + private = await enforce_private_message(msg=_msg) async with aiohttp.ClientSession() as http_session: - log.debug("Generating device code...") - async with http_session.post(config["auth.url.device"], data={ - "client_id": config["auth.client.id"], - "scope": "profile email openid", - "prompt": "consent", - }) as request: - response = await request.json() - start = arrow.now() + dc = await device_code_request( + http_session=http_session, + client_id=config["auth.client.id"], + device_url=config["auth.url.device"], + scopes=["profile", "email", "openid"], + ) - log.debug("Asking user to login...") - await private.send_message( - text=f"🌍 Effettua il RYGlogin al seguente URL, poi premi Confirm:\n" - f"{response['verification_uri_complete']}\n" - f"\n" - f"(Codice: {response['user_code']})" + await prompt_login( + channel=private, + verification_url=dc['verification_uri_complete'], + user_code=dc['user_code'] + ) + + try: + async with async_timeout.timeout(dc["expires_in"]): + at = await device_code_exchange( + http_session=http_session, + client_id=config["auth.client.id"], + token_url=config["auth.url.token"], + device_code=dc["device_code"], + sleep_time=9 + ) + except asyncio.TimeoutError: + await notify_expiration( + channel=private ) - - expiration = start + datetime.timedelta(seconds=response["expires_in"]) - while arrow.now() < expiration: - log.debug("Sleeping for 10 seconds...") - await asyncio.sleep(10) - - async with http_session.post(config["auth.url.token"], data={ - "client_id": config["auth.client.id"], - "grant_type": "urn:ietf:params:oauth:grant-type:device_code", - "device_code": response["device_code"], - }) as request: - response = await request.json() - if "error" in response: - log.debug(f"Response returned error {response['error']!r}, retrying...") - continue - elif "access_token" in response: - log.debug(f"Obtained access token...") - break - else: - log.error(f"Didn't get an access token, but didn't get an error either?!") - continue - else: - log.debug("Login request expired.") - await private.send_message(text="πŸ•’ La tua richiesta di login Γ¨ scaduta. " - "Riinvia il comando per ricominciare!") return - async with http_session.post(config["auth.url.userinfo"], headers={ - "Authorization": f"{response['token_type']} {response['access_token']}" - }) as request: - response = await request.json() + ui = await get_user_info( + http_session=http_session, + userinfo_url=config["auth.url.userinfo"], + token_type=at["token_type"], + access_token=at["access_token"], + ) - log.debug("Checking if the user already exists...") - user: db.User = _session.execute( - ss.select(db.User).where(db.User.sub == response["sub"]) - ).scalar() - - log.debug("Creating user dict...") - user_dict = { - "sub": response['sub'], - "last_update": arrow.now(), - "name": response['name'], - "nickname": response['nickname'], - "avatar": response['picture'], - "email": response['email'], - } - - if user is None: - log.info(f"Creating new user: {response['sub']}") - user = db.User(**user_dict) - _session.add(user) - else: - log.debug(f"Updating existing user: {response['sub']}") - user.update(**user_dict) + user = await register_user_generic(session=_session, user_info=ui) if isinstance(_imp, royalnet_telethon.TelethonPDAImplementation): - log.debug("Found out I'm running on Telethon...") - - sender: royalnet_telethon.bullet.contents.TelegramUser - - log.debug("Checking if the TelegramAccount already exists...") - tg: db.TelegramAccount = _session.execute( - ss.select(db.TelegramAccount).where(db.TelegramAccount.id == sender._user.id) - ).scalar() - - log.debug("Creating tg_dict...") - tg_dict = { - "user_fk": response["sub"], - "id": sender._user.id, - "first_name": sender._user.first_name, - "last_name": sender._user.last_name, - "username": sender._user.username, - "avatar_url": None, # TODO: avatars - } - - if tg is None: - log.info(f"Creating new TelegramAccount: {sender._user.id}") - tg = db.TelegramAccount(**tg_dict) - _session.add(tg) - else: - log.debug(f"Updating existing TelegramAccount: {sender._user.id}") - tg.update(**tg_dict) + sender = await _msg.sender + tg = await register_user_telethon(session=_session, user_info=ui, telethon_user=sender._user) log.debug(f"Committing session...") _session.commit() log.debug(f"Done, notifying the user...") - await private.send_message(text=f"βœ… Login riuscito! Sei loggato come {response['name']}!") + await private.send_message(text=f"βœ… Login riuscito! Sei loggato come {user.name}!") __all__ = ("login",)