sue445/plant_erd

View on GitHub
adapter/oracle/adapter.go

Summary

Maintainability
A
1 hr
Test Coverage
package oracle

import (
    mapset "github.com/deckarep/golang-set"
    "github.com/jmoiron/sqlx"
    _ "github.com/mattn/go-oci8" // for sql
    "github.com/sue445/plant_erd/db"
)

// Adapter represents Oracle adapter.
// All query methods in this file run against the sqlx handle it holds.
type Adapter struct {
    db *sqlx.DB // database handle; populated by NewAdapter
}

// Close represents function for close database.
// NewAdapter hands back the Close method of the handle it opened as a value of this type.
type Close func() error

// NewAdapter connects to Oracle through the "oci8" driver using the DSN
// produced by config.FormatDSN, and returns the resulting Adapter together
// with the function that closes the underlying connection.
// When the connection cannot be established, the adapter and the close
// function are both nil and the connect error is returned.
func NewAdapter(config *Config) (*Adapter, Close, error) {
    conn, err := sqlx.Connect("oci8", config.FormatDSN())
    if err != nil {
        return nil, nil, err
    }

    adapter := &Adapter{db: conn}
    return adapter, conn.Close, nil
}

// GetAllTableNames lists the tables owned by the current schema, leaving out
// secondary tables and materialized views.
// A name stored entirely in upper case is returned lower-cased; any other
// name is returned exactly as stored.
// On query failure it returns an empty (non-nil) slice along with the error.
func (a *Adapter) GetAllTableNames() ([]string, error) {
    // c.f. https://github.com/rsim/oracle-enhanced/blob/v6.0.0/lib/active_record/connection_adapters/oracle_enhanced/schema_statements.rb#L15
    query := `
        SELECT DECODE(table_name, UPPER(table_name), LOWER(table_name), table_name) AS table_name
        FROM all_tables
        WHERE owner = SYS_CONTEXT('userenv', 'current_schema')
        AND secondary = 'N'
        minus
        SELECT DECODE(mview_name, UPPER(mview_name), LOWER(mview_name), mview_name) AS table_name
        FROM all_mviews
        WHERE owner = SYS_CONTEXT('userenv', 'current_schema')
    `

    var records []allTables
    if err := a.db.Select(&records, query); err != nil {
        return []string{}, err
    }

    // Stays nil when the schema has no tables, matching append-on-nil semantics.
    var names []string
    for i := range records {
        names = append(names, records[i].TableName)
    }

    return names, nil
}

// GetTable returns table info: the columns (with type, nullability and
// primary-key flag), foreign keys and indexes of tableName in the current
// schema.
// tableName is upper-cased by the queries before it is matched against the
// data dictionary.
// Any error from the underlying queries is returned with a nil table.
func (a *Adapter) GetTable(tableName string) (*db.Table, error) {
    primaryKeyColumns, err := a.getPrimaryKeyColumns(tableName)
    if err != nil {
        return nil, err
    }

    // ORDER BY COLUMN_ID returns the columns in their declared order; without
    // an ORDER BY the row order of the result is unspecified.
    sql := `
        SELECT COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE
        FROM ALL_TAB_COLUMNS
        WHERE TABLE_NAME = UPPER(?)
        AND owner = SYS_CONTEXT('userenv', 'current_schema')
        ORDER BY COLUMN_ID
    `
    stmt, err := a.db.Preparex(a.db.Rebind(sql))
    if err != nil {
        return nil, err
    }
    // Register the close as soon as the statement exists so every later
    // return path releases it.
    defer stmt.Close()

    var rows []allTabColumns
    if err := stmt.Select(&rows, tableName); err != nil {
        return nil, err
    }

    table := db.Table{
        Name: tableName,
    }
    for _, row := range rows {
        column := &db.Column{
            Name:       row.ColumnName,
            Type:       row.FormatColumnType(),
            NotNull:    row.Nullable == "N",
            PrimaryKey: primaryKeyColumns.Contains(row.ColumnName),
        }
        table.Columns = append(table.Columns, column)
    }

    foreignKeys, err := a.getForeignKeys(tableName)
    if err != nil {
        return nil, err
    }
    table.ForeignKeys = foreignKeys

    indexes, err := a.getIndexes(tableName)
    if err != nil {
        return nil, err
    }
    table.Indexes = indexes

    return &table, nil
}

// getPrimaryKeyColumns returns the set of column names that make up the
// primary-key constraint of tableName in the current schema.
// tableName is upper-cased by the query before matching. The set is empty
// when the table has no primary key.
func (a *Adapter) getPrimaryKeyColumns(tableName string) (mapset.Set, error) {
    // c.f. https://github.com/rsim/oracle-enhanced/blob/v6.0.0/lib/active_record/connection_adapters/oracle_enhanced_adapter.rb#L612
    query := `
        SELECT cc.column_name
        FROM all_constraints c, all_cons_columns cc
        WHERE c.owner = SYS_CONTEXT('userenv', 'current_schema')
        AND c.table_name = UPPER(?)
        AND c.constraint_type = 'P'
        AND cc.owner = c.owner
        AND cc.constraint_name = c.constraint_name
        order by cc.position
    `

    stmt, err := a.db.Preparex(a.db.Rebind(query))
    if err != nil {
        return nil, err
    }
    defer stmt.Close()

    var records []primaryKeys
    if err := stmt.Select(&records, tableName); err != nil {
        return nil, err
    }

    pkColumns := mapset.NewSet()
    for i := range records {
        pkColumns.Add(records[i].ColumnName)
    }

    return pkColumns, nil
}

// getForeignKeys returns the foreign keys declared on tableName in the
// current schema. Each referencing column is paired with the referenced
// column at the same position in the referenced constraint, and the rows are
// ordered by target table, source column and referenced column.
// tableName is upper-cased by the query before matching. The result is nil
// when the table has no foreign keys.
func (a *Adapter) getForeignKeys(tableName string) ([]*db.ForeignKey, error) {
    // c.f. https://github.com/rsim/oracle-enhanced/blob/v6.0.0/lib/active_record/connection_adapters/oracle_enhanced/schema_statements.rb#L544
    query := `
            SELECT r.table_name to_table
                  ,rc.column_name references_column
                  ,cc.column_name
              FROM all_constraints c, all_cons_columns cc,
                   all_constraints r, all_cons_columns rc
             WHERE c.owner = SYS_CONTEXT('userenv', 'current_schema')
               AND c.table_name = UPPER(?)
               AND c.constraint_type = 'R'
               AND cc.owner = c.owner
               AND cc.constraint_name = c.constraint_name
               AND r.constraint_name = c.r_constraint_name
               AND r.owner = c.owner
               AND rc.owner = r.owner
               AND rc.constraint_name = r.constraint_name
               AND rc.position = cc.position
            ORDER BY to_table, column_name, references_column
    `

    stmt, err := a.db.Preparex(a.db.Rebind(query))
    if err != nil {
        return nil, err
    }
    defer stmt.Close()

    var records []foreignKey
    if err := stmt.Select(&records, tableName); err != nil {
        return nil, err
    }

    var result []*db.ForeignKey
    for i := range records {
        result = append(result, &db.ForeignKey{
            FromColumn: records[i].ColumnName,
            ToColumn:   records[i].ReferencesColumn,
            ToTable:    records[i].ToTable,
        })
    }

    return result, nil
}

// getIndexes returns the indexes defined on tableName in the current schema,
// excluding any index that backs a primary-key constraint, sorted by index
// name. The columns of each index are looked up with getIndexColumns.
// tableName is upper-cased by the query before matching. The result is nil
// when the table has no such indexes.
func (a *Adapter) getIndexes(tableName string) ([]*db.Index, error) {
    // c.f. https://github.com/rsim/oracle-enhanced/blob/v6.0.0/lib/active_record/connection_adapters/oracle_enhanced/schema_statements.rb#L91
    //
    // Ordering by index_name gives a stable result; table_name is fixed by the
    // WHERE clause, so sorting on it would leave the row order unspecified.
    sql := `
        SELECT index_name, uniqueness
        FROM all_indexes i
        WHERE owner = SYS_CONTEXT('userenv', 'current_schema')
        AND table_owner = SYS_CONTEXT('userenv', 'current_schema')
        AND table_name = UPPER(?)
        AND NOT EXISTS (
            SELECT uc.index_name
            FROM all_constraints uc
            WHERE uc.index_name = i.index_name AND uc.owner = i.owner AND uc.constraint_type = 'P'
        )
        ORDER BY index_name
    `

    stmt, err := a.db.Preparex(a.db.Rebind(sql))
    if err != nil {
        return nil, err
    }
    // Register the close as soon as the statement exists so every later
    // return path releases it.
    defer stmt.Close()

    var rows []allIndexes
    if err := stmt.Select(&rows, tableName); err != nil {
        return nil, err
    }

    var indexes []*db.Index
    for _, row := range rows {
        columns, err := a.getIndexColumns(row.IndexName)
        if err != nil {
            return nil, err
        }

        index := &db.Index{
            Name:    row.IndexName,
            Unique:  row.Uniqueness == "UNIQUE",
            Columns: columns,
        }
        indexes = append(indexes, index)
    }

    return indexes, nil
}

// getIndexColumns returns the column names of the index called indexName in
// the current schema, ordered by their position within the index.
// indexName is matched exactly as given. The result is nil when no matching
// index columns are found.
func (a *Adapter) getIndexColumns(indexName string) ([]string, error) {
    // c.f. https://github.com/rsim/oracle-enhanced/blob/v6.0.0/lib/active_record/connection_adapters/oracle_enhanced/schema_statements.rb#L91
    //
    // Index names are only unique per owner, so restrict the lookup to the
    // current schema; otherwise a same-named index visible in another schema
    // would contribute its columns as well.
    sql := `
        SELECT column_name
        FROM all_ind_columns
        WHERE index_name = ?
        AND index_owner = SYS_CONTEXT('userenv', 'current_schema')
        ORDER BY column_position
    `

    stmt, err := a.db.Preparex(a.db.Rebind(sql))
    if err != nil {
        return nil, err
    }
    // Register the close as soon as the statement exists so every later
    // return path releases it.
    defer stmt.Close()

    var rows []allIndColumns
    if err := stmt.Select(&rows, indexName); err != nil {
        return nil, err
    }

    var columns []string
    for _, row := range rows {
        columns = append(columns, row.ColumnName)
    }

    return columns, nil
}