hrntsm/Tunny

View on GitHub
Optuna/Visualization/Visualization.cs

Summary

Maintainability
B
4 hrs
Test Coverage
using System;

using Optuna.Util;

using Python.Runtime;

namespace Optuna.Visualization
{
    /// <summary>
    /// Builds Optuna visualization figures for a study by running Python plotting scripts
    /// embedded as resources, and keeps the most recently created figure so it can be
    /// shown or written to HTML.
    /// </summary>
    /// <remarks>
    /// NOTE(review): every member calls into Python.NET, which requires the GIL to be held.
    /// This class does not acquire it, so presumably callers wrap calls in <c>Py.GIL()</c> — confirm.
    /// </remarks>
    public class Visualization
    {
        // Embedded script resources are named "<prefix><script name>.py".
        private const string ScriptResourcePrefix = "Optuna.Visualization.Python.";

        private readonly dynamic _study;
        private dynamic _fig;

        /// <summary>Creates a visualizer bound to the given study object.</summary>
        /// <param name="study">Python study object handed to every plotting script.</param>
        public Visualization(dynamic study)
        {
            _study = study;
        }

        /// <summary>Gets whether one of the plot methods has produced a figure.</summary>
        public bool HasFigure => _fig != null;

        /// <summary>Creates a slice plot of one variable against one objective.</summary>
        /// <param name="objectiveName">Objective name forwarded to <c>plot_slice</c>.</param>
        /// <param name="objectiveIndex">Objective index forwarded to <c>plot_slice</c>.</param>
        /// <param name="variableName">Variable name forwarded to <c>plot_slice</c>.</param>
        public void Slice(string objectiveName, int objectiveIndex, string variableName)
        {
            _fig = InvokeScript("plot_slice", "plot_slice",
                plot => plot(_study, objectiveName, objectiveIndex, variableName));
        }

        /// <summary>Creates a Pareto-front plot for the given objectives.</summary>
        /// <param name="objectiveNames">Objective names forwarded to <c>plot_pareto_front</c>.</param>
        /// <param name="objectiveIndices">Objective indices forwarded to <c>plot_pareto_front</c>.</param>
        /// <param name="hasConstraint">Constraint flag forwarded to <c>plot_pareto_front</c>.</param>
        /// <param name="includeDominatedTrials">Flag forwarded to <c>plot_pareto_front</c>.</param>
        public void ParetoFront(string[] objectiveNames, int[] objectiveIndices, bool hasConstraint, bool includeDominatedTrials)
        {
            _fig = InvokeScript("plot_pareto_front", "plot_pareto_front",
                plot => plot(_study, objectiveNames, objectiveIndices, hasConstraint, includeDominatedTrials));
        }

        /// <summary>Creates a parameter-importance plot.</summary>
        /// <param name="objectiveName">Objective name forwarded to <c>plot_importances</c>.</param>
        /// <param name="objectiveIndex">Objective index forwarded to <c>plot_importances</c>.</param>
        /// <param name="variableNames">Variable names, passed to Python as a native list.</param>
        /// <exception cref="ArgumentNullException"><paramref name="variableNames"/> is null.</exception>
        public void ParamImportances(string objectiveName, int objectiveIndex, string[] variableNames)
        {
            if (variableNames == null)
            {
                throw new ArgumentNullException(nameof(variableNames));
            }

            // The .NET wrapper is released after the call; Python keeps its own reference
            // if the script stores the list in the figure.
            using (PyList pyVariableNames = ToPyList(variableNames))
            {
                _fig = InvokeScript("plot_importances", "plot_importances",
                    plot => plot(_study, objectiveName, objectiveIndex, pyVariableNames));
            }
        }

        /// <summary>Creates a parallel-coordinate plot.</summary>
        /// <param name="objectiveName">Objective name forwarded to <c>plot_parallel_coordinate</c>.</param>
        /// <param name="objectiveIndex">Objective index forwarded to <c>plot_parallel_coordinate</c>.</param>
        /// <param name="variableNames">Variable names forwarded to <c>plot_parallel_coordinate</c>.</param>
        public void ParallelCoordinate(string objectiveName, int objectiveIndex, string[] variableNames)
        {
            _fig = InvokeScript("plot_parallel_coordinate", "plot_parallel_coordinate",
                plot => plot(_study, objectiveName, objectiveIndex, variableNames));
        }

        /// <summary>Creates an optimization-history plot for one objective.</summary>
        /// <param name="objectiveName">Objective name forwarded to <c>plot_optimization_history</c>.</param>
        /// <param name="objectiveIndex">Objective index forwarded to <c>plot_optimization_history</c>.</param>
        public void OptimizationHistory(string objectiveName, int objectiveIndex)
        {
            _fig = InvokeScript("plot_optimization_history", "plot_optimization_history",
                plot => plot(_study, objectiveName, objectiveIndex));
        }

        /// <summary>Creates an EDF plot for one objective.</summary>
        /// <param name="objectiveName">Objective name forwarded to <c>plot_edf</c>.</param>
        /// <param name="objectiveIndex">Objective index forwarded to <c>plot_edf</c>.</param>
        public void EDF(string objectiveName, int objectiveIndex)
        {
            _fig = InvokeScript("plot_edf", "plot_edf",
                plot => plot(_study, objectiveName, objectiveIndex));
        }

        /// <summary>Creates a contour plot.</summary>
        /// <param name="objectiveName">Objective name forwarded to <c>plot_contour</c>.</param>
        /// <param name="objectiveIndex">Objective index forwarded to <c>plot_contour</c>.</param>
        /// <param name="variableNames">Variable names forwarded to <c>plot_contour</c>.</param>
        public void Contour(string objectiveName, int objectiveIndex, string[] variableNames)
        {
            _fig = InvokeScript("plot_contour", "plot_contour",
                plot => plot(_study, objectiveName, objectiveIndex, variableNames));
        }

        /// <summary>Creates a hypervolume plot for the study.</summary>
        public void Hypervolume()
        {
            _fig = InvokeScript("plot_hypervolume", "plot_hypervolume", plot => plot(_study));
        }

        /// <summary>Creates a clustering plot.</summary>
        /// <param name="nClusters">Cluster count forwarded to <c>plot_clustering</c>.</param>
        /// <param name="objectiveIndex">Objective indices forwarded to <c>plot_clustering</c>.</param>
        /// <param name="variableIndex">Variable indices forwarded to <c>plot_clustering</c>.</param>
        public void Clustering(int nClusters, int[] objectiveIndex, int[] variableIndex)
        {
            _fig = InvokeScript("plot_clustering", "plot_clustering",
                plot => plot(_study, nClusters, objectiveIndex, variableIndex));
        }

        /// <summary>
        /// Replaces the current figure with the result of the <c>truncate</c> function
        /// defined in the Pareto-front script.
        /// </summary>
        /// <exception cref="InvalidOperationException">No figure has been created yet.</exception>
        public void TruncateParetoFrontPlotHover()
        {
            CheckPlotCreated();
            // The lambda runs inside InvokeScript, so it reads the current figure before
            // _fig is reassigned.
            _fig = InvokeScript("plot_pareto_front", "truncate",
                truncate => truncate(_fig, _study));
        }

        /// <summary>Calls <c>show()</c> on the current figure.</summary>
        /// <exception cref="InvalidOperationException">No figure has been created yet.</exception>
        public void Show()
        {
            CheckPlotCreated();
            _fig.show();
        }

        /// <summary>Calls <c>write_html(path)</c> on the current figure.</summary>
        /// <param name="path">Destination passed to <c>write_html</c>.</param>
        /// <exception cref="InvalidOperationException">No figure has been created yet.</exception>
        public void SaveHtml(string path)
        {
            CheckPlotCreated();
            _fig.write_html(path);
        }

        // Throws when no plot method has stored a figure yet.
        private void CheckPlotCreated()
        {
            if (_fig == null)
            {
                throw new InvalidOperationException("No plot has been created yet.");
            }
        }

        /// <summary>
        /// Executes an embedded Python script in a fresh scope, looks up one of its
        /// top-level names, and returns what <paramref name="call"/> returns for it.
        /// </summary>
        /// <param name="scriptName">Script file name without the resource prefix or ".py".</param>
        /// <param name="functionName">Name of the function to fetch from the executed script.</param>
        /// <param name="call">Invokes the fetched function and returns its result.</param>
        /// <remarks>
        /// Both the scope and the fetched function object are disposed as soon as
        /// <paramref name="call"/> returns, rather than waiting for finalization. The value
        /// returned by <paramref name="call"/> is its own Python reference and stays valid.
        /// </remarks>
        private static dynamic InvokeScript(string scriptName, string functionName, Func<dynamic, dynamic> call)
        {
            using (PyModule scope = Py.CreateScope())
            {
                scope.Exec(ReadFileFromResource.Text(ScriptResourcePrefix + scriptName + ".py"));
                using (PyObject function = scope.Get(functionName))
                {
                    return call(function);
                }
            }
        }

        // Copies a string array into a new Python list. The caller owns and must dispose the list.
        private static PyList ToPyList(string[] items)
        {
            var list = new PyList();
            try
            {
                foreach (string item in items)
                {
                    // Append takes its own reference, so the temporary wrapper can be released.
                    using (var pyItem = new PyString(item))
                    {
                        list.Append(pyItem);
                    }
                }
                return list;
            }
            catch
            {
                list.Dispose();
                throw;
            }
        }
    }
}