1
Fork 0
mirror of https://github.com/Steffo99/unimore-bda-6.git synced 2024-11-22 16:04:18 +00:00
bda-6-steffo/unimore_bda_6/analysis/tf_text.py

272 lines
10 KiB
Python
Raw Normal View History

2023-02-09 17:54:58 +00:00
import abc
import typing as t
import numpy
2023-02-04 00:36:42 +00:00
import tensorflow
2023-02-08 09:54:14 +00:00
import logging
2023-02-04 00:36:42 +00:00
2023-02-09 17:54:58 +00:00
from ..database import Text, Category, CachedDatasetFunc, Review
2023-02-08 09:54:14 +00:00
from ..config import TENSORFLOW_EMBEDDING_SIZE, TENSORFLOW_MAX_FEATURES, TENSORFLOW_EPOCHS
from ..tokenizer import BaseTokenizer
from .base import BaseSentimentAnalyzer, AlreadyTrainedError, NotTrainedError, TrainingFailedError
log = logging.getLogger(__name__)

# Import-time diagnostic: report whether Tensorflow can see at least one GPU device.
if tensorflow.config.list_physical_devices(device_type="GPU"):
    log.debug("Tensorflow successfully found GPU acceleration!")
else:
    log.warning("Tensorflow reports no GPU acceleration available.")

# Signature of the callables that turn one `Review` into one element (a tensor, or a tuple of tensors)
# of the datasets produced by `build_dataset`.
ConversionFunc = t.Callable[[Review], tensorflow.Tensor | tuple]
2023-02-09 17:54:58 +00:00
2023-02-10 03:20:35 +00:00
def build_dataset(dataset_func: CachedDatasetFunc, conversion_func: ConversionFunc, output_signature: tensorflow.TensorSpec | tuple) -> tensorflow.data.Dataset:
    """
    Convert a `CachedDatasetFunc` to a `tensorflow.data.Dataset`.

    Every review returned by calling `dataset_func` is passed through `conversion_func`. The resulting dataset is then
    cached and prefetched with a `tensorflow.data.AUTOTUNE` buffer size.

    :param dataset_func: Callable whose return value is iterated to obtain the reviews.
    :param conversion_func: Turns a single review into a single dataset element.
    :param output_signature: The `tensorflow.TensorSpec` (or tuple of specs) describing each element.
    :returns: The cached, prefetching dataset.
    """
    def converted_reviews():
        # Lazily convert reviews one at a time as `from_generator` pulls them.
        yield from map(conversion_func, dataset_func())

    log.debug("Creating dataset...")
    raw = tensorflow.data.Dataset.from_generator(converted_reviews, output_signature=output_signature)

    log.debug("Caching dataset...")
    cached = raw.cache()

    log.debug("Configuring dataset prefetch...")
    return cached.prefetch(buffer_size=tensorflow.data.AUTOTUNE)
class TensorflowSentimentAnalyzer(BaseSentimentAnalyzer, metaclass=abc.ABCMeta):
    """
    Base class for a sentiment analyzer using `tensorflow`.

    Subclasses supply the model (`_build_model`), the conversion of reviews into a dataset (`_build_dataset`) and the
    mapping from raw model output to a `.Category` (`_translate_prediction`). This class takes care of text
    vectorization, training, and prediction.
    """

    def __init__(self, *, tokenizer: BaseTokenizer):
        """
        :param tokenizer: The tokenizer to use; its `supports_tensorflow()` must return a truthy value.
        :raises TypeError: If the tokenizer does not support Tensorflow.
        """
        if not tokenizer.supports_tensorflow():
            raise TypeError("Tokenizer does not support Tensorflow")

        super().__init__(tokenizer=tokenizer)

        # Lifecycle flags checked by `train` and `use`.
        self.trained: bool = False
        self.failed: bool = False

        self.tokenizer: BaseTokenizer = tokenizer
        self.text_vectorization_layer: tensorflow.keras.layers.TextVectorization = self._build_text_vectorization_layer()
        self.model: tensorflow.keras.Sequential = self._build_model()

        # Set by `train` to the object returned by `model.fit`.
        self.history: tensorflow.keras.callbacks.History | None = None

    def _build_text_vectorization_layer(self) -> tensorflow.keras.layers.TextVectorization:
        """
        Create a `tensorflow`-compatible `TextVectorization` layer.

        The tokenizer's `tokenize_tensorflow_and_expand_dims` is used as the standardization step, and the vocabulary
        is capped at `TENSORFLOW_MAX_FEATURES` tokens.
        """
        log.debug("Creating TextVectorization layer...")
        layer = tensorflow.keras.layers.TextVectorization(
            standardize=self.tokenizer.tokenize_tensorflow_and_expand_dims,
            max_tokens=TENSORFLOW_MAX_FEATURES.__wrapped__
        )
        log.debug("Created TextVectorization layer: %s", layer)
        return layer

    @abc.abstractmethod
    def _build_model(self) -> tensorflow.keras.Sequential:
        """
        Create the `tensorflow.keras.Sequential` model that should be executed by this sentiment analyzer.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _build_dataset(self, dataset_func: CachedDatasetFunc) -> tensorflow.data.Dataset:
        """
        Create a `tensorflow.data.Dataset` from the given `CachedDatasetFunc`.

        The elements are expected to be `(text, category)` pairs, as consumed by `_adapt_textvectorization` and
        `_vectorize_dataset`.
        """
        raise NotImplementedError()

    def _adapt_textvectorization(self, dataset: tensorflow.data.Dataset) -> None:
        """
        Adapt the `.text_vectorization_layer` to the given dataset.

        Only the text half of each `(text, category)` pair is used.
        """
        log.debug("Preparing dataset to adapt %s...", self.text_vectorization_layer)
        dataset = dataset.map(lambda text, category: text)
        log.debug("Adapting %s...", self.text_vectorization_layer)
        self.text_vectorization_layer.adapt(dataset)

    def _vectorize_dataset(self, dataset: tensorflow.data.Dataset) -> tensorflow.data.Dataset:
        """
        Apply the `.text_vectorization_layer` to the text in the dataset, leaving the category untouched.
        """
        def vectorize_entry(text, category):
            return self.text_vectorization_layer(text), category

        log.debug("Vectorizing dataset: %s", dataset)
        dataset = dataset.map(vectorize_entry)
        log.debug("Vectorized dataset: %s", dataset)
        return dataset

    def train(self, training_dataset_func: CachedDatasetFunc, validation_dataset_func: CachedDatasetFunc) -> None:
        """
        Adapt the vectorization layer to the training set, then fit `.model` for `TENSORFLOW_EPOCHS` epochs.

        :param training_dataset_func: Source of the training reviews.
        :param validation_dataset_func: Source of the validation reviews.
        :raises AlreadyTrainedError: If the model was already trained, or if a previous training attempt failed.
        :raises TrainingFailedError: If training stopped before completing every epoch, or produced a non-finite loss.
            In both cases the model is marked as failed.
        """
        if self.failed:
            log.error("Tried to train a failed model.")
            raise AlreadyTrainedError("Cannot re-train a failed model.")
        if self.trained:
            log.error("Tried to train an already trained model.")
            raise AlreadyTrainedError("Cannot re-train an already trained model.")

        training_set = self._build_dataset(training_dataset_func)
        validation_set = self._build_dataset(validation_dataset_func)

        # The vocabulary is learned from the training set only.
        self._adapt_textvectorization(training_set)

        training_set = self._vectorize_dataset(training_set)
        validation_set = self._vectorize_dataset(validation_set)

        expected_epochs = TENSORFLOW_EPOCHS.__wrapped__

        log.info("Training: %s", self.model)
        self.history = self.model.fit(
            training_set,
            validation_data=validation_set,
            epochs=expected_epochs,
            callbacks=[
                # Interrupts training as soon as a batch yields a NaN/inf loss.
                tensorflow.keras.callbacks.TerminateOnNaN()
            ],
        )

        computed_epochs = len(self.history.epoch)
        if computed_epochs < expected_epochs:
            log.error("Model %s training failed: only %d epochs computed", self.model, computed_epochs)
            self.failed = True
            raise TrainingFailedError()

        # Keras still records the epoch in which `TerminateOnNaN` interrupted training. A divergence during the
        # final epoch therefore yields a full-length history, so the recorded losses must be checked too.
        losses = numpy.asarray(self.history.history.get("loss", []), dtype=float)
        if not numpy.isfinite(losses).all():
            log.error("Model %s training failed: non-finite loss recorded", self.model)
            self.failed = True
            raise TrainingFailedError()

        log.info("Model %s training succeeded!", self.model)
        self.trained = True

    @abc.abstractmethod
    def _translate_prediction(self, a: numpy.ndarray) -> Category:
        """
        Convert the results of `tensorflow.keras.Sequential.predict` into a `.Category`.
        """
        raise NotImplementedError()

    def use(self, text: Text) -> Category:
        """
        Predict the category of a single text with the trained model.

        :param text: The text to classify; it is vectorized with `.text_vectorization_layer` before prediction.
        :returns: The category produced by `_translate_prediction`.
        :raises NotTrainedError: If training failed, or if the model has not been trained yet.
        """
        if self.failed:
            log.error("Tried to use a failed model.")
            raise NotTrainedError("Cannot use a failed model.")
        if not self.trained:
            log.error("Tried to use a non-trained model.")
            raise NotTrainedError("Cannot use a non-trained model.")

        vector = self.text_vectorization_layer(text)
        prediction = self.model.predict(vector, verbose=False)
        return self._translate_prediction(prediction)
2023-02-09 17:54:58 +00:00
class TensorflowCategorySentimentAnalyzer(TensorflowSentimentAnalyzer):
    """
    A `tensorflow`-based sentiment analyzer that considers each star rating as a separate category.
    """

    def _build_dataset(self, dataset_func: CachedDatasetFunc) -> tensorflow.data.Dataset:
        """
        Build a dataset of `(text, category_one_hot)` pairs using `Review.to_tensor_tuple_category`.
        """
        return build_dataset(
            dataset_func=dataset_func,
            conversion_func=Review.to_tensor_tuple_category,
            output_signature=(
                tensorflow.TensorSpec(shape=(), dtype=tensorflow.string, name="text"),
                # One row of five values, matching the five outputs of `_build_model`.
                tensorflow.TensorSpec(shape=(1, 5,), dtype=tensorflow.float32, name="category_one_hot"),
            ),
        )

    def _build_model(self) -> tensorflow.keras.Sequential:
        """
        Create and compile an embedding -> pooling -> dense classifier with a 5-way softmax output.
        """
        log.debug("Creating sequential categorizer model...")
        model = tensorflow.keras.Sequential([
            tensorflow.keras.layers.Embedding(
                # +1 on top of the vectorizer's `max_tokens`, as in the original configuration.
                input_dim=TENSORFLOW_MAX_FEATURES.__wrapped__ + 1,
                output_dim=TENSORFLOW_EMBEDDING_SIZE.__wrapped__,
            ),
            tensorflow.keras.layers.Dropout(0.25),
            tensorflow.keras.layers.GlobalAveragePooling1D(),
            tensorflow.keras.layers.Dropout(0.25),
            # NOTE(review): no activation is set here, so this layer is linear; confirm that is intended.
            tensorflow.keras.layers.Dense(8),
            tensorflow.keras.layers.Dropout(0.25),
            tensorflow.keras.layers.Dense(5, activation="softmax"),
        ])
        log.debug("Compiling model: %s", model)
        model.compile(
            optimizer=tensorflow.keras.optimizers.Adam(clipnorm=1.0),
            loss=tensorflow.keras.losses.CategoricalCrossentropy(),
            metrics=[
                tensorflow.keras.metrics.CategoricalAccuracy(),
            ]
        )
        log.debug("Compiled model: %s", model)
        return model

    def _translate_prediction(self, a: numpy.ndarray) -> Category:
        """
        Convert the softmax output of the model into a star rating.

        :param a: The output of `.model.predict`; only its first row is read.
        :returns: The 1-based position of the highest score in that row, as a `float`.
            This is ``1.0`` to ``5.0`` for the five-unit output of `_build_model`.
        :raises ValueError: If the first row is empty.
        """
        scores = a[0]
        if len(scores) == 0:
            raise ValueError("Cannot translate an empty prediction")
        # `max` keeps the first of several equal scores, like a strict `>` scan would.
        best_index = max(range(len(scores)), key=lambda i: scores[i])
        # Indices are 0-based, star ratings are 1-based.
        return float(best_index + 1)
2023-02-09 17:54:58 +00:00
2023-02-10 04:52:13 +00:00
class TensorflowPolarSentimentAnalyzer(TensorflowSentimentAnalyzer):
    """
    A `tensorflow`-based sentiment analyzer that uses the floating point value rating to get as close as possible to the correct category.
    """

    def _build_dataset(self, dataset_func: CachedDatasetFunc) -> tensorflow.data.Dataset:
        """
        Build a dataset of `(text, category)` pairs using `Review.to_tensor_tuple_normvalue`.
        """
        return build_dataset(
            dataset_func=dataset_func,
            conversion_func=Review.to_tensor_tuple_normvalue,
            output_signature=(
                tensorflow.TensorSpec(shape=(), dtype=tensorflow.string, name="text"),
                tensorflow.TensorSpec(shape=(1,), dtype=tensorflow.float32, name="category"),
            ),
        )

    def _build_model(self) -> tensorflow.keras.Sequential:
        """
        Create and compile an embedding -> pooling -> single-sigmoid-output regressor trained with mean absolute error.
        """
        log.debug("Creating sequential polar model...")
        model = tensorflow.keras.Sequential([
            tensorflow.keras.layers.Embedding(
                input_dim=TENSORFLOW_MAX_FEATURES.__wrapped__ + 1,
                output_dim=TENSORFLOW_EMBEDDING_SIZE.__wrapped__,
            ),
            tensorflow.keras.layers.Dropout(0.25),
            tensorflow.keras.layers.GlobalAveragePooling1D(),
            tensorflow.keras.layers.Dropout(0.25),
            # A single output squashed into [0, 1].
            tensorflow.keras.layers.Dense(1, activation="sigmoid"),
        ])
        log.debug("Compiling model: %s", model)
        model.compile(
            optimizer=tensorflow.keras.optimizers.Adam(clipnorm=1.0),
            loss=tensorflow.keras.losses.MeanAbsoluteError(),
        )
        log.debug("Compiled model: %s", model)
        return model

    def _translate_prediction(self, a: numpy.ndarray) -> Category:
        """
        Convert the single sigmoid output of the model into a rating.

        The value at ``a[0, 0]`` is mapped linearly with ``value * 2 + 1`` and rounded to the nearest integer.

        :param a: The output of `.model.predict`.
        :returns: The rounded rating, as a `float`.
        """
        # NOTE(review): this assumes ``value * 2 + 1`` inverts the normalization of
        # `Review.to_tensor_tuple_normvalue`. Confirm this, since a sigmoid output mapped this way never exceeds 3.0.
        value = float(a[0, 0])
        return float(round(value * 2 + 1))
2023-02-10 04:52:13 +00:00
2023-02-09 17:54:58 +00:00
# Public API of this module: the abstract base analyzer followed by its two concrete variants.
__all__ = ("TensorflowSentimentAnalyzer", "TensorflowCategorySentimentAnalyzer", "TensorflowPolarSentimentAnalyzer")