hrntsm/Tunny

View on GitHub
Optuna/Visualization/Python/plot_pareto_front.py

Summary

Maintainability
A
1 hr
Test Coverage
import json
import optuna
from optuna import Study
import plotly.graph_objects as go


def plot_pareto_front(
    study: "Study",
    objective_names: list[str],
    objective_index: list[int],
    includeDominatedTrials: bool,
) -> "go.Figure":
    """Build an Optuna Pareto-front figure for two or three selected objectives.

    Args:
        study: Study whose trials are plotted.
        objective_names: Axis labels, forwarded to Optuna as ``target_names``.
        objective_index: Positions inside ``trial.values`` of the objectives
            to plot. Must contain exactly 2 or 3 entries.
        includeDominatedTrials: Forwarded to Optuna as
            ``include_dominated_trials``.

    Returns:
        The figure returned by ``optuna.visualization.plot_pareto_front``.

    Raises:
        ValueError: If ``objective_index`` does not contain 2 or 3 entries.
    """
    # Snapshot the indices so the closure below is not affected if the caller
    # mutates the list afterwards.
    indices = tuple(objective_index)
    if len(indices) not in (2, 3):
        # Fail early with a clear message instead of hitting an unbound
        # ``targets`` variable further down.
        raise ValueError(
            f"objective_index must contain 2 or 3 indices, got {len(indices)}"
        )

    def targets(trial):
        # Pick the requested objective values of one trial, in the given order.
        return tuple(trial.values[i] for i in indices)

    fig: "go.Figure" = optuna.visualization.plot_pareto_front(
        study,
        target_names=objective_names,
        targets=targets,
        include_dominated_trials=bool(includeDominatedTrials),
    )

    return fig


def _shorten_hover_label(label: str, max_params: int = 10) -> str:
    """Return *label* without the "Geometry" attribute and with capped params.

    The label is expected to be indented JSON whose line breaks have been
    replaced by ``<br>`` (assumption: this is the hover text format of the
    figure passed to ``truncate`` -- confirm against the plotting library).
    """
    # A real newline is required here: the two-character sequence
    # backslash + "n" is not valid JSON whitespace and would make
    # ``json.loads`` fail.
    record = json.loads(label.replace("<br>", "\n"))

    # Only the first trial is inspected by the caller, so individual labels
    # may lack "user_attrs" or "Geometry"; tolerate both.
    user_attrs = record.get("user_attrs")
    if isinstance(user_attrs, dict):
        user_attrs.pop("Geometry", None)

    params = record.get("params")
    if isinstance(params, dict) and len(params) > max_params:
        # Keep the first ``max_params`` entries (insertion order) and flag
        # that the remainder was dropped.
        kept = dict(list(params.items())[:max_params])
        kept["__Omit_values__"] = "True"
        record["params"] = kept

    # Convert the line breaks back to the HTML form used by the hover text.
    return json.dumps(record, indent=2).replace("\n", "<br>")


def truncate(fig: "go.Figure", study: "Study") -> "go.Figure":
    """Strip bulky data from the hover labels of *fig*.

    When the first trial of *study* has a "Geometry" user attribute, every
    text label of every trace in ``fig.data`` is rewritten: the "Geometry"
    attribute is removed and at most 10 parameters are kept (an
    ``"__Omit_values__": "True"`` entry marks that parameters were dropped).

    Args:
        fig: Figure whose traces carry JSON hover labels; modified in place.
        study: Study used to decide whether the labels contain "Geometry".

    Returns:
        The same *fig* object. It is returned unchanged when the study has no
        trials or its first trial has no "Geometry" user attribute.
    """
    trials = study.trials
    if not trials or "Geometry" not in trials[0].user_attrs:
        return fig

    for trace in fig.data:
        labels = trace["text"]
        if labels is None:
            # Trace without hover text: nothing to rewrite.
            continue
        trace["text"] = [_shorten_hover_label(label) for label in labels]

    return fig