2023-02-04 00:36:42 +00:00
|
|
|
import tensorflow
|
|
|
|
import itertools
|
|
|
|
import typing as t
|
|
|
|
|
2023-02-05 16:40:22 +00:00
|
|
|
from ..database import Text, Category, Review, DatasetFunc
|
2023-02-04 00:36:42 +00:00
|
|
|
from ..tokenizer import BaseTokenizer
|
|
|
|
from .base import BaseSentimentAnalyzer, AlreadyTrainedError, NotTrainedError
|
|
|
|
|
|
|
|
|
|
|
|
class TensorflowSentimentAnalyzer(BaseSentimentAnalyzer):
    """
    Two-category sentiment analyzer backed by a small TensorFlow/Keras network.

    Raw text is converted to integer token ids by a ``TextVectorization`` layer, which uses the given tokenizer as
    its standardizer and is adapted on the training texts. The ids are fed to an
    embedding + average-pooling network that outputs a single logit per text.
    """

    # Maximum vocabulary size of the TextVectorization layer; also sizes the Embedding layer.
    MAX_FEATURES = 20000

    # Length of the embedding vector learned for each token.
    EMBEDDING_DIM = 16

    # Number of passes over the training set performed by `fit`.
    EPOCHS = 10

    # Number of reviews per batch; the network needs batched (and therefore padded) token sequences.
    BATCH_SIZE = 32

    def __init__(self, *, tokenizer: BaseTokenizer):
        super().__init__()
        self.trained = False
        self.neural_network: tensorflow.keras.Sequential | None = None
        # Maps raw strings to padded int token ids; created and adapted by `train`, reused by `use`.
        self.text_vectorization_layer: tensorflow.keras.layers.TextVectorization | None = None
        self.tokenizer: BaseTokenizer = tokenizer  # TODO

    def train(self, dataset_func: DatasetFunc) -> None:
        """
        Adapt the vectorization layer to the dataset's texts, then build, compile and fit the network on it.

        :param dataset_func: Callable returning the reviews to train on. It is invoked again every time one of the
            ``tf.data`` pipelines built from it is iterated: once while adapting the vocabulary, and once per epoch.
        :raises AlreadyTrainedError: If this analyzer has already been trained.
        """
        if self.trained:
            raise AlreadyTrainedError()

        text_vectorization_layer = self._adapt_text_vectorization_layer(dataset_func)
        training_set = self._build_training_set(dataset_func, text_vectorization_layer)
        neural_network = self._build_neural_network()

        neural_network.fit(
            training_set,
            epochs=self.EPOCHS,
        )

        # Only publish the trained components once fitting succeeded, so a failed run leaves no partial state.
        self.text_vectorization_layer = text_vectorization_layer
        self.neural_network = neural_network
        self.trained = True

    def _adapt_text_vectorization_layer(
        self,
        dataset_func: DatasetFunc,
    ) -> tensorflow.keras.layers.TextVectorization:
        """Create a TextVectorization layer standardized by the tokenizer and adapt its vocabulary to the texts."""

        def dataset_func_with_tensor_text():
            for review in dataset_func():
                yield review.to_tensor_text()

        text_set = tensorflow.data.Dataset.from_generator(
            dataset_func_with_tensor_text,
            output_signature=tensorflow.TensorSpec(shape=(), dtype=tensorflow.string),
        )

        text_vectorization_layer = tensorflow.keras.layers.TextVectorization(
            max_tokens=self.MAX_FEATURES,
            standardize=self.tokenizer.tokenize_tensorflow,
        )
        # Adapt on batches, matching the batched input the layer receives during training and prediction.
        text_vectorization_layer.adapt(text_set.batch(self.BATCH_SIZE))
        return text_vectorization_layer

    def _build_training_set(
        self,
        dataset_func: DatasetFunc,
        text_vectorization_layer: tensorflow.keras.layers.TextVectorization,
    ) -> tensorflow.data.Dataset:
        """Build a batched dataset of ``(token ids, category)`` pairs from the reviews returned by `dataset_func`."""

        def dataset_func_with_tensor_tuple():
            for review in dataset_func():
                yield review.to_tensor_tuple()

        training_set = tensorflow.data.Dataset.from_generator(
            dataset_func_with_tensor_tuple,
            output_signature=(
                tensorflow.TensorSpec(shape=(), dtype=tensorflow.string, name="text"),
                tensorflow.TensorSpec(shape=(), dtype=tensorflow.float32, name="category"),
            ),
        )

        # Batch before vectorizing so each batch of token ids is padded into a rectangular tensor.
        # `map` unpacks each (text, category) tuple into two arguments, so only the text is vectorized
        # and the category is passed through unchanged.
        return training_set.batch(self.BATCH_SIZE).map(
            lambda text, category: (text_vectorization_layer(text), category)
        )

    def _build_neural_network(self) -> tensorflow.keras.Sequential:
        """Create and compile the network: token embeddings, average pooling, and a single-logit dense output."""
        neural_network = tensorflow.keras.Sequential([
            # +1 so every id the vectorization layer can emit (all < MAX_FEATURES) has an embedding row.
            tensorflow.keras.layers.Embedding(self.MAX_FEATURES + 1, self.EMBEDDING_DIM),
            tensorflow.keras.layers.Dropout(0.2),
            # Average the token embeddings of each text into a single vector.
            tensorflow.keras.layers.GlobalAveragePooling1D(),
            tensorflow.keras.layers.Dropout(0.2),
            # One unnormalized score (logit) per text.
            tensorflow.keras.layers.Dense(1),
        ])

        neural_network.compile(
            loss=tensorflow.losses.BinaryCrossentropy(from_logits=True),  # Only works with two tags
            # The model outputs logits, so a 0.0 threshold corresponds to a 0.5 probability.
            metrics=[tensorflow.metrics.BinaryAccuracy(threshold=0.0)],
        )
        return neural_network

    def use(self, text: Text) -> Category:
        """
        Predict the category of a single text.

        :param text: The text to classify (assumed to be a plain string — TODO confirm against `Text`).
        :returns: ``1.0`` if the network's logit is above 0.0, else ``0.0``.
        :raises NotTrainedError: If `train` has not completed successfully yet.
        """
        if not self.trained:
            raise NotTrainedError()

        # The network consumes token ids, not strings: vectorize a batch containing just this text.
        token_ids = self.text_vectorization_layer(tensorflow.constant([text], dtype=tensorflow.string))
        logit = float(self.neural_network.predict(token_ids)[0][0])

        # Same decision rule as the BinaryAccuracy metric used in training (logit > 0.0 -> label 1).
        # NOTE(review): assumes categories are the float32 labels 0.0 / 1.0 yielded by `Review.to_tensor_tuple`;
        # confirm how these map back to `Category`.
        return 1.0 if logit > 0.0 else 0.0
|