chie8842/expstock

View on GitHub
expstock/dbconnect.py

Summary

Maintainability
B
6 hrs
Test Coverage
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sqlite3

class DbConnect(object):
    """Store experiment metadata and parameters in a SQLite database.

    Two tables are used:

    * ``experiments`` -- one row per experiment run (name, memo, timing
      strings, result, git HEAD and log directory).
    * ``params`` -- one row per experiment parameter, linked to its run via
      ``experiment_id``.

    On construction the database at ``filepath`` is opened, missing tables
    are created, and the next free ``experiment_id`` / ``param_id`` values
    are computed and kept on the instance. The instance owns the
    connection; call :meth:`close` (or use it in a ``with`` statement) when
    finished.
    """

    def __init__(self, filepath):
        """Open ``filepath``, ensure the schema exists and load the next ids.

        Args:
            filepath: Database path passed to :func:`sqlite3.connect`.
        """
        self.dbfile = filepath
        self.conn, self.c = self.db_connect()
        try:
            self._create_table_if_not_exists()
            self.experiment_id, self.param_id = self._get_ids()
        except Exception:
            # Do not leak the connection when schema setup fails.
            self.conn.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def close(self):
        """Close the underlying connection (safe to call more than once)."""
        self.conn.close()

    @staticmethod
    def _to_text(value):
        """Return ``value`` decoded as UTF-8 if it is bytes, else unchanged."""
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return value

    def _create_table_if_not_exists(self):
        """Create the ``experiments`` and ``params`` tables if missing."""
        sqlite_tables = self._get_tables()
        if 'experiments' not in sqlite_tables:
            self._create_table_experiments()
        if 'params' not in sqlite_tables:
            self._create_table_params()

    def db_connect(self):
        """Open ``self.dbfile`` and return a ``(connection, cursor)`` pair."""
        conn = sqlite3.connect(self.dbfile)
        c = conn.cursor()
        return conn, c

    def _get_tables(self):
        """Return the names of all tables in the database."""
        # Single quotes: double quotes denote identifiers in SQL, and SQLite
        # builds without the legacy double-quoted-string fallback reject them.
        query = "select name from sqlite_master where type = 'table'"
        return [table_info[0] for table_info in self.c.execute(query)]

    def _create_table_experiments(self):
        """Create the ``experiments`` table."""
        query = """
            create table if not exists experiments(
                experiment_id integer primary key autoincrement
                , experiment_name text
                , memo text
                , start_time text
                , finish_time text
                , execution_time text
                , result text
                , git_head text
                , log_dir text)
        """
        self.c.execute(query)
        self.conn.commit()

    def _create_table_params(self):
        """Create the ``params`` table."""
        query = """
            create table if not exists params(
                param_id integer primary key autoincrement
                , experiment_id integer
                , param_name text
                , value text)
        """
        self.c.execute(query)
        self.conn.commit()

    def _get_experiments_count(self):
        """Return the number of rows in ``experiments``."""
        query = 'select count(*) from experiments'
        return self.c.execute(query).fetchone()[0]

    def _get_params_count(self):
        """Return the number of rows in ``params``."""
        query = 'select count(*) from params'
        return self.c.execute(query).fetchone()[0]

    def _get_ids(self):
        """Return the next free ``(experiment_id, param_id)`` pair.

        Each id is one more than the current maximum in its table, or 1 when
        the table is empty.
        """
        # max() is NULL on an empty table; coalesce turns that into 0.
        experiment_query = (
            'select coalesce(max(experiment_id), 0) + 1 from experiments')
        param_query = 'select coalesce(max(param_id), 0) + 1 from params'
        experiment_id = self.c.execute(experiment_query).fetchone()[0]
        param_id = self.c.execute(param_query).fetchone()[0]
        return experiment_id, param_id

    def insert_into_experiments(self, e):
        """Insert a complete ``experiments`` row for ``e``.

        Args:
            e: Experiment object. Reads ``exp_name``, ``memo``,
                ``start_time_str``, ``finish_time_str``,
                ``execution_time_str``, ``result``, ``git_head`` (bytes or
                str) and ``log_dirname``. The row id is
                ``self.experiment_id``.
        """
        query = """
        insert into experiments(
          experiment_id
          , experiment_name
          , memo
          , start_time
          , finish_time
          , execution_time
          , result
          , git_head
          , log_dir
        ) values (?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
        value = (
            self.experiment_id,
            e.exp_name,
            e.memo,
            e.start_time_str,
            e.finish_time_str,
            e.execution_time_str,
            e.result,
            self._to_text(e.git_head),
            e.log_dirname + '/')
        self.c.execute(query, value)
        self.conn.commit()

    def insert_into_experiments_pre(self, e):
        """Insert the start-of-run columns for ``e``.

        Finish time, execution time and result are left NULL so that
        :meth:`update_experiments` can fill them in later.
        """
        query = """
        insert into experiments(
          experiment_id
          , experiment_name
          , memo
          , start_time
          , git_head
          , log_dir
        ) values (?, ?, ?, ?, ?, ?)
        """
        value = (
            self.experiment_id,
            e.exp_name,
            e.memo,
            e.start_time_str,
            self._to_text(e.git_head),
            e.log_dirname + '/')
        self.c.execute(query, value)
        self.conn.commit()

    def update_experiments(self, e):
        """Fill in finish time, execution time and result for this run."""
        query = """
        update experiments set
          finish_time = ?
          , execution_time = ?
          , result = ?
        where experiment_id = ?
        """
        value = (
            e.finish_time_str,
            e.execution_time_str,
            e.result,
            self.experiment_id)
        self.c.execute(query, value)
        self.conn.commit()

    def insert_into_params(self, e):
        """Insert one ``params`` row per entry of ``e.params``.

        Each entry is indexed as ``(name, value)``, and the value is stored
        as ``str(value)``. Rows receive consecutive ids starting at
        ``self.param_id``, which is advanced past them.
        """
        query = """
        insert into params(
          param_id
          , experiment_id
          , param_name
          , value
        ) values (?, ?, ?, ?)
        """
        first_id = self.param_id
        values = [
            (first_id + offset, self.experiment_id, param[0], str(param[1]))
            for offset, param in enumerate(e.params)]
        # Reserve the ids before executing, as the original per-row counter did.
        self.param_id = first_id + len(values)
        self.c.executemany(query, values)
        self.conn.commit()