sylvchev/mdla

View on GitHub
experiments/experiment_multivariate_recovering.py

Summary

Maintainability
F
6 days
Test Coverage
"""Dictionary recovering experiment for multivariate random dataset"""
import os
import pickle
from os.path import exists

import matplotlib.pyplot as plt
from numpy import arange, array, zeros
from numpy.linalg import norm
from numpy.random import RandomState, permutation, rand, randint, randn

from dict_metrics import detection_rate, emd, hausdorff
from mdla import MiniBatchMultivariateDictLearning


display = os.environ.get("DISPLAY")
if display is None:
    # No X display is available (e.g. the script was launched from a detached
    # `screen` session): fall back to the non-interactive Agg backend so the
    # figures can still be written to disk.
    # matplotlib.pyplot is already imported at this point; on older Matplotlib
    # releases matplotlib.use() only warns and has no effect once pyplot is
    # loaded, whereas plt.switch_backend() works regardless of import order.
    plt.switch_backend("Agg")


def _generate_testbed(
    kernel_init_len,
    n_nonzero_coefs,
    n_kernels,
    n_samples=10,
    n_features=5,
    n_dims=3,
    snr=1000,
    random_state=None,
):
    """Generate a dataset from a random multivariate dictionary.

    A dictionary of ``n_kernels`` atoms is drawn. Each atom is a
    ``(kernel_init_len, n_dims)`` Gaussian matrix normalized to unit Frobenius
    norm. Each of the ``n_samples`` signals is an ``(n_features, n_dims)``
    array summing ``n_nonzero_coefs`` distinct atoms. Each atom is scaled by
    an amplitude drawn uniformly in [1, 4) and placed at a random offset along
    the first axis. Gaussian noise is then added so that the signal-to-noise
    ratio is ``snr`` dB; the sentinel value ``snr == 1000`` means that no
    noise is added.

    Parameters
    ----------
    kernel_init_len : int
        Length (number of rows) of each atom; must not exceed ``n_features``.
    n_nonzero_coefs : int
        Number of distinct atoms combined in each sample; must not exceed
        ``n_kernels``.
    n_kernels : int
        Number of atoms in the generated dictionary.
    n_samples, n_features, n_dims : int
        Number of signals and shape ``(n_features, n_dims)`` of each signal.
    snr : float
        Target signal-to-noise ratio in dB, or 1000 for noiseless signals.
    random_state : None, int or numpy.random.RandomState, optional
        Source of randomness. ``None`` (default) draws from numpy's global
        random state. An int seeds a new ``RandomState``. A ``RandomState``
        instance is used as-is.

    Returns
    -------
    dico : list of ndarray
        The ``n_kernels`` generating atoms.
    signals : ndarray
        Array of shape ``(n_samples, n_features, n_dims)``.
    decomposition : list of ndarray
        One ``(n_nonzero_coefs, 3)`` array per sample. Each row is
        ``[amplitude, offset, atom_index]`` for one active atom.

    Raises
    ------
    ValueError
        If ``n_nonzero_coefs > n_kernels`` or ``kernel_init_len > n_features``.
    """
    if n_nonzero_coefs > n_kernels:
        raise ValueError(
            "n_nonzero_coefs (%d) cannot exceed n_kernels (%d): each sample "
            "combines distinct atoms" % (n_nonzero_coefs, n_kernels)
        )
    if kernel_init_len > n_features:
        raise ValueError(
            "kernel_init_len (%d) cannot exceed n_features (%d): every atom "
            "must fit inside a sample" % (kernel_init_len, n_features)
        )

    if random_state is None:
        # The module-level numpy.random functions use numpy's global state.
        _randn, _rand, _randint, _permutation = randn, rand, randint, permutation
    else:
        rs = (
            random_state
            if isinstance(random_state, RandomState)
            else RandomState(random_state)
        )
        _randn, _rand, _randint, _permutation = (
            rs.randn,
            rs.rand,
            rs.randint,
            rs.permutation,
        )

    # Random atoms, each normalized to unit Frobenius norm.
    dico = [_randn(kernel_init_len, n_dims) for _ in range(n_kernels)]
    dico = [atom / norm(atom, "fro") for atom in dico]

    signals = list()
    decomposition = list()
    for _ in range(n_samples):
        s = zeros(shape=(n_features, n_dims))
        # One row per active atom: [amplitude, offset, atom index].
        d = zeros(shape=(n_nonzero_coefs, 3))
        # Taking the first entries of a shuffled index list yields distinct atoms.
        rk = _permutation(n_kernels)
        for j in range(n_nonzero_coefs):
            k_idx = rk[j]
            k_amplitude = 3.0 * _rand() + 1.0
            k_offset = _randint(n_features - kernel_init_len + 1)
            s[k_offset : k_offset + kernel_init_len, :] += k_amplitude * dico[k_idx]
            d[j, :] = array([k_amplitude, k_offset, k_idx])
        decomposition.append(d)
        # Noise is always drawn (even for snr == 1000), so the sequence of
        # random draws does not depend on snr.
        noise = _randn(n_features, n_dims)
        if snr == 1000:
            alpha = 0
        else:
            # Scale the noise so that 20 * log10(||s|| / ||alpha * noise||) == snr.
            ps = norm(s, "fro")
            pn = norm(noise, "fro")
            alpha = ps / (pn * 10 ** (snr / 20.0))
        signals.append(s + alpha * noise)
    signals = array(signals)

    return dico, signals, decomposition


def plot_boxes(fig, data, color="blue", n_iter=100, label=""):
    """Draw one box per column of *data* on *fig* and join the medians by a line.

    Parameters
    ----------
    fig : matplotlib Axes
        Target axes. Despite its name, this must be an object exposing
        ``boxplot`` and ``plot``, such as the result of ``fig.add_subplot``.
    data : 2-D array
        One box is drawn per column. In this script each column is one
        learning iteration and each row one experiment.
    color : str
        Colour applied to every boxplot element and to the median line.
    n_iter : int
        Number of boxes. The median line is drawn at x = 1..n_iter, which are
        the default box positions used by ``boxplot``.
    label : str
        Legend label attached to the median line.
    """
    bp = fig.boxplot(data)
    for element in ("medians", "caps", "boxes", "fliers", "whiskers"):
        plt.setp(bp[element], color=color)
    # Each median artist is a horizontal segment; its first y value is the median.
    med = [median_line.get_ydata()[0] for median_line in bp["medians"]]
    fig.plot(arange(1, n_iter + 1), med, linewidth=1, color=color, label=label)


def _finish_recov_panel(ax, n_iter, hide_yticklabels=True):
    """Apply the shared limits, blank tick labels and legend to one grid panel."""
    ax.axis([0, n_iter, 0, 1])
    ax.set_xticks(arange(0, n_iter + 1, 10))
    ax.set_xticklabels([])
    if hide_yticklabels:
        ax.set_yticklabels([])
    ax.legend(loc="lower right")


def _plot_recov_grid(haus_curves, wass_curves, dr99, dr97, n_iter, figname):
    """Draw the 3x3 recovery figure and save it as ``figname + ".png"``.

    Rows correspond to SNR 30, 20 and 10 dB. The columns are:
    1. the detection rates ``dr99`` / ``dr97`` (percentages, rescaled to [0, 1]);
    2. one minus each Hausdorff distance in *haus_curves*;
    3. one minus each Wasserstein distance in *wass_curves*.

    *haus_curves* and *wass_curves* are sequences of ``(scores, color, label)``,
    where ``scores`` is indexed like ``scores[i, :, :]`` for the i-th SNR row.
    """
    snr = ["30", "20", "10"]
    fig = plt.figure(figsize=(18, 10))
    for i, s in enumerate(snr):
        # plotting data from detection rate
        detection = fig.add_subplot(3, 3, i * 3 + 1)
        plot_boxes(
            detection, dr99[i, :, :] / 100.0, "green", n_iter, r"$c_\operatorname{99}$"
        )
        plot_boxes(
            detection, dr97[i, :, :] / 100.0, "magenta", n_iter, r"$c_\operatorname{97}$"
        )
        _finish_recov_panel(detection, n_iter, hide_yticklabels=False)

        # plotting data from hausdorff metric
        methaus = fig.add_subplot(3, 3, i * 3 + 2)
        for scores, color, label in haus_curves:
            plot_boxes(methaus, 1 - scores[i, :, :], color, n_iter, label)
        _finish_recov_panel(methaus, n_iter)

        # plotting data from wasserstein metric
        metwass = fig.add_subplot(3, 3, i * 3 + 3)
        for scores, color, label in wass_curves:
            plot_boxes(metwass, 1 - scores[i, :, :], color, n_iter, label)
        _finish_recov_panel(metwass, n_iter)
        # An empty title reserves vertical space for the row's SNR annotation.
        metwass.set_title(" ")

        metwass.annotate(
            "SNR " + s,
            xy=(0.51, 1.0 - i * 1.0 / 3.0 + i * 0.01 - 0.001),
            xycoords="figure fraction",
            horizontalalignment="center",
            verticalalignment="top",
            fontsize="large",
        )

    # Only the bottom row shows iteration numbers on the x axis.
    for ax in (detection, methaus, metwass):
        ax.set_xticks(arange(0, n_iter + 1, 10))
        ax.set_xticklabels(arange(0, n_iter + 1, 10))
    # pad must be given by keyword: passing it positionally was deprecated in
    # Matplotlib 3.3 and later removed (TypeError).
    fig.tight_layout(pad=1.2)
    fig.savefig(figname + ".png")


def plot_recov_all(
    wc, wfs, wcpa, wbc, wg, wfb, hc, hfs, hcpa, hbc, hg, hfb, dr99, dr97, n_iter, figname
):
    """Plot detection rates and every Hausdorff/Wasserstein score, save a PNG.

    ``w*`` / ``h*`` arguments are Wasserstein / Hausdorff distances for the
    chordal (c), Fubini-Study (fs), chordal principal angles (cpa),
    Binet-Cauchy (bc), geodesic (g) and Frobenius (fb) metrics. They are
    indexed as ``[snr_index, experiment, iteration]``, like the ``dr99`` /
    ``dr97`` detection rates. The figure is written to ``figname + ".png"``.
    """
    haus_curves = (
        (hc, "chartreuse", r"$1-d_H^c$"),
        (hcpa, "red", r"$1-d_H^{cpa}$"),
        (hfs, "magenta", r"$1-d_H^{fs}$"),
        (hbc, "blue", r"$1-d_H^{bc}$"),
        (hg, "deepskyblue", r"$1-d_H^{g}$"),
        (hfb, "orange", r"$1-d_H^{fb}$"),
    )
    wass_curves = (
        (wc, "chartreuse", r"$1-d_W^c$"),
        (wcpa, "red", r"$1-d_W^{cpa}$"),
        (wfs, "magenta", r"$1-d_W^{fs}$"),
        (wbc, "blue", r"$1-d_W^{bc}$"),
        (wg, "deepskyblue", r"$1-d_W^{g}$"),
        (wfb, "orange", r"$1-d_W^{fb}$"),
    )
    _plot_recov_grid(haus_curves, wass_curves, dr99, dr97, n_iter, figname)


def plot_recov(wc, wfs, hc, hfs, dr99, dr97, n_iter, figname):
    """Plot detection rates and chordal / Fubini-Study scores, save a PNG.

    Reduced variant of ``plot_recov_all`` limited to the chordal (c) and
    Fubini-Study (fs) Hausdorff/Wasserstein distances. All arrays are indexed
    as ``[snr_index, experiment, iteration]``. The figure is written to
    ``figname + ".png"``.
    """
    haus_curves = (
        (hc, "cyan", r"$1-d_H^c$"),
        (hfs, "yellow", r"$1-d_H^{fs}$"),
    )
    wass_curves = (
        (wc, "red", r"$1-d_W^c$"),
        (wfs, "blue", r"$1-d_W^{fs}$"),
    )
    _plot_recov_grid(haus_curves, wass_curves, dr99, dr97, n_iter, figname)


def callback_recovery(loc):
    """Record dictionary-recovery scores for the current learning iteration.

    ``loc`` is the mapping the learner passes to its ``callback``. It is
    presumably the ``locals()`` of the fit loop (TODO confirm in mdla). It
    must provide ``"dict_obj"``, the learner being fitted, and
    ``"dictionary"``, its current dictionary estimate.

    The estimate is compared with ``dict_obj.generating_dict``. One score is
    appended to each of the learner's list attributes, in this order:
    - ``wc, wfs, wcpa, wbc, wg, wfb`` from ``emd``;
    - ``hc, hfs, hcpa, hbc, hg, hfb`` from ``hausdorff``, using the same
      metric sequence;
    - ``dr99`` / ``dr97`` from ``detection_rate`` at 0.99 / 0.97.
    """
    learner = loc["dict_obj"]
    estimate = loc["dictionary"]
    reference = learner.generating_dict

    metric_names = (
        "chordal",
        "fubinistudy",
        "chordal_principal_angles",
        "binetcauchy",
        "geodesic",
        "frobenius",
    )
    # Every emd-based (Wasserstein) score is recorded before any Hausdorff one.
    per_distance = (
        (emd, ("wc", "wfs", "wcpa", "wbc", "wg", "wfb")),
        (hausdorff, ("hc", "hfs", "hcpa", "hbc", "hg", "hfb")),
    )
    for distance, attr_names in per_distance:
        for attr_name, metric in zip(attr_names, metric_names):
            getattr(learner, attr_name).append(
                distance(estimate, reference, metric, scale=True)
            )

    for attr_name, threshold in (("dr99", 0.99), ("dr97", 0.97)):
        getattr(learner, attr_name).append(
            detection_rate(estimate, reference, threshold)
        )


rng_global = RandomState(1)
n_samples, n_dims, n_kernels = 1500, 5, 50
n_features = kernel_init_len = 20
n_nonzero_coefs, learning_rate = 3, 1.5
n_experiments, n_iter = 15, 25
snr = [30, 20, 10]
n_snr = len(snr)
n_jobs, batch_size = -1, 60

backup_fname = "expe_multi_reco_all.pck"

# Recovery scores tracked for every (SNR level, experiment, iteration):
# - w* / h* are the Wasserstein (emd) / Hausdorff distances between the
#   learned and generating dictionaries. The suffix gives the metric:
#   c: chordal, fs: fubinistudy, cpa: chordal_principal_angles,
#   bc: binetcauchy, g: geodesic, fb: frobenius.
# - dr99 / dr97 are detection rates at 0.99 / 0.97.
# Each name is both a learner list attribute filled by callback_recovery and
# a key of the pickled backup dict.
# fmt: off
score_names = (
    "wc", "wfs", "wcpa", "wbc", "wg", "wfb",
    "hc", "hfs", "hcpa", "hbc", "hg", "hfb",
    "dr99", "dr97",
)
# fmt: on

if exists(backup_fname):
    # Pickle data is bytes, so the backup must be opened in binary mode.
    # latin1 decoding lets Python 3 read numpy arrays pickled by Python 2.
    # Unpickling can execute arbitrary code: only load a backup you trust.
    with open(backup_fname, "rb") as f:
        o = pickle.load(f, encoding="latin1")
else:
    o = {name: zeros((n_snr, n_experiments, n_iter)) for name in score_names}
    for i, s in enumerate(snr):
        for e in range(n_experiments):
            g, X, code = _generate_testbed(
                kernel_init_len,
                n_nonzero_coefs,
                n_kernels,
                n_samples,
                n_features,
                n_dims,
                s,
            )
            d = MiniBatchMultivariateDictLearning(
                n_kernels=n_kernels,
                batch_size=batch_size,
                n_iter=n_iter,
                n_nonzero_coefs=n_nonzero_coefs,
                callback=callback_recovery,
                n_jobs=n_jobs,
                learning_rate=learning_rate,
                kernel_init_len=kernel_init_len,
                verbose=1,
                random_state=rng_global,
            )
            d.generating_dict = list(g)
            # Empty accumulators that callback_recovery appends to. Each one
            # must end up with n_iter entries to fill its row below.
            for name in score_names:
                setattr(d, name, list())
            print("\nExperiment", e + 1, "on", n_experiments)
            d = d.fit(X)
            for name in score_names:
                o[name][i, e, :] = array(getattr(d, name))
    with open(backup_fname, "wb") as f:
        pickle.dump(o, f)

wc, wfs, hc, hfs = o["wc"], o["wfs"], o["hc"], o["hfs"]
wcpa, wbc, wg, wfb = o["wcpa"], o["wbc"], o["wg"], o["wfb"]
hcpa, hbc, hg, hfb = o["hcpa"], o["hbc"], o["hg"], o["hfb"]
dr99, dr97 = o["dr99"], o["dr97"]

# The same full figure is produced whether the scores were just computed or
# reloaded from the backup.
# fmt: off
plot_recov_all(
    wc, wfs, wcpa, wbc, wg, wfb, hc, hfs, hcpa, hbc, hg, hfb,
    dr99, dr97, n_iter, "multivariate_recov_all",
)
# fmt: on