synapsecns/sanguine

View on GitHub
core/dbcommon/enum_test.go

Summary

Maintainability
A
35 mins
Test Coverage
package dbcommon_test

import (
    "context"
    "database/sql/driver"
    "fmt"
    "github.com/Flaque/filet"
    "github.com/ipfs/go-log"
    . "github.com/stretchr/testify/assert"
    "github.com/synapsecns/sanguine/core/dbcommon"
    "gorm.io/driver/sqlite"
    "gorm.io/gorm"
    "os"
    "testing"
)

var (
    // testDBLogger is the logger handed to gorm (via dbcommon.GetGormLogger)
    // when RunEnumExample opens its database.
    testDBLogger = log.Logger("dbcommon")
)

// TestEnum round-trips every Fruit in AllFruits through a sqlite database via
// RunEnumExample and compares the returned rows, index by index, with AllFruits.
func (s DbSuite) TestEnum() {
    res, err := RunEnumExample(s.GetTestContext(), fmt.Sprintf("%s/sql.db", filet.TmpDir(s.T(), "")))
    // Stop on failure: continuing would index into an empty result and panic.
    if !NoError(s.T(), err) {
        return
    }
    // Check the length before indexing res so a short result fails cleanly
    // instead of panicking with an out-of-range index.
    if !Len(s.T(), res, len(AllFruits)) {
        return
    }

    for i, fruit := range AllFruits {
        Equal(s.T(), fruit.Int(), res[i].Fruit.Int())
    }
}

// ExampleEnum demonstrates storing and reading back an enum type (Fruit) that
// implements dbcommon.EnumInter, using RunEnumExample against a throwaway
// sqlite database.
func ExampleEnum() {
    // Use a fresh directory per run: a fixed path under os.TempDir() would be
    // shared between runs/processes, so every run would append more rows.
    dir, err := os.MkdirTemp("", "dbcommon-enum-example")
    if err != nil {
        panic(err)
    }
    // Best-effort cleanup; failing to remove a temp dir should not fail the example.
    defer func() { _ = os.RemoveAll(dir) }()

    res, err := RunEnumExample(context.Background(), fmt.Sprintf("%s/sql.db", dir))
    if err != nil {
        panic(err)
    }

    for _, item := range res {
        fmt.Printf("got result %s \n", item.Fruit.String())
    }
}

// RunEnumExample opens the sqlite database file at dbDir, migrates
// InventoryModel, inserts one row per entry in AllFruits and returns all rows of
// the table ordered by primary key. It is shared by TestEnum and ExampleEnum so
// both exercise the same code path.
//
// Note that dbDir is the path of the database file itself, not a directory.
// On error the returned slice is nil.
func RunEnumExample(ctx context.Context, dbDir string) (res []InventoryModel, err error) {
    gdb, err := gorm.Open(sqlite.Open(dbDir), &gorm.Config{
        Logger: dbcommon.GetGormLogger(testDBLogger),
    })
    if err != nil {
        return nil, fmt.Errorf("could not open db: %w", err)
    }

    sqlDB, err := gdb.DB()
    if err != nil {
        return nil, fmt.Errorf("could not get underlying db: %w", err)
    }
    // Release the database handle when done; only report a close failure if
    // nothing else has gone wrong already.
    defer func() {
        if closeErr := sqlDB.Close(); closeErr != nil && err == nil {
            res = nil
            err = fmt.Errorf("could not close db: %w", closeErr)
        }
    }()

    // WithContext starts a new session, which is safe to reuse across calls.
    db := gdb.WithContext(ctx)

    // migrate the inventory model
    if err = db.AutoMigrate(&InventoryModel{}); err != nil {
        return nil, fmt.Errorf("could not migrate db: %w", err)
    }

    for _, fruit := range AllFruits {
        // Wrap tx.Error: err is nil here, so wrapping it would drop the real failure.
        if tx := db.Create(&InventoryModel{Fruit: fruit}); tx.Error != nil {
            return nil, fmt.Errorf("could not insert fruit: %w", tx.Error)
        }
    }

    // Order by primary key so callers can rely on rows coming back in insertion order.
    if tx := db.Order("id").Find(&res); tx.Error != nil {
        return nil, fmt.Errorf("could not query db: %w", tx.Error)
    }

    return res, nil
}

// InventoryModel is an example gorm model for an inventory table of fruit.
type InventoryModel struct {
    gorm.Model
    // Fruit is the fruit stored in this row.
    Fruit Fruit
}

// Fruit values are persisted to the database as integers, so each constant is
// given an explicit integer value rather than an iota: reordering or inserting
// constants must never change the value of one that is already stored.
const (
    // Apple is an example value of the Fruit enum.
    Apple Fruit = 0
    // Pear is an example value of the Fruit enum.
    Pear Fruit = 1
)

// AllFruits lists every defined Fruit; RunEnumExample inserts one row per entry
// and TestEnum compares the query results against it.
var AllFruits = []Fruit{Apple, Pear}

// Fruit is an example enum type backed by a uint8, stored in the database
// through the dbcommon enum helpers (EnumDataType, EnumValue, EnumScan).
type Fruit uint8

// String returns the name of the fruit. Values without a name are rendered as
// "Fruit(N)", mirroring the fallback format of stringer-generated code, so
// unexpected values stay visible in logs instead of printing as "".
//
// In a production setting this method should be generated with stringer;
// see https://pkg.go.dev/golang.org/x/tools/cmd/stringer for details.
func (f Fruit) String() string {
    switch f {
    case Apple:
        return "Apple"
    case Pear:
        return "Pear"
    default:
        // Convert to uint8 so %d formats the raw number rather than recursing via String.
        return fmt.Sprintf("Fruit(%d)", uint8(f))
    }
}

// Int returns the raw integer value backing the fruit.
func (f Fruit) Int() uint8 {
    asUint := uint8(f)
    return asUint
}

// GormDataType reports the column data type gorm should use for a Fruit,
// which is the shared dbcommon.EnumDataType.
func (f Fruit) GormDataType() string {
    dataType := dbcommon.EnumDataType
    return dataType
}

// Scan decodes a value read from the database (via dbcommon.EnumScan) and
// stores it in f. Its signature matches database/sql's Scanner interface.
func (f *Fruit) Scan(src interface{}) error {
    raw, scanErr := dbcommon.EnumScan(src)
    if scanErr == nil {
        *f = Fruit(raw)
        return nil
    }
    return fmt.Errorf("could not scan: %w", scanErr)
}

// Value implements driver.Valuer by converting the fruit with dbcommon.EnumValue.
//
// It uses a value receiver (the same split as sql.NullString: value-receiver
// Value, pointer-receiver Scan) so that both Fruit and *Fruit are Valuers.
// database/sql then treats a nil *Fruit as NULL instead of dereferencing it.
//
//nolint:wrapcheck
func (f Fruit) Value() (driver.Value, error) {
    return dbcommon.EnumValue(f)
}

var (
    // Compile-time check that *Fruit satisfies dbcommon.EnumInter.
    _ dbcommon.EnumInter = (*Fruit)(nil)
)

// testEnum is a minimal dbcommon.EnumInter implementation used by TestEnumValue.
type testEnum uint8

// Sample testEnum values exercised by TestEnumValue.
const (
    testEnumValue1 testEnum = 1
    testEnumValue2 testEnum = 2
)

// Int returns the raw integer value backing t.
func (t testEnum) Int() uint8 {
    asUint := uint8(t)
    return asUint
}

// TestEnumValue checks that dbcommon.EnumValue returns an enum's integer value
// as an int64 driver value.
func TestEnumValue(t *testing.T) {
    tests := []struct {
        name    string
        enum    dbcommon.EnumInter
        want    int64
        wantErr error
    }{
        {
            // Subtest names must be unique; otherwise t.Run silently suffixes
            // duplicates with "#01", which makes failures and -run filters confusing.
            name: "Valid enum value 1",
            enum: testEnumValue1,
            want: 1,
        },
        {
            name: "Valid enum value 2",
            enum: testEnumValue2,
            want: 2,
        },
    }

    for i := range tests {
        tt := tests[i]
        t.Run(tt.name, func(t *testing.T) {
            got, err := dbcommon.EnumValue(tt.enum)
            if tt.wantErr != nil {
                ErrorIs(t, err, tt.wantErr)
                return
            }
            NoError(t, err)
            Equal(t, tt.want, got)
        })
    }
}

func TestEnumScan(t *testing.T) {
    tests := []struct {
        name    string
        src     interface{}
        want    uint8
        wantErr string
    }{
        {
            name: "Valid int64 value",
            src:  int64(1),
            want: 1,
        },
        {
            name: "Valid int32 value",
            src:  int32(2),
            want: 2,
        },
        {
            name:    "Invalid type",
            src:     "invalid",
            want:    0,
            wantErr: "could not scan enum: converting driver.Value type string (\"invalid\") to a int32: invalid syntax",
        },
    }

    for i := range tests {
        tt := tests[i]
        t.Run(tt.name, func(t *testing.T) {
            got, err := dbcommon.EnumScan(tt.src)
            if tt.wantErr != "" {
                Error(t, err)
                EqualError(t, err, tt.wantErr)
            } else {
                NoError(t, err)
                Equal(t, tt.want, got)
            }
        })
    }
}