1
Fork 0
mirror of https://github.com/Steffo99/unimore-bda-6.git synced 2024-11-23 00:14:19 +00:00
bda-6-steffo/unimore_bda_6/analysis/vanilla.py

100 lines
3.6 KiB
Python
Raw Normal View History

2023-02-01 16:46:25 +00:00
import nltk
import nltk.classify
import nltk.sentiment
import nltk.sentiment.util
import logging
2023-02-02 01:56:37 +00:00
import typing as t
2023-02-01 16:46:25 +00:00
2023-02-02 16:24:11 +00:00
from .base import Input, Category, BaseSA, AlreadyTrainedError, NotTrainedError
# A tokenized text: the list of string tokens produced by a tokenizer callable.
TokenBag = list[str]
# Type variable for the value that the extractor returns alongside the text and that the categorizer
# then converts into a Category.
IntermediateValue = t.TypeVar("IntermediateValue")
2023-02-01 16:46:25 +00:00
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
2023-02-02 16:24:11 +00:00
class VanillaSA(BaseSA):
    """
    A sentiment analyzer resembling in structure the one implemented in the classroom, using the basic sentiment analyzer of NLTK.

    The analyzer is configured with three callables:

    - ``extractor`` splits an input into its text and an intermediate value;
    - ``tokenizer`` converts the text into a bag of tokens;
    - ``categorizer`` converts the intermediate value into a category.

    A single instance can be trained only once.
    """

    def __init__(
        self,
        *,
        extractor: t.Callable[[Input], tuple[str, IntermediateValue]],
        tokenizer: t.Callable[[str], TokenBag],
        categorizer: t.Callable[[IntermediateValue], Category],
    ) -> None:
        """
        Create a new, untrained analyzer.

        :param extractor: Callable turning an input into a ``(text, intermediate value)`` pair.
        :param tokenizer: Callable turning a text into a bag of tokens.
        :param categorizer: Callable turning the intermediate value returned by the extractor into a category.
        """
        super().__init__()
        self.model: nltk.sentiment.SentimentAnalyzer = nltk.sentiment.SentimentAnalyzer()
        # Guards against training twice and against using the model before it is trained.
        self.trained: bool = False

        self.extractor: t.Callable[[Input], tuple[str, IntermediateValue]] = extractor
        self.tokenizer: t.Callable[[str], TokenBag] = tokenizer
        self.categorizer: t.Callable[[IntermediateValue], Category] = categorizer

    def __add_feature_unigrams(self, training_set: list[tuple[TokenBag, Category]]) -> None:
        """
        Add the `nltk.sentiment.util.extract_unigram_feats` feature to the model.

        The unigrams are selected by NLTK from the words of the training set, using ``min_freq=4``.
        """
        all_words = self.model.all_words(training_set, labeled=True)
        unigrams = self.model.unigram_word_feats(words=all_words, min_freq=4)
        self.model.add_feat_extractor(nltk.sentiment.util.extract_unigram_feats, unigrams=unigrams)

    def _add_features(self, training_set: list[tuple[TokenBag, Category]]) -> None:
        """
        Add new features to the sentiment analyzer.

        This is the single place where feature extractors get registered on the model before training.
        """
        self.__add_feature_unigrams(training_set)

    def _train_from_dataset(self, dataset: list[tuple[TokenBag, Category]]) -> None:
        """
        Train the model with the given training set.

        :param dataset: The already tokenized and categorized training set.
        :raises AlreadyTrainedError: If the model has already been trained.
        """
        if self.trained:
            raise AlreadyTrainedError()

        # Go through the feature hook instead of calling the unigram helper directly, so that
        # every registered feature is added from one place.
        self._add_features(dataset)
        training_set_with_features = self.model.apply_features(dataset, labeled=True)

        self.model.train(trainer=nltk.classify.NaiveBayesClassifier.train, training_set=training_set_with_features)
        self.trained = True

    def _evaluate_from_dataset(self, dataset: list[tuple[TokenBag, Category]]) -> dict:
        """
        Perform a model evaluation with the given test set.

        :param dataset: The already tokenized and categorized test set.
        :return: Whatever the underlying NLTK model's ``evaluate`` method returns.
        :raises NotTrainedError: If the model has not been trained yet.
        """
        if not self.trained:
            raise NotTrainedError()

        test_set_with_features = self.model.apply_features(dataset, labeled=True)
        return self.model.evaluate(test_set_with_features)

    def _use_from_tokenbag(self, tokens: TokenBag) -> Category:
        """
        Categorize the given token bag.

        :param tokens: The tokens of the text to categorize.
        :return: The category predicted by the model.
        :raises NotTrainedError: If the model has not been trained yet.
        """
        if not self.trained:
            raise NotTrainedError()

        return self.model.classify(instance=tokens)

    def _extract_data(self, inp: Input) -> tuple[TokenBag, Category]:
        """
        Convert a single input into a ``(tokens, category)`` pair using the configured callables.
        """
        text, value = self.extractor(inp)
        return self.tokenizer(text), self.categorizer(value)

    def _extract_dataset(self, inp: list[Input]) -> list[tuple[TokenBag, Category]]:
        """
        Convert every input of the given list into a ``(tokens, category)`` pair.
        """
        return [self._extract_data(item) for item in inp]

    def train(self, training_set: list[Input]) -> None:
        """
        Extract a dataset from the given inputs and train the model with it.

        :raises AlreadyTrainedError: If the model has already been trained.
        """
        dataset = self._extract_dataset(training_set)
        self._train_from_dataset(dataset)

    def evaluate(self, test_set: list[Input]) -> dict:
        """
        Extract a dataset from the given inputs and evaluate the model against it.

        :return: The evaluation results produced by the underlying NLTK model.
        :raises NotTrainedError: If the model has not been trained yet.
        """
        dataset = self._extract_dataset(test_set)
        return self._evaluate_from_dataset(dataset)

    def use(self, text: Input) -> Category:
        """
        Tokenize the given value and categorize it with the trained model.

        :raises NotTrainedError: If the model has not been trained yet.
        """
        # NOTE(review): unlike `_extract_data`, the value is handed straight to the tokenizer without
        # going through the extractor, so it is treated as text here — confirm against callers.
        tokens = self.tokenizer(text)
        return self._use_from_tokenbag(tokens)
2023-02-01 16:46:25 +00:00
# Public names exported by this module.
__all__ = ("VanillaSA",)