mirror of
https://github.com/Steffo99/unimore-bda-6.git
synced 2024-11-21 23:44:19 +00:00
Fix tensorspec error
This commit is contained in:
parent
0a4ce38982
commit
0ce584e856
2 changed files with 7 additions and 7 deletions
|
@ -19,10 +19,10 @@ else:
|
|||
log.debug("Tensorflow successfully found GPU acceleration!")
|
||||
|
||||
|
||||
ConversionFunc = t.Callable[[Review], list[tensorflow.Tensor]]
|
||||
ConversionFunc = t.Callable[[Review], tensorflow.Tensor | tuple]
|
||||
|
||||
|
||||
def build_dataset(dataset_func: CachedDatasetFunc, conversion_func: ConversionFunc, output_signature: tensorflow.TensorSpec | list[tensorflow.TensorSpec]) -> tensorflow.data.Dataset:
|
||||
def build_dataset(dataset_func: CachedDatasetFunc, conversion_func: ConversionFunc, output_signature: tensorflow.TensorSpec | tuple) -> tensorflow.data.Dataset:
|
||||
"""
|
||||
Convert a `CachedDatasetFunc` to a `tensorflow.data.Dataset`.
|
||||
"""
|
||||
|
@ -176,10 +176,10 @@ class TensorflowCategorySentimentAnalyzer(TensorflowSentimentAnalyzer):
|
|||
return build_dataset(
|
||||
dataset_func=dataset_func,
|
||||
conversion_func=Review.to_tensor_tuple,
|
||||
output_signature=[
|
||||
output_signature=(
|
||||
tensorflow.TensorSpec(shape=(1,), dtype=tensorflow.string, name="text"),
|
||||
tensorflow.TensorSpec(shape=(5,), dtype=tensorflow.float32, name="review_one_hot"),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
def _build_model(self) -> tensorflow.keras.Sequential:
|
||||
|
|
|
@ -49,11 +49,11 @@ class Review:
|
|||
1.0 if self.category == 5.0 else 0.0,
|
||||
]], dtype=tensorflow.float32)
|
||||
|
||||
def to_tensor_tuple(self) -> list[tensorflow.Tensor, tensorflow.Tensor]:
|
||||
t = [
|
||||
def to_tensor_tuple(self) -> tuple[tensorflow.Tensor, tensorflow.Tensor]:
|
||||
t = (
|
||||
self.to_tensor_text(),
|
||||
self.to_tensor_category(),
|
||||
]
|
||||
)
|
||||
log.debug("Converted %s", t)
|
||||
return t
|
||||
|
||||
|
|
Loading…
Reference in a new issue