1
Fork 0
mirror of https://github.com/Steffo99/unimore-bda-6.git synced 2024-11-22 16:04:18 +00:00
bda-6-steffo/unimore_bda_6/analysis/base.py

79 lines
2.1 KiB
Python
Raw Normal View History

2023-02-02 01:56:37 +00:00
import abc
2023-02-03 22:27:44 +00:00
import logging
2023-02-04 05:14:24 +00:00
import typing as t
import dataclasses
2023-02-02 16:24:11 +00:00
2023-02-04 05:14:24 +00:00
from ..database import Text, Category, Review, DatasetFunc
2023-02-02 16:24:11 +00:00
2023-02-03 22:27:44 +00:00
log = logging.getLogger(__name__)
2023-02-02 01:56:37 +00:00
2023-02-04 05:14:24 +00:00
@dataclasses.dataclass
class EvaluationResults:
    """
    Outcome of evaluating a sentiment analyzer on a dataset.
    """

    # Number of items whose predicted category matched the expected one.
    correct: int
    # Total number of items that were evaluated.
    evaluated: int

    def _percentage(self) -> float:
        """
        Return the accuracy as a percentage, or NaN when nothing was evaluated.

        Accuracy is undefined for 0/0; NaN is used instead of raising
        ZeroDivisionError so that printing an empty evaluation never crashes.
        """
        if not self.evaluated:
            return float("nan")
        return self.correct / self.evaluated * 100

    def __repr__(self):
        return f"<EvaluationResults: {self.correct}/{self.evaluated}, {self._percentage():.2f}>"

    def __str__(self):
        return f"{self.correct} / {self.evaluated} - {self._percentage():.2f} %"
2023-02-03 22:27:44 +00:00
class BaseSentimentAnalyzer(metaclass=abc.ABCMeta):
    """
    Abstract base class for sentiment analyzers implemented in this project.

    Subclasses must implement :meth:`train` and :meth:`use`; :meth:`evaluate`
    is provided on top of :meth:`use`.
    """

    @abc.abstractmethod
    def train(self, dataset_func: DatasetFunc) -> None:
        """
        Train the analyzer with the given training dataset.

        :param dataset_func: The training dataset, as a :data:`DatasetFunc`.
        """
        raise NotImplementedError()

    def evaluate(self, dataset_func: DatasetFunc) -> EvaluationResults:
        """
        Evaluate the model on a test dataset.

        Calls :meth:`use` on the text of every review produced by
        ``dataset_func()`` and compares the predicted category with the
        review's expected category.

        :param dataset_func: Callable that, when called, returns an iterable of
            reviews exposing ``.text`` and ``.category``.
        :returns: An :class:`EvaluationResults` holding the number of correct
            predictions and the number of evaluated reviews.
        """
        evaluated: int = 0
        correct: int = 0

        for review in dataset_func():
            resulting_category = self.use(review.text)
            evaluated += 1
            if resulting_category == review.category:
                correct += 1

            # Periodic progress report; evaluated >= 1 here, so the division is safe.
            # `%.2f` keeps two decimals (the previous `%0.2d` truncated to an integer).
            if not evaluated % 100:
                log.debug("%d evaluated, %d correct, %.2f %% accuracy", evaluated, correct, correct / evaluated * 100)

        return EvaluationResults(correct=correct, evaluated=evaluated)

    @abc.abstractmethod
    def use(self, text: Text) -> Category:
        """
        Run the model on the given input.

        :param text: The text to classify.
        :returns: The category predicted for ``text``.
        """
        raise NotImplementedError()
2023-02-02 01:56:37 +00:00
2023-02-04 00:36:42 +00:00
class AlreadyTrainedError(Exception):
    """
    Error signalling that the model was already trained and cannot be trained a second time.
    """


class NotTrainedError(Exception):
    """
    Error signalling that the model has not been trained yet.
    """
2023-02-02 01:56:37 +00:00
# Public API of this module. EvaluationResults is listed because it is the
# return type of BaseSentimentAnalyzer.evaluate.
__all__ = (
    "BaseSentimentAnalyzer",
    "EvaluationResults",
    "AlreadyTrainedError",
    "NotTrainedError",
)