mirror of
https://github.com/Steffo99/unimore-bda-6.git
synced 2024-11-21 23:44:19 +00:00
enough
This commit is contained in:
parent
4d6c8f0fee
commit
e3005ab8b0
21 changed files with 309 additions and 149 deletions
|
@ -40,6 +40,7 @@
|
|||
<option value="E501" />
|
||||
<option value="E221" />
|
||||
<option value="E203" />
|
||||
<option value="E402" />
|
||||
</list>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
|
|
|
@ -10,4 +10,7 @@
|
|||
<component name="ProjectRootManager" version="2" languageLevel="JDK_19">
|
||||
<output url="file://$PROJECT_DIR$/out" />
|
||||
</component>
|
||||
<component name="PythonCompatibilityInspectionAdvertiser">
|
||||
<option name="version" value="3" />
|
||||
</component>
|
||||
</project>
|
|
@ -1,12 +1,12 @@
|
|||
<component name="ProjectRunConfigurationManager">
|
||||
<configuration default="false" name="unimore_bda_6" type="PythonConfigurationType" factoryName="Python" nameIsGenerated="true">
|
||||
<module name="unimore-bda-6" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="INTERPRETER_OPTIONS" value="-O" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
<env name="CONFIRM_OVERWRITE" value="False" />
|
||||
<env name="NLTK_DATA" value="./data/nltk" />
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
<env name="TF_CPP_MIN_LOG_LEVEL" value="2" />
|
||||
<env name="WORKING_SET_SIZE" value="1000000" />
|
||||
<env name="XLA_FLAGS" value="--xla_gpu_cuda_data_dir=/opt/cuda" />
|
||||
|
|
57
poetry.lock
generated
57
poetry.lock
generated
|
@ -1230,6 +1230,61 @@ files = [
|
|||
[package.extras]
|
||||
tests = ["pytest", "pytest-cov"]
|
||||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.13.2"
|
||||
description = "Fast and Customizable Tokenizers"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:a6f36b1b499233bb4443b5e57e20630c5e02fba61109632f5e00dab970440157"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:bc6983282ee74d638a4b6d149e5dadd8bc7ff1d0d6de663d69f099e0c6bddbeb"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16756e6ab264b162f99c0c0a8d3d521328f428b33374c5ee161c0ebec42bf3c0"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b10db6e4b036c78212c6763cb56411566edcf2668c910baa1939afd50095ce48"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:238e879d1a0f4fddc2ce5b2d00f219125df08f8532e5f1f2ba9ad42f02b7da59"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47ef745dbf9f49281e900e9e72915356d69de3a4e4d8a475bda26bfdb5047736"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-win32.whl", hash = "sha256:96cedf83864bcc15a3ffd088a6f81a8a8f55b8b188eabd7a7f2a4469477036df"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eda77de40a0262690c666134baf19ec5c4f5b8bde213055911d9f5a718c506e1"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:9eee037bb5aa14daeb56b4c39956164b2bebbe6ab4ca7779d88aa16b79bd4e17"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d1b079c4c9332048fec4cb9c2055c2373c74fbb336716a5524c9a720206d787e"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a689654fc745135cce4eea3b15e29c372c3e0b01717c6978b563de5c38af9811"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3606528c07cda0566cff6cbfbda2b167f923661be595feac95701ffcdcbdbb21"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41291d0160946084cbd53c8ec3d029df3dc2af2673d46b25ff1a7f31a9d55d51"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7892325f9ca1cc5fca0333d5bfd96a19044ce9b092ce2df625652109a3de16b8"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-win32.whl", hash = "sha256:93714958d4ebe5362d3de7a6bd73dc86c36b5af5941ebef6c325ac900fa58865"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:fa7ef7ee380b1f49211bbcfac8a006b1a3fa2fa4c7f4ee134ae384eb4ea5e453"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:da521bfa94df6a08a6254bb8214ea04854bb9044d61063ae2529361688b5440a"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a739d4d973d422e1073989769723f3b6ad8b11e59e635a63de99aea4b2208188"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cac01fc0b868e4d0a3aa7c5c53396da0a0a63136e81475d32fcf5c348fcb2866"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0901a5c6538d2d2dc752c6b4bde7dab170fddce559ec75662cfad03b3187c8f6"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ba9baa76b5a3eefa78b6cc351315a216232fd727ee5e3ce0f7c6885d9fb531b"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-win32.whl", hash = "sha256:a537061ee18ba104b7f3daa735060c39db3a22c8a9595845c55b6c01d36c5e87"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c82fb87b1cbfa984d8f05b2b3c3c73e428b216c1d4f0e286d0a3b27f521b32eb"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ce298605a833ac7f81b8062d3102a42dcd9fa890493e8f756112c346339fe5c5"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f44d59bafe3d61e8a56b9e0a963075187c0f0091023120b13fbe37a87936f171"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51b93932daba12ed07060935978a6779593a59709deab04a0d10e6fd5c29e60"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6969e5ea7ccb909ce7d6d4dfd009115dc72799b0362a2ea353267168667408c4"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92f040c4d938ea64683526b45dfc81c580e3b35aaebe847e7eec374961231734"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d3bc9f7d7f4c1aa84bb6b8d642a60272c8a2c987669e9bb0ac26daf0c6a9fc8"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-win32.whl", hash = "sha256:efbf189fb9cf29bd29e98c0437bdb9809f9de686a1e6c10e0b954410e9ca2142"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-win_amd64.whl", hash = "sha256:0b4cb2c60c094f31ea652f6cf9f349aae815f9243b860610c29a69ed0d7a88f8"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:b47d6212e7dd05784d7330b3b1e5a170809fa30e2b333ca5c93fba1463dec2b7"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:80a57501b61ec4f94fb7ce109e2b4a1a090352618efde87253b4ede6d458b605"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:61507a9953f6e7dc3c972cbc57ba94c80c8f7f686fbc0876afe70ea2b8cc8b04"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c09f4fa620e879debdd1ec299bb81e3c961fd8f64f0e460e64df0818d29d845c"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:66c892d85385b202893ac6bc47b13390909e205280e5df89a41086cfec76fedb"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3e306b0941ad35087ae7083919a5c410a6b672be0343609d79a1171a364ce79"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-win32.whl", hash = "sha256:79189e7f706c74dbc6b753450757af172240916d6a12ed4526af5cc6d3ceca26"},
|
||||
{file = "tokenizers-0.13.2-cp39-cp39-win_amd64.whl", hash = "sha256:486d637b413fddada845a10a45c74825d73d3725da42ccd8796ccd7a1c07a024"},
|
||||
{file = "tokenizers-0.13.2.tar.gz", hash = "sha256:f9525375582fd1912ac3caa2f727d36c86ff8c0c6de45ae1aaff90f87f33b907"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
|
||||
docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
|
||||
testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
|
||||
|
||||
[[package]]
|
||||
name = "tqdm"
|
||||
version = "4.64.1"
|
||||
|
@ -1390,4 +1445,4 @@ files = [
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "~3.10"
|
||||
content-hash = "b17bed73011c627355660e5ba5ca176a920c1d87915a1ab875e5a5ddd28a9dca"
|
||||
content-hash = "d63867d77886e8ae2c09771a5ec1053dae99a5699f7e905a69bd298e1b986a80"
|
||||
|
|
|
@ -137,6 +137,7 @@ nltk = "^3.8.1"
|
|||
cfig = {extras = ["cli"], version = "^0.3.0"}
|
||||
coloredlogs = "^15.0.1"
|
||||
tensorflow = "^2.11.0"
|
||||
tokenizers = "^0.13.2"
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -1,84 +1,81 @@
|
|||
import logging
|
||||
import tensorflow
|
||||
import pymongo.errors
|
||||
from .log import install_log_handler
|
||||
|
||||
from .config import config, DATA_SET_SIZE
|
||||
from .database import mongo_client_from_config, reviews_collection, sample_reviews_polar, sample_reviews_varied, store_cache, load_cache, delete_cache
|
||||
install_log_handler()
|
||||
|
||||
from .config import config
|
||||
from .database import mongo_client_from_config, reviews_collection, sample_reviews_polar, sample_reviews_varied
|
||||
from .analysis.nltk_sentiment import NLTKSentimentAnalyzer
|
||||
from .analysis.tf_text import TensorflowSentimentAnalyzer
|
||||
from .analysis.base import TrainingFailedError
|
||||
from .tokenizer import LowercaseTokenizer
|
||||
from .log import install_log_handler
|
||||
from .tokenizer import PlainTokenizer, LowercaseTokenizer, NLTKWordTokenizer, PottsTokenizer, PottsTokenizerWithNegation
|
||||
from .gathering import Caches
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
if len(tensorflow.config.list_physical_devices(device_type="GPU")) == 0:
|
||||
log.warning("Tensorflow reports no GPU acceleration available.")
|
||||
else:
|
||||
log.debug("Tensorflow successfully found GPU acceleration!")
|
||||
log.info("Started unimore-bda-6 in %s mode!", "DEBUG" if __debug__ else "PRODUCTION")
|
||||
|
||||
try:
|
||||
delete_cache("./data/training")
|
||||
delete_cache("./data/evaluation")
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
log.debug("Validating configuration...")
|
||||
config.proxies.resolve()
|
||||
|
||||
for dataset_func in [sample_reviews_polar, sample_reviews_varied]:
|
||||
for SentimentAnalyzer in [TensorflowSentimentAnalyzer, NLTKSentimentAnalyzer]:
|
||||
for Tokenizer in [
|
||||
# NLTKWordTokenizer,
|
||||
# PottsTokenizer,
|
||||
# PottsTokenizerWithNegation,
|
||||
LowercaseTokenizer,
|
||||
]:
|
||||
while True:
|
||||
try:
|
||||
tokenizer = Tokenizer()
|
||||
model = SentimentAnalyzer(tokenizer=tokenizer)
|
||||
log.debug("Ensuring there are no leftover caches...")
|
||||
Caches.ensure_clean()
|
||||
|
||||
with mongo_client_from_config() as db:
|
||||
log.debug("Finding the reviews MongoDB collection...")
|
||||
collection = reviews_collection(db)
|
||||
try:
|
||||
db.admin.command("ping")
|
||||
except pymongo.errors.ServerSelectionTimeoutError:
|
||||
log.fatal("MongoDB database is not available, exiting...")
|
||||
exit(1)
|
||||
|
||||
reviews = reviews_collection(db)
|
||||
|
||||
for sample_func in [sample_reviews_varied, sample_reviews_polar]:
|
||||
|
||||
for SentimentAnalyzer in [
|
||||
TensorflowSentimentAnalyzer,
|
||||
NLTKSentimentAnalyzer
|
||||
]:
|
||||
|
||||
for Tokenizer in [
|
||||
PlainTokenizer,
|
||||
LowercaseTokenizer,
|
||||
NLTKWordTokenizer,
|
||||
PottsTokenizer,
|
||||
PottsTokenizerWithNegation,
|
||||
]:
|
||||
|
||||
slog = logging.getLogger(f"{__name__}.{sample_func.__name__}.{SentimentAnalyzer.__name__}.{Tokenizer.__name__}")
|
||||
|
||||
while True:
|
||||
|
||||
try:
|
||||
training_cache = load_cache("./data/training")
|
||||
evaluation_cache = load_cache("./data/evaluation")
|
||||
except FileNotFoundError:
|
||||
log.debug("Gathering datasets...")
|
||||
reviews_training = dataset_func(collection=collection, amount=DATA_SET_SIZE.__wrapped__)
|
||||
reviews_evaluation = dataset_func(collection=collection, amount=DATA_SET_SIZE.__wrapped__)
|
||||
slog.debug("Creating sentiment analyzer...")
|
||||
sa = SentimentAnalyzer(tokenizer=Tokenizer())
|
||||
except TypeError:
|
||||
slog.warning("%s does not support %s, skipping...", Tokenizer.__name__, SentimentAnalyzer.__name__)
|
||||
break
|
||||
|
||||
log.debug("Caching datasets...")
|
||||
store_cache(reviews_training, "./data/training")
|
||||
store_cache(reviews_evaluation, "./data/evaluation")
|
||||
del reviews_training
|
||||
del reviews_evaluation
|
||||
|
||||
training_cache = load_cache("./data/training")
|
||||
evaluation_cache = load_cache("./data/evaluation")
|
||||
log.debug("Caches stored and loaded successfully!")
|
||||
else:
|
||||
log.debug("Caches loaded successfully!")
|
||||
|
||||
log.info("Training model: %s", model)
|
||||
model.train(training_cache)
|
||||
log.info("Evaluating model: %s", model)
|
||||
evaluation_results = model.evaluate(evaluation_cache)
|
||||
log.info("%s", evaluation_results)
|
||||
with Caches.from_database_samples(collection=reviews, sample_func=sample_func) as datasets:
|
||||
try:
|
||||
slog.info("Training sentiment analyzer: %s", sa)
|
||||
sa.train(training_dataset_func=datasets.training, validation_dataset_func=datasets.validation)
|
||||
|
||||
except TrainingFailedError:
|
||||
log.error("Training failed, restarting with a different dataset.")
|
||||
slog.error("Training failed, trying again with a different dataset...")
|
||||
continue
|
||||
|
||||
else:
|
||||
log.info("Training")
|
||||
slog.info("Training succeeded!")
|
||||
|
||||
slog.info("Evaluating sentiment analyzer: %s", sa)
|
||||
evaluation_results = sa.evaluate(evaluation_dataset_func=datasets.evaluation)
|
||||
slog.info("Evaluation results: %s", evaluation_results)
|
||||
break
|
||||
finally:
|
||||
delete_cache("./data/training")
|
||||
delete_cache("./data/evaluation")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
install_log_handler()
|
||||
config.proxies.resolve()
|
||||
main()
|
||||
|
|
|
@ -1,38 +1,42 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import dataclasses
|
||||
|
||||
from ..database import Text, Category, DatasetFunc
|
||||
from ..database import Text, Category, CachedDatasetFunc
|
||||
from ..tokenizer import BaseTokenizer
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class EvaluationResults:
|
||||
correct: int
|
||||
evaluated: int
|
||||
score: float
|
||||
|
||||
def __repr__(self):
|
||||
return f"<EvaluationResults: score of {self.score} out of {self.evaluated} evaluated tuples>"
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.evaluated} evaluated, {self.correct} correct, {self.correct / self.evaluated * 100:.2} % accuracy, {self.score:.2} score, {self.score / self.evaluated * 100:.2} scoreaccuracy"
|
||||
|
||||
|
||||
class BaseSentimentAnalyzer(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
Abstract base class for sentiment analyzers implemented in this project.
|
||||
"""
|
||||
|
||||
# noinspection PyUnusedLocal
|
||||
def __init__(self, *, tokenizer: BaseTokenizer):
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.__class__.__qualname__}>"
|
||||
|
||||
@abc.abstractmethod
|
||||
def train(self, dataset_func: DatasetFunc) -> None:
|
||||
def train(self, training_dataset_func: CachedDatasetFunc, validation_dataset_func: CachedDatasetFunc) -> None:
|
||||
"""
|
||||
Train the analyzer with the given training dataset.
|
||||
Train the analyzer with the given training and validation datasets.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def evaluate(self, dataset_func: DatasetFunc) -> EvaluationResults:
|
||||
@abc.abstractmethod
|
||||
def use(self, text: Text) -> Category:
|
||||
"""
|
||||
Run the model on the given input.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def evaluate(self, evaluation_dataset_func: CachedDatasetFunc) -> EvaluationResults:
|
||||
"""
|
||||
Perform a model evaluation by calling repeatedly `.use` on every text of the test dataset and by comparing its resulting category with the expected category.
|
||||
|
||||
|
@ -43,23 +47,30 @@ class BaseSentimentAnalyzer(metaclass=abc.ABCMeta):
|
|||
correct: int = 0
|
||||
score: float = 0.0
|
||||
|
||||
for review in dataset_func():
|
||||
for review in evaluation_dataset_func():
|
||||
resulting_category = self.use(review.text)
|
||||
evaluated += 1
|
||||
correct += 1 if resulting_category == review.category else 0
|
||||
score += 1 - (abs(resulting_category - review.category) / 4)
|
||||
if not evaluated % 100:
|
||||
temp_results = EvaluationResults(correct=correct, evaluated=evaluated, score=score)
|
||||
log.debug(f"{temp_results!s}")
|
||||
|
||||
return EvaluationResults(correct=correct, evaluated=evaluated, score=score)
|
||||
|
||||
@abc.abstractmethod
|
||||
def use(self, text: Text) -> Category:
|
||||
|
||||
@dataclasses.dataclass
|
||||
class EvaluationResults:
|
||||
"""
|
||||
Run the model on the given input.
|
||||
Container for the results of a dataset evaluation.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
correct: int
|
||||
evaluated: int
|
||||
score: float
|
||||
|
||||
def __repr__(self):
|
||||
return f"<EvaluationResults: {self!s}>"
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.evaluated} evaluated, {self.correct} correct, {self.correct / self.evaluated:.2%} accuracy, {self.score:.2f} score, {self.score / self.evaluated:.2%} scoreaccuracy"
|
||||
|
||||
|
||||
class AlreadyTrainedError(Exception):
|
||||
|
|
|
@ -6,7 +6,7 @@ import logging
|
|||
import typing as t
|
||||
import itertools
|
||||
|
||||
from ..database import Text, Category, Review, DatasetFunc
|
||||
from ..database import Text, Category, Review, CachedDatasetFunc
|
||||
from .base import BaseSentimentAnalyzer, AlreadyTrainedError, NotTrainedError
|
||||
from ..log import count_passage
|
||||
from ..tokenizer import BaseTokenizer
|
||||
|
@ -23,7 +23,11 @@ class NLTKSentimentAnalyzer(BaseSentimentAnalyzer):
|
|||
"""
|
||||
|
||||
def __init__(self, *, tokenizer: BaseTokenizer) -> None:
|
||||
super().__init__()
|
||||
if not tokenizer.supports_plain():
|
||||
raise TypeError("Tokenizer does not support NLTK")
|
||||
|
||||
super().__init__(tokenizer=tokenizer)
|
||||
|
||||
self.model: nltk.sentiment.SentimentAnalyzer = nltk.sentiment.SentimentAnalyzer()
|
||||
self.trained: bool = False
|
||||
self.tokenizer: BaseTokenizer = tokenizer
|
||||
|
@ -36,7 +40,7 @@ class NLTKSentimentAnalyzer(BaseSentimentAnalyzer):
|
|||
Convert the `Text` of a `DataTuple` to a `TokenBag`.
|
||||
"""
|
||||
count_passage(log, "tokenize_datatuple", 100)
|
||||
return self.tokenizer.tokenize_builtins(datatuple.text), datatuple.category
|
||||
return self.tokenizer.tokenize_plain(datatuple.text), datatuple.category
|
||||
|
||||
def _add_feature_unigrams(self, dataset: t.Iterator[tuple[TokenBag, Category]]) -> None:
|
||||
"""
|
||||
|
@ -67,13 +71,13 @@ class NLTKSentimentAnalyzer(BaseSentimentAnalyzer):
|
|||
count_passage(log, "extract_features", 100)
|
||||
return self.model.extract_features(data[0]), data[1]
|
||||
|
||||
def train(self, dataset_func: DatasetFunc) -> None:
|
||||
def train(self, training_dataset_func: CachedDatasetFunc, validation_dataset_func: CachedDatasetFunc) -> None:
|
||||
# Forbid retraining the model
|
||||
if self.trained:
|
||||
raise AlreadyTrainedError()
|
||||
|
||||
# Get a generator
|
||||
dataset: t.Generator[Review] = dataset_func()
|
||||
dataset: t.Generator[Review] = training_dataset_func()
|
||||
|
||||
# Tokenize the dataset
|
||||
dataset: t.Iterator[tuple[TokenBag, Category]] = map(self.__tokenize_review, dataset)
|
||||
|
@ -103,7 +107,7 @@ class NLTKSentimentAnalyzer(BaseSentimentAnalyzer):
|
|||
raise NotTrainedError()
|
||||
|
||||
# Tokenize the input
|
||||
tokens = self.tokenizer.tokenize_builtins(text)
|
||||
tokens = self.tokenizer.tokenize_plain(text)
|
||||
|
||||
# Run the classification method
|
||||
return self.model.classify(instance=tokens)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import tensorflow
|
||||
import logging
|
||||
|
||||
from ..database import Text, Category, DatasetFunc
|
||||
from ..database import Text, Category, CachedDatasetFunc
|
||||
from ..config import TENSORFLOW_EMBEDDING_SIZE, TENSORFLOW_MAX_FEATURES, TENSORFLOW_EPOCHS
|
||||
from ..tokenizer import BaseTokenizer
|
||||
from .base import BaseSentimentAnalyzer, AlreadyTrainedError, NotTrainedError, TrainingFailedError
|
||||
|
@ -9,9 +9,19 @@ from .base import BaseSentimentAnalyzer, AlreadyTrainedError, NotTrainedError, T
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if len(tensorflow.config.list_physical_devices(device_type="GPU")) == 0:
|
||||
log.warning("Tensorflow reports no GPU acceleration available.")
|
||||
else:
|
||||
log.debug("Tensorflow successfully found GPU acceleration!")
|
||||
|
||||
|
||||
class TensorflowSentimentAnalyzer(BaseSentimentAnalyzer):
|
||||
def __init__(self, tokenizer: BaseTokenizer):
|
||||
super().__init__()
|
||||
def __init__(self, *, tokenizer: BaseTokenizer):
|
||||
if not tokenizer.supports_tensorflow():
|
||||
raise TypeError("Tokenizer does not support Tensorflow")
|
||||
|
||||
super().__init__(tokenizer=tokenizer)
|
||||
|
||||
self.trained: bool = False
|
||||
|
||||
self.text_vectorization_layer: tensorflow.keras.layers.TextVectorization = self._build_vectorizer(tokenizer)
|
||||
|
@ -19,7 +29,11 @@ class TensorflowSentimentAnalyzer(BaseSentimentAnalyzer):
|
|||
self.history: tensorflow.keras.callbacks.History | None = None
|
||||
|
||||
@staticmethod
|
||||
def _build_dataset(dataset_func: DatasetFunc) -> tensorflow.data.Dataset:
|
||||
def _build_dataset(dataset_func: CachedDatasetFunc) -> tensorflow.data.Dataset:
|
||||
"""
|
||||
Convert a `CachedDatasetFunc` to a `tensorflow.data.Dataset`.
|
||||
"""
|
||||
|
||||
def dataset_func_with_tensor_tuple():
|
||||
for review in dataset_func():
|
||||
yield review.to_tensor_tuple()
|
||||
|
@ -43,15 +57,16 @@ class TensorflowSentimentAnalyzer(BaseSentimentAnalyzer):
|
|||
|
||||
@staticmethod
|
||||
def _build_model() -> tensorflow.keras.Sequential:
|
||||
log.debug("Creating %s model...", tensorflow.keras.Sequential)
|
||||
log.debug("Creating model...")
|
||||
model = tensorflow.keras.Sequential([
|
||||
tensorflow.keras.layers.Embedding(
|
||||
input_dim=TENSORFLOW_MAX_FEATURES.__wrapped__ + 1,
|
||||
output_dim=TENSORFLOW_EMBEDDING_SIZE.__wrapped__,
|
||||
),
|
||||
tensorflow.keras.layers.Dropout(0.2),
|
||||
tensorflow.keras.layers.Dropout(0.25),
|
||||
tensorflow.keras.layers.GlobalAveragePooling1D(),
|
||||
tensorflow.keras.layers.Dropout(0.2),
|
||||
tensorflow.keras.layers.Dropout(0.25),
|
||||
tensorflow.keras.layers.Dense(25),
|
||||
tensorflow.keras.layers.Dense(5, activation="softmax"),
|
||||
])
|
||||
log.debug("Compiling model: %s", model)
|
||||
|
@ -72,31 +87,35 @@ class TensorflowSentimentAnalyzer(BaseSentimentAnalyzer):
|
|||
max_tokens=TENSORFLOW_MAX_FEATURES.__wrapped__
|
||||
)
|
||||
|
||||
def train(self, dataset_func: DatasetFunc) -> None:
|
||||
def train(self, training_dataset_func: CachedDatasetFunc, validation_dataset_func: CachedDatasetFunc) -> None:
|
||||
if self.trained:
|
||||
log.error("Tried to train an already trained model.")
|
||||
raise AlreadyTrainedError()
|
||||
|
||||
log.debug("Building dataset...")
|
||||
training_set = self._build_dataset(dataset_func)
|
||||
log.debug("Building datasets...")
|
||||
training_set = self._build_dataset(training_dataset_func)
|
||||
validation_set = self._build_dataset(validation_dataset_func)
|
||||
log.debug("Built dataset: %s", training_set)
|
||||
|
||||
log.debug("Preparing training_set for %s...", self.text_vectorization_layer.adapt)
|
||||
only_text_set = training_set.map(lambda text, category: text)
|
||||
|
||||
log.debug("Adapting text_vectorization_layer: %s", self.text_vectorization_layer)
|
||||
self.text_vectorization_layer.adapt(only_text_set)
|
||||
log.debug("Adapted text_vectorization_layer: %s", self.text_vectorization_layer)
|
||||
|
||||
log.debug("Preparing training_set for %s...", self.model.fit)
|
||||
training_set = training_set.map(lambda text, category: (self.text_vectorization_layer(text), category))
|
||||
validation_set = validation_set.map(lambda text, category: (self.text_vectorization_layer(text), category))
|
||||
log.info("Training: %s", self.model)
|
||||
self.history: tensorflow.keras.callbacks.History | None = self.model.fit(
|
||||
training_set,
|
||||
validation_data=validation_set,
|
||||
epochs=TENSORFLOW_EPOCHS.__wrapped__,
|
||||
callbacks=[
|
||||
tensorflow.keras.callbacks.TerminateOnNaN()
|
||||
])
|
||||
log.info("Trained: %s", self.model)
|
||||
],
|
||||
)
|
||||
|
||||
if len(self.history.epoch) < TENSORFLOW_EPOCHS.__wrapped__:
|
||||
log.error("Model %s training failed: only %d epochs computed", self.model, len(self.history.epoch))
|
||||
|
|
|
@ -45,14 +45,44 @@ def WORKING_SET_SIZE(val: str | None) -> int:
|
|||
|
||||
|
||||
@config.optional()
|
||||
def DATA_SET_SIZE(val: str | None) -> int:
|
||||
def TRAINING_SET_SIZE(val: str | None) -> int:
|
||||
"""
|
||||
The number of reviews from each category to fetch for the datasets.
|
||||
The number of reviews from each category to fetch for the training dataset.
|
||||
|
||||
Defaults to `1750`.
|
||||
Defaults to `5000`.
|
||||
"""
|
||||
if val is None:
|
||||
return 1750
|
||||
return 5000
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
raise cfig.InvalidValueError("Not an int.")
|
||||
|
||||
|
||||
@config.optional()
|
||||
def VALIDATION_SET_SIZE(val: str | None) -> int:
|
||||
"""
|
||||
The number of reviews from each category to fetch for the training dataset.
|
||||
|
||||
Defaults to `400`.
|
||||
"""
|
||||
if val is None:
|
||||
return 400
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
raise cfig.InvalidValueError("Not an int.")
|
||||
|
||||
|
||||
@config.optional()
|
||||
def EVALUATION_SET_SIZE(val: str | None) -> int:
|
||||
"""
|
||||
The number of reviews from each category to fetch for the evaluation dataset.
|
||||
|
||||
Defaults to `1000`.
|
||||
"""
|
||||
if val is None:
|
||||
return 1000
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
|
@ -79,10 +109,10 @@ def TENSORFLOW_EMBEDDING_SIZE(val: str | None) -> int:
|
|||
"""
|
||||
The size of the embeddings tensor to use in Tensorflow models.
|
||||
|
||||
Defaults to `12`.
|
||||
Defaults to `6`.
|
||||
"""
|
||||
if val is None:
|
||||
return 12
|
||||
return 6
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
|
@ -94,10 +124,10 @@ def TENSORFLOW_EPOCHS(val: str | None) -> int:
|
|||
"""
|
||||
The number of epochs to train Tensorflow models for.
|
||||
|
||||
Defaults to `15`.
|
||||
Defaults to `12`.
|
||||
"""
|
||||
if val is None:
|
||||
return 15
|
||||
return 12
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
|
@ -109,7 +139,9 @@ __all__ = (
|
|||
"MONGO_HOST",
|
||||
"MONGO_PORT",
|
||||
"WORKING_SET_SIZE",
|
||||
"DATA_SET_SIZE",
|
||||
"TRAINING_SET_SIZE",
|
||||
"VALIDATION_SET_SIZE",
|
||||
"EVALUATION_SET_SIZE",
|
||||
"TENSORFLOW_MAX_FEATURES",
|
||||
"TENSORFLOW_EMBEDDING_SIZE",
|
||||
"TENSORFLOW_EPOCHS",
|
||||
|
|
|
@ -9,7 +9,7 @@ from .datatypes import Review
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DatasetFunc = t.Callable[[], t.Generator[Review, t.Any, None]]
|
||||
CachedDatasetFunc = t.Callable[[], t.Generator[Review, t.Any, None]]
|
||||
|
||||
|
||||
def store_cache(reviews: t.Iterator[Review], path: str | pathlib.Path) -> None:
|
||||
|
@ -34,7 +34,7 @@ def store_cache(reviews: t.Iterator[Review], path: str | pathlib.Path) -> None:
|
|||
pickle.dump(document, file)
|
||||
|
||||
|
||||
def load_cache(path: str | pathlib.Path) -> DatasetFunc:
|
||||
def load_cache(path: str | pathlib.Path) -> CachedDatasetFunc:
|
||||
"""
|
||||
Load the contents of a directory into a `Review` iterator.
|
||||
"""
|
||||
|
@ -69,12 +69,12 @@ def delete_cache(path: str | pathlib.Path) -> None:
|
|||
if not path.exists():
|
||||
raise FileNotFoundError("The specified path does not exist.")
|
||||
|
||||
log.warning("Deleting cache directory: %s", path)
|
||||
log.debug("Deleting cache directory: %s", path)
|
||||
shutil.rmtree(path)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"DatasetFunc",
|
||||
"CachedDatasetFunc",
|
||||
"store_cache",
|
||||
"load_cache",
|
||||
"delete_cache",
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import pymongo
|
||||
import pymongo.errors
|
||||
import contextlib
|
||||
import typing as t
|
||||
import logging
|
||||
|
@ -13,12 +14,12 @@ def mongo_client_from_config() -> t.ContextManager[pymongo.MongoClient]:
|
|||
"""
|
||||
Create a new MongoDB client and yield it.
|
||||
"""
|
||||
log.debug("Opening connection to MongoDB...")
|
||||
log.debug("Creating MongoDB client...")
|
||||
client: pymongo.MongoClient = pymongo.MongoClient(
|
||||
host=MONGO_HOST.__wrapped__,
|
||||
port=MONGO_PORT.__wrapped__,
|
||||
)
|
||||
log.info("Opened connection to MongoDB!")
|
||||
log.debug("Created MongoDB client!")
|
||||
|
||||
yield client
|
||||
|
||||
|
|
|
@ -10,6 +10,11 @@ Category = float
|
|||
|
||||
|
||||
class Review:
|
||||
__slots__ = (
|
||||
"text",
|
||||
"category",
|
||||
)
|
||||
|
||||
def __init__(self, text: Text, category: Category):
|
||||
self.text: str = text
|
||||
self.category: float = category
|
||||
|
|
|
@ -9,6 +9,9 @@ from .datatypes import Review
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SampleFunc = t.Callable[[pymongo.collection.Collection, int], t.Iterator[Review]]
|
||||
|
||||
|
||||
def sample_reviews(collection: pymongo.collection.Collection, amount: int) -> t.Iterator[Review]:
|
||||
"""
|
||||
Get ``amount`` random reviews from the ``reviews`` collection.
|
||||
|
@ -41,18 +44,20 @@ def sample_reviews_by_rating(collection: pymongo.collection.Collection, rating:
|
|||
|
||||
|
||||
def sample_reviews_polar(collection: pymongo.collection.Collection, amount: int) -> t.Iterator[Review]:
|
||||
log.debug("Getting a sample of %d polar reviews...", amount * 2)
|
||||
category_amount = amount // 2
|
||||
|
||||
log.debug("Getting a sample of %d polar reviews...", category_amount * 2)
|
||||
|
||||
cursor = collection.aggregate([
|
||||
{"$limit": WORKING_SET_SIZE.__wrapped__},
|
||||
{"$match": {"overall": 1.0}},
|
||||
{"$sample": {"size": amount}},
|
||||
{"$sample": {"size": category_amount}},
|
||||
{"$unionWith": {
|
||||
"coll": collection.name,
|
||||
"pipeline": [
|
||||
{"$limit": WORKING_SET_SIZE.__wrapped__},
|
||||
{"$match": {"overall": 5.0}},
|
||||
{"$sample": {"size": amount}},
|
||||
{"$sample": {"size": category_amount}},
|
||||
],
|
||||
}},
|
||||
{"$addFields": {
|
||||
|
@ -69,37 +74,39 @@ def sample_reviews_polar(collection: pymongo.collection.Collection, amount: int)
|
|||
|
||||
|
||||
def sample_reviews_varied(collection: pymongo.collection.Collection, amount: int) -> t.Iterator[Review]:
|
||||
log.debug("Getting a sample of %d varied reviews...", amount * 5)
|
||||
category_amount = amount // 5
|
||||
|
||||
log.debug("Getting a sample of %d varied reviews...", category_amount * 5)
|
||||
|
||||
# Wow, this is ugly.
|
||||
cursor = collection.aggregate([
|
||||
{"$limit": WORKING_SET_SIZE.__wrapped__},
|
||||
{"$match": {"overall": 1.0}},
|
||||
{"$sample": {"size": amount}},
|
||||
{"$sample": {"size": category_amount}},
|
||||
{"$unionWith": {
|
||||
"coll": collection.name,
|
||||
"pipeline": [
|
||||
{"$limit": WORKING_SET_SIZE.__wrapped__},
|
||||
{"$match": {"overall": 2.0}},
|
||||
{"$sample": {"size": amount}},
|
||||
{"$sample": {"size": category_amount}},
|
||||
{"$unionWith": {
|
||||
"coll": collection.name,
|
||||
"pipeline": [
|
||||
{"$limit": WORKING_SET_SIZE.__wrapped__},
|
||||
{"$match": {"overall": 3.0}},
|
||||
{"$sample": {"size": amount}},
|
||||
{"$sample": {"size": category_amount}},
|
||||
{"$unionWith": {
|
||||
"coll": collection.name,
|
||||
"pipeline": [
|
||||
{"$limit": WORKING_SET_SIZE.__wrapped__},
|
||||
{"$match": {"overall": 4.0}},
|
||||
{"$sample": {"size": amount}},
|
||||
{"$sample": {"size": category_amount}},
|
||||
{"$unionWith": {
|
||||
"coll": collection.name,
|
||||
"pipeline": [
|
||||
{"$limit": WORKING_SET_SIZE.__wrapped__},
|
||||
{"$match": {"overall": 5.0}},
|
||||
{"$sample": {"size": amount}},
|
||||
{"$sample": {"size": category_amount}},
|
||||
],
|
||||
}}
|
||||
],
|
||||
|
@ -122,6 +129,7 @@ def sample_reviews_varied(collection: pymongo.collection.Collection, amount: int
|
|||
|
||||
|
||||
__all__ = (
|
||||
"SampleFunc",
|
||||
"sample_reviews",
|
||||
"sample_reviews_by_rating",
|
||||
"sample_reviews_polar",
|
||||
|
|
|
@ -15,15 +15,15 @@ def install_log_handler(loggers: list[logging.Logger] = None):
|
|||
for logger in loggers:
|
||||
coloredlogs.install(
|
||||
logger=logger,
|
||||
level="DEBUG",
|
||||
fmt="{asctime} | {name:<32} | {levelname:>8} | {message}",
|
||||
level="DEBUG" if __debug__ else "INFO",
|
||||
fmt="{asctime} | {name:<80} | {levelname:>8} | {message}",
|
||||
style="{",
|
||||
level_styles=dict(
|
||||
debug=dict(color="white"),
|
||||
info=dict(color="cyan"),
|
||||
warning=dict(color="yellow"),
|
||||
error=dict(color="red"),
|
||||
critical=dict(color="red", bold=True),
|
||||
critical=dict(color="black", background="red", bold=True),
|
||||
),
|
||||
field_styles=dict(
|
||||
asctime=dict(color='magenta'),
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from .base import BaseTokenizer
|
||||
from .nltk_word_tokenize import NLTKWordTokenizer
|
||||
from .potts import PottsTokenizer, PottsTokenizerWithNegation
|
||||
from .plain import PlainTokenizer
|
||||
from .lower import LowercaseTokenizer
|
||||
|
||||
|
||||
|
@ -9,5 +10,6 @@ __all__ = (
|
|||
"NLTKWordTokenizer",
|
||||
"PottsTokenizer",
|
||||
"PottsTokenizerWithNegation",
|
||||
"PlainTokenizer",
|
||||
"LowercaseTokenizer",
|
||||
)
|
||||
|
|
|
@ -14,21 +14,21 @@ class BaseTokenizer:
|
|||
f.__notimplemented__ = True
|
||||
return f
|
||||
|
||||
def can_tokenize_builtins(self) -> bool:
|
||||
return getattr(self.tokenize_builtins, "__notimplemented__", False)
|
||||
def supports_plain(self) -> bool:
|
||||
return not getattr(self.tokenize_plain, "__notimplemented__", False)
|
||||
|
||||
def can_tokenize_tensorflow(self) -> bool:
|
||||
return getattr(self.tokenize_tensorflow, "__notimplemented__", False)
|
||||
def supports_tensorflow(self) -> bool:
|
||||
return not getattr(self.tokenize_tensorflow, "__notimplemented__", False)
|
||||
|
||||
@__not_implemented
|
||||
def tokenize_builtins(self, text: str) -> list[str]:
|
||||
def tokenize_plain(self, text: str) -> list[str]:
|
||||
"""
|
||||
Convert a text string into a list of tokens.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@__not_implemented
|
||||
def tokenize_tensorflow(self, text: tensorflow.Tensor) -> tensorflow.Tensor:
|
||||
def tokenize_tensorflow(self, text: "tensorflow.Tensor") -> "tensorflow.Tensor":
|
||||
"""
|
||||
Convert a `tensorflow.Tensor` string into another `tensorflow.Tensor` space-separated string.
|
||||
"""
|
||||
|
|
|
@ -4,7 +4,11 @@ from .base import BaseTokenizer
|
|||
|
||||
|
||||
class LowercaseTokenizer(BaseTokenizer):
|
||||
def tokenize_builtins(self, text: str) -> list[str]:
|
||||
"""
|
||||
Tokenizer which converts the words to lowercase before splitting them via spaces.
|
||||
"""
|
||||
|
||||
def tokenize_plain(self, text: str) -> list[str]:
|
||||
return text.lower().split()
|
||||
|
||||
def tokenize_tensorflow(self, text: tensorflow.Tensor) -> tensorflow.Tensor:
|
||||
|
|
|
@ -10,7 +10,7 @@ class NLTKWordTokenizer(BaseTokenizer):
|
|||
Tokenizer based on `nltk.word_tokenize`.
|
||||
"""
|
||||
|
||||
def tokenize_builtins(self, text: str) -> t.Iterable[str]:
|
||||
def tokenize_plain(self, text: str) -> t.Iterable[str]:
|
||||
tokens = nltk.word_tokenize(text)
|
||||
nltk.sentiment.util.mark_negation(tokens, shallow=True)
|
||||
return tokens
|
||||
|
|
16
unimore_bda_6/tokenizer/plain.py
Normal file
16
unimore_bda_6/tokenizer/plain.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
import tensorflow
|
||||
|
||||
from .base import BaseTokenizer
|
||||
|
||||
|
||||
class PlainTokenizer(BaseTokenizer):
|
||||
"""
|
||||
Tokenizer which just splits the text into tokens by separating them at whitespaces.
|
||||
"""
|
||||
|
||||
def tokenize_plain(self, text: str) -> list[str]:
|
||||
return text.split()
|
||||
|
||||
def tokenize_tensorflow(self, text: tensorflow.Tensor) -> tensorflow.Tensor:
|
||||
text = tensorflow.expand_dims(text, -1, name="tokens")
|
||||
return text
|
|
@ -124,7 +124,7 @@ regex_strings = (
|
|||
|
|
||||
(?:\S) # Everything else that isn't whitespace.
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
######################################################################
|
||||
# This is the core tokenizing regex:
|
||||
|
@ -139,6 +139,7 @@ html_entity_digit_re = re.compile(r"&#\d+;")
|
|||
html_entity_alpha_re = re.compile(r"&\w+;")
|
||||
amp = "&"
|
||||
|
||||
|
||||
######################################################################
|
||||
|
||||
|
||||
|
@ -165,7 +166,7 @@ class PottsTokenizer(BaseTokenizer):
|
|||
pass
|
||||
# Now the alpha versions:
|
||||
ents = set(html_entity_alpha_re.findall(s))
|
||||
ents = filter((lambda x : x != amp), ents)
|
||||
ents = filter((lambda x: x != amp), ents)
|
||||
for ent in ents:
|
||||
entname = ent[1:-1]
|
||||
try:
|
||||
|
@ -175,7 +176,7 @@ class PottsTokenizer(BaseTokenizer):
|
|||
s = s.replace(amp, " and ")
|
||||
return s
|
||||
|
||||
def tokenize_builtins(self, text: str) -> t.Iterable[str]:
|
||||
def tokenize_plain(self, text: str) -> t.Iterable[str]:
|
||||
# Fix HTML character entitites:
|
||||
s = self.__html2string(text)
|
||||
# Tokenize:
|
||||
|
@ -187,8 +188,8 @@ class PottsTokenizer(BaseTokenizer):
|
|||
|
||||
|
||||
class PottsTokenizerWithNegation(PottsTokenizer):
|
||||
def tokenize_builtins(self, text: str) -> t.Iterable[str]:
|
||||
words = super().tokenize_builtins(text)
|
||||
def tokenize_plain(self, text: str) -> t.Iterable[str]:
|
||||
words = super().tokenize_plain(text)
|
||||
nltk.sentiment.util.mark_negation(words, shallow=True)
|
||||
return words
|
||||
|
||||
|
|
Loading…
Reference in a new issue