synapsecns/sanguine

View on GitHub
contrib/screener-api/db/suite_test.go

Summary

Maintainability
A
0 mins
Test Coverage
package db_test

import (
    dbSQL "database/sql"
    "fmt"
    "os"
    "sync"
    "testing"

    "github.com/synapsecns/sanguine/contrib/screener-api/db"
    "github.com/synapsecns/sanguine/contrib/screener-api/db/sql"
    "github.com/synapsecns/sanguine/contrib/screener-api/db/sql/mysql"
    "github.com/synapsecns/sanguine/contrib/screener-api/metadata"

    "github.com/Flaque/filet"
    . "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/suite"
    "github.com/synapsecns/sanguine/core"
    "github.com/synapsecns/sanguine/core/dbcommon"
    "github.com/synapsecns/sanguine/core/metrics"
    "github.com/synapsecns/sanguine/core/metrics/localmetrics"
    "github.com/synapsecns/sanguine/core/testsuite"
    "gorm.io/gorm/schema"
)

// DBSuite is the test suite for the db package. It embeds the shared
// testsuite.TestSuite and keeps the state the tests in this package share.
type DBSuite struct {
    *testsuite.TestSuite

    // dbs holds the stores every test is run against; it is rebuilt in SetupTest.
    dbs []db.DB

    // metrics is the metrics handler created in SetupSuite and handed to the stores.
    metrics metrics.Handler
}

// NewDBSuite builds a DBSuite around a new testsuite.TestSuite for tb.
// The returned suite starts with an empty, non-nil list of databases.
func NewDBSuite(tb testing.TB) *DBSuite {
    tb.Helper()

    dbSuite := new(DBSuite)
    dbSuite.TestSuite = testsuite.NewTestSuite(tb)
    dbSuite.dbs = make([]db.DB, 0)

    return dbSuite
}
// SetupSuite runs the embedded suite's setup and then creates the metrics
// handler stored in d.metrics.
//
// The Jaeger handler (backed by localmetrics.SetupTestJaeger) is used when
// core.GetEnvBool("CI", false) is false; otherwise the null handler is used.
// If the handler cannot be created the suite is stopped immediately instead of
// continuing with an unusable d.metrics.
func (d *DBSuite) SetupSuite() {
    d.TestSuite.SetupSuite()

    // don't use metrics on ci for integration tests
    metricsHandler := metrics.Null
    if isCI := core.GetEnvBool("CI", false); !isCI {
        localmetrics.SetupTestJaeger(d.GetSuiteContext(), d.T())
        metricsHandler = metrics.Jaeger
    }

    var err error
    d.metrics, err = metrics.NewByType(d.GetSuiteContext(), metadata.BuildInfo(), metricsHandler)
    if !NoError(d.T(), err, "could not create metrics handler") {
        // a failed assertion alone only marks the test as failed; abort so that
        // nothing later in the suite runs with a broken metrics handler.
        d.T().FailNow()
    }
}

// SetupTest runs the embedded suite's per-test setup and then rebuilds d.dbs:
// a sqlite store created in a temporary directory is always present, and
// setupMysqlDB may append a mysql store.
//
// If the sqlite store cannot be created the test is stopped immediately, so
// that a store returned alongside an error is never handed to the tests.
func (d *DBSuite) SetupTest() {
    d.TestSuite.SetupTest()

    sqliteStore, err := sql.Connect(d.GetTestContext(), dbcommon.Sqlite, filet.TmpDir(d.T(), ""), d.metrics)
    if !NoError(d.T(), err, "could not connect to sqlite store") {
        // a failed assertion alone only marks the test as failed; abort before
        // the unusable store ends up in d.dbs.
        d.T().FailNow()
    }

    d.dbs = []db.DB{sqliteStore}
    d.setupMysqlDB()
}

// setupMysqlDB appends a mysql-backed store to d.dbs. It does nothing unless
// the environment variable named by dbcommon.EnableMysqlTestVar is "true".
func (d *DBSuite) setupMysqlDB() {
    mysqlEnabled := os.Getenv(dbcommon.EnableMysqlTestVar) == "true"
    if !mysqlEnabled {
        return
    }

    // derive the table prefix from the suite's test id
    tablePrefix := fmt.Sprintf("api_%d", d.GetTestID())
    mysql.NamingStrategy = schema.NamingStrategy{TablePrefix: tablePrefix}

    // connection string for the default database
    dsn := dbcommon.GetTestConnString()

    // open a raw mysql handle for that connection string
    rawDB, openErr := dbSQL.Open("mysql", dsn)
    d.Require().NoError(openErr)
    // close the handle again once this function returns
    defer func() {
        closeErr := rawDB.Close()
        d.Require().NoError(closeErr)
    }()

    store, storeErr := mysql.NewMysqlStore(d.GetTestContext(), dsn, d.metrics)
    d.Require().NoError(storeErr)

    d.dbs = append(d.dbs, store)
}

// RunOnAllDBs calls testFunc once for every store in d.dbs, each call in its
// own goroutine, and returns only after all of those calls have finished.
func (d *DBSuite) RunOnAllDBs(testFunc func(testDB db.DB)) {
    d.T().Helper()

    stores := d.dbs

    var pending sync.WaitGroup
    pending.Add(len(stores))
    for i := range stores {
        // pass the store as an argument so every goroutine works on its own value
        go func(store db.DB) {
            defer pending.Done()
            testFunc(store)
        }(stores[i])
    }
    pending.Wait()
}

func TestDBSuite(t *testing.T) {
    suite.Run(t, NewDBSuite(t))
}