LucaCappelletti94/udbnn
udbnn/udbnn.py
from typing import Dict, Tuple

import pandas as pd
from auto_tqdm import tqdm
from environments_utils import is_tmux
from extra_keras_utils import is_gpu_available
from notipy_me import Notipy

from .fit import fit
from .utils import (
    get_history_path, is_batch_size_cached, is_holdout_cached, load_dataset,
    load_settings, get_batch_sizes, normalized_holdouts_generator
)

def train_batch_sizes(dataset_path: str, holdout, training: Tuple, testing: Tuple, settings: Dict, N: Notipy):
    """Train the model on every batch size that is not yet cached for the given holdout."""
    # Skip batch sizes whose results are already cached on disk.
    batch_sizes = [
        v for v in get_batch_sizes(
            resolution=settings["batch_sizes"]["resolution"],
            minimum=settings["batch_sizes"]["minimum"],
            size=training[0].shape[0],
            seed=settings["batch_sizes"]["seed"]
        ) if not is_batch_size_cached(dataset_path, v, settings)
    ]
    for batch_size in tqdm(batch_sizes, desc="Batch sizes", leave=False):
        if is_holdout_cached(dataset_path, batch_size, holdout):
            continue
        # Fit the model and keep the per-epoch metrics as a DataFrame.
        history = pd.DataFrame(fit(training, testing, batch_size, settings).history)
        # Report the final AUPRC scores for this batch size.
        row = history[["auprc", "val_auprc"]].tail(1)
        row.index = [batch_size]
        row.index.name = "Batch size"
        N.add_report(row)
        # Persist the full training history for this batch size and holdout.
        history_path = get_history_path(dataset_path, batch_size, holdout)
        with open("{path}/history.json".format(path=history_path), "w") as f:
            history.to_json(f)

def train_holdout(dataset_path: str, settings: Dict, N: Notipy):
    """Run the batch-size sweep on every holdout of the given dataset."""
    dataset = load_dataset(dataset_path, settings["max_correlation"])
    for holdout, (training, testing) in zip(settings["holdouts"], normalized_holdouts_generator(dataset, settings["holdouts"])()):
        train_batch_sizes(dataset_path, holdout, training, testing, settings, N)

def train_datasets(target: str, N: Notipy):
    """Train on every enabled dataset listed in the settings of the given target directory."""
    settings = load_settings(target)
    datasets = [
        "{target}/{path}".format(target=target, path=dataset["path"])
        for dataset in settings["datasets"]
        if dataset["enabled"]
    ]
    for path in tqdm(datasets, desc="Datasets"):
        train_holdout(path, settings, N)
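
# Illustrative sketch (not part of the original module) of the settings structure this
# module reads, inferred from the keys accessed above. The concrete values below are
# placeholder assumptions; the real schema is whatever load_settings returns for the target.
#
#     {
#         "max_correlation": 0.9,
#         "holdouts": [0, 1, 2],
#         "batch_sizes": {"resolution": 10, "minimum": 8, "seed": 42},
#         "datasets": [{"path": "example_dataset", "enabled": True}]
#     }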

def run(target: str):
    """Run the complete training pipeline for the given target directory."""
    if not is_gpu_available():
        print("No GPU was detected!")
    if not is_tmux():
        print("Not running within TMUX!")
    # Notipy collects the reports added during training and notifies on completion.
    with Notipy() as N:
        train_datasets(target, N)
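
# Illustrative entry point (a sketch, not part of the original module): how run() might
# be invoked, assuming the first command-line argument is the target directory containing
# the settings file expected by load_settings.
if __name__ == "__main__":
    import sys

    run(sys.argv[1] if len(sys.argv) > 1 else ".")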