1
Fork 0
mirror of https://github.com/Steffo99/unimore-bda-6.git synced 2024-11-21 23:44:19 +00:00

Fix tensorspec error

This commit is contained in:
Steffo 2023-02-10 04:20:35 +01:00
parent 0a4ce38982
commit 0ce584e856
Signed by: steffo
GPG key ID: 2A24051445686895
2 changed files with 7 additions and 7 deletions

View file

@ -19,10 +19,10 @@ else:
log.debug("Tensorflow successfully found GPU acceleration!")
ConversionFunc = t.Callable[[Review], list[tensorflow.Tensor]]
ConversionFunc = t.Callable[[Review], tensorflow.Tensor | tuple]
def build_dataset(dataset_func: CachedDatasetFunc, conversion_func: ConversionFunc, output_signature: tensorflow.TensorSpec | list[tensorflow.TensorSpec]) -> tensorflow.data.Dataset:
def build_dataset(dataset_func: CachedDatasetFunc, conversion_func: ConversionFunc, output_signature: tensorflow.TensorSpec | tuple) -> tensorflow.data.Dataset:
"""
Convert a `CachedDatasetFunc` to a `tensorflow.data.Dataset`.
"""
@ -176,10 +176,10 @@ class TensorflowCategorySentimentAnalyzer(TensorflowSentimentAnalyzer):
return build_dataset(
dataset_func=dataset_func,
conversion_func=Review.to_tensor_tuple,
output_signature=[
output_signature=(
tensorflow.TensorSpec(shape=(1,), dtype=tensorflow.string, name="text"),
tensorflow.TensorSpec(shape=(5,), dtype=tensorflow.float32, name="review_one_hot"),
],
),
)
def _build_model(self) -> tensorflow.keras.Sequential:

View file

@ -49,11 +49,11 @@ class Review:
1.0 if self.category == 5.0 else 0.0,
]], dtype=tensorflow.float32)
def to_tensor_tuple(self) -> list[tensorflow.Tensor, tensorflow.Tensor]:
t = [
def to_tensor_tuple(self) -> tuple[tensorflow.Tensor, tensorflow.Tensor]:
t = (
self.to_tensor_text(),
self.to_tensor_category(),
]
)
log.debug("Converted %s", t)
return t