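"""
Entry point of the ``unimore_bda_6`` package: for every combination of dataset
sampler, sentiment analyzer, and tokenizer, train a model on a sample of
reviews and evaluate it on a second, independently drawn sample.
"""
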
import logging
import tensorflow
from .config import config, DATA_SET_SIZE
from .database import (
    mongo_client_from_config,
    reviews_collection,
    sample_reviews_polar,
    sample_reviews_varied,
    store_cache,
    load_cache,
    delete_cache,
)
from .analysis.nltk_sentiment import NLTKSentimentAnalyzer
from .analysis.tf_text import TensorflowSentimentAnalyzer
from .analysis.base import TrainingFailedError
from .tokenizer import LowercaseTokenizer
from .log import install_log_handler

log = logging.getLogger(__name__)


def main():
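    # Warn early when Tensorflow does not detect a GPU, as training will run on the CPU.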
    if len(tensorflow.config.list_physical_devices(device_type="GPU")) == 0:
        log.warning("Tensorflow reports no GPU acceleration available.")
    else:
        log.debug("Tensorflow successfully found GPU acceleration!")
    try:
        delete_cache("./data/training")
        delete_cache("./data/evaluation")
    except FileNotFoundError:
        pass
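
    # Train and evaluate a model for every combination of dataset sampler,
    # analyzer implementation, and tokenizer; unused tokenizers below are
    # left commented out.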
    for dataset_func in [sample_reviews_polar, sample_reviews_varied]:
        for SentimentAnalyzer in [TensorflowSentimentAnalyzer, NLTKSentimentAnalyzer]:
            for Tokenizer in [
                # NLTKWordTokenizer,
                # PottsTokenizer,
                # PottsTokenizerWithNegation,
                LowercaseTokenizer,
            ]:
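                # Keep retrying this combination until training succeeds,
                # re-sampling the dataset after each failure.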
                while True:
                    try:
                        tokenizer = Tokenizer()
                        model = SentimentAnalyzer(tokenizer=tokenizer)
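
                        # Open a MongoDB client and locate the reviews collection.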
                        with mongo_client_from_config() as db:
                            log.debug("Finding the reviews MongoDB collection...")
                            collection = reviews_collection(db)
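
                            # Load cached datasets if they exist; otherwise sample two
                            # independent review sets from MongoDB and cache them to disk.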
                            try:
                                training_cache = load_cache("./data/training")
                                evaluation_cache = load_cache("./data/evaluation")
                            except FileNotFoundError:
                                log.debug("Gathering datasets...")
                                # DATA_SET_SIZE seems to be a lazily-resolved config value;
                                # __wrapped__ accesses the underlying resolved amount.
                                reviews_training = dataset_func(collection=collection, amount=DATA_SET_SIZE.__wrapped__)
                                reviews_evaluation = dataset_func(collection=collection, amount=DATA_SET_SIZE.__wrapped__)
                                log.debug("Caching datasets...")
                                store_cache(reviews_training, "./data/training")
                                store_cache(reviews_evaluation, "./data/evaluation")
                                del reviews_training
                                del reviews_evaluation

                                training_cache = load_cache("./data/training")
                                evaluation_cache = load_cache("./data/evaluation")
                                log.debug("Caches stored and loaded successfully!")
                            else:
                                log.debug("Caches loaded successfully!")
                            log.info("Training model: %s", model)
                            model.train(training_cache)
                            log.info("Evaluating model: %s", model)
                            evaluation_results = model.evaluate(evaluation_cache)
                            log.info("%s", evaluation_results)
                    except TrainingFailedError:
                        log.error("Training failed, restarting with a different dataset.")
                        continue
                    else:
                        log.info("Training succeeded!")
                        break
                    finally:
                        # Always delete the cached datasets, so the next attempt or
                        # combination starts from a freshly sampled one.
                        delete_cache("./data/training")
                        delete_cache("./data/evaluation")


if __name__ == "__main__":
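    # Install logging handlers and resolve all configuration values before running.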
    install_log_handler()
    config.proxies.resolve()
    main()