matejak/estimagus

View on GitHub
estimage/visualize/pert.py

Summary

Maintainability
A
0 mins
Test Coverage
from .. import PluginResolver

from . import utils


@PluginResolver.class_is_extendable("PertPlotter")
class PertPlotter:
    """Draw the PERT probability density of a single estimate onto matplotlib axes.

    The class is registered as extendable under the name ``"PertPlotter"``
    (presumably so plugins can subclass it), so all colors are taken from the
    class attributes below rather than hard-coded, letting subclasses restyle
    every element of the plot consistently.
    """

    # Color of the density curve, its zero tails and all markers on it.
    PERT_COLOR = "blue"
    # Color of the vertical line marking the expected value.
    EXPECTED_COLOR = "orange"

    def __init__(self, task_name: str, estimation):
        """
        :param task_name: task name, shown in the legend as ``task <name>``.
        :param estimation: estimate object; this class reads its
            ``get_pert()``, ``expected`` and ``sigma`` members.
        """
        # get_pert() is unpacked elsewhere as (domain, values); presumably two
        # sequences of equal length — TODO confirm against the estimate class.
        self.pert = estimation.get_pert()
        self.task_name = task_name
        self.estimation = estimation
        self.expected = self.estimation.expected

    def handle_border_discontinuities(self, ax):
        """Mark jumps where the density is non-zero at either end of its domain.

        Outside the domain the density is drawn as zero, so a non-zero value at
        a border is a jump. For each such border a filled dot is drawn on the
        curve and a hollow (white) dot at zero.
        """
        domain, values = self.pert
        marker_kwargs = dict(ec=self.PERT_COLOR, lw=2, zorder=3)
        for x, y in ((domain[0], values[0]), (domain[-1], values[-1])):
            if y == 0:
                continue
            ax.scatter(x, y, fc=self.PERT_COLOR, **marker_kwargs)
            ax.scatter(x, 0, fc="white", **marker_kwargs)

    def plot_continuous_pert(self, ax):
        """Plot a non-degenerate PERT density, its zero tails and the expected value."""
        domain, values = self.pert
        ax.plot(domain, values, c=self.PERT_COLOR, lw=2, label=f'task {self.task_name}')
        # Capture the autoscaled x range now, so the zero-density tails below
        # can run to the plot edges.
        limits = ax.get_xlim()
        ax.plot((limits[0], domain[0]), (0, 0), c=self.PERT_COLOR, lw=2)
        ax.plot((domain[-1], limits[1]), (0, 0), c=self.PERT_COLOR, lw=2)
        self.handle_border_discontinuities(ax)
        ax.axvline(self.expected, c=self.EXPECTED_COLOR, label="expected value")
        # Restore the captured range; autoscaling would otherwise widen it to
        # fit the tail lines just drawn.
        ax.set_xlim(limits)

    def plot_delta_pert(self, ax):
        """Plot a degenerate (zero-sigma) estimate as a delta spike.

        The density is drawn as a flat zero line, with an upward arrow of
        height 1 at the expected value and a hollow marker at the arrow's base.
        """
        domain = self.pert[0]
        ax.axhline(0, color=self.PERT_COLOR, lw=2, label=f'task {self.task_name}')
        ax.axvline(self.expected, c=self.EXPECTED_COLOR, label="expected value", zorder=2)
        ax.set_ylim(-0.1, 1.1)
        # NOTE(review): the lower bound uses max(), i.e. it is never below -0.1;
        # confirm this is intended rather than min().
        ax.set_xlim(max(domain[0], -0.1), max(domain[-1], 0))
        ax.annotate(
            "", xy=(self.expected, 1), xycoords='data', xytext=(self.expected, 0), textcoords='data',
            arrowprops=dict(arrowstyle="->", connectionstyle="arc3", ec=self.PERT_COLOR, lw=2), zorder=4)
        # Use PERT_COLOR (not a hard-coded "b") so subclasses overriding the
        # color get a marker matching the rest of the plot.
        ax.scatter(self.expected, 0, ec=self.PERT_COLOR, fc="white", lw=2, zorder=3)

    def plot_any_pert(self, ax):
        """Draw the estimate: delta plot when its sigma is 0, continuous plot otherwise."""
        if self.estimation.sigma == 0:
            self.plot_delta_pert(ax)
        else:
            self.plot_continuous_pert(ax)


def get_pert_in_figure(estimation, task_name, cls=None):
    """Create a standalone figure showing the PERT distribution of one estimate.

    :param estimation: estimate passed to the plotter's constructor.
    :param task_name: task name forwarded to the plotter (``PertPlotter``
        uses it in the legend label).
    :param cls: plotter class to instantiate; any falsy value selects
        :class:`PertPlotter`.
    :return: the newly created matplotlib figure.
    """
    plotter_class = cls if cls else PertPlotter
    pyplot = utils.get_standard_pyplot()

    figure, axes = pyplot.subplots(1, 1)
    plotter_class(task_name, estimation).plot_any_pert(axes)

    # The y axis is labelled, but its numeric tick labels are suppressed.
    axes.set_xlabel("points")
    axes.set_ylabel("probability density")
    axes.set_yticklabels([])
    axes.grid()
    axes.legend()
    return figure