1
Fork 0
mirror of https://github.com/Steffo99/unimore-bda-6.git synced 2024-11-22 16:04:18 +00:00
bda-6-steffo/unimore_bda_6/analysis/tf_text.py

127 lines
5 KiB
Python
Raw Normal View History

2023-02-04 00:36:42 +00:00
import tensorflow
2023-02-08 09:54:14 +00:00
import logging
2023-02-04 00:36:42 +00:00
from ..database import Text, Category, DatasetFunc
2023-02-08 09:54:14 +00:00
from ..config import TENSORFLOW_EMBEDDING_SIZE, TENSORFLOW_MAX_FEATURES, TENSORFLOW_EPOCHS
from ..tokenizer import BaseTokenizer
from .base import BaseSentimentAnalyzer, AlreadyTrainedError, NotTrainedError, TrainingFailedError
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
2023-02-04 00:36:42 +00:00
class TensorflowSentimentAnalyzer(BaseSentimentAnalyzer):
    """
    Sentiment analyzer backed by a small Keras ``Sequential`` model.

    Input text goes through a ``TextVectorization`` layer and then through an
    embedding / dropout / global-average-pooling / dropout / dense stack whose
    final layer has five softmax outputs.

    An instance can be trained at most once: :meth:`train` raises
    :class:`AlreadyTrainedError` on a second call after a successful training,
    and :meth:`use` raises :class:`NotTrainedError` until training succeeded.
    """

    def __init__(self, tokenizer: BaseTokenizer):
        """
        Create an untrained analyzer.

        :param tokenizer: Tokenizer whose ``tokenize_tensorflow`` method is
            installed as the ``standardize`` step of the vectorization layer.
        """
        super().__init__()
        self.trained: bool = False

        self.text_vectorization_layer: tensorflow.keras.layers.TextVectorization = self._build_vectorizer(tokenizer)
        self.model: tensorflow.keras.Sequential = self._build_model()
        # Filled in by train() with the object returned by ``model.fit``.
        self.history: tensorflow.keras.callbacks.History | None = None

    @staticmethod
    def _build_dataset(dataset_func: DatasetFunc) -> tensorflow.data.Dataset:
        """
        Wrap ``dataset_func`` in a cached, prefetching ``tensorflow.data.Dataset``.

        :param dataset_func: Zero-argument callable returning an iterable of
            reviews; ``to_tensor_tuple()`` is called on every review.
        :return: Dataset whose elements follow the declared signature: a scalar
            string tensor named ``text`` and a ``float32`` tensor of shape
            ``(1, 5)`` named ``category``.
        """

        def dataset_func_with_tensor_tuple():
            # from_generator needs a callable that yields the tensors directly.
            for review in dataset_func():
                yield review.to_tensor_tuple()

        log.debug("Creating dataset...")
        dataset = tensorflow.data.Dataset.from_generator(
            dataset_func_with_tensor_tuple,
            output_signature=(
                tensorflow.TensorSpec(shape=(), dtype=tensorflow.string, name="text"),
                tensorflow.TensorSpec(shape=(1, 5,), dtype=tensorflow.float32, name="category"),
            )
        )

        log.debug("Caching dataset...")
        dataset = dataset.cache()

        log.debug("Configuring dataset prefetch...")
        dataset = dataset.prefetch(buffer_size=tensorflow.data.AUTOTUNE)

        return dataset

    @staticmethod
    def _build_model() -> tensorflow.keras.Sequential:
        """
        Build and compile the classification model.

        Layer sizes are read from the ``__wrapped__`` attribute of the
        ``TENSORFLOW_MAX_FEATURES`` and ``TENSORFLOW_EMBEDDING_SIZE`` config
        objects.

        :return: The compiled model (Adam with ``global_clipnorm=1.0``,
            categorical cross-entropy loss, categorical accuracy metric).
        """
        log.debug("Creating %s model...", tensorflow.keras.Sequential)
        model = tensorflow.keras.Sequential([
            tensorflow.keras.layers.Embedding(
                input_dim=TENSORFLOW_MAX_FEATURES.__wrapped__ + 1,
                output_dim=TENSORFLOW_EMBEDDING_SIZE.__wrapped__,
            ),
            tensorflow.keras.layers.Dropout(0.2),
            tensorflow.keras.layers.GlobalAveragePooling1D(),
            tensorflow.keras.layers.Dropout(0.2),
            tensorflow.keras.layers.Dense(5, activation="softmax"),
        ])

        log.debug("Compiling model: %s", model)
        model.compile(
            optimizer=tensorflow.keras.optimizers.Adam(global_clipnorm=1.0),
            loss=tensorflow.keras.losses.CategoricalCrossentropy(),
            metrics=[
                tensorflow.keras.metrics.CategoricalAccuracy(),
            ]
        )
        log.debug("Compiled model: %s", model)

        return model

    @staticmethod
    def _build_vectorizer(tokenizer: BaseTokenizer) -> tensorflow.keras.layers.TextVectorization:
        """
        Build the (not yet adapted) text vectorization layer.

        :param tokenizer: Tokenizer providing ``tokenize_tensorflow``, used as
            the layer's ``standardize`` callable.
        :return: A ``TextVectorization`` layer limited to
            ``TENSORFLOW_MAX_FEATURES`` tokens.
        """
        return tensorflow.keras.layers.TextVectorization(
            standardize=tokenizer.tokenize_tensorflow,
            max_tokens=TENSORFLOW_MAX_FEATURES.__wrapped__
        )

    def train(self, dataset_func: DatasetFunc) -> None:
        """
        Adapt the vectorization layer to the dataset, then fit the model on it.

        :param dataset_func: Zero-argument callable producing the training reviews.
        :raises AlreadyTrainedError: If this instance was already trained.
        :raises TrainingFailedError: If fewer epochs than ``TENSORFLOW_EPOCHS``
            were recorded in the training history.
        """
        if self.trained:
            log.error("Tried to train an already trained model.")
            raise AlreadyTrainedError()

        log.debug("Building dataset...")
        training_set = self._build_dataset(dataset_func)
        log.debug("Built dataset: %s", training_set)

        # The vectorization layer is adapted on the text component alone.
        log.debug("Preparing training_set for %s...", self.text_vectorization_layer.adapt)
        only_text_set = training_set.map(lambda text, category: text)

        log.debug("Adapting text_vectorization_layer: %s", self.text_vectorization_layer)
        self.text_vectorization_layer.adapt(only_text_set)
        log.debug("Adapted text_vectorization_layer: %s", self.text_vectorization_layer)

        # The model consumes vectorized text, so the mapping is applied to the dataset up front.
        log.debug("Preparing training_set for %s...", self.model.fit)
        training_set = training_set.map(lambda text, category: (self.text_vectorization_layer(text), category))

        log.info("Training: %s", self.model)
        self.history = self.model.fit(
            training_set,
            epochs=TENSORFLOW_EPOCHS.__wrapped__,
            callbacks=[
                tensorflow.keras.callbacks.TerminateOnNaN()
            ])
        log.info("Trained: %s", self.model)

        # An early stop (e.g. by TerminateOnNaN) leaves fewer recorded epochs than requested.
        # NOTE(review): a NaN first seen during the final epoch may still leave the epoch count
        # equal to the configured one - confirm whether that case needs separate detection.
        if len(self.history.epoch) < TENSORFLOW_EPOCHS.__wrapped__:
            log.error("Model %s training failed: only %d epochs computed", self.model, len(self.history.epoch))
            raise TrainingFailedError()

        log.info("Model %s training succeeded!", self.model)
        self.trained = True

    def use(self, text: Text) -> Category:
        """
        Predict the category of ``text`` with the trained model.

        :param text: The text to classify.
        :return: The zero-based index of the highest score in the first
            prediction row, plus one, as a float.
        :raises NotTrainedError: If the model has not been trained yet.
        """
        if not self.trained:
            log.error("Tried to use a non-trained model.")
            raise NotTrainedError()

        vector = self.text_vectorization_layer(text)
        prediction = self.model.predict(vector, verbose=False)

        # max() only replaces its candidate on a strictly greater score,
        # so the first of several equal scores wins.
        best_index, _ = max(enumerate(prediction[0]), key=lambda pair: pair[1])

        return float(best_index) + 1.0