oauth2-proxy/oauth2-proxy

View on GitHub
pkg/apis/options/load.go

Summary

Maintainability
A
1 hr
Test Coverage
A
90%
package options

import (
    "errors"
    "fmt"
    "os"
    "reflect"
    "strings"

    "github.com/a8m/envsubst"
    "github.com/ghodss/yaml"
    "github.com/mitchellh/mapstructure"
    "github.com/spf13/pflag"
    "github.com/spf13/viper"
)

// Load reads in the config file at the path given, then merges in environment
// variables (prefixed with `OAUTH2_PROXY`) and finally merges in flags from the flagSet.
// If a config value is unset and the flag has a non-zero value default, this default will be used.
// Eg. A field defined:
//
//    FooBar `cfg:"foo_bar" flag:"foo-bar"`
//
// Can be set in the config file as `foo_bar="baz"`, in the environment as `OAUTH2_PROXY_FOO_BAR=baz`,
// or via the command line flag `--foo-bar=baz`.
func Load(configFileName string, flagSet *pflag.FlagSet, into interface{}) error {
    v := viper.New()
    v.SetConfigFile(configFileName)
    v.SetConfigType("toml") // Config is in toml format
    v.SetEnvPrefix("OAUTH2_PROXY")
    v.AutomaticEnv()
    v.SetTypeByDefaultValue(true)

    if configFileName != "" {
        err := v.ReadInConfig()
        if err != nil {
            return fmt.Errorf("unable to load config file: %w", err)
        }
    }

    err := registerFlags(v, "", flagSet, into)
    if err != nil {
        // This should only happen if there is a programming error
        return fmt.Errorf("unable to register flags: %w", err)
    }

    // UnmarshalExact will return an error if the config includes options that are
    // not mapped to fields of the into struct
    err = v.UnmarshalExact(into, decodeFromCfgTag)
    if err != nil {
        return fmt.Errorf("error unmarshalling config: %w", err)
    }

    return nil
}

// registerFlags uses `cfg` and `flag` tags to associate flags in the flagSet
// to the fields in the options interface provided.
// Each exported field in the options must have a `cfg` tag otherwise an error will occur.
// - For fields, set `cfg` and `flag` so that `flag` is the name of the flag associated to this config option
// - For exported fields that are not user facing, set the `cfg` to `,internal`
// - For structs containing user facing fields, set the `cfg` to `,squash`
func registerFlags(v *viper.Viper, prefix string, flagSet *pflag.FlagSet, options interface{}) error {
    val := reflect.ValueOf(options)
    var typ reflect.Type
    if val.Kind() == reflect.Ptr {
        typ = val.Elem().Type()
    } else {
        typ = val.Type()
    }

    for i := 0; i < typ.NumField(); i++ {
        // pull out the struct tags:
        //    flag - the name of the command line flag
        //    cfg - the name of the config file option
        field := typ.Field(i)
        fieldV := reflect.Indirect(val).Field(i)
        fieldName := strings.Join([]string{prefix, field.Name}, ".")

        cfgName := field.Tag.Get("cfg")
        if cfgName == ",internal" {
            // Public but internal types that should not be exposed to users, skip them
            continue
        }

        if isUnexported(field.Name) {
            // Unexported fields cannot be set by a user, so won't have tags or flags, skip them
            continue
        }

        if field.Type.Kind() == reflect.Struct {
            if cfgName != ",squash" {
                return fmt.Errorf("field %q does not have required cfg tag: `,squash`", fieldName)
            }
            err := registerFlags(v, fieldName, flagSet, fieldV.Interface())
            if err != nil {
                return err
            }
            continue
        }

        flagName := field.Tag.Get("flag")
        if flagName == "" || cfgName == "" {
            return fmt.Errorf("field %q does not have required tags (cfg, flag)", fieldName)
        }

        if flagSet == nil {
            return fmt.Errorf("flagset cannot be nil")
        }

        f := flagSet.Lookup(flagName)
        if f == nil {
            return fmt.Errorf("field %q does not have a registered flag", flagName)
        }
        err := v.BindPFlag(cfgName, f)
        if err != nil {
            return fmt.Errorf("error binding flag for field %q: %w", fieldName, err)
        }
    }

    return nil
}

// decodeFromCfgTag sets the Viper decoder to read the names from the `cfg` tag
// on each struct entry.
func decodeFromCfgTag(c *mapstructure.DecoderConfig) {
    c.TagName = "cfg"
}

// isUnexported checks if a field name starts with a lowercase letter and therefore
// if it is unexported.
func isUnexported(name string) bool {
    if len(name) == 0 {
        // This should never happen
        panic("field name has len 0")
    }

    first := string(name[0])
    return first == strings.ToLower(first)
}

// LoadYAML will load a YAML based configuration file into the options interface provided.
func LoadYAML(configFileName string, into interface{}) error {
    buffer, err := loadAndParseYaml(configFileName)
    if err != nil {
        return err
    }

    // UnmarshalStrict will return an error if the config includes options that are
    // not mapped to fields of the into struct
    if err := yaml.UnmarshalStrict(buffer, into, yaml.DisallowUnknownFields); err != nil {
        return fmt.Errorf("error unmarshalling config: %w", err)
    }

    return nil
}

// Performs the heavy lifting of the LoadYaml function
func loadAndParseYaml(configFileName string) ([]byte, error) {
    if configFileName == "" {
        return nil, errors.New("no configuration file provided")
    }

    unparsedBuffer, err := os.ReadFile(configFileName)
    if err != nil {
        return nil, fmt.Errorf("unable to load config file: %w", err)
    }

    // We now parse over the yaml with env substring, and fill in the ENV's
    buffer, err := envsubst.Bytes(unparsedBuffer)
    if err != nil {
        return nil, fmt.Errorf("error in substituting env variables : %w", err)
    }

    return buffer, nil

}