mirror of
https://github.com/Steffo99/unimore-bda-6.git
synced 2024-11-22 07:54:19 +00:00
Fix tensorspec error
This commit is contained in:
parent
0a4ce38982
commit
0ce584e856
2 changed files with 7 additions and 7 deletions
|
@ -19,10 +19,10 @@ else:
|
||||||
log.debug("Tensorflow successfully found GPU acceleration!")
|
log.debug("Tensorflow successfully found GPU acceleration!")
|
||||||
|
|
||||||
|
|
||||||
ConversionFunc = t.Callable[[Review], list[tensorflow.Tensor]]
|
ConversionFunc = t.Callable[[Review], tensorflow.Tensor | tuple]
|
||||||
|
|
||||||
|
|
||||||
def build_dataset(dataset_func: CachedDatasetFunc, conversion_func: ConversionFunc, output_signature: tensorflow.TensorSpec | list[tensorflow.TensorSpec]) -> tensorflow.data.Dataset:
|
def build_dataset(dataset_func: CachedDatasetFunc, conversion_func: ConversionFunc, output_signature: tensorflow.TensorSpec | tuple) -> tensorflow.data.Dataset:
|
||||||
"""
|
"""
|
||||||
Convert a `CachedDatasetFunc` to a `tensorflow.data.Dataset`.
|
Convert a `CachedDatasetFunc` to a `tensorflow.data.Dataset`.
|
||||||
"""
|
"""
|
||||||
|
@ -176,10 +176,10 @@ class TensorflowCategorySentimentAnalyzer(TensorflowSentimentAnalyzer):
|
||||||
return build_dataset(
|
return build_dataset(
|
||||||
dataset_func=dataset_func,
|
dataset_func=dataset_func,
|
||||||
conversion_func=Review.to_tensor_tuple,
|
conversion_func=Review.to_tensor_tuple,
|
||||||
output_signature=[
|
output_signature=(
|
||||||
tensorflow.TensorSpec(shape=(1,), dtype=tensorflow.string, name="text"),
|
tensorflow.TensorSpec(shape=(1,), dtype=tensorflow.string, name="text"),
|
||||||
tensorflow.TensorSpec(shape=(5,), dtype=tensorflow.float32, name="review_one_hot"),
|
tensorflow.TensorSpec(shape=(5,), dtype=tensorflow.float32, name="review_one_hot"),
|
||||||
],
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_model(self) -> tensorflow.keras.Sequential:
|
def _build_model(self) -> tensorflow.keras.Sequential:
|
||||||
|
|
|
@ -49,11 +49,11 @@ class Review:
|
||||||
1.0 if self.category == 5.0 else 0.0,
|
1.0 if self.category == 5.0 else 0.0,
|
||||||
]], dtype=tensorflow.float32)
|
]], dtype=tensorflow.float32)
|
||||||
|
|
||||||
def to_tensor_tuple(self) -> list[tensorflow.Tensor, tensorflow.Tensor]:
|
def to_tensor_tuple(self) -> tuple[tensorflow.Tensor, tensorflow.Tensor]:
|
||||||
t = [
|
t = (
|
||||||
self.to_tensor_text(),
|
self.to_tensor_text(),
|
||||||
self.to_tensor_category(),
|
self.to_tensor_category(),
|
||||||
]
|
)
|
||||||
log.debug("Converted %s", t)
|
log.debug("Converted %s", t)
|
||||||
return t
|
return t
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue