1
Fork 0
mirror of https://github.com/Steffo99/unimore-bda-6.git synced 2024-11-23 00:14:19 +00:00
bda-6-steffo/unimore_bda_6/analysis/base.py

114 lines
3.1 KiB
Python
Raw Normal View History

2023-02-08 18:46:05 +00:00
from __future__ import annotations
2023-02-02 01:56:37 +00:00
import abc
2023-02-03 22:27:44 +00:00
import logging
2023-02-04 05:14:24 +00:00
import dataclasses
2023-02-02 16:24:11 +00:00
2023-02-12 04:11:58 +00:00
from ..database import CachedDatasetFunc, TextReview, TokenizedReview
2023-02-08 18:46:05 +00:00
from ..tokenizer import BaseTokenizer
2023-02-02 16:24:11 +00:00
2023-02-03 22:27:44 +00:00
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
2023-02-02 01:56:37 +00:00
2023-02-03 22:27:44 +00:00
class BaseSentimentAnalyzer(metaclass=abc.ABCMeta):
    """
    Abstract base class for sentiment analyzers implemented in this project.

    Subclasses must implement `.train` and `.use`; `.evaluate` is provided here and is built on top of `.use`.
    """

    def __init__(self, *, tokenizer: BaseTokenizer):
        """
        Create the analyzer, storing the given tokenizer in `.tokenizer`.
        """
        self.tokenizer: BaseTokenizer = tokenizer

    def __repr__(self):
        return f"<{self.__class__.__qualname__} with {self.tokenizer} tokenizer>"

    @abc.abstractmethod
    def train(self, training_dataset_func: CachedDatasetFunc, validation_dataset_func: CachedDatasetFunc) -> None:
        """
        Train the analyzer with the given training and validation datasets.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def use(self, text: str) -> float:
        """
        Run the model on the given input, and return the predicted rating.
        """
        raise NotImplementedError()

    def evaluate(self, evaluation_dataset_func: CachedDatasetFunc) -> EvaluationResults:
        """
        Perform a model evaluation by calling `.use` on the text of every review yielded by the evaluation dataset, and by comparing the predicted rating with the review's own rating.

        :param evaluation_dataset_func: Callable which, when called without arguments, returns an iterable of reviews exposing `.text` and `.rating`.
        :return: The number of evaluated reviews, the number of exactly correct predictions, and the mean squared error.

        A prediction that is NaN is logged as a warning; it still counts towards `evaluated`, but adds nothing to `perfect` or to the squared error.
        If the dataset yields no reviews at all, the mean squared error is NaN, since there is nothing to average.
        """
        # Imported here so that the method does not depend on changes to the module's import block.
        import math

        evaluated: int = 0
        perfect: int = 0
        squared_error: float = 0.0

        for review in evaluation_dataset_func():
            predicted = self.use(review.text)
            # `%s` instead of `%d`: the prediction is a float, which `%d` would truncate, and `%d` cannot format NaN.
            log.debug("Evaluation step: %s for %s", predicted, review)
            evaluated += 1

            # Float arithmetic on NaN does not raise, it silently propagates: a single NaN prediction
            # would turn the whole squared error into NaN, so it has to be detected explicitly.
            if math.isnan(predicted):
                log.warning("Model execution on %s resulted in a NaN value: %s", review, predicted)
                continue

            if predicted == review.rating:
                perfect += 1
            squared_error += (predicted - review.rating) ** 2

        # Guard against an empty dataset, which would otherwise cause a ZeroDivisionError.
        mse = squared_error / evaluated if evaluated else float("nan")
        return EvaluationResults(perfect=perfect, evaluated=evaluated, mse=mse)
2023-02-03 16:50:40 +00:00
2023-02-08 18:46:05 +00:00
@dataclasses.dataclass
class EvaluationResults:
    """
    Container for the results of a dataset evaluation.
    """

    evaluated: int
    """
    The number of reviews that were evaluated.
    """

    perfect: int
    """
    The number of reviews for which the model returned the correct rating.
    """

    mse: float
    """
    Mean squared error, already averaged over the evaluated reviews.
    """

    def __repr__(self):
        return f"<EvaluationResults: {self!s}>"

    def __str__(self):
        """
        Render the results as a single tab-separated line, including the accuracy computed as `perfect / evaluated`.

        If no reviews were evaluated, the accuracy is rendered as NaN instead of raising a ZeroDivisionError.
        """
        accuracy = self.perfect / self.evaluated if self.evaluated else float("nan")
        # `mse` is already a mean, so it must not be divided by `evaluated` a second time.
        # NOTE(review): the percent format on the MSE is kept to preserve the layout of the report line;
        # MSE is not a ratio, so confirm whether a plain number would be preferable here.
        return f"Evaluation results:\t{self.evaluated}\tevaluated\t{self.perfect}\tperfect\t{accuracy:.2%}\taccuracy\t{self.mse:.2%}\tmean squared error"
2023-02-02 01:56:37 +00:00
2023-02-04 00:36:42 +00:00
class AlreadyTrainedError(Exception):
    """
    Error signalling that the model was trained previously, and that training it a second time is not allowed.
    """
class NotTrainedError(Exception):
    """
    Error signalling that the model has yet to be trained.
    """
2023-02-08 09:54:14 +00:00
class TrainingFailedError(Exception):
    """
    Error signalling that training could not be completed, and that the model must not be used from now on.
    """
2023-02-02 01:56:37 +00:00
# Names exported by `from ... import *`.
# `EvaluationResults` is included because it is the public return type of `BaseSentimentAnalyzer.evaluate`.
__all__ = (
    "BaseSentimentAnalyzer",
    "EvaluationResults",
    "AlreadyTrainedError",
    "NotTrainedError",
    "TrainingFailedError",
)