From 0ce584e856f600c20b6f01739d1cf4a3187a3358 Mon Sep 17 00:00:00 2001 From: Stefano Pigozzi Date: Fri, 10 Feb 2023 04:20:35 +0100 Subject: [PATCH] Fix tensorspec error --- unimore_bda_6/analysis/tf_text.py | 8 ++++---- unimore_bda_6/database/datatypes.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/unimore_bda_6/analysis/tf_text.py b/unimore_bda_6/analysis/tf_text.py index f001f2d..840bfac 100644 --- a/unimore_bda_6/analysis/tf_text.py +++ b/unimore_bda_6/analysis/tf_text.py @@ -19,10 +19,10 @@ else: log.debug("Tensorflow successfully found GPU acceleration!") -ConversionFunc = t.Callable[[Review], list[tensorflow.Tensor]] +ConversionFunc = t.Callable[[Review], tensorflow.Tensor | tuple] -def build_dataset(dataset_func: CachedDatasetFunc, conversion_func: ConversionFunc, output_signature: tensorflow.TensorSpec | list[tensorflow.TensorSpec]) -> tensorflow.data.Dataset: +def build_dataset(dataset_func: CachedDatasetFunc, conversion_func: ConversionFunc, output_signature: tensorflow.TensorSpec | tuple) -> tensorflow.data.Dataset: """ Convert a `CachedDatasetFunc` to a `tensorflow.data.Dataset`. """ @@ -176,10 +176,10 @@ class TensorflowCategorySentimentAnalyzer(TensorflowSentimentAnalyzer): return build_dataset( dataset_func=dataset_func, conversion_func=Review.to_tensor_tuple, - output_signature=[ + output_signature=( tensorflow.TensorSpec(shape=(1,), dtype=tensorflow.string, name="text"), tensorflow.TensorSpec(shape=(5,), dtype=tensorflow.float32, name="review_one_hot"), - ], + ), ) def _build_model(self) -> tensorflow.keras.Sequential: diff --git a/unimore_bda_6/database/datatypes.py b/unimore_bda_6/database/datatypes.py index 4872ba2..980b7e6 100644 --- a/unimore_bda_6/database/datatypes.py +++ b/unimore_bda_6/database/datatypes.py @@ -49,11 +49,11 @@ class Review: 1.0 if self.category == 5.0 else 0.0, ]], dtype=tensorflow.float32) - def to_tensor_tuple(self) -> list[tensorflow.Tensor, tensorflow.Tensor]: - t = [ + def to_tensor_tuple(self) -> tuple[tensorflow.Tensor, tensorflow.Tensor]: + t = ( self.to_tensor_text(), self.to_tensor_category(), - ] + ) log.debug("Converted %s", t) return t