1
Fork 0
mirror of https://github.com/Steffo99/unimore-bda-6.git synced 2024-11-25 17:24:20 +00:00
bda-6-steffo/unimore_bda_6/__main__.py

57 lines
2.2 KiB
Python
Raw Normal View History

import logging
2023-02-04 00:36:42 +00:00
import tensorflow
from .config import config, DATA_SET_SIZE
2023-02-03 22:27:44 +00:00
from .database import mongo_reviews_collection_from_config, polar_dataset, varied_dataset
from .analysis.nltk_sentiment import NLTKSentimentAnalyzer
2023-02-04 00:36:42 +00:00
from .analysis.tf_text import TensorflowSentimentAnalyzer
from .tokenizer import NLTKWordTokenizer, PottsTokenizer, PottsTokenizerWithNegation, LowercaseTokenizer
2023-02-01 16:46:25 +00:00
from .log import install_log_handler
2023-02-01 03:20:09 +00:00
# Module-level logger, named after this module so its records can be filtered by origin.
log = logging.getLogger(__name__)
2023-02-01 01:33:42 +00:00
def main():
    """
    Train and evaluate every enabled analyzer/tokenizer pair on every dataset.

    First logs whether Tensorflow reports any physical GPU device. Then, for
    each combination of dataset function, sentiment analyzer class and
    tokenizer class listed below, it:

    1. instantiates the tokenizer and builds the model around it;
    2. requests a training and an evaluation dataset of ``DATA_SET_SIZE``
       reviews from the reviews collection;
    3. trains the model, evaluates it, and logs how many reviews were
       evaluated, how many were classified correctly, and the accuracy.

    Analyzers and tokenizers are enabled or disabled by (un)commenting the
    corresponding entries in the lists.
    """
    if not tensorflow.config.list_physical_devices(device_type="GPU"):
        log.warning("Tensorflow reports no GPU acceleration available.")
    else:
        log.debug("Tensorflow successfully found GPU acceleration!")

    for dataset_func in [polar_dataset, varied_dataset]:
        for SentimentAnalyzer in [
            # NLTKSentimentAnalyzer,
            TensorflowSentimentAnalyzer,
        ]:
            for Tokenizer in [
                # NLTKWordTokenizer,
                # PottsTokenizer,
                # PottsTokenizerWithNegation,
                LowercaseTokenizer,
            ]:
                tokenizer = Tokenizer()
                model = SentimentAnalyzer(tokenizer=tokenizer)

                # NOTE(review): training and evaluation are kept inside the
                # `with` block in case the datasets read lazily from the
                # collection - confirm against `.database`.
                with mongo_reviews_collection_from_config() as reviews:
                    # `__wrapped__` unwraps the configuration value before it
                    # is handed to the dataset function.
                    reviews_training = dataset_func(collection=reviews, amount=DATA_SET_SIZE.__wrapped__)
                    reviews_evaluation = dataset_func(collection=reviews, amount=DATA_SET_SIZE.__wrapped__)

                    log.info("Training model %s", model)
                    model.train(reviews_training)

                    log.info("Evaluating model %s", model)
                    correct, evaluated = model.evaluate(reviews_evaluation)

                    if evaluated:
                        # `%0.2f`, not `%0.2d`: the accuracy is a float, and
                        # `%d` would silently truncate it to an integer.
                        log.info(
                            "%d evaluated, %d correct, %0.2f %% accuracy",
                            evaluated,
                            correct,
                            correct / evaluated * 100,
                        )
                    else:
                        # Avoid a ZeroDivisionError aborting the remaining
                        # combinations when nothing was evaluated.
                        log.warning("Model %s evaluated no reviews, accuracy is undefined", model)

                # try:
                #     print("Manual testing for %s" % model)
                #     print("Input an empty string to continue to the next model.")
                #     while inp := input():
                #         print(model.use(inp))
                # except KeyboardInterrupt:
                #     pass
2023-02-01 01:33:42 +00:00
if __name__ == "__main__":
    # Set up logging before anything else runs, so that messages emitted by
    # the following steps go through the handler installed here.
    install_log_handler()

    # Resolve the configuration proxies up front, before the analysis starts.
    # NOTE(review): presumably this fails early on missing or invalid
    # configuration values - confirm against `.config`.
    config.proxies.resolve()

    # Run the full train/evaluate sweep.
    main()