1
Fork 0
mirror of https://github.com/Steffo99/unimore-bda-6.git synced 2024-11-22 07:54:19 +00:00

Add missing file

This commit is contained in:
Steffo 2023-02-08 19:48:41 +01:00
parent 027f8e07e8
commit 704624507a
Signed by: steffo
GPG key ID: 2A24051445686895

View file

@ -0,0 +1,74 @@
import typing as t
import contextlib
import dataclasses
import logging
import pymongo
from .config import TRAINING_SET_SIZE, VALIDATION_SET_SIZE, EVALUATION_SET_SIZE
from .database import SampleFunc, CachedDatasetFunc, mongo_client_from_config, reviews_collection, store_cache, load_cache, delete_cache
log = logging.getLogger(__name__)
@dataclasses.dataclass
class Caches:
"""
Container for the three generators that can create datasets.
"""
training: CachedDatasetFunc
validation: CachedDatasetFunc
evaluation: CachedDatasetFunc
@classmethod
@contextlib.contextmanager
def from_database_samples(cls, collection: pymongo.collection.Collection, sample_func: SampleFunc) -> t.ContextManager["Caches"]:
"""
Create a new caches object from a database collection and a sampling function.
"""
log.debug("Gathering datasets...")
reviews_training = sample_func(collection, TRAINING_SET_SIZE.__wrapped__)
reviews_validation = sample_func(collection, VALIDATION_SET_SIZE.__wrapped__)
reviews_evaluation = sample_func(collection, EVALUATION_SET_SIZE.__wrapped__)
log.debug("Caching datasets...")
store_cache(reviews_training, "./data/training")
store_cache(reviews_validation, "./data/validation")
store_cache(reviews_evaluation, "./data/evaluation")
log.debug("Loading dataset caches...")
training_cache = load_cache("./data/training")
validation_cache = load_cache("./data/validation")
evaluation_cache = load_cache("./data/evaluation")
yield Caches(training=training_cache, validation=validation_cache, evaluation=evaluation_cache)
log.debug("Cleaning up caches...")
delete_cache("./data/training")
delete_cache("./data/validation")
delete_cache("./data/evaluation")
@staticmethod
def ensure_clean():
log.debug("Ensuring there are no leftover caches...")
try:
delete_cache("./data/training")
except FileNotFoundError:
pass
try:
delete_cache("./data/validation")
except FileNotFoundError:
pass
try:
delete_cache("./data/evaluation")
except FileNotFoundError:
pass
__all__ = (
"Caches",
)