hrntsm/Tunny

View on GitHub
Tunny.Core/Settings/Pruner.cs

Summary

Maintainability
A
1 hr
Test Coverage
using System;
using System.Diagnostics;
using System.IO;

using Optuna.Pruner;

using Python.Runtime;

using Tunny.Core.TEnum;
using Tunny.Core.Util;

namespace Tunny.Core.Settings
{
    /// <summary>
    /// Pruner settings: selects which Optuna pruner to build and drives the external
    /// "reporter" and "stopper" executables configured through the path properties.
    /// </summary>
    /// <remarks>
    /// The instance owns the reporter <see cref="Process"/> it starts; call
    /// <see cref="ClearReporter"/> or <see cref="Dispose"/> to terminate and release it.
    /// </remarks>
    public class Pruner : IDisposable
    {
        /// <summary>Which pruner <see cref="ToPython"/> builds.</summary>
        public PrunerType Type { get; set; } = PrunerType.None;

        // NOTE(review): IsEnabled and EvaluateIntervalSeconds are not read inside this class;
        // presumably they are consumed by the caller that schedules Evaluate() - confirm.
        public bool IsEnabled { get; set; }
        public int EvaluateIntervalSeconds { get; set; }

        /// <summary>
        /// When true, <see cref="Evaluate"/> does not wait for the reporter process to exit
        /// and keeps the same process for subsequent calls.
        /// </summary>
        public bool IsWatcher { get; set; }

        /// <summary>Path of the reporter executable started by <see cref="Evaluate"/>.</summary>
        public string ReporterPath { get; set; } = string.Empty;

        /// <summary>Path of the report file that <see cref="Evaluate"/> reads and then deletes.</summary>
        public string ReportFilePath { get; set; } = string.Empty;

        /// <summary>Path of the stopper executable started by <see cref="RunStopperProcess"/>.</summary>
        public string StopperPath { get; set; } = string.Empty;

        /// <summary>Command-line arguments for the reporter; joined with single spaces.</summary>
        public string[] ReporterInput { get; set; } = Array.Empty<string>();

        /// <summary>Command-line arguments for the stopper; joined with single spaces.</summary>
        public string[] StopperInput { get; set; } = Array.Empty<string>();

        // Per-pruner settings; ToPython() uses the one that matches Type.
        public HyperbandPruner Hyperband { get; set; } = new HyperbandPruner();
        public MedianPruner Median { get; set; } = new MedianPruner();
        public NopPruner Nop { get; set; } = new NopPruner();
        public PatientPruner Patient { get; set; } = new PatientPruner();
        public PercentilePruner Percentile { get; set; } = new PercentilePruner();
        public SuccessiveHalvingPruner SuccessiveHalving { get; set; } = new SuccessiveHalvingPruner();
        public ThresholdPruner Threshold { get; set; } = new ThresholdPruner();
        public WilcoxonPruner Wilcoxon { get; set; } = new WilcoxonPruner();

        // Reporter process currently owned by this instance; null when none is running.
        private Process _reporter;
        // Result of the most recent CheckStatus() call.
        private PrunerStatus _status = PrunerStatus.None;

        /// <summary>Returns the status computed by the last <see cref="CheckStatus"/> call.</summary>
        /// <returns>The cached status; <see cref="PrunerStatus.None"/> until <see cref="CheckStatus"/> has run.</returns>
        public PrunerStatus GetPrunerStatus()
        {
            return _status;
        }

        /// <summary>
        /// Re-evaluates whether the external pruner executables are configured and present,
        /// caches the result and returns it.
        /// </summary>
        /// <returns>
        /// <see cref="PrunerStatus.None"/> when any of the three paths is null or empty,
        /// <see cref="PrunerStatus.PathError"/> when the reporter or the stopper executable
        /// does not exist, otherwise <see cref="PrunerStatus.Runnable"/>.
        /// </returns>
        public PrunerStatus CheckStatus()
        {
            if (string.IsNullOrEmpty(ReporterPath)
                || string.IsNullOrEmpty(ReportFilePath)
                || string.IsNullOrEmpty(StopperPath))
            {
                _status = PrunerStatus.None;
            }
            // Both executables are launched later, so a single missing one is already an error.
            else if (!File.Exists(ReporterPath) || !File.Exists(StopperPath))
            {
                _status = PrunerStatus.PathError;
            }
            else
            {
                _status = PrunerStatus.Runnable;
            }
            return _status;
        }

        /// <summary>
        /// Terminates the reporter process if one is held and releases it.
        /// Safe to call when no reporter exists or when it has already exited.
        /// </summary>
        public void ClearReporter()
        {
            ReleaseReporter(true);
        }

        /// <summary>
        /// Runs (or, in watcher mode, reuses) the reporter process and reads the report file
        /// it produced, deleting the file afterwards.
        /// </summary>
        /// <returns>
        /// The deserialized report, or null when the cached status is not
        /// <see cref="PrunerStatus.Runnable"/> or no report file exists.
        /// </returns>
        public PrunerReport Evaluate()
        {
            if (_status != PrunerStatus.Runnable)
            {
                return null;
            }

            if (_reporter == null)
            {
                RunReporterProcess();
            }

            PrunerReport report = null;
            try
            {
                if (!IsWatcher)
                {
                    _reporter.WaitForExit();
                }

                if (File.Exists(ReportFilePath))
                {
                    report = PrunerReport.Deserialize(ReportFilePath);
                    // Remove the consumed report so the next call does not read it again.
                    File.Delete(ReportFilePath);
                }
            }
            finally
            {
                if (!IsWatcher)
                {
                    // One-shot reporter: it has been waited for, so only the handle needs
                    // releasing. Doing this in finally also guarantees a fresh process on
                    // the next call if reading the report failed.
                    ReleaseReporter(false);
                }
            }

            return report;
        }

        /// <summary>Starts the reporter executable and stores the process in <c>_reporter</c>.</summary>
        private void RunReporterProcess()
        {
            var reporter = new Process();
            try
            {
                reporter.StartInfo.FileName = ReporterPath;
                reporter.StartInfo.Arguments = string.Join(" ", ReporterInput);
                reporter.StartInfo.WindowStyle = ProcessWindowStyle.Hidden;
                reporter.Start();
            }
            catch
            {
                // Do not keep a Process that never started; later calls would fail on it.
                reporter.Dispose();
                throw;
            }
            _reporter = reporter;
        }

        /// <summary>Starts the stopper executable and blocks until it exits.</summary>
        public void RunStopperProcess()
        {
            using (var stopper = new Process())
            {
                stopper.StartInfo.FileName = StopperPath;
                stopper.StartInfo.Arguments = string.Join(" ", StopperInput);
                stopper.StartInfo.WindowStyle = ProcessWindowStyle.Hidden;
                stopper.Start();
                stopper.WaitForExit();
            }
        }

        /// <summary>
        /// Builds the Optuna pruner selected by <see cref="Type"/> by importing the
        /// <c>optuna</c> Python module and delegating to the matching settings object.
        /// </summary>
        /// <returns>The Python pruner object, or null when <see cref="Type"/> is <see cref="PrunerType.None"/>.</returns>
        /// <exception cref="ArgumentException"><see cref="Type"/> is not a known pruner type.</exception>
        /// <remarks>
        /// NOTE(review): presumably must be called while the Python runtime is initialized
        /// and the GIL is held - confirm against callers.
        /// </remarks>
        public dynamic ToPython()
        {
            TLog.MethodStart();
            dynamic optunaPruner;
            dynamic optuna = Py.Import("optuna");
            switch (Type)
            {
                case PrunerType.Hyperband:
                    optunaPruner = Hyperband.ToPython(optuna);
                    break;
                case PrunerType.Median:
                    optunaPruner = Median.ToPython(optuna);
                    break;
                case PrunerType.Nop:
                    optunaPruner = Nop.ToPython(optuna);
                    break;
                case PrunerType.Patient:
                    optunaPruner = Patient.ToPython(optuna);
                    break;
                case PrunerType.Percentile:
                    optunaPruner = Percentile.ToPython(optuna);
                    break;
                case PrunerType.SuccessiveHalving:
                    optunaPruner = SuccessiveHalving.ToPython(optuna);
                    break;
                case PrunerType.Threshold:
                    optunaPruner = Threshold.ToPython(optuna);
                    break;
                case PrunerType.Wilcoxon:
                    optunaPruner = Wilcoxon.ToPython(optuna);
                    break;
                case PrunerType.None:
                    optunaPruner = null;
                    break;
                default:
                    throw new ArgumentException("Invalid pruner type.");
            }

            return optunaPruner;
        }

        /// <summary>
        /// Terminates and releases the reporter process, if any. Does not throw when the
        /// process has already exited, and may be called more than once.
        /// </summary>
        public void Dispose()
        {
            ReleaseReporter(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Detaches the reporter from this instance, optionally kills it if it is still
        /// running, and disposes the <see cref="Process"/> object.
        /// </summary>
        /// <param name="kill">True to terminate a still-running process before disposing it.</param>
        private void ReleaseReporter(bool kill)
        {
            Process reporter = _reporter;
            // Clear the field first so a repeated call never touches a disposed Process.
            _reporter = null;
            if (reporter == null)
            {
                return;
            }

            try
            {
                if (kill && !reporter.HasExited)
                {
                    reporter.Kill();
                }
            }
            catch (InvalidOperationException)
            {
                // Process exited between the check and Kill(), or was never started:
                // there is nothing left to terminate.
            }
            catch (System.ComponentModel.Win32Exception)
            {
                // Process could not be queried or terminated (e.g. it is already shutting
                // down); cleanup is best effort and must not throw from Dispose().
            }
            finally
            {
                reporter.Dispose();
            }
        }
    }
}