1
Fork 0
mirror of https://github.com/Steffo99/unimore-bda-6.git synced 2024-11-22 07:54:19 +00:00
bda-6-steffo/unimore_bda_6/analysis/tf_text.py

74 lines
2.5 KiB
Python

import tensorflow
import itertools
import typing as t
from ..database import Text, Category, Review
from ..tokenizer import BaseTokenizer
from .base import BaseSentimentAnalyzer, AlreadyTrainedError, NotTrainedError
class TensorflowSentimentAnalyzer(BaseSentimentAnalyzer):
    """
    Sentiment analyzer backed by a small Keras neural network.

    Texts are converted to integer token sequences by a `TextVectorization` layer, embedded, averaged and reduced to a
    single logit by a `Dense(1)` layer: the model is therefore a binary classifier, and training requires the training
    set to contain exactly two distinct categories.
    """

    MAX_FEATURES = 20000
    EMBEDDING_DIM = 16
    EPOCHS = 10
    BATCH_SIZE = 32

    def __init__(self):
        super().__init__()
        self.trained = False
        self.text_vectorization_layer: tensorflow.keras.layers.TextVectorization | None = None
        # Maps each category string to index 0 or 1 (and back, through its vocabulary); built during training.
        self.category_lookup_layer: tensorflow.keras.layers.StringLookup | None = None
        self.neural_network: tensorflow.keras.Sequential | None = None

    @classmethod
    def __bda_dataset_to_tf_dataset(
        cls,
        dataset_func: t.Callable[[], t.Iterator[Review]] | t.Iterable[Review],
    ) -> tensorflow.data.Dataset:
        """
        Convert a `unimore_bda_6.database.DataSet` to a "real" `tensorflow.data.Dataset` of `(text, category)` pairs.

        :param dataset_func: Either a zero-argument callable returning a fresh iterator of reviews (used as-is), or an
            iterable of reviews, which is materialized into a list so that the resulting dataset can be iterated more
            than once (it is consumed by each `adapt` call and by every training epoch).
        :return: A dataset whose elements are `(text, category)` pairs of scalar string tensors.

        NOTE(review): each review is assumed to be a 2-sequence `(text, category)` of strings, as the
        `output_signature` below declares — confirm against `unimore_bda_6.database.Review`.
        """
        if callable(dataset_func):
            generator = dataset_func
        else:
            # `from_generator` needs a callable it can invoke once per pass; a bare iterator would be exhausted
            # after the first pass, so keep the reviews around and hand out a new iterator each time.
            reviews = list(dataset_func)

            def generator() -> t.Iterator[Review]:
                return iter(reviews)

        return tensorflow.data.Dataset.from_generator(
            generator,
            output_signature=(
                tensorflow.TensorSpec(shape=(), dtype=tensorflow.string),
                tensorflow.TensorSpec(shape=(), dtype=tensorflow.string),
            ),
        )

    def train(self, training_set: t.Iterator[Review]) -> None:
        """
        Build and fit the category lookup, the text vectorization layer and the neural network on `training_set`.

        :param training_set: The reviews to train on: an iterable of `(text, category)` pairs, or a zero-argument
            callable returning a fresh iterator of them.
        :raises AlreadyTrainedError: If the analyzer has already been trained.
        :raises ValueError: If the training set does not contain exactly two distinct categories.
        """
        if self.trained:
            raise AlreadyTrainedError()

        dataset = self.__bda_dataset_to_tf_dataset(training_set)
        texts = dataset.map(lambda text, category: text)
        categories = dataset.map(lambda text, category: category)

        # Binary cross-entropy needs numeric 0/1 labels: learn a category string -> index mapping first.
        # With no OOV indices and no mask token, the vocabulary contains exactly the category tags seen.
        self.category_lookup_layer = tensorflow.keras.layers.StringLookup(num_oov_indices=0)
        self.category_lookup_layer.adapt(categories.batch(self.BATCH_SIZE))
        if self.category_lookup_layer.vocabulary_size() != 2:
            raise ValueError(
                f"{type(self).__name__} only supports exactly two categories, "
                f"got {self.category_lookup_layer.get_vocabulary()!r}"
            )

        # `adapt` is given a batched `tf.data.Dataset` of texts rather than a plain Python iterator.
        self.text_vectorization_layer = tensorflow.keras.layers.TextVectorization(
            max_tokens=self.MAX_FEATURES,
            standardize=self.tokenizer.tokenize_tensorflow,
        )
        self.text_vectorization_layer.adapt(texts.batch(self.BATCH_SIZE))

        vectorize = self.text_vectorization_layer
        lookup = self.category_lookup_layer

        def to_model_io(text, category):
            # Vectorize a batch of texts and turn each category into a float label shaped (batch, 1), matching the
            # (batch, 1) logits of the final Dense(1) layer.
            label = tensorflow.cast(lookup(category), tensorflow.float32)
            return vectorize(text), tensorflow.expand_dims(label, -1)

        # Batch before vectorizing: the pooling layer below expects (batch, sequence, embedding) inputs.
        training_dataset = dataset.batch(self.BATCH_SIZE).map(to_model_io)

        # Embed token ids, average them over the sequence, and map the average to a single logit
        # (positive logit => category index 1).
        self.neural_network = tensorflow.keras.Sequential([
            tensorflow.keras.layers.Embedding(self.MAX_FEATURES + 1, self.EMBEDDING_DIM),
            tensorflow.keras.layers.Dropout(0.2),
            tensorflow.keras.layers.GlobalAveragePooling1D(),
            tensorflow.keras.layers.Dropout(0.2),
            tensorflow.keras.layers.Dense(1),
        ])
        self.neural_network.compile(
            loss=tensorflow.losses.BinaryCrossentropy(from_logits=True),  # Only works with two tags
            metrics=[tensorflow.metrics.BinaryAccuracy(threshold=0.0)],
        )
        self.neural_network.fit(
            training_dataset,
            epochs=self.EPOCHS,
        )

        self.trained = True

    def use(self, text: Text) -> Category:
        """
        Classify `text` with the trained network.

        :param text: The text to classify.
        :return: One of the two categories seen during training: the second entry of the category vocabulary if the
            predicted logit is positive, the first one otherwise.
        :raises NotTrainedError: If the analyzer has not been trained yet.
        """
        if not self.trained:
            raise NotTrainedError()

        # The network was trained on token ids, not raw strings: vectorize a batch containing just this text.
        tokens = self.text_vectorization_layer(tensorflow.constant([text]))
        # Calling the model directly avoids the per-call overhead `predict` has for a single small input;
        # `training=False` disables the dropout layers.
        logit = float(self.neural_network(tokens, training=False)[0, 0])
        # Same decision threshold as the `BinaryAccuracy(threshold=0.0)` training metric.
        category_index = 1 if logit > 0.0 else 0
        # NOTE(review): assumes `Category` is the category string itself, as the tensorflow.string output
        # signature implies — confirm against `unimore_bda_6.database`.
        return self.category_lookup_layer.get_vocabulary()[category_index]