simonmittag/jabba

View on GitHub
config.go

Summary

Maintainability
D
1 day
Test Coverage
B
88%
package j8a

import (
    "bytes"
    "encoding/json"
    "errors"
    "fmt"
    "golang.org/x/net/idna"
    "io/ioutil"
    "net"
    "os"
    "sort"
    "strconv"
    "strings"
    "text/template"
    "time"

    "github.com/asaskevich/govalidator"
    "github.com/ghodss/yaml"
    isd "github.com/jbenet/go-is-domain"
    "github.com/rs/zerolog"
    "github.com/rs/zerolog/log"
)

// Config is the system wide configuration for j8a
type Config struct {
    Policies            map[string]Policy
    Routes              Routes
    Jwt                 map[string]*Jwt
    Resources           map[string][]ResourceMapping
    Connection          Connection
    DisableXRequestInfo bool
    TimeZone            string
    LogLevel            string
}

const HTTP = "HTTP"
const TLS = "TLS"
const J8ACFG_YML = "J8ACFG_YML"

var ConfigFile = ""

const DefaultConfigFile = "j8acfg.yml"

// load loads the configuration from a file or environment variables.
// If `ConfigFile` is set, it attempts to load the configuration from the specified file.
// If `ConfigFile` is empty and `J8ACFG_YML` environment variable is set,
// it attempts to load the configuration from the environment variable.
// If both `ConfigFile` and `J8ACFG_YML` are empty,
// it attempts to load the configuration from the default file `DefaultConfigFile`.
// It returns a pointer to the loaded configuration.
// If an error occurs while loading the configuration, it panics with an error message.
func (config Config) load() *Config {
    if len(ConfigFile) > 0 {
        log.Info().Msgf("attempting config from file '%s'", ConfigFile)
        if cfg := *config.readYmlFile(ConfigFile); cfg.loaded() {
            config = cfg
        } else {
            config.panic(fmt.Sprintf("error loading config from file '%s'", ConfigFile))
        }
    } else if len(os.Getenv(J8ACFG_YML)) > 0 {
        log.Info().Msgf("attempting config from env %s", J8ACFG_YML)
        if cfg := *config.readYmlEnv(); cfg.loaded() {
            config = cfg
        } else {
            config.panic(fmt.Sprintf("error loading config from env '%s'", J8ACFG_YML))
        }
    } else {
        log.Info().Msgf("attempting config from default file '%s'", DefaultConfigFile)
        if cfg := config.readYmlFile(DefaultConfigFile); cfg.loaded() {
            config = *cfg
        } else {
            config.panic(fmt.Sprintf("error loading config from default file '%s'", DefaultConfigFile))
        }
    }
    log.Info().Msgf("parsed config with %d live routes", config.Routes.Len())
    return &config
}

// Wrapper method for config to panic during loading. We never
// want invalid config.
func (config Config) panic(msg string) {
    msg = "unable to validate config. " + msg
    panic(msg)
}

func (config Config) readYmlEnv() *Config {
    conf := []byte(os.Getenv(J8ACFG_YML))
    return config.parse(conf)
}

// Returns true if config was loaded, but does not test if it's valid.
func (config Config) loaded() bool {
    return config.Routes.Len() > 0
}

func (config Config) readYmlFile(file string) *Config {
    f, err := os.Open(file)
    defer f.Close()
    if err != nil {
        msg := fmt.Sprintf("unable to load config from %s", file)
        config.panic(msg)
    }
    byteValue, _ := ioutil.ReadAll(f)
    return config.parse(byteValue)
}

func (config Config) parse(yml []byte) *Config {
    envMap := envToMap()
    configTemplate := template.New("config").Option("missingkey=error")
    configTemplate, _ = configTemplate.Parse(string(yml[:]))
    var configTpl bytes.Buffer
    renderingErr := configTemplate.Execute(&configTpl, envMap)
    if renderingErr != nil {
        config.panic("unable to parse config, cause: " + renderingErr.Error())
    }
    yml, _ = ioutil.ReadAll(&configTpl)
    jsn, _ := yaml.YAMLToJSON(yml)

    json.Unmarshal(jsn, &config)
    return &config
}

func (config Config) validateLogLevel() *Config {
    logLevel := strings.ToUpper(config.LogLevel)
    old := strings.ToUpper(zerolog.GlobalLevel().String())

    if len(logLevel) > 0 && logLevel != old {
        switch logLevel {
        case "TRACE", "DEBUG", "INFO", "WARN":
        default:
            config.panic(fmt.Sprintf("invalid log level %v must be one of TRACE | DEBUG | INFO | WARN ", logLevel))
        }
    }

    return &config
}

func (config Config) validateTimeZone() *Config {
    var tz *time.Location
    var e error
    if len(config.TimeZone) > 0 {
        tz, e = time.LoadLocation(config.TimeZone)
        if e != nil {
            config.panic(fmt.Sprintf("Not a valid TimeZone identifier %s, see: https://en.wikipedia.org/wiki/List_of_tz_database_time_zones", config.TimeZone))
        }
    } else {
        //we default to UTC if time wasn't specified
        tz = time.UTC
    }

    zerolog.TimestampFunc = func() time.Time {
        return time.Now().In(tz)
    }
    log.Info().Msgf("timeZone for this log and all system events set to %s", tz.String())

    return &config
}

func (config Config) validateResources() *Config {
    for name := range config.Resources {
        resourceMappings := config.Resources[name]
        if len(resourceMappings) == 0 {
            config.panic(fmt.Sprintf("resource '%v' needs to have at least one url, see https://j8a.io/docs", name))
        }
        for _, r := range resourceMappings {
            iPort, e := strconv.Atoi(r.URL.Port)
            if e != nil {
                config.panic(fmt.Sprintf("resource '%v' needs to have port between 1 and 65535, was: %v", name, r.URL.Port))
            }
            if iPort <= 1 || iPort > 65535 {
                config.panic(fmt.Sprintf("resource '%v' needs to have port between 1 and 65535, was: %v", name, r.URL.Port))
            }
            if len(r.URL.Host) == 0 {
                config.panic("resource needs to have host")
            } else {
                ie := validIpAddress(r)
                he := validHostName(r)
                if ie != nil && he != nil {
                    config.panic(fmt.Sprintf("resource '%v' host needs to be valid DNS name or IP address, was: %v", name, r.URL.Host))
                }
            }

            sm := "resource '%v' needs to have valid scheme, was: %v"
            if len(r.URL.Scheme) == 0 {
                config.panic(fmt.Sprintf(sm, name, r.URL.Scheme))
            } else if !validScheme(r.URL.Scheme) {
                config.panic(fmt.Sprintf(sm, name, r.URL.Scheme))
            }
        }
    }
    return &config
}

func validScheme(s string) bool {
    s = strings.TrimSpace(s)
    s = strings.TrimSuffix(s, "://")
    s = strings.ToLower(s)

    schemes := [4]string{"http", "https", "ws", "wss"}
    for _, v := range schemes {
        if v == s {
            return true
        }
    }
    return false
}

// validIpAddress checks if the given IP address is a valid IPv4 or IPv6 address.
// It returns an error if the address is invalid.
func validIpAddress(r ResourceMapping) error {
    defaultErr := errors.New(fmt.Sprintf("invalid ipv4 or ipv6 address: %v", r.URL.Host))

    h := r.URL.Host
    h = strings.TrimPrefix(h, "[")
    h = strings.TrimSuffix(h, "]")
    hasBrackets := h != r.URL.Host

    p := net.ParseIP(h)
    if p == nil || len(p) == 0 {
        return defaultErr
    } else if p.To4() != nil && hasBrackets {
        return defaultErr
    }
    return nil
}

// validHostName checks if the given host name in the ResourceMapping URL is a valid DNS name.
// It returns an error if the host name is invalid.
// It uses the golang.org/x/net/idna package to validate and convert the host name into Unicode and ASCII representations.
// The host name should not contain a wildcard, a port, or any invalid characters.
// The function also validates the host name according to the govalidator.IsDNSName function.
// If any validation fails, an error is returned.
func validHostName(r ResourceMapping) error {
    invalidErr := errors.New(fmt.Sprintf("resource host needs a valid DNS name: %v", r.URL.Host))
    p := idna.New(
        idna.ValidateLabels(true),
        //this has to be off it disallows * for registration
        //idna.ValidateForRegistration(),
        idna.StrictDomainName(true))
    _, err := p.ToUnicode(r.URL.Host)
    if err != nil {
        return invalidErr
    }
    a, err := p.ToASCII(r.URL.Host)
    if err != nil {
        return invalidErr
    }
    if a != r.URL.Host {
        return invalidErr
    }
    if strings.Contains(a, "*") {
        return errors.New(fmt.Sprintf("resource host cannot be a wildcard DNS name: %v", r.URL.Host))
    }
    if strings.Contains(a, ":") {
        return errors.New(fmt.Sprintf("resource host cannot contain port: %v", r.URL.Host))
    }
    if !govalidator.IsDNSName(a) {
        return invalidErr
    }
    return nil
}

func (config Config) reApplyResourceNames() *Config {
    for name := range config.Resources {
        resourceMappings := config.Resources[name]
        for i, resourceMapping := range resourceMappings {
            resourceMapping.Name = name
            resourceMappings[i] = resourceMapping
        }
    }
    return &config
}

var routePathTypes = NewRoutePathTypes()

const prefixS = "prefix"

func (config Config) validateRoutes() *Config {
    for i, _ := range config.Routes {
        v, e := config.Routes[i].validPath()
        if !v {
            config.panic(e.Error())
        }
        if config.Routes[i].hasJwt() {
            if _, ok := config.Jwt[config.Routes[i].Jwt]; !ok {
                config.panic(fmt.Sprintf("route [%s] jwt [%s] not found, check your configuration", config.Routes[i].Path, config.Routes[i].Jwt))
            }
        }
        if len(config.Routes[i].PathType) == 0 {
            config.Routes[i].PathType = prefixS
        } else {
            if !routePathTypes.isValid(config.Routes[i].PathType) {
                config.panic(fmt.Sprintf("path type %s invalid, not one of ['prefix', 'exact']", config.Routes[i].PathType))
            }
        }
        if len(config.Routes[i].Host) > 0 {
            if v2, e2 := config.Routes[i].validHostPattern(); !v2 {
                config.panic(fmt.Sprintf("host pattern %s invalid, cause %v", config.Routes[i].Host, e2))
            }
        }
        if len(config.Routes[i].Resource) == 0 {
            config.panic(fmt.Sprintf("route %s must have a resource", config.Routes[i].Path))
        } else {
            res := config.Routes[i].Resource
            if res != about {
                _, ok := config.Resources[res]
                if !ok {
                    config.panic(fmt.Sprintf("route %s must have a resource, but %s is not declared", config.Routes[i].Path, res))
                }
            }
        }
    }
    sort.Sort(Routes(config.Routes))
    return &config
}

func (config Config) compileRoutePaths() *Config {
    var err error
    for i, route := range config.Routes {
        err = route.compilePath()
        if err != nil {
            config.panic(fmt.Sprintf("config error, illegal route path %s", route.Path))
        } else {
            config.Routes[i] = route
        }
    }
    return &config
}

func (config Config) compileRouteHosts() *Config {
    var err error
    for i, route := range config.Routes {
        if len(route.Host) > 0 {
            err = route.compileHostPattern()
            if err != nil {
                config.panic(fmt.Sprintf("config error, illegal route host pattern %s", route.Host))
            } else {
                config.Routes[i] = route
            }
        }
    }
    return &config
}

func (config Config) compileRouteTransforms() *Config {
    for i, route := range config.Routes {
        if len(config.Routes[i].Transform) > 0 && config.Routes[i].Transform[:1] != "/" {
            config.panic(fmt.Sprintf("config error, illegal route transform %s", route.Transform))
        }
    }
    return &config
}

func (config Config) reformatResourceUrlSchemes() *Config {
    for name := range config.Resources {
        resourceMappings := config.Resources[name]
        for i, resourceMapping := range resourceMappings {
            scheme := resourceMapping.URL.Scheme
            scheme = strings.TrimSpace(scheme)
            scheme = strings.ToLower(scheme)
            scheme = strings.TrimSuffix(scheme, "://")
            scheme = strings.TrimSuffix(scheme, ":/")
            scheme = strings.TrimSuffix(scheme, ":")
            resourceMappings[i].URL.Scheme = scheme
        }
    }
    return &config
}

func (config Config) reApplyResourceURLDefaults() *Config {
    const http = "http"
    const https = "https"
    const p80S = "80"
    const p443S = "443"
    for name := range config.Resources {
        resourceMappings := config.Resources[name]
        for i, resourceMapping := range resourceMappings {
            scheme := resourceMapping.URL.Scheme
            if len(scheme) == 0 {
                scheme = http
            } else {
                if strings.Contains(scheme, https) {
                    scheme = https
                } else if strings.Contains(scheme, http) {
                    scheme = http
                }
            }
            resourceMappings[i].URL.Scheme = scheme

            //insert port here if not specified
            if len(resourceMapping.URL.Port) == 0 {
                if scheme == http {
                    resourceMappings[i].URL.Port = p80S
                }
                if scheme == https {
                    resourceMappings[i].URL.Port = p443S
                }
            }
        }
    }
    return &config
}

func (config Config) validateHTTPConfig() *Config {
    if !(config.isHTTPOn() || config.isTLSOn()) {
        config.panic("must have either http or tls enabled")
    }

    if config.isHTTPOn() && config.isTLSOn() {
        if config.Connection.Downstream.Http.Port == config.Connection.Downstream.Tls.Port {
            config.panic("connection downstream http and tls port must be different")
        }
    }

    if config.isHTTPOn() {
        if config.Connection.Downstream.Http.Port < 1 ||
            config.Connection.Downstream.Http.Port > 65535 {
            config.panic(fmt.Sprintf("connection downstream http port must be between 1 and 65535, was: %v",
                config.Connection.Downstream.Http.Port))
        }
    }

    if config.isTLSOn() {
        if config.Connection.Downstream.Tls.Port < 1 ||
            config.Connection.Downstream.Tls.Port > 65535 {
            config.panic(fmt.Sprintf("connection downstream tls port must be between 1 and 65535, was: %v",
                config.Connection.Downstream.Tls.Port))
        }
    }

    if !config.isTLSOn() &&
        config.Connection.Downstream.Http.Redirecttls == true {
        config.panic("cannot redirect to TLS if not properly configured.")
    }

    return &config
}

const wildcardDomainPrefix = "*."
const dot = "."

func (config Config) validateAcmeConfig() *Config {
    acmeProvider := len(config.Connection.Downstream.Tls.Acme.Provider) > 0
    acmeDomain := len(config.Connection.Downstream.Tls.Acme.Domains) > 0 && len(config.Connection.Downstream.Tls.Acme.Domains[0]) > 0
    acmeEmail := len(config.Connection.Downstream.Tls.Acme.Email) > 0

    if acmeProvider || acmeDomain || acmeEmail {
        if config.Connection.Downstream.Tls.Acme.GracePeriodDays == 0 {
            config.Connection.Downstream.Tls.Acme.GracePeriodDays = 30
        } else if config.Connection.Downstream.Tls.Acme.GracePeriodDays > 30 {
            config.panic("ACME grace period must be between 1 and 30 days")
        }

        if len(config.Connection.Downstream.Tls.Cert) > 0 {
            config.panic("cannot specify TLS cert with ACME configuration, it would be overridden.")
        }

        if len(config.Connection.Downstream.Tls.Key) > 0 {
            config.panic("cannot specify TLS private key with ACME configuration, it would be overridden.")
        }

        if config.Connection.Downstream.Http.Port != 80 {
            config.panic("HTTP listener must be configured and set to port 80 for ACME challenge")
        }

        if !acmeProvider {
            config.panic("ACME provider must be specified in ACME config")
        }

        if !acmeDomain {
            config.panic("ACME domain must be specified in ACME config")
        }

        if !acmeEmail {
            config.panic("ACME email must be specified in ACME config")
        }

        // ACME domain checks
        for _, domain := range config.Connection.Downstream.Tls.Acme.Domains {
            if !govalidator.IsDNSName(domain) {
                config.panic(fmt.Sprintf("ACME domain must be a valid DNS name, but was %s", domain))
            }

            if !isd.IsDomain(domain) {
                config.panic(fmt.Sprintf("ACME domain must be a valid FQDN name, but was %s", domain))
            }

            if domain[0:1] == wildcardDomainPrefix {
                config.panic(fmt.Sprintf("ACME domain validation does not support wildcard domain names, was %s", domain))
            }

            if string(domain[0]) == dot {
                config.panic(fmt.Sprintf("ACME domain validation does not support domains starting with '.', was %s", domain))
            }

            if strings.HasSuffix(domain, dot) {
                config.panic(fmt.Sprintf("ACME domain validation does not support domains ending with '.', was %s", domain))
            }
        }

        // ACME provider checks
        if _, supported := acmeProviders[config.Connection.Downstream.Tls.Acme.Provider]; !supported {
            config.panic(fmt.Sprintf("ACME provider not supported: %s", config.Connection.Downstream.Tls.Acme.Provider))
        }

        // ACME email checks
        if !govalidator.IsEmail(config.Connection.Downstream.Tls.Acme.Email) {
            config.panic(fmt.Sprintf("ACME email must be a valid email address, but was %s", config.Connection.Downstream.Tls.Acme.Email))
        }

        log.Info().Msgf("By using the ACME provider %s you agree to the provider terms of service (%s)", acmeProviders[config.Connection.Downstream.Tls.Acme.Provider].friendlyName, acmeProviders[config.Connection.Downstream.Tls.Acme.Provider].tosURL)

    }

    return &config
}

func (config Config) addDefaultPolicy() *Config {
    defaultPolicy := new(Policy)
    lw := LabelWeight{
        Label:  "default",
        Weight: 1.0,
    }
    var labelWeights []LabelWeight
    labelWeights = append(labelWeights, lw)
    *defaultPolicy = labelWeights
    if config.Policies == nil {
        config.Policies = make(map[string]Policy)
    }
    config.Policies["default"] = *defaultPolicy
    return &config
}

func (config Config) setDefaultDownstreamParams() *Config {

    if config.Connection.Downstream.ReadTimeoutSeconds == 0 {
        config.Connection.Downstream.ReadTimeoutSeconds = 5
    }

    if config.Connection.Downstream.RoundTripTimeoutSeconds == 0 {
        config.Connection.Downstream.RoundTripTimeoutSeconds = 10
    }

    if config.Connection.Downstream.IdleTimeoutSeconds == 0 {
        config.Connection.Downstream.IdleTimeoutSeconds = 5
    }

    if config.Connection.Downstream.MaxBodyBytes == 0 {
        //set to 2MB default value
        config.Connection.Downstream.MaxBodyBytes = 2 << 20
    }

    if !config.isHTTPOn() {
        config.Connection.Downstream.Http.Redirecttls = false
    }

    return &config
}

func (config Config) isTLSOn() bool {
    return config.Connection.Downstream.Tls.Port > 0
}

func (config Config) isHTTPOn() bool {
    return config.Connection.Downstream.Http.Port > 0
}

func (config Config) setDefaultUpstreamParams() *Config {

    if config.Connection.Upstream.SocketTimeoutSeconds == 0 {
        config.Connection.Upstream.SocketTimeoutSeconds = 3
    }
    if config.Connection.Upstream.ReadTimeoutSeconds == 0 {
        config.Connection.Upstream.ReadTimeoutSeconds = 10
    }
    if config.Connection.Upstream.IdleTimeoutSeconds == 0 {
        config.Connection.Upstream.IdleTimeoutSeconds = 120
    }
    if config.Connection.Upstream.PoolSize == 0 {
        config.Connection.Upstream.PoolSize = 32768
    }
    if config.Connection.Upstream.MaxAttempts == 0 {
        config.Connection.Upstream.MaxAttempts = 1
    }
    return &config
}

func (config Config) validateJwt() *Config {
    if len(config.Jwt) > 0 {
        for name, jwt := range config.Jwt {
            //update name on resource
            jwt.Name = name
            jwt.Init()
            err := jwt.Validate()
            if err != nil {
                config.panic(err.Error())
            }
            config.Jwt[name] = jwt
        }
        log.Info().Msgf("parsed %d jwt configurations", len(config.Jwt))
    }
    return &config
}

func (config Config) getDownstreamRoundTripTimeoutDuration() time.Duration {
    return time.Duration(time.Second * time.Duration(config.Connection.Downstream.RoundTripTimeoutSeconds))
}

func envToMap() map[string]string {
    envMap := make(map[string]string)

    for _, v := range os.Environ() {
        split_v := strings.SplitN(v, "=", 2)
        envMap[split_v[0]] = split_v[1]
    }

    return envMap
}