1
Fork 0
mirror of https://github.com/Steffo99/unimore-bda-6.git synced 2024-11-22 07:54:19 +00:00

Fix tensorspec error

This commit is contained in:
Steffo 2023-02-10 04:20:35 +01:00
parent 0a4ce38982
commit 0ce584e856
Signed by: steffo
GPG key ID: 2A24051445686895
2 changed files with 7 additions and 7 deletions

View file

@ -19,10 +19,10 @@ else:
log.debug("Tensorflow successfully found GPU acceleration!") log.debug("Tensorflow successfully found GPU acceleration!")
ConversionFunc = t.Callable[[Review], list[tensorflow.Tensor]] ConversionFunc = t.Callable[[Review], tensorflow.Tensor | tuple]
def build_dataset(dataset_func: CachedDatasetFunc, conversion_func: ConversionFunc, output_signature: tensorflow.TensorSpec | list[tensorflow.TensorSpec]) -> tensorflow.data.Dataset: def build_dataset(dataset_func: CachedDatasetFunc, conversion_func: ConversionFunc, output_signature: tensorflow.TensorSpec | tuple) -> tensorflow.data.Dataset:
""" """
Convert a `CachedDatasetFunc` to a `tensorflow.data.Dataset`. Convert a `CachedDatasetFunc` to a `tensorflow.data.Dataset`.
""" """
@ -176,10 +176,10 @@ class TensorflowCategorySentimentAnalyzer(TensorflowSentimentAnalyzer):
return build_dataset( return build_dataset(
dataset_func=dataset_func, dataset_func=dataset_func,
conversion_func=Review.to_tensor_tuple, conversion_func=Review.to_tensor_tuple,
output_signature=[ output_signature=(
tensorflow.TensorSpec(shape=(1,), dtype=tensorflow.string, name="text"), tensorflow.TensorSpec(shape=(1,), dtype=tensorflow.string, name="text"),
tensorflow.TensorSpec(shape=(5,), dtype=tensorflow.float32, name="review_one_hot"), tensorflow.TensorSpec(shape=(5,), dtype=tensorflow.float32, name="review_one_hot"),
], ),
) )
def _build_model(self) -> tensorflow.keras.Sequential: def _build_model(self) -> tensorflow.keras.Sequential:

View file

@ -49,11 +49,11 @@ class Review:
1.0 if self.category == 5.0 else 0.0, 1.0 if self.category == 5.0 else 0.0,
]], dtype=tensorflow.float32) ]], dtype=tensorflow.float32)
def to_tensor_tuple(self) -> list[tensorflow.Tensor, tensorflow.Tensor]: def to_tensor_tuple(self) -> tuple[tensorflow.Tensor, tensorflow.Tensor]:
t = [ t = (
self.to_tensor_text(), self.to_tensor_text(),
self.to_tensor_category(), self.to_tensor_category(),
] )
log.debug("Converted %s", t) log.debug("Converted %s", t)
return t return t