diff --git a/.gitignore b/.gitignore index 981870b..e83cfc0 100644 --- a/.gitignore +++ b/.gitignore @@ -102,5 +102,6 @@ ENV/ .idea/ config/config.ini +config/config.toml *.sqlite *.sqlite-journal \ No newline at end of file diff --git a/core.py b/core.py index cb9d90e..808d6a4 100644 --- a/core.py +++ b/core.py @@ -7,6 +7,10 @@ import threading import localization import logging import duckbot +import sqlalchemy +import sqlalchemy.orm +import sqlalchemy.ext.declarative as sed +import database try: import coloredlogs @@ -33,8 +37,8 @@ def main(): if not os.path.isfile("config/config.toml"): log.debug("config/config.toml does not exist.") - with open("config/template_config.ini", encoding="utf8") as template_cfg_file, \ - open("config/config.ini", "w", encoding="utf8") as user_cfg_file: + with open("config/template_config.toml", encoding="utf8") as template_cfg_file, \ + open("config/config.toml", "w", encoding="utf8") as user_cfg_file: # Copy the template file to the config file user_cfg_file.write(template_cfg_file.read()) @@ -43,8 +47,8 @@ def main(): exit(1) # Compare the template config with the user-made one - with open("config/template_config.ini", encoding="utf8") as template_cfg_file, \ - open("config/config.ini", encoding="utf8") as user_cfg_file: + with open("config/template_config.toml", encoding="utf8") as template_cfg_file, \ + open("config/config.toml", encoding="utf8") as user_cfg_file: template_cfg = nuconfig.NuConfig(template_cfg_file) user_cfg = nuconfig.NuConfig(user_cfg_file) if not template_cfg.cmplog(user_cfg): @@ -65,6 +69,16 @@ def main(): # Ignore most python-telegram-bot logs, as they are useless most of the time logging.getLogger("telegram").setLevel("ERROR") + # Create the database engine + log.debug("Creating the sqlalchemy engine...") + engine = sqlalchemy.create_engine(user_cfg["Database"]["engine"]) + log.debug("Preparing the tables through deferred reflection...") + sed.DeferredReflection.prepare(engine) + log.debug("Binding metadata to the engine...") + database.TableDeclarativeBase.metadata.bind = engine + log.debug("Creating all missing tables...") + database.TableDeclarativeBase.metadata.create_all() + # Create a bot instance bot = duckbot.factory(user_cfg)() @@ -122,7 +136,8 @@ def main(): new_worker = worker.Worker(bot=bot, chat=update.message.chat, telegram_user=update.message.from_user, - cfg=user_cfg) + cfg=user_cfg, + engine=engine) # Start the worker log.debug(f"Starting {new_worker.name}") new_worker.start() diff --git a/database.py b/database.py index 7b6d521..1066cad 100644 --- a/database.py +++ b/database.py @@ -1,28 +1,24 @@ import typing from sqlalchemy import create_engine, Column, ForeignKey, UniqueConstraint from sqlalchemy import Integer, BigInteger, String, Text, LargeBinary, DateTime, Boolean -from sqlalchemy.orm import sessionmaker, relationship, backref -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship, backref +from sqlalchemy.ext.declarative import declarative_base, DeferredReflection import telegram import requests import utils -import localization import logging +if typing.TYPE_CHECKING: + import worker + + log = logging.getLogger(__name__) -# Create a (lazy) database engine -engine = create_engine(configloader.config["Database"]["engine"]) - # Create a base class to define all the database subclasses -TableDeclarativeBase = declarative_base(bind=engine) - -# Create a Session class able to initialize database sessions -Session = sessionmaker() - +TableDeclarativeBase = declarative_base() # Define all the database tables using the sqlalchemy declarative base -class User(TableDeclarativeBase): +class User(DeferredReflection, TableDeclarativeBase): """A Telegram user who used the bot at least once.""" # Telegram data @@ -38,16 +34,18 @@ class User(TableDeclarativeBase): # Extra table parameters __tablename__ = "users" - def __init__(self, telegram_user: telegram.User, **kwargs): + def __init__(self, w: "worker.Worker", **kwargs): # Initialize the super super().__init__(**kwargs) # Get the data from telegram - self.user_id = telegram_user.id - self.first_name = telegram_user.first_name - self.last_name = telegram_user.last_name - self.username = telegram_user.username - self.language = telegram_user.language_code if telegram_user.language_code else configloader.config["Language"][ - "default_language"] + self.user_id = w.telegram_user.id + self.first_name = w.telegram_user.first_name + self.last_name = w.telegram_user.last_name + self.username = w.telegram_user.username + if w.telegram_user.language_code: + self.language = w.telegram_user.language_code + else: + self.language = w.cfg["Language"]["default_language"] # The starting wallet value is 0 self.credit = 0 @@ -87,7 +85,7 @@ class User(TableDeclarativeBase): return f"" -class Product(TableDeclarativeBase): +class Product(DeferredReflection, TableDeclarativeBase): """A purchasable product.""" # Product id @@ -108,37 +106,37 @@ class Product(TableDeclarativeBase): # No __init__ is needed, the default one is sufficient - def text(self, *, loc: localization.Localization, style: str = "full", cart_qty: int = None): + def text(self, w: "worker.Worker", *, style: str = "full", cart_qty: int = None): """Return the product details formatted with Telegram HTML. The image is omitted.""" if style == "short": - return f"{cart_qty}x {utils.telegram_html_escape(self.name)} - {str(utils.Price(self.price, loc) * cart_qty)}" + return f"{cart_qty}x {utils.telegram_html_escape(self.name)} - {str(w.Price(self.price) * cart_qty)}" elif style == "full": if cart_qty is not None: - cart = loc.get("in_cart_format_string", quantity=cart_qty) + cart = w.loc.get("in_cart_format_string", quantity=cart_qty) else: cart = '' - return loc.get("product_format_string", name=utils.telegram_html_escape(self.name), - description=utils.telegram_html_escape(self.description), - price=str(utils.Price(self.price, loc)), - cart=cart) + return w.loc.get("product_format_string", name=utils.telegram_html_escape(self.name), + description=utils.telegram_html_escape(self.description), + price=str(w.Price(self.price)), + cart=cart) else: raise ValueError("style is not an accepted value") def __repr__(self): return f"" - def send_as_message(self, loc: localization.Localization, chat_id: int) -> dict: + def send_as_message(self, w: "worker.Worker", chat_id: int) -> dict: """Send a message containing the product data.""" if self.image is None: - r = requests.get(f"https://api.telegram.org/bot{configloader.config['Telegram']['token']}/sendMessage", + r = requests.get(f"https://api.telegram.org/bot{w.cfg['Telegram']['token']}/sendMessage", params={"chat_id": chat_id, - "text": self.text(loc=loc), + "text": self.text(w), "parse_mode": "HTML"}) else: - r = requests.post(f"https://api.telegram.org/bot{configloader.config['Telegram']['token']}/sendPhoto", + r = requests.post(f"https://api.telegram.org/bot{w.cfg['Telegram']['token']}/sendPhoto", files={"photo": self.image}, params={"chat_id": chat_id, - "caption": self.text(loc=loc), + "caption": self.text(w), "parse_mode": "HTML"}) return r.json() @@ -151,7 +149,7 @@ class Product(TableDeclarativeBase): self.image = r.content -class Transaction(TableDeclarativeBase): +class Transaction(DeferredReflection, TableDeclarativeBase): """A greed wallet transaction. Wallet credit ISN'T calculated from these, but they can be used to recalculate it.""" # TODO: split this into multiple tables @@ -187,10 +185,10 @@ class Transaction(TableDeclarativeBase): __tablename__ = "transactions" __table_args__ = (UniqueConstraint("provider", "provider_charge_id"),) - def text(self, *, loc: localization.Localization): - string = f"T{self.transaction_id} | {str(self.user)} | {utils.Price(self.value, loc)}" + def text(self, w: "worker.Worker"): + string = f"T{self.transaction_id} | {str(self.user)} | {w.Price(self.value)}" if self.refunded: - string += f" | {loc.get('emoji_refunded')}" + string += f" | {w.loc['emoji_refunded']}" if self.provider: string += f" | {self.provider}" if self.notes: @@ -201,7 +199,7 @@ class Transaction(TableDeclarativeBase): return f"" -class Admin(TableDeclarativeBase): +class Admin(DeferredReflection, TableDeclarativeBase): """A greed administrator with his permissions.""" # The telegram id @@ -223,7 +221,7 @@ class Admin(TableDeclarativeBase): return f"" -class Order(TableDeclarativeBase): +class Order(DeferredReflection, TableDeclarativeBase): """An order which has been placed by an user. It may include multiple products, available in the OrderItem table.""" @@ -253,41 +251,41 @@ class Order(TableDeclarativeBase): def __repr__(self): return f"" - def text(self, *, loc: localization.Localization, session, user=False): + def text(self, w: "worker.Worker", session, user=False): joined_self = session.query(Order).filter_by(order_id=self.order_id).join(Transaction).one() items = "" for item in self.items: - items += item.text(loc=loc) + "\n" + items += item.text(w) + "\n" if self.delivery_date is not None: - status_emoji = loc.get("emoji_completed") - status_text = loc.get("text_completed") + status_emoji = w.loc.get("emoji_completed") + status_text = w.loc.get("text_completed") elif self.refund_date is not None: - status_emoji = loc.get("emoji_refunded") - status_text = loc.get("text_refunded") + status_emoji = w.loc.get("emoji_refunded") + status_text = w.loc.get("text_refunded") else: - status_emoji = loc.get("emoji_not_processed") - status_text = loc.get("text_not_processed") - if user and configloader.config["Appearance"]["full_order_info"] == "no": - return loc.get("user_order_format_string", - status_emoji=status_emoji, - status_text=status_text, - items=items, - notes=self.notes, - value=str(utils.Price(-joined_self.transaction.value, loc))) + \ - (loc.get("refund_reason", reason=self.refund_reason) if self.refund_date is not None else "") + status_emoji = w.loc.get("emoji_not_processed") + status_text = w.loc.get("text_not_processed") + if user and w.cfg["Appearance"]["full_order_info"] == "no": + return w.loc.get("user_order_format_string", + status_emoji=status_emoji, + status_text=status_text, + items=items, + notes=self.notes, + value=str(w.Price(-joined_self.transaction.value))) + \ + (w.loc.get("refund_reason", reason=self.refund_reason) if self.refund_date is not None else "") else: return status_emoji + " " + \ - loc.get("order_number", id=self.order_id) + "\n" + \ - loc.get("order_format_string", - user=self.user.mention(), - date=self.creation_date.isoformat(), - items=items, - notes=self.notes if self.notes is not None else "", - value=str(utils.Price(-joined_self.transaction.value, loc))) + \ - (loc.get("refund_reason", reason=self.refund_reason) if self.refund_date is not None else "") + w.loc.get("order_number", id=self.order_id) + "\n" + \ + w.loc.get("order_format_string", + user=self.user.mention(), + date=self.creation_date.isoformat(), + items=items, + notes=self.notes if self.notes is not None else "", + value=str(w.Price(-joined_self.transaction.value))) + \ + (w.loc.get("refund_reason", reason=self.refund_reason) if self.refund_date is not None else "") -class OrderItem(TableDeclarativeBase): +class OrderItem(DeferredReflection, TableDeclarativeBase): """A product that has been purchased as part of an order.""" # The unique item id @@ -301,11 +299,8 @@ class OrderItem(TableDeclarativeBase): # Extra table parameters __tablename__ = "orderitems" - def text(self, *, loc: localization.Localization): - return f"{self.product.name} - {str(utils.Price(self.product.price, loc))}" + def text(self, w: "worker.Worker"): + return f"{self.product.name} - {str(w.Price(self.product.price))}" def __repr__(self): return f"" - - -TableDeclarativeBase.metadata.create_all() diff --git a/worker.py b/worker.py index 9f17862..5aadf32 100644 --- a/worker.py +++ b/worker.py @@ -14,6 +14,7 @@ from html import escape import requests import logging import localization +import sqlalchemy.orm log = logging.getLogger(__name__) @@ -37,6 +38,7 @@ class Worker(threading.Thread): chat: telegram.Chat, telegram_user: telegram.User, cfg: nuconfig.NuConfig, + engine, *args, **kwargs): # Initialize the thread @@ -48,7 +50,7 @@ class Worker(threading.Thread): self.cfg = cfg # Open a new database session log.debug(f"Opening new database session for {self.name}") - self.session = db.Session() + self.session = sqlalchemy.orm.sessionmaker(bind=engine)() # Get the user db data from the users and admin tables self.user: Optional[db.User] = None self.admin: Optional[db.Admin] = None @@ -132,7 +134,6 @@ class Worker(threading.Thread): return Price(Price(other).value - self.value) def __rmul__(self, other): - return self.__mul__(other) def __iadd__(self, other): @@ -165,7 +166,7 @@ class Worker(threading.Thread): # Check if there are other registered users: if there aren't any, the first user will be owner of the bot will_be_owner = (self.session.query(db.Admin).first() is None) # Create the new record - self.user = db.User(self.telegram_user) + self.user = db.User(w=self) # Add the new record to the db self.session.add(self.user) # Flush the session to get an userid @@ -502,7 +503,7 @@ class Worker(threading.Thread): if product.price is None: continue # Send the message without the keyboard to get the message id - message = product.send_as_message(loc=self.loc, chat_id=self.chat.id) + message = product.send_as_message(w=self, chat_id=self.chat.id) # Add the product to the cart cart[message['result']['message_id']] = [product, 0] # Create the inline keyboard to add the product to the cart @@ -513,12 +514,12 @@ class Worker(threading.Thread): if product.image is None: self.bot.edit_message_text(chat_id=self.chat.id, message_id=message['result']['message_id'], - text=product.text(loc=self.loc), + text=product.text(w=self), reply_markup=inline_keyboard) else: self.bot.edit_message_caption(chat_id=self.chat.id, message_id=message['result']['message_id'], - caption=product.text(loc=self.loc), + caption=product.text(w=self), reply_markup=inline_keyboard) # Create the keyboard with the cancel button inline_keyboard = telegram.InlineKeyboardMarkup([[telegram.InlineKeyboardButton(self.loc.get("menu_cancel"), @@ -562,13 +563,13 @@ class Worker(threading.Thread): if product.image is None: self.bot.edit_message_text(chat_id=self.chat.id, message_id=callback.message.message_id, - text=product.text(loc=self.loc, + text=product.text(w=self, cart_qty=cart[callback.message.message_id][1]), reply_markup=product_inline_keyboard) else: self.bot.edit_message_caption(chat_id=self.chat.id, message_id=callback.message.message_id, - caption=product.text(loc=self.loc, + caption=product.text(w=self, cart_qty=cart[callback.message.message_id][1]), reply_markup=product_inline_keyboard) @@ -610,13 +611,13 @@ class Worker(threading.Thread): # Edit the product message if product.image is None: self.bot.edit_message_text(chat_id=self.chat.id, message_id=callback.message.message_id, - text=product.text(loc=self.loc, + text=product.text(w=self, cart_qty=cart[callback.message.message_id][1]), reply_markup=product_inline_keyboard) else: self.bot.edit_message_caption(chat_id=self.chat.id, message_id=callback.message.message_id, - caption=product.text(loc=self.loc, + caption=product.text(w=self, cart_qty=cart[callback.message.message_id][1]), reply_markup=product_inline_keyboard) @@ -684,7 +685,7 @@ class Worker(threading.Thread): product_list = "" for product_id in cart: if cart[product_id][1] > 0: - product_list += cart[product_id][0].text(loc=self.loc, + product_list += cart[product_id][0].text(w=self, style="short", cart_qty=cart[product_id][1]) + "\n" return product_list @@ -706,7 +707,7 @@ class Worker(threading.Thread): def __order_notify_admins(self, order): # Notify the user of the order result - self.bot.send_message(self.chat.id, self.loc.get("success_order_created", order=order.text(loc=self.loc, + self.bot.send_message(self.chat.id, self.loc.get("success_order_created", order=order.text(w=self, session=self.session, user=True))) # Notify the admins (in Live Orders mode) of the new order @@ -721,7 +722,7 @@ class Worker(threading.Thread): for admin in admins: self.bot.send_message(admin.user_id, self.loc.get('notification_order_placed', - order=order.text(loc=self.loc, session=self.session)), + order=order.text(w=self, session=self.session)), reply_markup=order_keyboard) def __order_status(self): @@ -738,7 +739,7 @@ class Worker(threading.Thread): self.bot.send_message(self.chat.id, self.loc.get("error_no_orders")) # Display the order status to the user for order in orders: - self.bot.send_message(self.chat.id, order.text(loc=self.loc, session=self.session, user=True)) + self.bot.send_message(self.chat.id, order.text(w=self, session=self.session, user=True)) # TODO: maybe add a page displayer instead of showing the latest 5 orders def __add_credit_menu(self): @@ -1128,7 +1129,7 @@ class Worker(threading.Thread): # Create a message for every one of them for order in orders: # Send the created message - self.bot.send_message(self.chat.id, order.text(loc=self.loc, session=self.session), + self.bot.send_message(self.chat.id, order.text(w=self, session=self.session), reply_markup=order_keyboard) # Set the Live mode flag to True self.admin.live_mode = True @@ -1157,11 +1158,11 @@ class Worker(threading.Thread): # Commit the transaction self.session.commit() # Update order message - self.bot.edit_message_text(order.text(loc=self.loc, session=self.session), chat_id=self.chat.id, + self.bot.edit_message_text(order.text(w=self, session=self.session), chat_id=self.chat.id, message_id=update.message.message_id) # Notify the user of the completition self.bot.send_message(order.user_id, - self.loc.get("notification_order_completed", order=order.text(loc=self.loc, session=self.session, user=True))) + self.loc.get("notification_order_completed", order=order.text(w=self, session=self.session, user=True))) # If the user pressed the refund order button, refund the order... elif update.data == "order_refund": # Ask for a refund reason @@ -1185,12 +1186,12 @@ class Worker(threading.Thread): # Commit the changes self.session.commit() # Update the order message - self.bot.edit_message_text(order.text(loc=self.loc, session=self.session), + self.bot.edit_message_text(order.text(w=self, session=self.session), chat_id=self.chat.id, message_id=update.message.message_id) # Notify the user of the refund self.bot.send_message(order.user_id, - self.loc.get("notification_order_refunded", order=order.text(loc=self.loc, + self.loc.get("notification_order_refunded", order=order.text(w=self, session=self.session, user=True))) # Notify the admin of the refund @@ -1236,10 +1237,10 @@ class Worker(threading.Thread): # Notify the user of the credit/debit self.bot.send_message(user.user_id, self.loc.get("notification_transaction_created", - transaction=transaction.text(loc=self.loc))) + transaction=transaction.text(w=self))) # Notify the admin of the success self.bot.send_message(self.chat.id, self.loc.get("success_transaction_created", - transaction=transaction.text(loc=self.loc))) + transaction=transaction.text(w=self))) def __help_menu(self): """Help menu. Allows the user to ask for assistance, get a guide or see some info about the bot.""" @@ -1305,7 +1306,7 @@ class Worker(threading.Thread): # Create the inline keyboard markup inline_keyboard = telegram.InlineKeyboardMarkup(inline_keyboard_list) # Create the message text - transactions_string = "\n".join([transaction.text(loc=self.loc) for transaction in transactions]) + transactions_string = "\n".join([transaction.text(w=self) for transaction in transactions]) text = self.loc.get("transactions_page", page=page + 1, transactions=transactions_string) # Update the previously sent message self.bot.edit_message_text(chat_id=self.chat.id, message_id=message.message_id, text=text,