chie8842/expstock

View on GitHub
expstock/expstock.py

Summary

Maintainability
A
50 mins
Test Coverage
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import functools
import os
import platform
import sys
import subprocess
from expstock.dbconnect import DbConnect


class ExpStock(object):
    """Record the settings, environment and output of one experiment run.

    Each run gets a log directory that receives the parameters, memo,
    machine information, git state, captured stdout/stderr, timings and
    the result.

    e.g. Now, you have variables `a`, `b` and an experimental function,

    >>> a = 1
    >>> b = 2
    >>> e = ExpStock()
    >>> e.append_param(a=a, b=b)
    >>> @expstock(e)
    ... def exp_func(x, y):
    ...     print(x + y)
    >>> exp_func(a, b)

    or, in a simple script, you can call `pre_stock` and `post_stock`
    directly.

    >>> a = 1
    >>> b = 2
    >>> e = ExpStock()
    >>> e.append_param(a=a, b=b)
    >>> e.pre_stock()
    >>> print(a + b)
    >>> e.post_stock()
    """

    def __init__(self,
                 log_dirname='',
                 params=None,
                 memo='',
                 git_check=True,
                 report=False,
                 exp_name='Untitled',
                 dbsave=False):
        """__init__

        :param log_dirname: name of the directory to stock experiments.
                            When empty, a directory named
                            `<timestamp>_<exp_name>` is used under
                            `experiments/`.
        :param params: parameters of the function of the experiments.
                       `params` should be a list of tuples.
                       e.g. [('param_1', 1), ('param_2', 'hoge')]
                       At there, `param_1` and `param_2` are the names of variables,
                       and `1` and `hoge` are the variables for experiments.
                       Defaults to a new empty list for every instance.
        :param memo: free text written to `memo.txt` and to the report
        :param git_check: when true, `pre_stock` records the git HEAD hash
                          and `git diff` output
        :param report: when true, `post_stock` writes `report.txt`
        :param exp_name: name of the experiment; used in the default log
                         directory name and in the report
        :param dbsave: when true, the run is also stored through `DbConnect`
        """
        self.stock_root_dir = 'experiments'
        # A `[]` default would be shared by every instance, so parameters
        # appended to one experiment would leak into all later ones.
        self.params = [] if params is None else params
        self.memo = memo
        self.git_check = git_check
        self.git_head = b''
        self.git_diff = b''
        self.start_time = None
        self.finish_time = None
        self.exp_name = exp_name
        self.result = ''
        # Needs `stock_root_dir` and `exp_name`, which are set above.
        self._set_dirname(log_dirname)
        self.report = report
        self.dbsave = dbsave
        # NOTE(review): `_dbsave_pre` opens `~/.experiments.db`, not this
        # path -- confirm which location is intended.
        self.dbfile = os.path.join(self.stock_root_dir, 'experiments.db')
        self.dbconn = None
        # Streams replaced by `pre_stock` and the log files standing in for
        # them; all None while nothing is redirected.
        self._orig_stdout = None
        self._orig_stderr = None
        self._stdout_log = None
        self._stderr_log = None

    def _set_dirname(self, log_dirname):
        """Set `self.log_dirname` to an absolute path.

        :param log_dirname: directory to use, or '' to derive
                            `<stock_root_dir>/<timestamp>_<exp_name>`
        """
        if log_dirname != '':
            self.log_dirname = os.path.abspath(log_dirname)
        else:
            current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
            self.log_dirname = os.path.abspath(
                    os.path.join(
                        self.stock_root_dir,
                        '{}_{}'.format(current_time, self.exp_name)
                        )
                    )

    def set_memo(self, memo):
        """Replace the memo text stored with the experiment."""
        self.memo = memo

    def _mk_logdir(self):
        """Create the log directory if it does not exist yet."""
        if not os.path.isdir(self.log_dirname):
            os.makedirs(self.log_dirname)

    def _get_machine_info(self):
        """_get_machine_info
        TODO: split to independent module
        TODO: add function to get resource usage

        :return: list of (name, value) tuples taken from `platform`
        """
        return [
            ('system', platform.system()),
            ('node', platform.node()),
            ('release', platform.release()),
            ('version', platform.version()),
            ('machine', platform.machine()),
            ('processor', platform.processor()),
            ('python_version', platform.python_version()),
        ]

    def _get_git_info(self):
        """_get_git_info
        TODO: split to independent module

        Runs `git` in the current working directory; errors raised by
        `subprocess.check_output` propagate to the caller.

        :return: (git_head, git_diff)
        :return type: (bytes, bytes) -- raw output of the git commands
        """
        git_head = subprocess.check_output(['git', 'log', '-n', '1', '--format=%H'])
        git_diff = subprocess.check_output(['git', 'diff'])
        return git_head, git_diff

    def append_param(self, **kwargs):
        """Append each keyword argument to `params` as a (name, value) tuple.

        e.g. `e.append_param(lr=0.1, epochs=10)`
        """
        self.params.extend(kwargs.items())

    def _write_log(self, filename, value, open_type='w'):
        """Write `value` to `filename` inside the log directory.

        :param filename: file name relative to `log_dirname`
        :param value: content to write; must match `open_type`
                      (bytes for binary modes, text otherwise)
        :param open_type: mode passed to `open`
        """
        filepath = os.path.join(self.log_dirname, filename)
        with open(filepath, open_type) as f:
            f.write(value)

    def _write_logs(self, filename, values):
        """Write one line per entry of `values` to `filename`.

        Tuples are written as `name = value` (first two items), strings are
        written as they are, and entries of any other type are skipped.

        :param filename: file name relative to `log_dirname`
        :param values: iterable of tuples and/or strings
        """
        filepath = os.path.join(self.log_dirname, filename)
        with open(filepath, 'w') as f:
            for value in values:
                if isinstance(value, tuple):
                    f.write('{} = {}\n'.format(value[0], value[1]))
                elif isinstance(value, str):
                    f.write(value + '\n')

    def _dbsave_pre(self):
        """Open the experiments database and insert this run and its params."""
        dbfile = os.path.expanduser('~/.experiments.db')
        self.dbconn = DbConnect(dbfile)
        self.dbconn.insert_into_experiments_pre(self)
        self.dbconn.insert_into_params(self)

    def _dbsave_post(self):
        """Update the database row of this run; no-op without a connection.

        A connection only exists after `_dbsave_pre` has run. A missing
        attribute raises AttributeError, not NameError, so the absence is
        checked explicitly instead of being caught.
        """
        dbconn = getattr(self, 'dbconn', None)
        if dbconn is None:
            return
        dbconn.update_experiments(self)

    def _create_report(self):
        """_create_report
        Write `report.txt` with the name, memo, timings, git HEAD and params.

        TODO: rewrite using Jinja2
        """

        filepath = os.path.join(self.log_dirname, 'report.txt')

        git_head = self.git_head
        # On Python 3 the git output is bytes and would be rendered as
        # "b'...'"; on Python 2 bytes is str and is left untouched.
        if isinstance(git_head, bytes) and not isinstance(git_head, str):
            git_head = git_head.decode('utf-8', 'replace')

        text = """
report
========

exp_name: {}
memo: {}

start_time: {}
finish_time: {}

git_head: {}

""".format(
                self.exp_name,
                self.memo,
                self.start_time,
                self.finish_time,
                git_head)

        params_text = 'params:\n'
        empty_row = '\n'

        with open(filepath, 'w') as f:
            f.write(text)
            f.write(params_text)
            for param in self.params:
                f.write('{} = {}\n'.format(param[0], param[1]))

            f.write(empty_row)

    def _restore_streams(self):
        """Close the log files opened by `pre_stock` and restore the streams.

        Only streams that `pre_stock` redirected are touched, and they are
        put back to what they were before (not to `sys.__stdout__`), so
        streams installed by notebooks or test runners survive.
        """
        stdout_log = getattr(self, '_stdout_log', None)
        if stdout_log is not None:
            stdout_log.close()
            sys.stdout = self._orig_stdout
            self._stdout_log = None
            self._orig_stdout = None

        stderr_log = getattr(self, '_stderr_log', None)
        if stderr_log is not None:
            stderr_log.close()
            sys.stderr = self._orig_stderr
            self._stderr_log = None
            self._orig_stderr = None

    def pre_stock(self):
        """Record the state before the experiment and start capturing output.

        Creates the log directory, writes git information (when `git_check`
        is set), memo, params, machine information and the start time,
        optionally stores the run in the database, and finally redirects
        `sys.stdout` / `sys.stderr` to `stdout.txt` / `stderr.txt`.
        """
        self._mk_logdir()
        self.machine_info = self._get_machine_info()

        if self.git_check:
            self.git_head, self.git_diff = self._get_git_info()
            self._write_log('git_head.txt', self.git_head, 'wb')
            self._write_log('git_diff.txt', self.git_diff, 'wb')
        self._write_log('memo.txt', self.memo, 'w')
        self._write_logs('params.txt', self.params)
        self._write_logs('machine_info.txt', self.machine_info)

        self.start_time = datetime.now()
        self.start_time_str = self.start_time.strftime('%Y/%m/%d %H:%M:%S')
        self._write_log('exec_time.txt', 'start_time: {}\n'.format(self.start_time_str), 'w')
        if self.dbsave:
            self._dbsave_pre()

        # Change target of stdout and stderr to log files. Both files are
        # opened before either stream is swapped, and the current streams
        # are remembered so `post_stock` can put them back.
        stdout_path = os.path.join(self.log_dirname, 'stdout.txt')
        stderr_path = os.path.join(self.log_dirname, 'stderr.txt')
        stdout_log = open(stdout_path, 'w')
        stderr_log = open(stderr_path, 'w')
        self._orig_stdout = sys.stdout
        self._orig_stderr = sys.stderr
        self._stdout_log = stdout_log
        self._stderr_log = stderr_log
        sys.stdout = stdout_log
        sys.stderr = stderr_log

    def post_stock(self, result=''):
        """Stop capturing output and record the outcome of the experiment.

        Restores stdout/stderr, appends the finish and execution time to
        `exec_time.txt`, writes `result.txt`, and creates the report and
        updates the database when those options are enabled.

        :param result: text written to `result.txt`
        """
        self.result = result
        self._restore_streams()
        self.finish_time = datetime.now()
        self.finish_time_str = self.finish_time.strftime('%Y/%m/%d %H:%M:%S')
        self.execution_time = self.finish_time - self.start_time
        self.execution_time_str = str(self.execution_time)
        self._write_log(
                'exec_time.txt',
                'finish_time: {}\nexecution_time: {}\n'.format(
                    self.finish_time_str,
                    self.execution_time_str),
                'a')
        self._write_log('result.txt', result)

        if self.report:
            self._create_report()

        if self.dbsave:
            self._dbsave_post()

def expstock(e):
    """Build a decorator that records every call of a function with `e`.

    The decorated function calls `e.pre_stock()` before running and
    `e.post_stock(...)` afterwards with `str()` of the return value.

    :param e: ExpStock instance
    :return: decorator; the wrapped function returns whatever the
             decorated function returns
    """

    def _expstock(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            e.pre_stock()
            recorded = ''
            try:
                result = func(*args, **kwargs)
                recorded = str(result)
                return result
            finally:
                # Runs on failure too (with an empty result), so the run is
                # always finalized and stdout/stderr are not left
                # redirected when `func` raises; the exception propagates.
                e.post_stock(recorded)
        return wrapper
    return _expstock