diff --git a/unimore_bda_6/__main__.py b/unimore_bda_6/__main__.py index a8be723..e4fd18c 100644 --- a/unimore_bda_6/__main__.py +++ b/unimore_bda_6/__main__.py @@ -2,7 +2,7 @@ import logging from .config import config, DATA_SET_SIZE from .database import mongo_reviews_collection_from_config, get_reviews_dataset_polar, get_reviews_dataset_uniform -from .analysis.vanilla import VanillaReviewSA, VanillaUniformReviewSA +from .analysis.vanilla import VanillaReviewSA, polar_categorizer, stars_categorizer from .analysis.potts import PottsReviewSA from .log import install_log_handler @@ -16,7 +16,7 @@ def main(): reviews_uniform_training = get_reviews_dataset_uniform(collection=reviews, amount=DATA_SET_SIZE.__wrapped__) reviews_uniform_evaluation = get_reviews_dataset_uniform(collection=reviews, amount=DATA_SET_SIZE.__wrapped__) - vanilla_polar = VanillaReviewSA() + vanilla_polar = VanillaReviewSA(categorizer=polar_categorizer) vanilla_polar.train(reviews_polar_training) log.info("Vanilla polar evaluation results: %s", vanilla_polar.evaluate(reviews_polar_evaluation)) @@ -24,7 +24,7 @@ def main(): potts_polar.train(reviews_polar_training) log.info("Potts polar evaluation results: %s", potts_polar.evaluate(reviews_polar_evaluation)) - vanilla_uniform = VanillaUniformReviewSA() + vanilla_uniform = VanillaReviewSA(categorizer=stars_categorizer) vanilla_uniform.train(reviews_uniform_training) log.info("Vanilla uniform evaluation results: %s", vanilla_polar.evaluate(reviews_polar_evaluation)) diff --git a/unimore_bda_6/analysis/potts.py b/unimore_bda_6/analysis/potts.py index 2438838..611e206 100644 --- a/unimore_bda_6/analysis/potts.py +++ b/unimore_bda_6/analysis/potts.py @@ -1,5 +1,5 @@ from ..vendor.potts import Tokenizer -from .vanilla import VanillaSA, VanillaReviewSA, VanillaUniformReviewSA +from .vanilla import VanillaSA, VanillaReviewSA class PottsSA(VanillaSA): @@ -24,12 +24,6 @@ class PottsReviewSA(VanillaReviewSA, PottsSA): """ -class PottsUniformReviewSA(VanillaUniformReviewSA, PottsSA): - """ - A `PottsSA` with 5 buckets instead of 2. - """ - - __all__ = ( "PottsSA", "PottsReviewSA", diff --git a/unimore_bda_6/analysis/vanilla.py b/unimore_bda_6/analysis/vanilla.py index 9660215..2dde62a 100644 --- a/unimore_bda_6/analysis/vanilla.py +++ b/unimore_bda_6/analysis/vanilla.py @@ -76,25 +76,15 @@ class VanillaReviewSA(VanillaSA): A `VanillaSA` to be used with `Review`s. """ - @staticmethod - def _rating_to_label(rating: float) -> str: - """ - Return the label corresponding to the given rating. - - Possible categories are: - * negative (0.0 <= rating < 3.0) - * positive (3.0 < rating <= 5.0) - """ - if rating < 3.0: - return "negative" - else: - return "positive" + def __init__(self, categorizer: t.Callable[[Review], str]) -> None: + super().__init__() + self.categorizer: t.Callable[[Review], str] = categorizer def _review_to_data_set(self, review: Review) -> tuple[list[str], str]: """ Convert a review to a NLTK-compatible dataset. """ - return self._tokenize_text(text=review["reviewText"]), self._rating_to_label(rating=review["overall"]) + return self._tokenize_text(text=review["reviewText"]), self.categorizer(rating=review["overall"]) def train(self, reviews: t.Iterable[Review]) -> None: data_set = list(map(self._review_to_data_set, reviews)) @@ -108,27 +98,56 @@ class VanillaReviewSA(VanillaSA): return self._use_with_tokens(self._tokenize_text(text)) -class VanillaUniformReviewSA(VanillaReviewSA): - @staticmethod - def _rating_to_label(rating: float) -> str: - match rating: - case 0.0: - return "abysmal" - case 1.0: - return "terrible" - case 2.0: - return "negative" - case 3.0: - return "mixed" - case 4.0: - return "positive" - case 5.0: - return "great" - case _: - return "unknown" +def polar_categorizer(rating: float) -> str: + """ + Return the polar label corresponding to the given rating. + + Possible categories are: + + * negative (1.0, 2.0) + * positive (3.0, 4.0, 5.0) + * unknown (everything else) + """ + match rating: + case 1.0 | 2.0: + return "negative" + case 3.0 | 4.0 | 5.0: + return "positive" + case _: + return "unknown" + + +def stars_categorizer(rating: float) -> str: + """ + Return the "stars" label corresponding to the given rating. + + Possible categories are: + + * terrible (1.0) + * negative (2.0) + * mixed (3.0) + * positive (4.0) + * great (5.0) + * unknown (everything else) + """ + match rating: + case 1.0: + return "terrible" + case 2.0: + return "negative" + case 3.0: + return "mixed" + case 4.0: + return "positive" + case 5.0: + return "great" + case _: + return "unknown" __all__ = ( "VanillaSA", "VanillaReviewSA", + "polar_categorizer", + "stars_categorizer", )