hrntsm/Tunny

View on GitHub
Optuna/Study/StudyWrapper.cs

Summary

Maintainability
A
45 mins
Test Coverage
using System;
using System.Linq;

using Optuna.Trial;
using Optuna.Util;

using Python.Runtime;

namespace Optuna.Study
{
    public class StudyWrapper
    {
        public dynamic PyInstance { get; private set; }
        public bool StopFlag => PyInstance._stop_flag;
        public int Id => PyInstance._study_id;
        public double[][] BestTrials
        {
            get
            {
                dynamic[] bestTrials = PyInstance.best_trials;
                return bestTrials.Select(t => (double[])t.values).ToArray();
            }
        }

        public StudyWrapper(dynamic study)
        {
            PyInstance = study;
        }

        //TODO: Add check HumanInTheLoop instance
        public bool ShouldGenerate()
        {
            return PyInstance.should_generate();
        }

        public TrialWrapper Ask()
        {
            return new TrialWrapper(PyInstance.ask());
        }

        public void Tell(TrialWrapper trial, double value)
        {
            PyInstance.tell(trial.PyInstance, value);
        }

        public void Tell(TrialWrapper trial, double[] values)
        {
            PyInstance.tell(trial.PyInstance, values);
        }

        public void Tell(TrialWrapper trial, TrialState state)
        {
            dynamic optuna = Py.Import("optuna");
            dynamic pyState;
            switch (state)
            {
                case TrialState.COMPLETE:
                    pyState = optuna.trial.TrialState.COMPLETE;
                    break;
                case TrialState.FAIL:
                    pyState = optuna.trial.TrialState.FAIL;
                    break;
                case TrialState.PRUNED:
                    pyState = optuna.trial.TrialState.PRUNED;
                    break;
                case TrialState.RUNNING:
                    pyState = optuna.trial.TrialState.RUNNING;
                    break;
                default:
                    pyState = PyInstance.Complete;
                    break;
            }
            PyInstance.tell(trial.PyInstance, pyState);
        }

        public void EnqueueTrial(PyDict paramDict, PyDict attrDict, bool skipIfExist)
        {
            PyInstance.enqueue_trial(paramDict, user_attrs: attrDict, skip_if_exists: skipIfExist);
        }

        public void SetUserAttribute(string key, string value)
        {
            PyInstance.set_user_attr(key, value);
        }

        public void SetUserAttribute(string key, string[] values)
        {
            PyList pyList = PyConverter.EnumeratorToPyList(values);
            PyInstance.set_user_attr(key, pyList);
        }

        public void SetMetricNames(string[] names)
        {
            PyList pyList = PyConverter.EnumeratorToPyList(names);
            PyInstance.set_metric_names(pyList);
        }

        public static StudyWrapper CreateStudy(string studyName, dynamic sampler, string[] directions, dynamic storage, dynamic pruner, bool loadIfExists)
        {
            dynamic optuna = Py.Import("optuna");
            dynamic study = optuna.create_study(
                sampler: sampler,
                directions: directions,
                storage: storage,
                study_name: studyName,
                load_if_exists: loadIfExists,
                pruner: pruner
            );
            // for escape exception in Brute Force sampler
            // Study.stop() method throws exception when in_optimize_loop is false
            study._thread_local.in_optimize_loop = true;
            return new StudyWrapper(study);
        }

        public static StudyWrapper LoadStudy(dynamic storage, string studyName)
        {
            dynamic optuna = Py.Import("optuna");
            return studyName == null
                ? throw new ArgumentException("studyName must not be null.")
                : new StudyWrapper(optuna.load_study(storage: storage, study_name: studyName));
        }
    }
}