orkungursel/go-eco

View on GitHub
eco.go

Summary

Maintainability
A
0 mins
Test Coverage
A
100%
package eco

import (
    "errors"
    "fmt"
    "os"
    "reflect"
    "strconv"
    "strings"
)

// eco holds the configuration used to map environment values onto the
// fields of a struct. Instances are created by New and adjusted through the
// chainable Set* methods.
type eco struct {
    // sliceSeparator is the separator used to split a single value into
    // the elements of a slice field.
    sliceSeparator        string
    // envNameSeparator is handed to envNameTransformer together with the name
    // parts; its characters are also what getPrefix trims from the end of
    // the prefix when auto trimming is enabled.
    envNameSeparator      string
    // envNamePrefix, when non-empty, becomes the first name part of every
    // variable looked up by Unmarshal.
    envNamePrefix         string
    // envNamePrefixAutoTrim controls whether getPrefix strips trailing
    // separator characters from envNamePrefix.
    envNamePrefixAutoTrim bool
    // envNameTransformer turns the collected name parts and the separator
    // into the key passed to envValueGetter.
    envNameTransformer    envNameTransformerFunc
    // envValueGetter returns the raw string value for a key (New installs
    // os.Getenv).
    envValueGetter        envValueGetterFunc
    // tagNameEnv is the struct tag key that overrides a field's name part.
    tagNameEnv            string
    // tagNameDefault is the struct tag key holding the fallback value used
    // when the looked-up value is empty.
    tagNameDefault        string
    // tagSkipIdentifier is the name part that is not appended to the name
    // parts when building a field's key.
    tagSkipIdentifier     string
}

// New creates an eco instance preloaded with the default configuration:
// "," as slice separator, "_" as name separator, prefix auto trimming
// enabled, the default name transformer, os.Getenv as value source, and the
// "env" / "default" tag names with "-" as the skip identifier.
func New() *eco {
    e := new(eco)

    // Separators.
    e.sliceSeparator = ","
    e.envNameSeparator = "_"
    e.envNamePrefixAutoTrim = true

    // Name building and value lookup.
    e.envNameTransformer = defaultEnvNameTransformerFunc
    e.envValueGetter = os.Getenv

    // Struct tag conventions.
    e.tagNameEnv = "env"
    e.tagNameDefault = "default"
    e.tagSkipIdentifier = "-"

    return e
}

// SetPrefix stores prefix, with leading and trailing whitespace removed, as
// the prefix for environment variable names.
// It returns the receiver so calls can be chained.
func (e *eco) SetPrefix(prefix string) *eco {
    cleaned := strings.TrimSpace(prefix)
    e.envNamePrefix = cleaned

    return e
}

// getPrefix returns the configured environment variable name prefix.
// When auto trim is enabled, every trailing character that appears in the
// name separator is stripped from the end of the prefix; the start of the
// prefix is never modified.
func (e *eco) getPrefix() string {
    if !e.envNamePrefixAutoTrim {
        return e.envNamePrefix
    }

    return strings.TrimRight(e.envNamePrefix, e.envNameSeparator)
}

// SetPrefixAutoTrim switches automatic trimming of trailing separator
// characters from the prefix on or off. New enables it by default.
// It returns the receiver so calls can be chained.
func (e *eco) SetPrefixAutoTrim(autoTrim bool) *eco {
    e.envNamePrefixAutoTrim = autoTrim

    return e
}

// SetArraySeparator changes the separator used to split a value into slice
// elements. An empty separator is ignored and the current one is kept.
// New configures "," as the default.
// It returns the receiver so calls can be chained.
func (e *eco) SetArraySeparator(arrSep string) *eco {
    if arrSep == "" {
        return e
    }

    e.sliceSeparator = arrSep
    return e
}

// SetEnvNameTransformer sets the function that turns the collected name parts
// and the name separator into an environment variable name.
// A nil function is ignored and the current transformer is kept, because the
// transformer is called unconditionally while binding and a nil value would
// panic there; this mirrors how SetValueGetter treats nil.
// It returns the receiver so calls can be chained.
func (e *eco) SetEnvNameTransformer(transformerFunc envNameTransformerFunc) *eco {
    if transformerFunc != nil {
        e.envNameTransformer = transformerFunc
    }
    return e
}

// SetEnvNameSeparator changes the separator passed to the name transformer
// when building environment variable names. An empty separator is ignored
// and the current one is kept.
// It returns the receiver so calls can be chained.
func (e *eco) SetEnvNameSeparator(envNameSeparator string) *eco {
    if envNameSeparator == "" {
        return e
    }

    e.envNameSeparator = envNameSeparator
    return e
}

// SetValueGetter replaces the function used to look up the raw value for an
// environment variable name. A nil function is ignored and the current
// getter is kept.
// It returns the receiver so calls can be chained.
func (e *eco) SetValueGetter(valueGetter envValueGetterFunc) *eco {
    if valueGetter == nil {
        return e
    }

    e.envValueGetter = valueGetter
    return e
}

// Unmarshal fills the struct pointed to by v from environment values.
//
// v must be a non-nil pointer to a struct. A nil value, a non-pointer or a
// nil pointer yields ErrRequiresNonNilPtr. A pointer to anything other than a
// struct yields an error that wraps ErrRequiresNonNilPtr, instead of
// panicking later while iterating over fields.
//
// When a prefix is configured it becomes the first part of every variable
// name. Any error produced while binding the fields is returned as is.
func (e *eco) Unmarshal(v interface{}) error {
    // ValueOf(nil) is the zero Value whose Kind is Invalid, so this single
    // check covers nil, non-pointer and nil-pointer arguments.
    rv := reflect.ValueOf(v)
    if rv.Kind() != reflect.Ptr || rv.IsNil() {
        return ErrRequiresNonNilPtr
    }

    // Field iteration only works on structs; reject everything else up front.
    if rv.Elem().Kind() != reflect.Struct {
        return fmt.Errorf("%w: expected a pointer to a struct, got %s", ErrRequiresNonNilPtr, rv.Type())
    }

    var p []string
    if prefix := e.getPrefix(); prefix != "" {
        p = append(p, prefix)
    }

    return e.bindStructValues(v, p...)
}

// bindStructValues populates the exported fields of the struct pointed to by
// s, using envNameParts as the leading parts of every variable name.
//
// For each exported field:
//   - the field's name part is the value of the tagNameEnv tag, or the Go
//     field name when the tag is missing or empty, passed through
//     toSnakeCase; a part equal to tagSkipIdentifier is not appended;
//   - the name transformer builds the key from the parts and the separator,
//     and the value getter returns the raw value for it;
//   - an empty value falls back to the tagNameDefault tag, if present;
//   - struct fields and pointers to structs are bound recursively with the
//     extended name parts (nil struct pointers are allocated first);
//   - other pointers are allocated only when there is a value to store;
//   - fields whose resolved value is empty are left untouched.
//
// It returns the first conversion error, with the variable name appended to
// the message.
func (e *eco) bindStructValues(s interface{}, envNameParts ...string) error {
    sr := e.getStructReflection(s)
    st := sr.Type()

    // Cap the inherited parts at their length so that every append below
    // allocates a fresh backing array instead of writing into storage that is
    // shared with the caller or with the slices built for sibling fields.
    base := envNameParts[:len(envNameParts):len(envNameParts)]

    for i := 0; i < st.NumField(); i++ {
        typeField := st.Field(i)

        // Skip unexported fields
        if !typeField.IsExported() {
            continue
        }

        field := sr.Field(i)
        tags := typeField.Tag
        isPtr := typeField.Type.Kind() == reflect.Ptr
        isStruct := typeField.Type.Kind() == reflect.Struct

        // Resolve this field's name part: explicit tag first, field name otherwise.
        envTagValue, ok := tags.Lookup(e.tagNameEnv)
        if !ok || envTagValue == "" {
            envTagValue = typeField.Name
        }
        envTagValue = toSnakeCase(envTagValue)

        // A part equal to the skip identifier contributes nothing to the name.
        p := base
        if envTagValue != e.tagSkipIdentifier {
            p = append(p, envTagValue)
        }

        envKey := e.envNameTransformer(p, e.envNameSeparator)

        // Look the value up and fall back to the default tag when it is empty.
        envVal := e.envValueGetter(envKey)
        if envVal == "" {
            if defaultValue, ok := tags.Lookup(e.tagNameDefault); ok {
                envVal = defaultValue
            }
        }

        if isPtr {
            elemIsStruct := typeField.Type.Elem().Kind() == reflect.Struct

            // Leave non-struct pointers nil when there is nothing to store.
            if !elemIsStruct && envVal == "" {
                continue
            }

            // Allocate the pointee so it can be bound or assigned below.
            if field.IsNil() {
                field.Set(reflect.New(typeField.Type.Elem()))
            }

            if elemIsStruct {
                if err := e.bindStructValues(field.Interface(), p...); err != nil {
                    return err
                }

                continue
            }
        } else if isStruct {
            // Nested struct value: bind in place through its address.
            if err := e.bindStructValues(field.Addr().Interface(), p...); err != nil {
                return err
            }

            continue
        }

        // Nothing to convert for this field.
        if envVal == "" {
            continue
        }

        // Convert the raw string to a value matching the field.
        val, err := e.convertStrToFieldVal(sr, i, envVal)
        if err != nil {
            return errors.New(err.Error() + ": " + envKey)
        }

        if isPtr {
            field.Elem().Set(val)
        } else {
            field.Set(val)
        }
    }

    return nil
}

// getStructReflection returns the reflect.Value for s. When s is a pointer
// the value it points to is returned, otherwise the value of s itself.
func (e *eco) getStructReflection(s interface{}) reflect.Value {
    rv := reflect.ValueOf(s)
    if rv.Type().Kind() != reflect.Ptr {
        return rv
    }

    // Dereference once so callers work on the pointee.
    return rv.Elem()
}

// convertStrToFieldVal converts the string val into a value that can be
// assigned to field number index of the struct ref.
//
// The target type is the field's declared type, or its element type when the
// field is a pointer. The returned value always has exactly that type, so
// named types (for example `type Level int`) are assignable as well.
//
// Supported targets are string, bool, all sized and unsized signed and
// unsigned integers, float32 and float64, and slices of those. Slice values
// are split on the configured slice separator and every element is trimmed
// of surrounding whitespace before it is parsed; scalar values are parsed as
// given.
//
// An error is returned when val is not a string, when parsing fails, or when
// the target type is not supported. On error the returned value is the zero
// reflect.Value.
func (e *eco) convertStrToFieldVal(ref reflect.Value, index int, val interface{}) (out reflect.Value, err error) {
    raw, ok := val.(string)
    if !ok {
        return reflect.Value{}, fmt.Errorf("unsupported value type: %T", val)
    }

    // Work from the declared type so the result does not depend on whether a
    // pointer field has already been allocated.
    target := ref.Type().Field(index).Type
    if target.Kind() == reflect.Ptr {
        target = target.Elem()
    }

    if target.Kind() != reflect.Slice {
        return strToScalarFieldVal(raw, target)
    }

    elemType := target.Elem()
    if !isScalarFieldKind(elemType.Kind()) {
        return reflect.Value{}, fmt.Errorf("unsupported slice type: %s", elemType.Kind())
    }

    parts := strings.Split(raw, e.sliceSeparator)
    out = reflect.MakeSlice(target, len(parts), len(parts))
    for i, part := range parts {
        item, perr := strToScalarFieldVal(strings.TrimSpace(part), elemType)
        if perr != nil {
            return reflect.Value{}, perr
        }
        out.Index(i).Set(item)
    }

    return out, nil
}

// isScalarFieldKind reports whether values of kind k can be parsed from a
// single string by strToScalarFieldVal.
func isScalarFieldKind(k reflect.Kind) bool {
    switch k {
    case reflect.String, reflect.Bool,
        reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
        reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
        reflect.Float32, reflect.Float64:
        return true
    }
    return false
}

// strToScalarFieldVal parses s into a new value of type t, which must be of a
// kind accepted by isScalarFieldKind; any other kind yields an error.
func strToScalarFieldVal(s string, t reflect.Type) (reflect.Value, error) {
    v := reflect.New(t).Elem()

    switch t.Kind() {
    case reflect.String:
        v.SetString(s)
    case reflect.Bool:
        b, err := strconv.ParseBool(s)
        if err != nil {
            return reflect.Value{}, err
        }
        v.SetBool(b)
    case reflect.Int:
        // Atoi keeps the error text callers have seen for int fields so far.
        n, err := strconv.Atoi(s)
        if err != nil {
            return reflect.Value{}, err
        }
        v.SetInt(int64(n))
    case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        // The bit size makes out-of-range input an error instead of wrapping.
        n, err := strconv.ParseInt(s, 10, t.Bits())
        if err != nil {
            return reflect.Value{}, err
        }
        v.SetInt(n)
    case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
        n, err := strconv.ParseUint(s, 10, t.Bits())
        if err != nil {
            return reflect.Value{}, err
        }
        v.SetUint(n)
    case reflect.Float32, reflect.Float64:
        f, err := strconv.ParseFloat(s, t.Bits())
        if err != nil {
            return reflect.Value{}, err
        }
        v.SetFloat(f)
    default:
        return reflect.Value{}, fmt.Errorf("unsupported field type: %s", t)
    }

    return v, nil
}