nikoksr/proji

View on GitHub
pkg/package/store/store.go

Summary

Maintainability
B
4 hrs
Test Coverage
package packagestore

import (
    "context"
    "database/sql"
    "sync"
    "time"

    "github.com/nikoksr/proji/pkg/domain"
    "github.com/pkg/errors"
    "golang.org/x/sync/errgroup"
    "gopkg.in/guregu/null.v4"
    "gorm.io/gorm"
    "gorm.io/gorm/clause"
)

type packageStore struct {
    db *gorm.DB
}

func New(db *gorm.DB) domain.PackageStore {
    return &packageStore{
        db: db,
    }
}

func (ps packageStore) StorePackage(pkg *domain.Package) error {
    // Check if package exists
    err := ps.db.Where("label = ?", pkg.Label).First(pkg).Error
    if err == nil {
        return ErrPackageExists
    }
    if err != gorm.ErrRecordNotFound {
        return err
    }

    tx := ps.db.Begin()
    defer func() {
        if r := recover(); r != nil {
            tx.Rollback()
        }
    }()
    if err := tx.Error; err != nil {
        return err
    }

    err = tx.Omit(clause.Associations).Create(pkg).Error
    if err != nil {
        tx.Rollback()
        return errors.Wrap(err, "insert package")
    }

    err = storeTemplates(tx, pkg.Templates, pkg.ID)
    if err != nil {
        tx.Rollback()
        return err
    }

    err = storePlugins(tx, pkg.Plugins, pkg.ID)
    if err != nil {
        tx.Rollback()
        return err
    }

    return tx.Commit().Error
}

func storeTemplates(tx *gorm.DB, templates []*domain.Template, packageID uint) error {
    insertTemplateStmt := "INSERT OR IGNORE INTO templates (created_at, updated_at, is_file, destination, path, description) VALUES (?, ?, ?, ?, ?, ?)"
    insertAssociationStmt := "INSERT OR IGNORE INTO package_templates (package_id, template_id) VALUES (?, ?)"
    queryIDStmt := "SELECT id from templates WHERE destination = ? AND path = ?"
    for _, template := range templates {
        now := time.Now()
        err := tx.Exec(insertTemplateStmt, now, now, template.IsFile, template.Destination, template.Path, template.Description).Error
        if err != nil {
            return err
        }

        var id null.Int
        err = tx.Raw(queryIDStmt, template.Destination, template.Path).Row().Scan(&id)
        if err != nil {
            return err
        }
        if !id.Valid {
            return err
        }
        template.ID = uint(id.Int64)

        err = tx.Exec(insertAssociationStmt, packageID, template.ID).Error
        if err != nil {
            return err
        }
    }
    return nil
}

func storePlugins(tx *gorm.DB, plugins []*domain.Plugin, packageID uint) error {
    var err error
    insertPluginStmt := "INSERT OR IGNORE INTO plugins (created_at, updated_at, path, exec_number, description) VALUES (?, ?, ?, ?, ?)"
    insertAssociationStmt := "INSERT OR IGNORE INTO package_plugins (package_id, plugin_id) VALUES (?, ?)"
    queryIDStmt := "SELECT id from plugins WHERE path = ?"
    for _, plugin := range plugins {
        now := time.Now()
        err = tx.Exec(insertPluginStmt, now, now, plugin.Path, plugin.ExecNumber, plugin.Description).Error
        if err != nil {
            return err
        }

        var id null.Int
        err = tx.Raw(queryIDStmt, plugin.Path).Row().Scan(&id)
        if err != nil {
            return err
        }
        if !id.Valid {
            return err
        }
        plugin.ID = uint(id.Int64)

        err = tx.Exec(insertAssociationStmt, packageID, plugin.ID).Error
        if err != nil {
            return err
        }
    }
    return nil
}

func (ps packageStore) LoadPackage(loadDependencies bool, label string) (*domain.Package, error) {
    conditions := "WHERE label = ?"
    if loadDependencies {
        conditions = "WHERE packages.label = ?"
    }
    return ps.loadPackage(loadDependencies, conditions, label)
}

func (ps packageStore) loadPackage(loadDependencies bool, conditions string, values ...string) (*domain.Package, error) {
    if loadDependencies {
        return ps.deepQueryPackage(conditions, values...)
    }
    return ps.queryPackage(conditions, values...)
}

func (ps packageStore) LoadPackageList(loadDependencies bool, labels ...string) ([]*domain.Package, error) {
    var err error
    labelCount := len(labels)
    if labelCount < 1 {
        labels, err = ps.queryAllLabels()
        labelCount = len(labels)
        if err != nil {
            return nil, err
        }
    }
    packageList := make([]*domain.Package, 0, labelCount)
    mut := sync.Mutex{}
    errs, _ := errgroup.WithContext(context.Background())
    for _, label := range labels {
        label := label
        errs.Go(func() error {
            pkg, err := ps.LoadPackage(loadDependencies, label)
            if err != nil {
                return err
            }
            mut.Lock()
            packageList = append(packageList, pkg)
            mut.Unlock()
            return nil
        })
    }
    return packageList, errs.Wait()
}

const (
    defaultPackageQueryBase     = `SELECT name, label, description FROM packages`
    defaultPackageDeepQueryBase = `SELECT
    packages.name,
    packages.label,
    packages.description as package_description,
    templates.is_file,
    templates.destination,
    templates."path" as template_path,
    templates.description as template_description,
    plugins."path" as plugin_path,
    plugins.exec_number,
    plugins.description as plugin_description
    FROM packages
LEFT OUTER JOIN package_templates
    ON packages.id = package_templates.package_id
LEFT OUTER JOIN templates
    ON package_templates.template_id = templates.id
LEFT OUTER JOIN package_plugins
    ON packages.id = package_plugins.package_id
LEFT OUTER JOIN plugins
    ON package_plugins.plugin_id = plugins.id`
)

// loadAllPackages loads and returns all packages found in the database.
func (ps packageStore) queryPackage(conditions string, values ...string) (*domain.Package, error) {
    var name, label string
    var description null.String
    err := ps.db.Raw(defaultPackageQueryBase+" "+conditions, values).Row().Scan(&name, &label, &description)
    if err == sql.ErrNoRows {
        return nil, ErrPackageNotFound
    }
    if err != nil {
        return nil, err
    }
    return &domain.Package{Name: name, Label: label, Description: description.String}, nil
}

func (ps packageStore) deepQueryPackage(conditions string, values ...string) (pkg *domain.Package, err error) {
    rows, err := ps.db.Raw(defaultPackageDeepQueryBase+" "+conditions, values).Rows()
    if err != nil {
        return nil, err
    }

    defer func() {
        e := rows.Close()
        if e != nil {
            if err != nil {
                err = errors.Wrap(err, e.Error())
            } else {
                err = e
            }
        }
    }()

    pkg = &domain.Package{}
    var gotPkgInfo bool
    for rows.Next() {
        var (
            packageName, packageLabel string
            packageDescription        null.String

            templateIsFile                                         null.Bool
            templateDestination, templatePath, templateDescription null.String

            pluginPath, pluginDescription null.String
            pluginExecNumber              null.Int
        )

        err = rows.Scan(
            &packageName,
            &packageLabel,
            &packageDescription,
            &templateIsFile,
            &templateDestination,
            &templatePath,
            &templateDescription,
            &pluginPath,
            &pluginExecNumber,
            &pluginDescription,
        )
        if err != nil {
            return nil, err
        }

        if !gotPkgInfo {
            pkg.Name = packageName
            pkg.Label = packageLabel
            pkg.Description = packageDescription.String
            gotPkgInfo = true
        }
        if templateIsFile.Valid && templateDestination.Valid && templatePath.Valid {
            pkg.Templates = append(pkg.Templates, &domain.Template{
                IsFile:      templateIsFile.Bool,
                Destination: templateDestination.String,
                Path:        templatePath.String,
                Description: templateDescription.String,
            })
        }
        if pluginPath.Valid && pluginExecNumber.Valid {
            pkg.Plugins = append(pkg.Plugins, &domain.Plugin{
                Path:        pluginPath.String,
                ExecNumber:  int(pluginExecNumber.Int64),
                Description: pluginDescription.String,
            })
        }
    }
    if rows.Err() != nil {
        return nil, rows.Err()
    }
    if !gotPkgInfo {
        return nil, ErrPackageNotFound
    }
    return pkg, nil
}

func (ps packageStore) queryAllLabels() ([]string, error) {
    rows, err := ps.db.Raw("SELECT label FROM packages").Rows()
    if err != nil {
        return nil, err
    }
    if rows.Err() != nil {
        return nil, err
    }

    defer func() {
        e := rows.Close()
        if e != nil {
            if err != nil {
                err = errors.Wrap(err, e.Error())
            } else {
                err = e
            }
        }
    }()

    var labels []string
    for rows.Next() {
        var label string
        err = rows.Scan(&label)
        if err != nil {
            return nil, err
        }
        labels = append(labels, label)
    }
    return labels, nil
}

func (ps packageStore) RemovePackage(label string) error {
    tx := ps.db.Begin()
    defer func() {
        if r := recover(); r != nil {
            tx.Rollback()
        }
    }()

    var pkg domain.Package
    tx = tx.Select("id").Where("label = ?", label).First(&pkg)
    if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
        tx.Rollback()
        return ErrPackageNotFound
    }
    if tx.Error != nil {
        tx.Rollback()
        return tx.Error
    }

    err := tx.Exec("DELETE FROM package_templates WHERE package_id = ?", pkg.ID).Error
    if err != nil {
        tx.Rollback()
        return err
    }

    // Delete the actual package
    err = tx.Delete(pkg).Error
    if errors.Is(tx.Error, gorm.ErrRecordNotFound) || tx.RowsAffected < 1 {
        tx.Rollback()
        return ErrPackageNotFound
    }
    if err != nil {
        tx.Rollback()
        return err
    }
    return tx.Commit().Error
}