mirror of
https://github.com/Steffo99/unimore-bda-6.git
synced 2024-11-24 08:44:19 +00:00
Use composition instead of inheritance
This commit is contained in:
parent
3ae43b2714
commit
4c3f892038
3 changed files with 55 additions and 42 deletions
|
@ -2,7 +2,7 @@ import logging
|
|||
|
||||
from .config import config, DATA_SET_SIZE
|
||||
from .database import mongo_reviews_collection_from_config, get_reviews_dataset_polar, get_reviews_dataset_uniform
|
||||
from .analysis.vanilla import VanillaReviewSA, VanillaUniformReviewSA
|
||||
from .analysis.vanilla import VanillaReviewSA, polar_categorizer, stars_categorizer
|
||||
from .analysis.potts import PottsReviewSA
|
||||
from .log import install_log_handler
|
||||
|
||||
|
@ -16,7 +16,7 @@ def main():
|
|||
reviews_uniform_training = get_reviews_dataset_uniform(collection=reviews, amount=DATA_SET_SIZE.__wrapped__)
|
||||
reviews_uniform_evaluation = get_reviews_dataset_uniform(collection=reviews, amount=DATA_SET_SIZE.__wrapped__)
|
||||
|
||||
vanilla_polar = VanillaReviewSA()
|
||||
vanilla_polar = VanillaReviewSA(categorizer=polar_categorizer)
|
||||
vanilla_polar.train(reviews_polar_training)
|
||||
log.info("Vanilla polar evaluation results: %s", vanilla_polar.evaluate(reviews_polar_evaluation))
|
||||
|
||||
|
@ -24,7 +24,7 @@ def main():
|
|||
potts_polar.train(reviews_polar_training)
|
||||
log.info("Potts polar evaluation results: %s", potts_polar.evaluate(reviews_polar_evaluation))
|
||||
|
||||
vanilla_uniform = VanillaUniformReviewSA()
|
||||
vanilla_uniform = VanillaReviewSA(categorizer=stars_categorizer)
|
||||
vanilla_uniform.train(reviews_uniform_training)
|
||||
log.info("Vanilla uniform evaluation results: %s", vanilla_polar.evaluate(reviews_polar_evaluation))
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from ..vendor.potts import Tokenizer
|
||||
from .vanilla import VanillaSA, VanillaReviewSA, VanillaUniformReviewSA
|
||||
from .vanilla import VanillaSA, VanillaReviewSA
|
||||
|
||||
|
||||
class PottsSA(VanillaSA):
|
||||
|
@ -24,12 +24,6 @@ class PottsReviewSA(VanillaReviewSA, PottsSA):
|
|||
"""
|
||||
|
||||
|
||||
class PottsUniformReviewSA(VanillaUniformReviewSA, PottsSA):
|
||||
"""
|
||||
A `PottsSA` with 5 buckets instead of 2.
|
||||
"""
|
||||
|
||||
|
||||
__all__ = (
|
||||
"PottsSA",
|
||||
"PottsReviewSA",
|
||||
|
|
|
@ -76,25 +76,15 @@ class VanillaReviewSA(VanillaSA):
|
|||
A `VanillaSA` to be used with `Review`s.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _rating_to_label(rating: float) -> str:
|
||||
"""
|
||||
Return the label corresponding to the given rating.
|
||||
|
||||
Possible categories are:
|
||||
* negative (0.0 <= rating < 3.0)
|
||||
* positive (3.0 < rating <= 5.0)
|
||||
"""
|
||||
if rating < 3.0:
|
||||
return "negative"
|
||||
else:
|
||||
return "positive"
|
||||
def __init__(self, categorizer: t.Callable[[Review], str]) -> None:
|
||||
super().__init__()
|
||||
self.categorizer: t.Callable[[Review], str] = categorizer
|
||||
|
||||
def _review_to_data_set(self, review: Review) -> tuple[list[str], str]:
|
||||
"""
|
||||
Convert a review to a NLTK-compatible dataset.
|
||||
"""
|
||||
return self._tokenize_text(text=review["reviewText"]), self._rating_to_label(rating=review["overall"])
|
||||
return self._tokenize_text(text=review["reviewText"]), self.categorizer(rating=review["overall"])
|
||||
|
||||
def train(self, reviews: t.Iterable[Review]) -> None:
|
||||
data_set = list(map(self._review_to_data_set, reviews))
|
||||
|
@ -108,12 +98,39 @@ class VanillaReviewSA(VanillaSA):
|
|||
return self._use_with_tokens(self._tokenize_text(text))
|
||||
|
||||
|
||||
class VanillaUniformReviewSA(VanillaReviewSA):
|
||||
@staticmethod
|
||||
def _rating_to_label(rating: float) -> str:
|
||||
def polar_categorizer(rating: float) -> str:
|
||||
"""
|
||||
Return the polar label corresponding to the given rating.
|
||||
|
||||
Possible categories are:
|
||||
|
||||
* negative (1.0, 2.0)
|
||||
* positive (3.0, 4.0, 5.0)
|
||||
* unknown (everything else)
|
||||
"""
|
||||
match rating:
|
||||
case 1.0 | 2.0:
|
||||
return "negative"
|
||||
case 3.0 | 4.0 | 5.0:
|
||||
return "positive"
|
||||
case _:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def stars_categorizer(rating: float) -> str:
|
||||
"""
|
||||
Return the "stars" label corresponding to the given rating.
|
||||
|
||||
Possible categories are:
|
||||
|
||||
* terrible (1.0)
|
||||
* negative (2.0)
|
||||
* mixed (3.0)
|
||||
* positive (4.0)
|
||||
* great (5.0)
|
||||
* unknown (everything else)
|
||||
"""
|
||||
match rating:
|
||||
case 0.0:
|
||||
return "abysmal"
|
||||
case 1.0:
|
||||
return "terrible"
|
||||
case 2.0:
|
||||
|
@ -131,4 +148,6 @@ class VanillaUniformReviewSA(VanillaReviewSA):
|
|||
__all__ = (
|
||||
"VanillaSA",
|
||||
"VanillaReviewSA",
|
||||
"polar_categorizer",
|
||||
"stars_categorizer",
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue