1
Fork 0
mirror of https://github.com/Steffo99/unimore-bda-6.git synced 2024-11-25 17:24:20 +00:00
bda-6-steffo/unimore_bda_6/analysis/base.py

210 lines
6.4 KiB
Python
Raw Normal View History

2023-02-08 18:46:05 +00:00
from __future__ import annotations
2023-02-02 01:56:37 +00:00
import abc
2023-02-03 22:27:44 +00:00
import logging
2023-02-14 01:25:38 +00:00
import collections
2023-02-02 16:24:11 +00:00
2023-02-13 14:42:45 +00:00
from ..database import CachedDatasetFunc
2023-02-08 18:46:05 +00:00
from ..tokenizer import BaseTokenizer
2023-02-02 16:24:11 +00:00
2023-02-03 22:27:44 +00:00
# Module-level logger; EvaluationResults.add reports every recorded prediction through it.
log: logging.Logger = logging.getLogger(__name__)


class BaseSentimentAnalyzer(metaclass=abc.ABCMeta):
    """
    Common interface shared by every sentiment analyzer of this project.

    Concrete subclasses must provide `train` and `use`; `evaluate` is implemented here on top of `use`.
    """

    def __init__(self, *, tokenizer: BaseTokenizer):
        # Keyword-only on purpose: the tokenizer must always be passed by name.
        self.tokenizer: BaseTokenizer = tokenizer

    def __repr__(self):
        class_name = self.__class__.__qualname__
        return f"<{class_name} with {self.tokenizer} tokenizer>"

    @abc.abstractmethod
    def train(self, training_dataset_func: CachedDatasetFunc, validation_dataset_func: CachedDatasetFunc) -> None:
        """
        Fit the analyzer on the given training and validation datasets.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def use(self, text: str) -> float:
        """
        Run the model on a single text and return the rating it predicts.
        """
        raise NotImplementedError()

    def evaluate(self, evaluation_dataset_func: CachedDatasetFunc) -> EvaluationResults:
        """
        Measure how well the analyzer performs on a dataset.

        Every review produced by ``evaluation_dataset_func()`` has its ``text`` passed to `use`; the resulting
        prediction is recorded against the review's ``rating`` in a fresh `EvaluationResults`, which is returned.
        """
        results = EvaluationResults()
        for review in evaluation_dataset_func():
            expected = review.rating
            predicted = self.use(review.text)
            results.add(expected=expected, predicted=predicted)
        return results


class EvaluationResults:
    """
    Container for the results of a dataset evaluation.

    Predictions are recorded one at a time with `add`. Two results can be merged with ``+``, which sums both their
    confusion matrices and their accumulated error totals.

    The query methods never insert entries into the confusion matrix, not even for ratings that were never seen.
    """

    def __init__(self):
        # Confusion matrix of the evaluation.
        # First key is the expected rating, second key is the predicted (output) rating,
        # value is how many times that (expected, predicted) pair was recorded.
        self.confusion_matrix: dict[float, dict[float, int]] = collections.defaultdict(
            lambda: collections.defaultdict(int)
        )

        # Sum of the absolute errors committed in the evaluation.
        self.absolute_error_total: float = 0.0

        # Sum of the squared errors committed in the evaluation.
        self.squared_error_total: float = 0.0

    def __repr__(self) -> str:
        return f"<EvaluationResults with {self.evaluated_count()} evaluated and {len(self.keys())} categories>"

    def __str__(self) -> str:
        text = [f"Evaluation results: {self.evaluated_count()} evaluated, {self.mean_absolute_error()} mean absolute error, {self.mean_squared_error()} mean squared error, "]
        # Sorted so the per-rating figures always appear in the same, readable order.
        for key in sorted(self.keys()):
            text.append(f"{self.recall(key)} recall of {key}, ")
            text.append(f"{self.precision(key)} precision of {key}, ")
        text.append(f"{self.perfect_count()} perfect matches.")
        return "".join(text)

    def __add__(self, other: EvaluationResults) -> EvaluationResults:
        """
        Return a new `EvaluationResults` combining ``self`` and ``other``; neither operand is modified.

        Confusion-matrix counts and both error totals are summed, so the mean errors of the result cover every
        prediction recorded in either operand.

        Returns `NotImplemented` for operands that are not `EvaluationResults`, which makes Python raise `TypeError`.
        """
        if not isinstance(other, EvaluationResults):
            return NotImplemented
        new = self.__class__()
        for source in (self, other):
            for expected, row in source.confusion_matrix.items():
                for predicted, amount in row.items():
                    new.confusion_matrix[expected][predicted] += amount
            # The error totals must be merged as well, otherwise the means of the sum would be wrong.
            new.absolute_error_total += source.absolute_error_total
            new.squared_error_total += source.squared_error_total
        return new

    def _count(self, expected: float, predicted: float) -> int:
        """
        Return how many times ``expected`` was predicted as ``predicted``, without inserting missing keys.
        """
        row = self.confusion_matrix.get(expected)
        return 0 if row is None else row.get(predicted, 0)

    def keys(self) -> set[float]:
        """
        Return all processed categories, i.e. every rating that appeared as either expected or predicted.
        """
        keys: set[float] = set(self.confusion_matrix)
        for row in self.confusion_matrix.values():
            keys.update(row)
        return keys

    def evaluated_count(self) -> int:
        """
        Return the total number of evaluated reviews.
        """
        return sum(sum(row.values()) for row in self.confusion_matrix.values())

    def perfect_count(self) -> int:
        """
        Return the total number of reviews whose predicted rating equals the expected one.
        """
        # Only ratings that have a row can contribute: a rating that was never expected has no diagonal cell.
        return sum(row.get(expected, 0) for expected, row in self.confusion_matrix.items())

    def recall_count(self, rating: float) -> int:
        """
        Return the number of reviews processed with the given (expected) rating.
        """
        return sum(self.confusion_matrix.get(rating, {}).values())

    def precision_count(self, rating: float) -> int:
        """
        Return the number of reviews for which the model returned (predicted) the given rating.
        """
        return sum(row.get(rating, 0) for row in self.confusion_matrix.values())

    def recall(self, rating: float) -> float:
        """
        Return the recall for a given rating: the fraction of reviews with that expected rating that were predicted
        exactly.

        Returns ``float("inf")`` if no review had that expected rating.
        """
        try:
            return self._count(rating, rating) / self.recall_count(rating)
        except ZeroDivisionError:
            return float("inf")

    def precision(self, rating: float) -> float:
        """
        Return the precision for a given rating: the fraction of predictions of that rating that were correct.

        Returns ``float("inf")`` if the model never predicted that rating.
        """
        try:
            return self._count(rating, rating) / self.precision_count(rating)
        except ZeroDivisionError:
            return float("inf")

    def mean_absolute_error(self) -> float:
        """
        Return the mean absolute error, or NaN if nothing has been evaluated yet.
        """
        count = self.evaluated_count()
        if count == 0:
            return float("nan")
        return self.absolute_error_total / count

    def mean_squared_error(self) -> float:
        """
        Return the mean squared error, or NaN if nothing has been evaluated yet.
        """
        count = self.evaluated_count()
        if count == 0:
            return float("nan")
        return self.squared_error_total / count

    def add(self, expected: float, predicted: float) -> None:
        """
        Count a new prediction, updating the confusion matrix and both error totals.
        """
        # Custom log levels just above DEBUG (10): 11 marks an exact match, 12 a miss.
        # "%.1f" (not "%.1d") so that fractional ratings are not truncated in the log.
        level = 11 if expected == predicted else 12
        log.log(level, "Expected %.1f*, predicted %.1f*", expected, predicted)
        self.confusion_matrix[expected][predicted] += 1
        error = expected - predicted
        self.absolute_error_total += abs(error)
        self.squared_error_total += error ** 2


class AlreadyTrainedError(Exception):
    """Signals that the model has already been trained, so it cannot be trained a second time."""
class NotTrainedError(Exception):
    """Signals that the model has not been trained yet."""


class TrainingFailedError(Exception):
    """Signals that the model could not complete its training; the model should not be used from then on."""


# Public API of this module. EvaluationResults is included because BaseSentimentAnalyzer.evaluate returns it.
__all__ = (
    "BaseSentimentAnalyzer",
    "EvaluationResults",
    "AlreadyTrainedError",
    "NotTrainedError",
    "TrainingFailedError",
)