
View on GitHub


35 mins
Test Coverage
package dbcommon_test

import (
    . ""

// testDBLogger is the logger handed to dbcommon.GetGormLogger when RunEnumExample opens the database.
var testDBLogger = log.Logger("dbcommon")

// TestEnum tests the default providers for an enum type.
func (s DbSuite) TestEnum() {
    res, err := RunEnumExample(s.GetTestContext(), fmt.Sprintf("%s/sql.db", filet.TmpDir(s.T(), "")))
    Nil(s.T(), err)

    for i, fruit := range AllFruits {
        Equal(s.T(), fruit.Int(), res[i].Fruit.Int())

    Equal(s.T(), len(res), len(AllFruits))

// ExampleEnum demonstrates example use of the enum interface.
// this implementation can be confusing, so there's an example below.
func ExampleEnum() {
    res, err := RunEnumExample(context.Background(), fmt.Sprintf("%s/sql.db", os.TempDir()))
    if err != nil {

    for _, res := range res {
        fmt.Printf("got result %s \n", res.Fruit.String())

// RunEnumExample is used to separate out tests from the example.
func RunEnumExample(ctx context.Context, dbDir string) (res []InventoryModel, err error) {
    gdb, err := gorm.Open(sqlite.Open(dbDir), &gorm.Config{
        Logger: dbcommon.GetGormLogger(testDBLogger),
    if err != nil {
        return res, fmt.Errorf("could not open db: %w", err)

    // migrate the inventory model
    err = gdb.WithContext(ctx).AutoMigrate(&InventoryModel{})
    if err != nil {
        return res, fmt.Errorf("could not migrate db: %w", err)

    for _, fruit := range AllFruits {
        tx := gdb.WithContext(ctx).Create(&InventoryModel{
            Fruit: fruit,

        if tx.Error != nil {
            return res, fmt.Errorf("could not insert fruit: %w", err)

    tx := gdb.WithContext(ctx).Find(&res)
    if tx.Error != nil {
        return res, fmt.Errorf("could not query db: %w", err)

    return res, nil

// InventoryModel is an example model for of an inventory table for fruit.
type InventoryModel struct {
    // fruit is the fruit we're storing
    Fruit Fruit

// Fruit is the example enum type; its underlying representation is a uint8.
type Fruit uint8

// You should use explicit ints rather than iota when interacting with the database,
// so that the stored values do not shift if the declaration order changes.
const (
	// Apple is an example implementing enum.
	Apple Fruit = 0
	// Pear is an example implementing enum.
	Pear Fruit = 1
)

// AllFruits lists every defined Fruit, in ascending value order.
var AllFruits = []Fruit{Apple, Pear}

// String gets a string of the enum
// in a production setting, generater should be used.
// see: for details
func (f Fruit) String() string {
    switch f {
    case Apple:
        return "Apple"
    case Pear:
        return "Pear"
    return ""

// Int get the integer value of the fruit.
func (f Fruit) Int() uint8 {
    return uint8(f)

// GormDataType is the gorm data type.
func (f Fruit) GormDataType() string {
    return dbcommon.EnumDataType

// Scan will scan the fruit into the db.
func (f *Fruit) Scan(src interface{}) error {
    res, err := dbcommon.EnumScan(src)
    if err != nil {
        return fmt.Errorf("could not scan: %w", err)
    newFruit := Fruit(res)
    *f = newFruit
    return nil

// nolint: wrapcheck
func (f *Fruit) Value() (driver.Value, error) {
    return dbcommon.EnumValue(f)

// Compile-time assertion that *Fruit satisfies dbcommon.EnumInter.
var _ dbcommon.EnumInter = (*Fruit)(nil)

// testEnum is a minimal enum type used as the input of TestEnumValue.
type testEnum uint8

// Int returns the integer value of the enum.
func (t testEnum) Int() uint8 {
	return uint8(t)
}

const (
	testEnumValue1 testEnum = 1
	testEnumValue2 testEnum = 2
)

func TestEnumValue(t *testing.T) {
    tests := []struct {
        name    string
        enum    dbcommon.EnumInter
        want    int64
        wantErr error
            name: "Valid enum value",
            enum: testEnumValue1,
            want: 1,
            name: "Valid enum value",
            enum: testEnumValue2,
            want: 2,

    for i := range tests {
        tt := tests[i]
        t.Run(, func(t *testing.T) {
            got, err := dbcommon.EnumValue(tt.enum)
            if tt.wantErr != nil {
                ErrorIs(t, err, tt.wantErr)
            } else {
                Nil(t, err)
                Equal(t, tt.want, got)

func TestEnumScan(t *testing.T) {
    tests := []struct {
        name    string
        src     interface{}
        want    uint8
        wantErr string
            name: "Valid int64 value",
            src:  int64(1),
            want: 1,
            name: "Valid int32 value",
            src:  int32(2),
            want: 2,
            name:    "Invalid type",
            src:     "invalid",
            want:    0,
            wantErr: "could not scan enum: converting driver.Value type string (\"invalid\") to a int32: invalid syntax",

    for i := range tests {
        tt := tests[i]
        t.Run(, func(t *testing.T) {
            got, err := dbcommon.EnumScan(tt.src)
            if tt.wantErr != "" {
                Error(t, err)
                EqualError(t, err, tt.wantErr)
            } else {
                NoError(t, err)
                Equal(t, tt.want, got)