aergoio/aergo

View on GitHub
contract/statesql.go

Summary

Maintainability
A
2 hrs
Test Coverage
F
1%
package contract

/*
#include "sqlite3-binding.h"
*/
import "C"
import (
    "context"
    "database/sql"
    "encoding/json"
    "errors"
    "fmt"
    "os"
    "path/filepath"
    "sync"

    "github.com/aergoio/aergo-lib/log"
    "github.com/aergoio/aergo/v2/internal/enc/base58"
    "github.com/aergoio/aergo/v2/state"
    "github.com/aergoio/aergo/v2/types"
)

var (
    // ErrDBOpen is returned when a statesql database cannot be opened,
    // connected to, or pinged (openDB, readOnlyConn, beginTx).
    ErrDBOpen = errors.New("failed to open the sql database")
    // ErrUndo is returned by restoreRecoveryPoint when the recovery point
    // recorded in state is ahead of the database's last commit.
    ErrUndo   = errors.New("failed to undo the sql database")
    // ErrFindRp is returned when the database's commit count reads as zero;
    // recoveryPoint also reports zero when the branch_info query or its JSON
    // decoding fails.
    ErrFindRp = errors.New("cannot find a recovery point")

    // database is the process-wide registry of open writable databases.
    database = &sqlDatabase{}
    // load makes the initialization in LoadDatabase run only once.
    load     sync.Once

    // sqlLgr is the logger used throughout this file.
    sqlLgr = log.NewLogger("statesql")

    // queryConn is set by the query driver's ConnectHook to the raw driver
    // connection it was handed; readOnlyConn holds queryConnLock while it
    // opens a connection and reads queryConn.
    queryConn     *SQLiteConn
    queryConnLock sync.Mutex
)

const (
    // statesqlDriver names the database/sql driver used for writable
    // databases; it is also the sub-directory created under the data dir.
    statesqlDriver = "statesql"
    // queryDriver names the driver used for read-only query connections.
    queryDriver    = "query"
)

// sqlDatabase is the registry of writable statesql databases.
type sqlDatabase struct {
    // DBs maps a database name to its open litetree. Entries are created by
    // the statesql driver's ConnectHook and removed by CloseDatabase, or by
    // beginTx when opening fails.
    DBs        map[string]*litetree
    // OpenDbName is the name of the database openDB is currently opening;
    // the ConnectHook reads it to decide which entry to create.
    OpenDbName string
    // DataDir is the directory holding the <name>.db files
    // (<dataDir>/statesql, set by LoadDatabase/LoadTestDatabase).
    DataDir    string
}

// init registers the two database/sql drivers used by this file.
//
//   - statesqlDriver ("statesql") backs the writable databases opened by
//     openDB. Its ConnectHook runs for every new driver connection. It uses
//     database.OpenDbName as the name of the database being opened and
//     records a litetree in database.DBs holding the raw *SQLiteConn and the
//     account ID decoded from that base58 name. openDB then attaches the
//     *sql.DB and *sql.Conn to that entry.
//   - queryDriver ("query") backs read-only query connections. Its
//     ConnectHook only publishes the raw connection in queryConn.
func init() {
    sql.Register(statesqlDriver, &SQLiteDriver{
        ConnectHook: func(conn *SQLiteConn) error {
            name := database.OpenDbName
            if _, exists := database.DBs[name]; exists {
                sqlLgr.Warn().Err(errors.New("duplicated connection")).Msg("Open SQL Connection")
                return nil
            }
            b, err := base58.Decode(name)
            if err != nil {
                sqlLgr.Error().Err(err).Msg("Open SQL Connection")
                // Refuse the connection rather than accept it without a
                // registry entry: openDB expects database.DBs[name] to exist
                // once the connection is established.
                // NOTE(review): assumes SQLiteDriver.Open fails the open when
                // the hook returns an error (as upstream go-sqlite3 does).
                return err
            }
            database.DBs[name] = &litetree{
                Conn:      nil,
                db:        nil,
                tx:        nil,
                conn:      conn,
                name:      name,
                accountID: types.AccountID(types.ToHashID(b)),
            }
            return nil
        },
    })
    sql.Register(queryDriver, &SQLiteDriver{
        ConnectHook: func(conn *SQLiteConn) error {
            queryConn = conn
            return nil
        },
    })
}

// checkPath makes sure path exists. A missing path is created as a single
// directory with mode 0755; parent directories are not created. Any other
// Stat error is returned unchanged, and an existing path is accepted
// whatever its type.
func checkPath(path string) error {
    _, statErr := os.Stat(path)
    if !os.IsNotExist(statErr) {
        return statErr
    }
    return os.Mkdir(path, 0755)
}

// loadDatabaseErr holds the outcome of the one-time initialization done by
// LoadDatabase, so that every caller sees it and not only the first one.
var loadDatabaseErr error

// LoadDatabase prepares the statesql directory (<dataDir>/statesql) and
// resets the database registry. Only the first call does any work (guarded
// by sync.Once). Later calls return the first call's result, whatever
// dataDir they pass, so a failed initialization is never reported as
// success.
func LoadDatabase(dataDir string) error {
    load.Do(func() {
        path := filepath.Join(dataDir, statesqlDriver)
        sqlLgr.Debug().Str("path", path).Msg("loading statesql")
        if loadDatabaseErr = checkPath(path); loadDatabaseErr == nil {
            database.DBs = make(map[string]*litetree)
            database.DataDir = path
        }
    })
    return loadDatabaseErr
}

// LoadTestDatabase does the same initialization as LoadDatabase but without
// the sync.Once guard. Every successful call replaces database.DBs with an
// empty map; entries already present are dropped, not closed.
func LoadTestDatabase(dataDir string) error {
    dbDir := filepath.Join(dataDir, statesqlDriver)
    sqlLgr.Debug().Str("path", dbDir).Msg("loading statesql")
    err := checkPath(dbDir)
    if err != nil {
        return err
    }
    database.DBs = make(map[string]*litetree)
    database.DataDir = dbDir
    return nil
}

// CloseDatabase rolls back any open transaction, closes every registered
// database and empties the registry. Errors are logged, never returned.
func CloseDatabase() {
    // Deleting the current key while ranging over a map is allowed in Go.
    for name, db := range database.DBs {
        if pending := db.tx; pending != nil {
            if rbErr := pending.rollback(); rbErr != nil {
                sqlLgr.Warn().Err(rbErr).Str("db_name", name).Msg("SQL TX close")
            }
            db.tx = nil
        }
        if closeErr := db.close(); closeErr != nil {
            sqlLgr.Warn().Err(closeErr).Str("db_name", name).Msg("SQL DB close")
        }
        delete(database.DBs, name)
    }
}

// SaveRecoveryPoint commits every open writable transaction. For each one it
// stores the database's resulting commit count as the owning account's
// SqlRecoveryPoint in bs.
//
// Databases without an open transaction are skipped. A failed commit is
// logged and that database is skipped. ErrFindRp is returned when the commit
// count reads as zero after a successful commit, and state lookup/update
// errors are returned as is. In every case the deferred CloseDatabase
// rolls back and closes whatever is still open.
func SaveRecoveryPoint(bs *state.BlockState) error {
    defer CloseDatabase()

    for id, db := range database.DBs {
        if db.tx == nil {
            continue
        }
        commitErr := db.tx.commit()
        db.tx = nil
        if commitErr != nil {
            sqlLgr.Warn().Err(commitErr).Str("db_name", id).Msg("SQL TX commit")
            continue
        }

        rp := db.recoveryPoint()
        if rp == 0 {
            return ErrFindRp
        }
        if sqlLgr.IsDebugEnabled() {
            sqlLgr.Debug().Str("db_name", id).Uint64("commit_id", rp).Msg("save recovery point")
        }

        current, err := bs.GetAccountState(db.accountID)
        if err != nil {
            return err
        }
        updated := current.Clone()
        updated.SqlRecoveryPoint = rp
        if err = bs.PutState(db.accountID, updated); err != nil {
            return err
        }
    }
    return nil
}

// beginTx returns the writable transaction for dbName, opening the database
// first when it is not registered yet.
//
// If opening fails, any registry entry the ConnectHook may have created is
// removed and conn's error is returned. When rp == 1, a best-effort setup
// transaction first runs `create table if not exists _dummy(_dummy)`.
// Presumably this gives the master branch at least one commit, since
// restoreRecoveryPoint rejects a zero commit count with ErrFindRp. The
// returned transaction comes from litetree.beginTx, which reconciles the
// database with rp.
func beginTx(dbName string, rp uint64) (sqlTx, error) {
    db, err := conn(dbName)
    if err != nil {
        delete(database.DBs, dbName)
        return nil, err
    }
    if rp == 1 {
        // NOTE(review): errors in this setup step are ignored, execution
        // falls through to db.beginTx, and a failed ExecContext leaves the
        // setup transaction open. This matches the previous implementation,
        // where an inner `:=` shadowed err so these failures never reached
        // its error check. Before making them fatal, consider a database
        // that already has an open transaction (db.tx != nil): starting
        // another one on the same connection presumably fails there, and
        // callers may rely on getting the existing db.tx back.
        ctx := context.Background()
        if setup, beginErr := db.BeginTx(ctx, nil); beginErr == nil {
            if _, execErr := setup.ExecContext(ctx, "create table if not exists _dummy(_dummy)"); execErr == nil {
                _ = setup.Commit()
            }
        }
    }
    return db.beginTx(rp)
}

// beginReadOnly opens a fresh query-only connection to dbName and returns a
// read-only transaction viewing recovery point rp. The connection is owned
// by the returned transaction (its rollback/close closes it). If the
// snapshot cannot be selected, the connection is closed here so it does not
// leak.
func beginReadOnly(dbName string, rp uint64) (sqlTx, error) {
    db, err := readOnlyConn(dbName)
    if err != nil {
        return nil, err
    }
    tx, err := newReadOnlySqlTx(db, rp)
    if err != nil {
        // Best effort: the snapshot error is the one worth reporting.
        _ = db.close()
        return nil, err
    }
    return tx, nil
}

// conn returns the registered litetree for dbName, or opens it via openDB
// when it is not registered yet.
func conn(dbName string) (*litetree, error) {
    registered, found := database.DBs[dbName]
    if !found {
        return openDB(dbName)
    }
    return registered, nil
}

func dataSrc(dbName string) string {
    return fmt.Sprintf(
        "file:%s/%s.db?branches=on&max_db_size=%d",
        database.DataDir,
        dbName,
        maxSQLDBSize*1024*1024)
}

// readOnlyConn opens a new query-only connection (`_query_only=true`) to
// dbName through the query driver. The result is not registered in
// database.DBs; its caller owns it and must close it. Any failure is logged
// and reported as ErrDBOpen.
func readOnlyConn(dbName string) (*litetree, error) {
    // The query driver's ConnectHook writes the package-level queryConn when
    // a connection is opened, and queryConn is read below. Hold the lock
    // across the whole sequence so concurrent callers do not interleave.
    queryConnLock.Lock()
    defer queryConnLock.Unlock()

    db, err := sql.Open(queryDriver, dataSrc(dbName)+"&_query_only=true")
    if err != nil {
        // Error (not Fatal): the event must be emitted with Msg, and a failed
        // query connection should not terminate the process.
        sqlLgr.Error().Err(err).Str("db_name", dbName).Msg("open read-only SQL DB")
        return nil, ErrDBOpen
    }
    var c *sql.Conn
    err = db.Ping()
    if err == nil {
        // Presumably the pool hands back the connection opened by Ping, i.e.
        // the one whose raw handle the hook stored in queryConn.
        c, err = db.Conn(context.Background())
    }
    if err != nil {
        sqlLgr.Error().Err(err).Str("db_name", dbName).Msg("connect read-only SQL DB")
        _ = db.Close()
        return nil, ErrDBOpen
    }
    return &litetree{
        Conn: c,
        db:   db,
        tx:   nil,
        conn: queryConn,
        name: dbName,
    }, nil
}

// openDB opens the writable database dbName and returns its registry entry.
//
// The statesql driver's ConnectHook reads database.OpenDbName to learn which
// database a new connection belongs to, and creates the litetree entry in
// database.DBs. openDB then attaches the *sql.DB and the dedicated *sql.Conn
// to that entry. Every failure is logged and reported as ErrDBOpen. That
// includes the hook not having registered an entry, which previously caused
// a nil-pointer panic.
func openDB(dbName string) (*litetree, error) {
    ctx := context.Background()
    database.OpenDbName = dbName
    db, err := sql.Open(statesqlDriver, dataSrc(dbName))
    if err != nil {
        sqlLgr.Error().Err(err).Str("db_name", dbName).Msg("open SQL DB")
        return nil, ErrDBOpen
    }
    c, err := db.Conn(ctx)
    if err != nil {
        sqlLgr.Error().Err(err).Str("db_name", dbName).Msg("connect SQL DB")
        _ = db.Close()
        return nil, ErrDBOpen
    }
    if err = c.PingContext(ctx); err != nil {
        sqlLgr.Error().Err(err).Str("db_name", dbName).Msg("ping SQL DB")
        _ = c.Close()
        _ = db.Close()
        return nil, ErrDBOpen
    }
    entry, ok := database.DBs[dbName]
    if !ok || entry == nil {
        sqlLgr.Error().Str("db_name", dbName).Msg("SQL connection is not registered")
        _ = c.Close()
        _ = db.Close()
        return nil, ErrDBOpen
    }
    entry.Conn = c
    entry.db = db
    return entry, nil
}

// litetree is one open statesql database: an SQLite file opened with
// `branches=on` (see dataSrc), on which the branch pragmas branch_info,
// branch_truncate and branch are used.
type litetree struct {
    // Embedded dedicated connection; every statement runs on it, and its
    // methods (BeginTx, ExecContext, QueryRowContext, ...) are promoted.
    *sql.Conn
    // db is the *sql.DB the connection came from; close() closes both.
    db        *sql.DB
    // tx is the open writable transaction, or nil when none is in progress.
    tx        sqlTx
    // conn is the raw driver connection captured by a ConnectHook; it
    // supplies the C handle (getHandle) and DBCacheFlush.
    conn      *SQLiteConn
    // name is the database name; the file is <DataDir>/<name>.db.
    name      string
    // accountID is the account whose SqlRecoveryPoint SaveRecoveryPoint
    // updates; readOnlyConn leaves it zero.
    accountID types.AccountID
}

// beginTx returns the database's writable transaction. If one is already
// open it is returned unchanged. Otherwise the database is first reconciled
// with recovery point rp (restoreRecoveryPoint), and then a new transaction
// is started on the dedicated connection and cached in db.tx.
func (db *litetree) beginTx(rp uint64) (sqlTx, error) {
    if db.tx != nil {
        return db.tx, nil
    }
    if err := db.restoreRecoveryPoint(rp); err != nil {
        return nil, err
    }
    if sqlLgr.IsDebugEnabled() {
        sqlLgr.Debug().Str("db_name", db.name).Msg("begin transaction")
    }
    sqlTxn, err := db.BeginTx(context.Background(), nil)
    if err != nil {
        return nil, err
    }
    db.tx = &writableSqlTx{
        sqlTxCommon: sqlTxCommon{litetree: db},
        Tx:          sqlTxn,
    }
    return db.tx, nil
}

// branchInfo decodes the JSON text returned by `pragma branch_info(master)`.
// Only the total_commits field is read (see recoveryPoint).
type branchInfo struct {
    TotalCommits uint64 `json:"total_commits"`
}

// recoveryPoint returns the number of commits on the master branch, as
// reported by `pragma branch_info(master)`. It returns 0 when the pragma
// cannot be read or its JSON cannot be decoded, so callers cannot tell those
// failures apart from an empty branch.
func (db *litetree) recoveryPoint() uint64 {
    var raw string
    if err := db.QueryRowContext(context.Background(), "pragma branch_info(master)").Scan(&raw); err != nil {
        return 0
    }
    info := branchInfo{}
    if err := json.Unmarshal([]byte(raw), &info); err != nil {
        return 0
    }
    return info.TotalCommits
}

// restoreRecoveryPoint brings the master branch back to commit stateRp.
//
// It returns ErrFindRp if the current commit count reads as zero, nil if
// the database is already at stateRp, and ErrUndo if stateRp is ahead of the
// database. Otherwise it truncates the branch to stateRp and returns any
// error from that.
func (db *litetree) restoreRecoveryPoint(stateRp uint64) error {
    lastRp := db.recoveryPoint()
    if sqlLgr.IsDebugEnabled() {
        sqlLgr.Debug().Str("db_name", db.name).
            Uint64("state_rp", stateRp).
            Uint64("last_rp", lastRp).Msgf("restore recovery point")
    }
    switch {
    case lastRp == 0:
        return ErrFindRp
    case stateRp == lastRp:
        return nil
    case stateRp > lastRp:
        return ErrUndo
    }
    truncErr := db.rollbackToRecoveryPoint(stateRp)
    if truncErr != nil {
        return truncErr
    }
    if sqlLgr.IsDebugEnabled() {
        sqlLgr.Debug().Str("db_name", db.name).Uint64("commit_id", stateRp).
            Msg("restore recovery point")
    }
    return nil
}

// rollbackToRecoveryPoint drops every master-branch commit after rp by
// executing `pragma branch_truncate(master.<rp>)`.
func (db *litetree) rollbackToRecoveryPoint(rp uint64) error {
    pragma := fmt.Sprintf("pragma branch_truncate(master.%d)", rp)
    if _, err := db.ExecContext(context.Background(), pragma); err != nil {
        return err
    }
    return nil
}

// snapshotView points the connection at commit rp of the master branch with
// `pragma branch=master.<rp>`. For rp == 1 the outcome is always nil, even
// if the pragma fails. Presumably this is because a new database may not
// have commit 1 yet — NOTE(review): confirm.
func (db *litetree) snapshotView(rp uint64) error {
    if sqlLgr.IsDebugEnabled() {
        sqlLgr.Debug().Uint64("rp", rp).Msgf("snapshot view, %p", db.Conn)
    }
    pragma := fmt.Sprintf("pragma branch=master.%d", rp)
    _, err := db.ExecContext(context.Background(), pragma)
    if rp == 1 {
        return nil
    }
    return err
}

// close closes the dedicated connection and then the *sql.DB. The DB is
// closed even when closing the connection fails. The connection error, if
// any, takes precedence in the return value.
func (db *litetree) close() error {
    connErr := db.Conn.Close()
    poolErr := db.db.Close()
    if connErr != nil {
        return connErr
    }
    return poolErr
}

// sqlTx abstracts over the two transaction kinds in this file:
// writableSqlTx (from beginTx) and readOnlySqlTx (from beginReadOnly).
// Operations that do not apply to a kind either return an error or are
// no-ops for it.
type sqlTx interface {
    // commit commits the transaction (read-only: always an error).
    commit() error
    // rollback abandons the transaction (read-only: closes the connection).
    rollback() error
    // savepoint, release and rollbackToSavepoint manage a savepoint named
    // after the database (read-only: the first two error, the last closes).
    savepoint() error
    release() error
    rollbackToSavepoint() error
    // subSavepoint, subRelease and rollbackToSubSavepoint do the same for a
    // caller-named savepoint (read-only: no-ops).
    subSavepoint(string) error
    subRelease(string) error
    rollbackToSubSavepoint(string) error
    // getHandle returns the raw sqlite3 handle of the driver connection.
    getHandle() *C.sqlite3
    // close closes a read-only transaction (writable: always an error).
    close() error
    // begin executes BEGIN on a writable transaction (read-only: error).
    begin() error
}

// sqlTxCommon is the part shared by both transaction kinds: a reference to
// the litetree the transaction runs on.
type sqlTxCommon struct {
    *litetree
}

// getHandle returns the raw *C.sqlite3 of the driver connection that the
// litetree captured through a ConnectHook.
func (tx *sqlTxCommon) getHandle() *C.sqlite3 {
    driverConn := tx.litetree.conn
    return driverConn.db
}

// writableSqlTx is the read-write sqlTx created by litetree.beginTx. It wraps
// a *sql.Tx opened on the litetree's dedicated connection. Savepoints are
// issued as plain SQL statements inside that transaction: the database's own
// name is used for savepoint/release/rollbackToSavepoint, and a caller-given
// name for the sub* variants.
type writableSqlTx struct {
    sqlTxCommon
    *sql.Tx
}

// Compile-time check that writableSqlTx implements sqlTx.
var _ sqlTx = (*writableSqlTx)(nil)

// debugLog emits a debug record with a single string field, but only when
// debug logging is enabled.
func (tx *writableSqlTx) debugLog(key, value, msg string) {
    if sqlLgr.IsDebugEnabled() {
        sqlLgr.Debug().Str(key, value).Msg(msg)
    }
}

// execQuoted runs `<verb> "<name>"` inside the transaction.
// NOTE(review): name is placed between double quotes without escaping; a
// name containing `"` would break the statement.
func (tx *writableSqlTx) execQuoted(verb, name string) error {
    _, err := tx.Tx.Exec(verb + ` "` + name + `"`)
    return err
}

// commit commits the underlying SQL transaction.
func (tx *writableSqlTx) commit() error {
    tx.debugLog("db_name", tx.litetree.name, "commit")
    return tx.Tx.Commit()
}

// rollback rolls back the underlying SQL transaction.
func (tx *writableSqlTx) rollback() error {
    tx.debugLog("db_name", tx.litetree.name, "rollback")
    return tx.Tx.Rollback()
}

// savepoint opens a savepoint named after the database.
func (tx *writableSqlTx) savepoint() error {
    tx.debugLog("db_name", tx.litetree.name, "savepoint")
    return tx.execQuoted("SAVEPOINT", tx.litetree.name)
}

// subSavepoint opens a savepoint with the given name.
func (tx *writableSqlTx) subSavepoint(name string) error {
    tx.debugLog("db_name", name, "savepoint")
    return tx.execQuoted("SAVEPOINT", name)
}

// release flushes the driver's database cache (DBCacheFlush) and then
// releases the database-named savepoint. It stops at the first error.
func (tx *writableSqlTx) release() error {
    tx.debugLog("db_name", tx.litetree.name, "release")
    if err := tx.litetree.conn.DBCacheFlush(); err != nil {
        return err
    }
    return tx.execQuoted("RELEASE SAVEPOINT", tx.litetree.name)
}

// subRelease releases the savepoint with the given name.
func (tx *writableSqlTx) subRelease(name string) error {
    tx.debugLog("name", name, "release")
    return tx.execQuoted("RELEASE SAVEPOINT", name)
}

// rollbackToSavepoint undoes work back to the database-named savepoint.
func (tx *writableSqlTx) rollbackToSavepoint() error {
    tx.debugLog("db_name", tx.litetree.name, "rollback to savepoint")
    return tx.execQuoted("ROLLBACK TO SAVEPOINT", tx.litetree.name)
}

// rollbackToSubSavepoint undoes work back to the savepoint with the given
// name.
func (tx *writableSqlTx) rollbackToSubSavepoint(name string) error {
    tx.debugLog("db_name", name, "rollback to savepoint")
    return tx.execQuoted("ROLLBACK TO SAVEPOINT", name)
}

// close is not supported on a writable transaction.
func (tx *writableSqlTx) close() error {
    return errors.New("assert(only read-tx allowed)")
}

// begin executes a plain BEGIN statement through the wrapped *sql.Tx.
func (tx *writableSqlTx) begin() error {
    _, err := tx.Tx.Exec("BEGIN")
    return err
}

// readOnlySqlTx is the sqlTx returned for queries. It wraps a query-only
// connection positioned on commit rp of the master branch (snapshotView).
// No writes are supported: commit, savepoint, release and begin always fail,
// and the sub-savepoint operations are no-ops. rollback,
// rollbackToSavepoint and close all close the underlying connection.
type readOnlySqlTx struct {
    sqlTxCommon
}

// Compile-time check that readOnlySqlTx implements sqlTx.
var _ sqlTx = (*readOnlySqlTx)(nil)

// newReadOnlySqlTx points db at recovery point rp and wraps it. When the
// snapshot cannot be selected, the error is returned and db is not closed
// here.
func newReadOnlySqlTx(db *litetree, rp uint64) (sqlTx, error) {
    err := db.snapshotView(rp)
    if err != nil {
        return nil, err
    }
    return &readOnlySqlTx{sqlTxCommon: sqlTxCommon{litetree: db}}, nil
}

func (tx *readOnlySqlTx) commit() error {
    return errors.New("only select queries allowed")
}

// rollback ends the read-only transaction by closing its connection.
func (tx *readOnlySqlTx) rollback() error {
    ro := tx.litetree
    if sqlLgr.IsDebugEnabled() {
        sqlLgr.Debug().Str("db_name", ro.name).Msg("read-only tx is closed")
    }
    return ro.close()
}

func (tx *readOnlySqlTx) savepoint() error {
    return errors.New("only select queries allowed")
}

func (tx *readOnlySqlTx) release() error {
    return errors.New("only select queries allowed")
}

// rollbackToSavepoint behaves like rollback: it closes the connection.
func (tx *readOnlySqlTx) rollbackToSavepoint() error {
    return tx.rollback()
}

func (*readOnlySqlTx) subSavepoint(string) error { return nil }

func (*readOnlySqlTx) subRelease(string) error { return nil }

func (*readOnlySqlTx) rollbackToSubSavepoint(string) error { return nil }

// close closes the underlying litetree (connection and *sql.DB).
func (tx *readOnlySqlTx) close() error {
    return tx.litetree.close()
}

func (tx *readOnlySqlTx) begin() error {
    return errors.New("assert(only writable-tx allowed)")
}