1
Fork 0
mirror of https://github.com/Steffo99/unimore-bda-6.git synced 2024-11-24 08:44:19 +00:00

Use composition instead of inheritance

This commit is contained in:
Steffo 2023-02-02 16:03:07 +01:00
parent 3ae43b2714
commit 4c3f892038
Signed by: steffo
GPG key ID: 2A24051445686895
3 changed files with 55 additions and 42 deletions

View file

@ -2,7 +2,7 @@ import logging
from .config import config, DATA_SET_SIZE
from .database import mongo_reviews_collection_from_config, get_reviews_dataset_polar, get_reviews_dataset_uniform
from .analysis.vanilla import VanillaReviewSA, VanillaUniformReviewSA
from .analysis.vanilla import VanillaReviewSA, polar_categorizer, stars_categorizer
from .analysis.potts import PottsReviewSA
from .log import install_log_handler
@ -16,7 +16,7 @@ def main():
reviews_uniform_training = get_reviews_dataset_uniform(collection=reviews, amount=DATA_SET_SIZE.__wrapped__)
reviews_uniform_evaluation = get_reviews_dataset_uniform(collection=reviews, amount=DATA_SET_SIZE.__wrapped__)
vanilla_polar = VanillaReviewSA()
vanilla_polar = VanillaReviewSA(categorizer=polar_categorizer)
vanilla_polar.train(reviews_polar_training)
log.info("Vanilla polar evaluation results: %s", vanilla_polar.evaluate(reviews_polar_evaluation))
@ -24,7 +24,7 @@ def main():
potts_polar.train(reviews_polar_training)
log.info("Potts polar evaluation results: %s", potts_polar.evaluate(reviews_polar_evaluation))
vanilla_uniform = VanillaUniformReviewSA()
vanilla_uniform = VanillaReviewSA(categorizer=stars_categorizer)
vanilla_uniform.train(reviews_uniform_training)
log.info("Vanilla uniform evaluation results: %s", vanilla_polar.evaluate(reviews_polar_evaluation))

View file

@ -1,5 +1,5 @@
from ..vendor.potts import Tokenizer
from .vanilla import VanillaSA, VanillaReviewSA, VanillaUniformReviewSA
from .vanilla import VanillaSA, VanillaReviewSA
class PottsSA(VanillaSA):
@ -24,12 +24,6 @@ class PottsReviewSA(VanillaReviewSA, PottsSA):
"""
class PottsUniformReviewSA(VanillaUniformReviewSA, PottsSA):
"""
A `PottsSA` with 5 buckets instead of 2.
"""
__all__ = (
"PottsSA",
"PottsReviewSA",

View file

@ -76,25 +76,15 @@ class VanillaReviewSA(VanillaSA):
A `VanillaSA` to be used with `Review`s.
"""
@staticmethod
def _rating_to_label(rating: float) -> str:
"""
Return the label corresponding to the given rating.
Possible categories are:
* negative (0.0 <= rating < 3.0)
* positive (3.0 < rating <= 5.0)
"""
if rating < 3.0:
return "negative"
else:
return "positive"
def __init__(self, categorizer: t.Callable[[Review], str]) -> None:
super().__init__()
self.categorizer: t.Callable[[Review], str] = categorizer
def _review_to_data_set(self, review: Review) -> tuple[list[str], str]:
"""
Convert a review to a NLTK-compatible dataset.
"""
return self._tokenize_text(text=review["reviewText"]), self._rating_to_label(rating=review["overall"])
return self._tokenize_text(text=review["reviewText"]), self.categorizer(rating=review["overall"])
def train(self, reviews: t.Iterable[Review]) -> None:
data_set = list(map(self._review_to_data_set, reviews))
@ -108,27 +98,56 @@ class VanillaReviewSA(VanillaSA):
return self._use_with_tokens(self._tokenize_text(text))
class VanillaUniformReviewSA(VanillaReviewSA):
@staticmethod
def _rating_to_label(rating: float) -> str:
match rating:
case 0.0:
return "abysmal"
case 1.0:
return "terrible"
case 2.0:
return "negative"
case 3.0:
return "mixed"
case 4.0:
return "positive"
case 5.0:
return "great"
case _:
return "unknown"
def polar_categorizer(rating: float) -> str:
"""
Return the polar label corresponding to the given rating.
Possible categories are:
* negative (1.0, 2.0)
* positive (3.0, 4.0, 5.0)
* unknown (everything else)
"""
match rating:
case 1.0 | 2.0:
return "negative"
case 3.0 | 4.0 | 5.0:
return "positive"
case _:
return "unknown"
def stars_categorizer(rating: float) -> str:
"""
Return the "stars" label corresponding to the given rating.
Possible categories are:
* terrible (1.0)
* negative (2.0)
* mixed (3.0)
* positive (4.0)
* great (5.0)
* unknown (everything else)
"""
match rating:
case 1.0:
return "terrible"
case 2.0:
return "negative"
case 3.0:
return "mixed"
case 4.0:
return "positive"
case 5.0:
return "great"
case _:
return "unknown"
__all__ = (
"VanillaSA",
"VanillaReviewSA",
"polar_categorizer",
"stars_categorizer",
)