Mirror of https://github.com/Steffo99/unimore-bda-6.git

Commit e3005ab8b0 ("enough")
Parent: 4d6c8f0fee
21 changed files with 309 additions and 149 deletions
@@ -40,6 +40,7 @@
       <option value="E501" />
       <option value="E221" />
       <option value="E203" />
+      <option value="E402" />
     </list>
   </option>
 </inspection_tool>
@@ -10,4 +10,7 @@
   <component name="ProjectRootManager" version="2" languageLevel="JDK_19">
     <output url="file://$PROJECT_DIR$/out" />
   </component>
+  <component name="PythonCompatibilityInspectionAdvertiser">
+    <option name="version" value="3" />
+  </component>
 </project>
@@ -1,12 +1,12 @@
 <component name="ProjectRunConfigurationManager">
   <configuration default="false" name="unimore_bda_6" type="PythonConfigurationType" factoryName="Python" nameIsGenerated="true">
     <module name="unimore-bda-6" />
-    <option name="INTERPRETER_OPTIONS" value="" />
+    <option name="INTERPRETER_OPTIONS" value="-O" />
     <option name="PARENT_ENVS" value="true" />
     <envs>
+      <env name="PYTHONUNBUFFERED" value="1" />
      <env name="CONFIRM_OVERWRITE" value="False" />
      <env name="NLTK_DATA" value="./data/nltk" />
-      <env name="PYTHONUNBUFFERED" value="1" />
      <env name="TF_CPP_MIN_LOG_LEVEL" value="2" />
      <env name="WORKING_SET_SIZE" value="1000000" />
      <env name="XLA_FLAGS" value="--xla_gpu_cuda_data_dir=/opt/cuda" />
poetry.lock (generated): 57 lines changed

@@ -1230,6 +1230,61 @@ files = [
 [package.extras]
 tests = ["pytest", "pytest-cov"]

+[[package]]
+name = "tokenizers"
+version = "0.13.2"
+description = "Fast and Customizable Tokenizers"
+category = "main"
+optional = false
+python-versions = "*"
+files = [
+    {file = "tokenizers-0.13.2-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:a6f36b1b499233bb4443b5e57e20630c5e02fba61109632f5e00dab970440157"},
+    {file = "tokenizers-0.13.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:bc6983282ee74d638a4b6d149e5dadd8bc7ff1d0d6de663d69f099e0c6bddbeb"},
+    {file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16756e6ab264b162f99c0c0a8d3d521328f428b33374c5ee161c0ebec42bf3c0"},
+    {file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b10db6e4b036c78212c6763cb56411566edcf2668c910baa1939afd50095ce48"},
+    {file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:238e879d1a0f4fddc2ce5b2d00f219125df08f8532e5f1f2ba9ad42f02b7da59"},
+    {file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47ef745dbf9f49281e900e9e72915356d69de3a4e4d8a475bda26bfdb5047736"},
+    {file = "tokenizers-0.13.2-cp310-cp310-win32.whl", hash = "sha256:96cedf83864bcc15a3ffd088a6f81a8a8f55b8b188eabd7a7f2a4469477036df"},
+    {file = "tokenizers-0.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eda77de40a0262690c666134baf19ec5c4f5b8bde213055911d9f5a718c506e1"},
+    {file = "tokenizers-0.13.2-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:9eee037bb5aa14daeb56b4c39956164b2bebbe6ab4ca7779d88aa16b79bd4e17"},
+    {file = "tokenizers-0.13.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d1b079c4c9332048fec4cb9c2055c2373c74fbb336716a5524c9a720206d787e"},
+    {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a689654fc745135cce4eea3b15e29c372c3e0b01717c6978b563de5c38af9811"},
+    {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3606528c07cda0566cff6cbfbda2b167f923661be595feac95701ffcdcbdbb21"},
+    {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41291d0160946084cbd53c8ec3d029df3dc2af2673d46b25ff1a7f31a9d55d51"},
+    {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7892325f9ca1cc5fca0333d5bfd96a19044ce9b092ce2df625652109a3de16b8"},
+    {file = "tokenizers-0.13.2-cp311-cp311-win32.whl", hash = "sha256:93714958d4ebe5362d3de7a6bd73dc86c36b5af5941ebef6c325ac900fa58865"},
+    {file = "tokenizers-0.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:fa7ef7ee380b1f49211bbcfac8a006b1a3fa2fa4c7f4ee134ae384eb4ea5e453"},
+    {file = "tokenizers-0.13.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:da521bfa94df6a08a6254bb8214ea04854bb9044d61063ae2529361688b5440a"},
+    {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a739d4d973d422e1073989769723f3b6ad8b11e59e635a63de99aea4b2208188"},
+    {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cac01fc0b868e4d0a3aa7c5c53396da0a0a63136e81475d32fcf5c348fcb2866"},
+    {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0901a5c6538d2d2dc752c6b4bde7dab170fddce559ec75662cfad03b3187c8f6"},
+    {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ba9baa76b5a3eefa78b6cc351315a216232fd727ee5e3ce0f7c6885d9fb531b"},
+    {file = "tokenizers-0.13.2-cp37-cp37m-win32.whl", hash = "sha256:a537061ee18ba104b7f3daa735060c39db3a22c8a9595845c55b6c01d36c5e87"},
+    {file = "tokenizers-0.13.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c82fb87b1cbfa984d8f05b2b3c3c73e428b216c1d4f0e286d0a3b27f521b32eb"},
+    {file = "tokenizers-0.13.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ce298605a833ac7f81b8062d3102a42dcd9fa890493e8f756112c346339fe5c5"},
+    {file = "tokenizers-0.13.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f44d59bafe3d61e8a56b9e0a963075187c0f0091023120b13fbe37a87936f171"},
+    {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51b93932daba12ed07060935978a6779593a59709deab04a0d10e6fd5c29e60"},
+    {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6969e5ea7ccb909ce7d6d4dfd009115dc72799b0362a2ea353267168667408c4"},
+    {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92f040c4d938ea64683526b45dfc81c580e3b35aaebe847e7eec374961231734"},
+    {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d3bc9f7d7f4c1aa84bb6b8d642a60272c8a2c987669e9bb0ac26daf0c6a9fc8"},
+    {file = "tokenizers-0.13.2-cp38-cp38-win32.whl", hash = "sha256:efbf189fb9cf29bd29e98c0437bdb9809f9de686a1e6c10e0b954410e9ca2142"},
+    {file = "tokenizers-0.13.2-cp38-cp38-win_amd64.whl", hash = "sha256:0b4cb2c60c094f31ea652f6cf9f349aae815f9243b860610c29a69ed0d7a88f8"},
+    {file = "tokenizers-0.13.2-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:b47d6212e7dd05784d7330b3b1e5a170809fa30e2b333ca5c93fba1463dec2b7"},
+    {file = "tokenizers-0.13.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:80a57501b61ec4f94fb7ce109e2b4a1a090352618efde87253b4ede6d458b605"},
+    {file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:61507a9953f6e7dc3c972cbc57ba94c80c8f7f686fbc0876afe70ea2b8cc8b04"},
+    {file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c09f4fa620e879debdd1ec299bb81e3c961fd8f64f0e460e64df0818d29d845c"},
+    {file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:66c892d85385b202893ac6bc47b13390909e205280e5df89a41086cfec76fedb"},
+    {file = "tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3e306b0941ad35087ae7083919a5c410a6b672be0343609d79a1171a364ce79"},
+    {file = "tokenizers-0.13.2-cp39-cp39-win32.whl", hash = "sha256:79189e7f706c74dbc6b753450757af172240916d6a12ed4526af5cc6d3ceca26"},
+    {file = "tokenizers-0.13.2-cp39-cp39-win_amd64.whl", hash = "sha256:486d637b413fddada845a10a45c74825d73d3725da42ccd8796ccd7a1c07a024"},
+    {file = "tokenizers-0.13.2.tar.gz", hash = "sha256:f9525375582fd1912ac3caa2f727d36c86ff8c0c6de45ae1aaff90f87f33b907"},
+]
+
+[package.extras]
+dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
+docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
+testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
+
 [[package]]
 name = "tqdm"
 version = "4.64.1"

@@ -1390,4 +1445,4 @@ files = [
 [metadata]
 lock-version = "2.0"
 python-versions = "~3.10"
-content-hash = "b17bed73011c627355660e5ba5ca176a920c1d87915a1ab875e5a5ddd28a9dca"
+content-hash = "d63867d77886e8ae2c09771a5ec1053dae99a5699f7e905a69bd298e1b986a80"
@@ -137,6 +137,7 @@ nltk = "^3.8.1"
 cfig = {extras = ["cli"], version = "^0.3.0"}
 coloredlogs = "^15.0.1"
 tensorflow = "^2.11.0"
+tokenizers = "^0.13.2"
@@ -1,84 +1,81 @@
 import logging
-import tensorflow
+import pymongo.errors
+
+from .log import install_log_handler
+install_log_handler()

-from .config import config, DATA_SET_SIZE
-from .database import mongo_client_from_config, reviews_collection, sample_reviews_polar, sample_reviews_varied, store_cache, load_cache, delete_cache
+from .config import config
+from .database import mongo_client_from_config, reviews_collection, sample_reviews_polar, sample_reviews_varied
 from .analysis.nltk_sentiment import NLTKSentimentAnalyzer
 from .analysis.tf_text import TensorflowSentimentAnalyzer
 from .analysis.base import TrainingFailedError
-from .tokenizer import LowercaseTokenizer
-from .log import install_log_handler
+from .tokenizer import PlainTokenizer, LowercaseTokenizer, NLTKWordTokenizer, PottsTokenizer, PottsTokenizerWithNegation
+from .gathering import Caches

 log = logging.getLogger(__name__)


 def main():
-    if len(tensorflow.config.list_physical_devices(device_type="GPU")) == 0:
-        log.warning("Tensorflow reports no GPU acceleration available.")
-    else:
-        log.debug("Tensorflow successfully found GPU acceleration!")
+    log.info("Started unimore-bda-6 in %s mode!", "DEBUG" if __debug__ else "PRODUCTION")

-    try:
-        delete_cache("./data/training")
-        delete_cache("./data/evaluation")
-    except FileNotFoundError:
-        pass
+    log.debug("Validating configuration...")
+    config.proxies.resolve()

-    for dataset_func in [sample_reviews_polar, sample_reviews_varied]:
-        for SentimentAnalyzer in [TensorflowSentimentAnalyzer, NLTKSentimentAnalyzer]:
-            for Tokenizer in [
-                # NLTKWordTokenizer,
-                # PottsTokenizer,
-                # PottsTokenizerWithNegation,
-                LowercaseTokenizer,
-            ]:
-                while True:
-                    try:
-                        tokenizer = Tokenizer()
-                        model = SentimentAnalyzer(tokenizer=tokenizer)
+    log.debug("Ensuring there are no leftover caches...")
+    Caches.ensure_clean()

-                        with mongo_client_from_config() as db:
-                            log.debug("Finding the reviews MongoDB collection...")
-                            collection = reviews_collection(db)
+    with mongo_client_from_config() as db:
+        try:
+            db.admin.command("ping")
+        except pymongo.errors.ServerSelectionTimeoutError:
+            log.fatal("MongoDB database is not available, exiting...")
+            exit(1)

-                            try:
-                                training_cache = load_cache("./data/training")
-                                evaluation_cache = load_cache("./data/evaluation")
-                            except FileNotFoundError:
-                                log.debug("Gathering datasets...")
-                                reviews_training = dataset_func(collection=collection, amount=DATA_SET_SIZE.__wrapped__)
-                                reviews_evaluation = dataset_func(collection=collection, amount=DATA_SET_SIZE.__wrapped__)
-                                log.debug("Caching datasets...")
-                                store_cache(reviews_training, "./data/training")
-                                store_cache(reviews_evaluation, "./data/evaluation")
-                                del reviews_training
-                                del reviews_evaluation
-                                training_cache = load_cache("./data/training")
-                                evaluation_cache = load_cache("./data/evaluation")
-                                log.debug("Caches stored and loaded successfully!")
-                            else:
-                                log.debug("Caches loaded successfully!")
+        reviews = reviews_collection(db)

-                            log.info("Training model: %s", model)
-                            model.train(training_cache)
-                            log.info("Evaluating model: %s", model)
-                            evaluation_results = model.evaluate(evaluation_cache)
-                            log.info("%s", evaluation_results)
-                    except TrainingFailedError:
-                        log.error("Training failed, restarting with a different dataset.")
-                        continue
-                    else:
-                        log.info("Training")
-                        break
-                    finally:
-                        delete_cache("./data/training")
-                        delete_cache("./data/evaluation")
+        for sample_func in [sample_reviews_varied, sample_reviews_polar]:
+
+            for SentimentAnalyzer in [
+                TensorflowSentimentAnalyzer,
+                NLTKSentimentAnalyzer
+            ]:
+
+                for Tokenizer in [
+                    PlainTokenizer,
+                    LowercaseTokenizer,
+                    NLTKWordTokenizer,
+                    PottsTokenizer,
+                    PottsTokenizerWithNegation,
+                ]:
+
+                    slog = logging.getLogger(f"{__name__}.{sample_func.__name__}.{SentimentAnalyzer.__name__}.{Tokenizer.__name__}")
+
+                    while True:
+                        try:
+                            slog.debug("Creating sentiment analyzer...")
+                            sa = SentimentAnalyzer(tokenizer=Tokenizer())
+                        except TypeError:
+                            slog.warning("%s does not support %s, skipping...", Tokenizer.__name__, SentimentAnalyzer.__name__)
+                            break
+
+                        with Caches.from_database_samples(collection=reviews, sample_func=sample_func) as datasets:
+                            try:
+                                slog.info("Training sentiment analyzer: %s", sa)
+                                sa.train(training_dataset_func=datasets.training, validation_dataset_func=datasets.validation)
+                            except TrainingFailedError:
+                                slog.error("Training failed, trying again with a different dataset...")
+                                continue
+                            else:
+                                slog.info("Training succeeded!")
+                                slog.info("Evaluating sentiment analyzer: %s", sa)
+                                evaluation_results = sa.evaluate(evaluation_dataset_func=datasets.evaluation)
+                                slog.info("Evaluation results: %s", evaluation_results)
+                                break


 if __name__ == "__main__":
-    install_log_handler()
-    config.proxies.resolve()
     main()
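For reference, a minimal sketch (not part of the commit) of why the grid loop above can now skip a tokenizer/analyzer pair by catching TypeError at construction time; the stand-in classes below are simplified illustrations, not the project's real Caches or analyzer classes.

    # Sketch only: unsupported tokenizers are rejected in __init__, so the
    # driver loop just catches TypeError, logs a warning, and moves on.
    class FakeTokenizer:
        def supports_tensorflow(self) -> bool:
            return False

    class FakeTensorflowAnalyzer:
        def __init__(self, *, tokenizer):
            # Mirrors the check added to TensorflowSentimentAnalyzer.__init__.
            if not tokenizer.supports_tensorflow():
                raise TypeError("Tokenizer does not support Tensorflow")

    try:
        FakeTensorflowAnalyzer(tokenizer=FakeTokenizer())
    except TypeError:
        print("skipping this combination")  # the loop above logs a warning and breaks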
@@ -1,38 +1,42 @@
+from __future__ import annotations
+
 import abc
 import logging
 import dataclasses

-from ..database import Text, Category, DatasetFunc
+from ..database import Text, Category, CachedDatasetFunc
+from ..tokenizer import BaseTokenizer

 log = logging.getLogger(__name__)


-@dataclasses.dataclass
-class EvaluationResults:
-    correct: int
-    evaluated: int
-    score: float
-
-    def __repr__(self):
-        return f"<EvaluationResults: score of {self.score} out of {self.evaluated} evaluated tuples>"
-
-    def __str__(self):
-        return f"{self.evaluated} evaluated, {self.correct} correct, {self.correct / self.evaluated * 100:.2} % accuracy, {self.score:.2} score, {self.score / self.evaluated * 100:.2} scoreaccuracy"
-
-
 class BaseSentimentAnalyzer(metaclass=abc.ABCMeta):
     """
     Abstract base class for sentiment analyzers implemented in this project.
     """

+    # noinspection PyUnusedLocal
+    def __init__(self, *, tokenizer: BaseTokenizer):
+        pass
+
+    def __repr__(self):
+        return f"<{self.__class__.__qualname__}>"
+
     @abc.abstractmethod
-    def train(self, dataset_func: DatasetFunc) -> None:
+    def train(self, training_dataset_func: CachedDatasetFunc, validation_dataset_func: CachedDatasetFunc) -> None:
         """
-        Train the analyzer with the given training dataset.
+        Train the analyzer with the given training and validation datasets.
         """
         raise NotImplementedError()

-    def evaluate(self, dataset_func: DatasetFunc) -> EvaluationResults:
+    @abc.abstractmethod
+    def use(self, text: Text) -> Category:
+        """
+        Run the model on the given input.
+        """
+        raise NotImplementedError()
+
+    def evaluate(self, evaluation_dataset_func: CachedDatasetFunc) -> EvaluationResults:
         """
         Perform a model evaluation by calling repeatedly `.use` on every text of the test dataset and by comparing its resulting category with the expected category.

@@ -43,23 +47,30 @@ class BaseSentimentAnalyzer(metaclass=abc.ABCMeta):
         correct: int = 0
         score: float = 0.0

-        for review in dataset_func():
+        for review in evaluation_dataset_func():
             resulting_category = self.use(review.text)
             evaluated += 1
             correct += 1 if resulting_category == review.category else 0
             score += 1 - (abs(resulting_category - review.category) / 4)
-            if not evaluated % 100:
-                temp_results = EvaluationResults(correct=correct, evaluated=evaluated, score=score)
-                log.debug(f"{temp_results!s}")

         return EvaluationResults(correct=correct, evaluated=evaluated, score=score)

-    @abc.abstractmethod
-    def use(self, text: Text) -> Category:
-        """
-        Run the model on the given input.
-        """
-        raise NotImplementedError()
+
+@dataclasses.dataclass
+class EvaluationResults:
+    """
+    Container for the results of a dataset evaluation.
+    """
+
+    correct: int
+    evaluated: int
+    score: float
+
+    def __repr__(self):
+        return f"<EvaluationResults: {self!s}>"
+
+    def __str__(self):
+        return f"{self.evaluated} evaluated, {self.correct} correct, {self.correct / self.evaluated:.2%} accuracy, {self.score:.2f} score, {self.score / self.evaluated:.2%} scoreaccuracy"


 class AlreadyTrainedError(Exception):
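For reference, a short worked example (not part of the commit) of the scoring rule used by evaluate() above: each review contributes 1 - |predicted - expected| / 4, so on the 1-to-5 star scale an exact match scores 1.0 and the maximal error (1 vs 5) scores 0.0.

    # (predicted, expected) category pairs, chosen only for illustration.
    pairs = [(5.0, 5.0), (4.0, 5.0), (1.0, 5.0)]
    score = sum(1 - abs(p - e) / 4 for p, e in pairs)
    print(score)  # 1.0 + 0.75 + 0.0 = 1.75 over 3 evaluated reviews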
@@ -6,7 +6,7 @@ import logging
 import typing as t
 import itertools

-from ..database import Text, Category, Review, DatasetFunc
+from ..database import Text, Category, Review, CachedDatasetFunc
 from .base import BaseSentimentAnalyzer, AlreadyTrainedError, NotTrainedError
 from ..log import count_passage
 from ..tokenizer import BaseTokenizer

@@ -23,7 +23,11 @@ class NLTKSentimentAnalyzer(BaseSentimentAnalyzer):
     """

     def __init__(self, *, tokenizer: BaseTokenizer) -> None:
-        super().__init__()
+        if not tokenizer.supports_plain():
+            raise TypeError("Tokenizer does not support NLTK")
+
+        super().__init__(tokenizer=tokenizer)
+
         self.model: nltk.sentiment.SentimentAnalyzer = nltk.sentiment.SentimentAnalyzer()
         self.trained: bool = False
         self.tokenizer: BaseTokenizer = tokenizer

@@ -36,7 +40,7 @@ class NLTKSentimentAnalyzer(BaseSentimentAnalyzer):
         Convert the `Text` of a `DataTuple` to a `TokenBag`.
         """
         count_passage(log, "tokenize_datatuple", 100)
-        return self.tokenizer.tokenize_builtins(datatuple.text), datatuple.category
+        return self.tokenizer.tokenize_plain(datatuple.text), datatuple.category

     def _add_feature_unigrams(self, dataset: t.Iterator[tuple[TokenBag, Category]]) -> None:
         """

@@ -67,13 +71,13 @@ class NLTKSentimentAnalyzer(BaseSentimentAnalyzer):
         count_passage(log, "extract_features", 100)
         return self.model.extract_features(data[0]), data[1]

-    def train(self, dataset_func: DatasetFunc) -> None:
+    def train(self, training_dataset_func: CachedDatasetFunc, validation_dataset_func: CachedDatasetFunc) -> None:
         # Forbid retraining the model
         if self.trained:
             raise AlreadyTrainedError()

         # Get a generator
-        dataset: t.Generator[Review] = dataset_func()
+        dataset: t.Generator[Review] = training_dataset_func()

         # Tokenize the dataset
         dataset: t.Iterator[tuple[TokenBag, Category]] = map(self.__tokenize_review, dataset)

@@ -103,7 +107,7 @@ class NLTKSentimentAnalyzer(BaseSentimentAnalyzer):
             raise NotTrainedError()

         # Tokenize the input
-        tokens = self.tokenizer.tokenize_builtins(text)
+        tokens = self.tokenizer.tokenize_plain(text)

         # Run the classification method
         return self.model.classify(instance=tokens)
@@ -1,7 +1,7 @@
 import tensorflow
 import logging

-from ..database import Text, Category, DatasetFunc
+from ..database import Text, Category, CachedDatasetFunc
 from ..config import TENSORFLOW_EMBEDDING_SIZE, TENSORFLOW_MAX_FEATURES, TENSORFLOW_EPOCHS
 from ..tokenizer import BaseTokenizer
 from .base import BaseSentimentAnalyzer, AlreadyTrainedError, NotTrainedError, TrainingFailedError

@@ -9,9 +9,19 @@ from .base import BaseSentimentAnalyzer, AlreadyTrainedError, NotTrainedError, TrainingFailedError
 log = logging.getLogger(__name__)


+if len(tensorflow.config.list_physical_devices(device_type="GPU")) == 0:
+    log.warning("Tensorflow reports no GPU acceleration available.")
+else:
+    log.debug("Tensorflow successfully found GPU acceleration!")
+
+
 class TensorflowSentimentAnalyzer(BaseSentimentAnalyzer):
-    def __init__(self, tokenizer: BaseTokenizer):
-        super().__init__()
+    def __init__(self, *, tokenizer: BaseTokenizer):
+        if not tokenizer.supports_tensorflow():
+            raise TypeError("Tokenizer does not support Tensorflow")
+
+        super().__init__(tokenizer=tokenizer)
+
         self.trained: bool = False

         self.text_vectorization_layer: tensorflow.keras.layers.TextVectorization = self._build_vectorizer(tokenizer)

@@ -19,7 +29,11 @@ class TensorflowSentimentAnalyzer(BaseSentimentAnalyzer):
         self.history: tensorflow.keras.callbacks.History | None = None

     @staticmethod
-    def _build_dataset(dataset_func: DatasetFunc) -> tensorflow.data.Dataset:
+    def _build_dataset(dataset_func: CachedDatasetFunc) -> tensorflow.data.Dataset:
+        """
+        Convert a `CachedDatasetFunc` to a `tensorflow.data.Dataset`.
+        """
+
         def dataset_func_with_tensor_tuple():
             for review in dataset_func():
                 yield review.to_tensor_tuple()

@@ -43,15 +57,16 @@ class TensorflowSentimentAnalyzer(BaseSentimentAnalyzer):

     @staticmethod
     def _build_model() -> tensorflow.keras.Sequential:
-        log.debug("Creating %s model...", tensorflow.keras.Sequential)
+        log.debug("Creating model...")
         model = tensorflow.keras.Sequential([
             tensorflow.keras.layers.Embedding(
                 input_dim=TENSORFLOW_MAX_FEATURES.__wrapped__ + 1,
                 output_dim=TENSORFLOW_EMBEDDING_SIZE.__wrapped__,
             ),
-            tensorflow.keras.layers.Dropout(0.2),
+            tensorflow.keras.layers.Dropout(0.25),
             tensorflow.keras.layers.GlobalAveragePooling1D(),
-            tensorflow.keras.layers.Dropout(0.2),
+            tensorflow.keras.layers.Dropout(0.25),
+            tensorflow.keras.layers.Dense(25),
             tensorflow.keras.layers.Dense(5, activation="softmax"),
         ])
         log.debug("Compiling model: %s", model)

@@ -72,31 +87,35 @@ class TensorflowSentimentAnalyzer(BaseSentimentAnalyzer):
             max_tokens=TENSORFLOW_MAX_FEATURES.__wrapped__
         )

-    def train(self, dataset_func: DatasetFunc) -> None:
+    def train(self, training_dataset_func: CachedDatasetFunc, validation_dataset_func: CachedDatasetFunc) -> None:
         if self.trained:
             log.error("Tried to train an already trained model.")
             raise AlreadyTrainedError()

-        log.debug("Building dataset...")
-        training_set = self._build_dataset(dataset_func)
+        log.debug("Building datasets...")
+        training_set = self._build_dataset(training_dataset_func)
+        validation_set = self._build_dataset(validation_dataset_func)
         log.debug("Built dataset: %s", training_set)

         log.debug("Preparing training_set for %s...", self.text_vectorization_layer.adapt)
         only_text_set = training_set.map(lambda text, category: text)

         log.debug("Adapting text_vectorization_layer: %s", self.text_vectorization_layer)
         self.text_vectorization_layer.adapt(only_text_set)
         log.debug("Adapted text_vectorization_layer: %s", self.text_vectorization_layer)

         log.debug("Preparing training_set for %s...", self.model.fit)
         training_set = training_set.map(lambda text, category: (self.text_vectorization_layer(text), category))
+        validation_set = validation_set.map(lambda text, category: (self.text_vectorization_layer(text), category))
         log.info("Training: %s", self.model)
         self.history: tensorflow.keras.callbacks.History | None = self.model.fit(
             training_set,
+            validation_data=validation_set,
             epochs=TENSORFLOW_EPOCHS.__wrapped__,
             callbacks=[
                 tensorflow.keras.callbacks.TerminateOnNaN()
-            ])
-        log.info("Trained: %s", self.model)
+            ],
+        )

         if len(self.history.epoch) < TENSORFLOW_EPOCHS.__wrapped__:
             log.error("Model %s training failed: only %d epochs computed", self.model, len(self.history.epoch))
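For reference, a toy sketch (not the project's model) of the two fit() details used above: a validation dataset is passed alongside the training one, and TerminateOnNaN can end training early, which the caller detects by comparing len(history.epoch) with the requested epoch count. The data and layer sizes below are made up purely for illustration.

    import tensorflow

    # Tiny synthetic dataset: 32 samples, 4 features, labels in 0..4.
    xs = tensorflow.random.normal((32, 4))
    ys = tensorflow.random.uniform((32,), maxval=5, dtype=tensorflow.int32)
    train = tensorflow.data.Dataset.from_tensor_slices((xs, ys)).batch(8)
    val = tensorflow.data.Dataset.from_tensor_slices((xs, ys)).batch(8)

    model = tensorflow.keras.Sequential([tensorflow.keras.layers.Dense(5, activation="softmax")])
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

    history = model.fit(train, validation_data=val, epochs=3,
                        callbacks=[tensorflow.keras.callbacks.TerminateOnNaN()])
    if len(history.epoch) < 3:
        print("training stopped early")  # the analyzer treats this as a failed run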
@@ -45,14 +45,44 @@ def WORKING_SET_SIZE(val: str | None) -> int:


 @config.optional()
-def DATA_SET_SIZE(val: str | None) -> int:
+def TRAINING_SET_SIZE(val: str | None) -> int:
     """
-    The number of reviews from each category to fetch for the datasets.
+    The number of reviews from each category to fetch for the training dataset.

-    Defaults to `1750`.
+    Defaults to `5000`.
     """
     if val is None:
-        return 1750
+        return 5000
+    try:
+        return int(val)
+    except ValueError:
+        raise cfig.InvalidValueError("Not an int.")
+
+
+@config.optional()
+def VALIDATION_SET_SIZE(val: str | None) -> int:
+    """
+    The number of reviews from each category to fetch for the validation dataset.
+
+    Defaults to `400`.
+    """
+    if val is None:
+        return 400
+    try:
+        return int(val)
+    except ValueError:
+        raise cfig.InvalidValueError("Not an int.")
+
+
+@config.optional()
+def EVALUATION_SET_SIZE(val: str | None) -> int:
+    """
+    The number of reviews from each category to fetch for the evaluation dataset.
+
+    Defaults to `1000`.
+    """
+    if val is None:
+        return 1000
     try:
         return int(val)
     except ValueError:

@@ -79,10 +109,10 @@ def TENSORFLOW_EMBEDDING_SIZE(val: str | None) -> int:
     """
     The size of the embeddings tensor to use in Tensorflow models.

-    Defaults to `12`.
+    Defaults to `6`.
     """
     if val is None:
-        return 12
+        return 6
     try:
         return int(val)
     except ValueError:

@@ -94,10 +124,10 @@ def TENSORFLOW_EPOCHS(val: str | None) -> int:
     """
     The number of epochs to train Tensorflow models for.

-    Defaults to `15`.
+    Defaults to `12`.
     """
     if val is None:
-        return 15
+        return 12
     try:
         return int(val)
     except ValueError:

@@ -109,7 +139,9 @@ __all__ = (
     "MONGO_HOST",
     "MONGO_PORT",
     "WORKING_SET_SIZE",
-    "DATA_SET_SIZE",
+    "TRAINING_SET_SIZE",
+    "VALIDATION_SET_SIZE",
+    "EVALUATION_SET_SIZE",
     "TENSORFLOW_MAX_FEATURES",
     "TENSORFLOW_EMBEDDING_SIZE",
     "TENSORFLOW_EPOCHS",
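For reference, a minimal sketch (not part of the commit) of the cfig pattern these settings follow: a decorated resolver receives the raw environment value, parses it or falls back to a default, and raises cfig.InvalidValueError on bad input. The EXAMPLE_SET_SIZE key is hypothetical, and the cfig.Configuration() construction line is an assumption about how the project's `config` object is created elsewhere in its config module.

    import cfig

    config = cfig.Configuration()  # assumed setup; defined once in the real config module

    @config.optional()
    def EXAMPLE_SET_SIZE(val: str | None) -> int:
        """
        Hypothetical key following the same pattern as TRAINING_SET_SIZE above.

        Defaults to `100`.
        """
        if val is None:
            return 100
        try:
            return int(val)
        except ValueError:
            raise cfig.InvalidValueError("Not an int.")

    # The rest of the codebase reads the raw resolved value via EXAMPLE_SET_SIZE.__wrapped__,
    # as seen in the TENSORFLOW_* usages elsewhere in this diff.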
@@ -9,7 +9,7 @@ from .datatypes import Review
 log = logging.getLogger(__name__)


-DatasetFunc = t.Callable[[], t.Generator[Review, t.Any, None]]
+CachedDatasetFunc = t.Callable[[], t.Generator[Review, t.Any, None]]


 def store_cache(reviews: t.Iterator[Review], path: str | pathlib.Path) -> None:

@@ -34,7 +34,7 @@ def store_cache(reviews: t.Iterator[Review], path: str | pathlib.Path) -> None:
             pickle.dump(document, file)


-def load_cache(path: str | pathlib.Path) -> DatasetFunc:
+def load_cache(path: str | pathlib.Path) -> CachedDatasetFunc:
     """
     Load the contents of a directory into a `Review` iterator.
     """

@@ -69,12 +69,12 @@ def delete_cache(path: str | pathlib.Path) -> None:
     if not path.exists():
         raise FileNotFoundError("The specified path does not exist.")

-    log.warning("Deleting cache directory: %s", path)
+    log.debug("Deleting cache directory: %s", path)
     shutil.rmtree(path)


 __all__ = (
-    "DatasetFunc",
+    "CachedDatasetFunc",
     "store_cache",
     "load_cache",
     "delete_cache",
@@ -1,4 +1,5 @@
 import pymongo
+import pymongo.errors
 import contextlib
 import typing as t
 import logging

@@ -13,12 +14,12 @@ def mongo_client_from_config() -> t.ContextManager[pymongo.MongoClient]:
     """
     Create a new MongoDB client and yield it.
     """
-    log.debug("Opening connection to MongoDB...")
+    log.debug("Creating MongoDB client...")
     client: pymongo.MongoClient = pymongo.MongoClient(
         host=MONGO_HOST.__wrapped__,
         port=MONGO_PORT.__wrapped__,
     )
-    log.info("Opened connection to MongoDB!")
+    log.debug("Created MongoDB client!")

     yield client
@@ -10,6 +10,11 @@ Category = float


 class Review:
+    __slots__ = (
+        "text",
+        "category",
+    )
+
     def __init__(self, text: Text, category: Category):
         self.text: str = text
         self.category: float = category
@@ -9,6 +9,9 @@ from .datatypes import Review
 log = logging.getLogger(__name__)


+SampleFunc = t.Callable[[pymongo.collection.Collection, int], t.Iterator[Review]]
+
+
 def sample_reviews(collection: pymongo.collection.Collection, amount: int) -> t.Iterator[Review]:
     """
     Get ``amount`` random reviews from the ``reviews`` collection.

@@ -41,18 +44,20 @@ def sample_reviews_by_rating(collection: pymongo.collection.Collection, rating:


 def sample_reviews_polar(collection: pymongo.collection.Collection, amount: int) -> t.Iterator[Review]:
-    log.debug("Getting a sample of %d polar reviews...", amount * 2)
+    category_amount = amount // 2
+
+    log.debug("Getting a sample of %d polar reviews...", category_amount * 2)
+
     cursor = collection.aggregate([
         {"$limit": WORKING_SET_SIZE.__wrapped__},
         {"$match": {"overall": 1.0}},
-        {"$sample": {"size": amount}},
+        {"$sample": {"size": category_amount}},
         {"$unionWith": {
             "coll": collection.name,
             "pipeline": [
                 {"$limit": WORKING_SET_SIZE.__wrapped__},
                 {"$match": {"overall": 5.0}},
-                {"$sample": {"size": amount}},
+                {"$sample": {"size": category_amount}},
             ],
         }},
         {"$addFields": {

@@ -69,37 +74,39 @@ def sample_reviews_polar(collection: pymongo.collection.Collection, amount: int)


 def sample_reviews_varied(collection: pymongo.collection.Collection, amount: int) -> t.Iterator[Review]:
-    log.debug("Getting a sample of %d varied reviews...", amount * 5)
+    category_amount = amount // 5
+
+    log.debug("Getting a sample of %d varied reviews...", category_amount * 5)
+
     # Wow, this is ugly.
     cursor = collection.aggregate([
         {"$limit": WORKING_SET_SIZE.__wrapped__},
         {"$match": {"overall": 1.0}},
-        {"$sample": {"size": amount}},
+        {"$sample": {"size": category_amount}},
         {"$unionWith": {
             "coll": collection.name,
             "pipeline": [
                 {"$limit": WORKING_SET_SIZE.__wrapped__},
                 {"$match": {"overall": 2.0}},
-                {"$sample": {"size": amount}},
+                {"$sample": {"size": category_amount}},
                 {"$unionWith": {
                     "coll": collection.name,
                     "pipeline": [
                         {"$limit": WORKING_SET_SIZE.__wrapped__},
                         {"$match": {"overall": 3.0}},
-                        {"$sample": {"size": amount}},
+                        {"$sample": {"size": category_amount}},
                         {"$unionWith": {
                             "coll": collection.name,
                             "pipeline": [
                                 {"$limit": WORKING_SET_SIZE.__wrapped__},
                                 {"$match": {"overall": 4.0}},
-                                {"$sample": {"size": amount}},
+                                {"$sample": {"size": category_amount}},
                                 {"$unionWith": {
                                     "coll": collection.name,
                                     "pipeline": [
                                         {"$limit": WORKING_SET_SIZE.__wrapped__},
                                         {"$match": {"overall": 5.0}},
-                                        {"$sample": {"size": amount}},
+                                        {"$sample": {"size": category_amount}},
                                     ],
                                 }}
                             ],

@@ -122,6 +129,7 @@ def sample_reviews_varied(collection: pymongo.collection.Collection, amount: int)


 __all__ = (
+    "SampleFunc",
     "sample_reviews",
     "sample_reviews_by_rating",
     "sample_reviews_polar",
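For reference, a short worked example (not part of the commit) of the new per-category arithmetic: the sample functions now treat `amount` as the overall sample size and split it across rating categories, instead of drawing `amount` reviews per rating.

    amount = 1000
    polar_per_rating = amount // 2    # 500 one-star and 500 five-star reviews
    varied_per_rating = amount // 5   # 200 reviews for each rating from 1.0 to 5.0
    print(polar_per_rating, varied_per_rating)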
@@ -15,15 +15,15 @@ def install_log_handler(loggers: list[logging.Logger] = None):
     for logger in loggers:
         coloredlogs.install(
             logger=logger,
-            level="DEBUG",
+            level="DEBUG" if __debug__ else "INFO",
-            fmt="{asctime} | {name:<32} | {levelname:>8} | {message}",
+            fmt="{asctime} | {name:<80} | {levelname:>8} | {message}",
             style="{",
             level_styles=dict(
                 debug=dict(color="white"),
                 info=dict(color="cyan"),
                 warning=dict(color="yellow"),
                 error=dict(color="red"),
-                critical=dict(color="red", bold=True),
+                critical=dict(color="black", background="red", bold=True),
             ),
             field_styles=dict(
                 asctime=dict(color='magenta'),
@@ -1,6 +1,7 @@
 from .base import BaseTokenizer
 from .nltk_word_tokenize import NLTKWordTokenizer
 from .potts import PottsTokenizer, PottsTokenizerWithNegation
+from .plain import PlainTokenizer
 from .lower import LowercaseTokenizer


@@ -9,5 +10,6 @@ __all__ = (
     "NLTKWordTokenizer",
     "PottsTokenizer",
     "PottsTokenizerWithNegation",
+    "PlainTokenizer",
     "LowercaseTokenizer",
 )
@@ -14,21 +14,21 @@ class BaseTokenizer:
         f.__notimplemented__ = True
         return f

-    def can_tokenize_builtins(self) -> bool:
-        return getattr(self.tokenize_builtins, "__notimplemented__", False)
+    def supports_plain(self) -> bool:
+        return not getattr(self.tokenize_plain, "__notimplemented__", False)

-    def can_tokenize_tensorflow(self) -> bool:
-        return getattr(self.tokenize_tensorflow, "__notimplemented__", False)
+    def supports_tensorflow(self) -> bool:
+        return not getattr(self.tokenize_tensorflow, "__notimplemented__", False)

     @__not_implemented
-    def tokenize_builtins(self, text: str) -> list[str]:
+    def tokenize_plain(self, text: str) -> list[str]:
         """
         Convert a text string into a list of tokens.
         """
         raise NotImplementedError()

     @__not_implemented
-    def tokenize_tensorflow(self, text: tensorflow.Tensor) -> tensorflow.Tensor:
+    def tokenize_tensorflow(self, text: "tensorflow.Tensor") -> "tensorflow.Tensor":
         """
         Convert a `tensorflow.Tensor` string into another `tensorflow.Tensor` space-separated string.
         """
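For reference, a minimal sketch (not part of the commit) of the capability-flag pattern BaseTokenizer uses above: a decorator marks the default stub methods, and supports_*() reports whether a subclass has overridden them, since overrides do not carry the marker. The class names below are simplified stand-ins.

    def not_implemented(f):
        f.__notimplemented__ = True
        return f

    class Base:
        def supports_plain(self) -> bool:
            return not getattr(self.tokenize_plain, "__notimplemented__", False)

        @not_implemented
        def tokenize_plain(self, text: str) -> list[str]:
            raise NotImplementedError()

    class Plain(Base):
        def tokenize_plain(self, text: str) -> list[str]:
            return text.split()

    print(Base().supports_plain())   # False: only the marked stub exists
    print(Plain().supports_plain())  # True: the override has no marker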
@@ -4,7 +4,11 @@ from .base import BaseTokenizer


 class LowercaseTokenizer(BaseTokenizer):
-    def tokenize_builtins(self, text: str) -> list[str]:
+    """
+    Tokenizer which converts the words to lowercase before splitting them via spaces.
+    """
+
+    def tokenize_plain(self, text: str) -> list[str]:
         return text.lower().split()

     def tokenize_tensorflow(self, text: tensorflow.Tensor) -> tensorflow.Tensor:
@@ -10,7 +10,7 @@ class NLTKWordTokenizer(BaseTokenizer):
     Tokenizer based on `nltk.word_tokenize`.
     """

-    def tokenize_builtins(self, text: str) -> t.Iterable[str]:
+    def tokenize_plain(self, text: str) -> t.Iterable[str]:
         tokens = nltk.word_tokenize(text)
         nltk.sentiment.util.mark_negation(tokens, shallow=True)
         return tokens
unimore_bda_6/tokenizer/plain.py (new file, 16 lines)

@@ -0,0 +1,16 @@
+import tensorflow
+
+from .base import BaseTokenizer
+
+
+class PlainTokenizer(BaseTokenizer):
+    """
+    Tokenizer which just splits the text into tokens by separating them at whitespaces.
+    """
+
+    def tokenize_plain(self, text: str) -> list[str]:
+        return text.split()
+
+    def tokenize_tensorflow(self, text: tensorflow.Tensor) -> tensorflow.Tensor:
+        text = tensorflow.expand_dims(text, -1, name="tokens")
+        return text
@@ -139,6 +139,7 @@ html_entity_digit_re = re.compile(r"&#\d+;")
 html_entity_alpha_re = re.compile(r"&\w+;")
 amp = "&"

+
 ######################################################################

@@ -175,7 +176,7 @@ class PottsTokenizer(BaseTokenizer):
         s = s.replace(amp, " and ")
         return s

-    def tokenize_builtins(self, text: str) -> t.Iterable[str]:
+    def tokenize_plain(self, text: str) -> t.Iterable[str]:
         # Fix HTML character entitites:
         s = self.__html2string(text)
         # Tokenize:

@@ -187,8 +188,8 @@ class PottsTokenizer(BaseTokenizer):


 class PottsTokenizerWithNegation(PottsTokenizer):
-    def tokenize_builtins(self, text: str) -> t.Iterable[str]:
-        words = super().tokenize_builtins(text)
+    def tokenize_plain(self, text: str) -> t.Iterable[str]:
+        words = super().tokenize_plain(text)
         nltk.sentiment.util.mark_negation(words, shallow=True)
         return words