simonmittag/p0d

View on GitHub
config.go

Summary

Maintainability
B
6 hrs
Test Coverage
C
73%
package p0d

import (
    "crypto/tls"
    "encoding/json"
    "fmt"
    "github.com/ghodss/yaml"
    . "github.com/logrusorgru/aurora"
    "golang.org/x/net/context"
    "golang.org/x/net/http2"
    "io/ioutil"
    "math"
    "net"
    "net/http"
    "net/url"
    "os"
    "strconv"
    "strings"
    "time"
)

// Config is the top-level load test configuration. It is decoded from a YAML
// file by loadConfigFromFile and normalised in place by validate.
type Config struct {
    // Req describes the request to send.
    Req  Req
    // Res describes the response; validate defaults Res.Code to 200.
    Res  Res
    // Exec holds the execution parameters (duration, concurrency, timeouts, ...).
    Exec Exec
    // File is not read or written anywhere in this file.
    // NOTE(review): presumably the path of the loaded config file — confirm against callers.
    File string
}

// Req describes the HTTP request issued by the load test.
type Req struct {
    // Method is upper-cased by validate; an empty value becomes "GET".
    Method        string
    // Url is mandatory; validate exits when it is empty or cannot be parsed.
    Url           string
    // Headers is a list of header-name -> value maps. setContentType appends
    // or rewrites the "Content-Type" entry.
    Headers       []map[string]string
    // ContentType mirrors the effective "Content-Type" header and is
    // maintained by setContentType.
    ContentType   string
    // Body is the raw request body; it is mutually exclusive with FormData.
    Body          string
    // FormData is a list of form key -> value maps. For keys prefixed with
    // "@" the value is a file path, which validateReqBody replaces with the
    // file's base name after loading the file.
    FormData      []map[string]string
    // FormDataFiles holds the contents of the files referenced by "@" form
    // keys, indexed by that form key. It is populated by validateReqBody.
    FormDataFiles map[string][]byte
    // Ips holds the IP(s) of the target host: either the literal IP from the
    // URL or the result of a DNS lookup performed by validate.
    Ips           []net.IP
}

// Res describes the response side of the configuration.
type Res struct {
    // Code defaults to 200 in validate when left at zero.
    // NOTE(review): presumably the expected HTTP status code — confirm against callers.
    Code int
}

// Exec holds the execution parameters of a load test run. Defaults and limits
// mentioned below are the ones applied by validate.
type Exec struct {
    // Mode selects the byte formatting used by byteCount: "decimal" uses
    // ByteCountSI, anything else ByteCountIEC. Defaults to "binary".
    Mode               string
    // DurationSeconds defaults to 10; non-zero values below 3 are rejected.
    DurationSeconds    int
    // RampSeconds defaults to ceil(DurationSeconds/10) and may not exceed
    // half of DurationSeconds.
    RampSeconds        int
    // Concurrency is the number of HTTP clients created by
    // scaffoldHttpClients; defaults to 1. The value UNLIMITED disables
    // keep-alives on the transport.
    Concurrency        int
    // DialTimeoutSeconds is used as dial timeout and TLS handshake timeout;
    // defaults to 3.
    DialTimeoutSeconds int64
    // LogSampling must lie in [0, 1]; values outside that range are reset to 0.
    LogSampling        float64
    // SpacingMillis is reset to 0 when negative.
    SpacingMillis      int64
    // HttpVersion must be 1.1 or 2; zero defaults to 1.1.
    HttpVersion        float32
    // SkipInetTest is forced to true by validate when the first resolved
    // target IP is private according to isPrivateIP.
    SkipInetTest       bool
}

// UNLIMITED is a sentinel for Exec.Concurrency. When the concurrency equals
// this value, scaffoldHttpClientWith disables keep-alives on the transport.
const UNLIMITED int = -1

// http11 and http20 are the supported values of Exec.HttpVersion.
const http11 = 1.1
const http20 = 2

// httpVers is the set of accepted Exec.HttpVersion values; validate rejects
// any non-zero version that is not a key of this map.
var httpVers = map[float32]float32{http11: http11, http20: http20}

// loadConfigFromFile reads the YAML file at fileName, converts it to JSON and
// decodes it into a freshly allocated Config.
//
// Any failure (opening, reading, YAML conversion or JSON decoding) is logged
// and terminates the process with exit status -1, so the function only ever
// returns a successfully decoded Config. The returned Config has not been
// validated yet.
func loadConfigFromFile(fileName string) *Config {
    cfgPanic := func(err error) {
        if err != nil {
            msg := Red(fmt.Sprintf("unable to load config from '%s', exiting...", fileName))
            logv(msg)
            os.Exit(-1)
        }
    }

    f, err := os.Open(fileName)
    cfgPanic(err)
    // Register the close only once we know the handle is valid.
    defer f.Close()

    yml, err := ioutil.ReadAll(f)
    cfgPanic(err)

    jsn, err := yaml.YAMLToJSON(yml)
    cfgPanic(err)

    c := &Config{}
    cfgPanic(json.Unmarshal(jsn, c))
    return c
}

// validate normalises the configuration in place, filling in defaults and
// rejecting invalid values, and returns the receiver for chaining.
//
// Invalid settings are reported through cfg.panic, which logs the message and
// exits the process. As a side effect the target host is resolved into
// cfg.Req.Ips (this may perform a DNS lookup), and cfg.Exec.SkipInetTest is
// switched on when the first resolved address is private.
func (cfg *Config) validate() *Config {
    // The method is always upper-cased; an empty method means GET.
    cfg.Req.Method = strings.ToUpper(cfg.Req.Method)
    if cfg.Req.Method == "" {
        cfg.Req.Method = "GET"
    }

    if cfg.Exec.Concurrency == 0 {
        cfg.Exec.Concurrency = 1
    }
    if cfg.Exec.DurationSeconds == 0 {
        cfg.Exec.DurationSeconds = 10
    } else if cfg.Exec.DurationSeconds < 3 {
        cfg.panic("duration cannot be less than 3 seconds")
    }

    if cfg.Exec.RampSeconds == 0 {
        // Default ramp is a tenth of the duration, rounded up.
        cfg.Exec.RampSeconds = int(math.Ceil(float64(cfg.Exec.DurationSeconds) / 10))
    } else if float64(cfg.Exec.RampSeconds) > (float64(cfg.Exec.DurationSeconds) / 2) {
        cfg.panic("ramp time cannot be longer than half the duration")
    }
    if cfg.Exec.DialTimeoutSeconds == 0 {
        cfg.Exec.DialTimeoutSeconds = 3
    }
    if cfg.Exec.Mode == "" {
        cfg.Exec.Mode = "binary"
    }
    if cfg.Exec.HttpVersion == 0 {
        cfg.Exec.HttpVersion = http11
    } else if _, ok := httpVers[cfg.Exec.HttpVersion]; !ok {
        hv := fmt.Sprintf("%.1f", cfg.Exec.HttpVersion)
        cfg.panic(fmt.Sprintf("bad http version %s, must be one of [1.1, 2.0], exiting...", hv))
    }
    if cfg.Exec.LogSampling < 0 || cfg.Exec.LogSampling > 1 {
        // Out-of-range sampling rates are reset to 0.
        cfg.Exec.LogSampling = 0
    }
    if cfg.Exec.SpacingMillis < 0 {
        // Negative spacing is reset to 0.
        cfg.Exec.SpacingMillis = 0
    }

    cfg.validateReqBody()

    if cfg.Req.Url == "" {
        cfg.panic("request url not specified")
    } else {
        u, e := url.Parse(cfg.Req.Url)
        if e != nil {
            cfg.panic(e.Error())
        }
        // SplitHostPort fails when the URL carries no port; in that case the
        // whole host component is the host name.
        h, p, _ := net.SplitHostPort(u.Host)
        if h == "" {
            h = u.Host
        }
        if len(p) > 0 {
            p1, e2 := strconv.Atoi(p)
            if e2 != nil {
                // Report the conversion error itself; the URL parse error is
                // nil at this point.
                cfg.panic(e2.Error())
            }
            if p1 < 0 || p1 > 65535 {
                cfg.panic(fmt.Sprintf("valid port range is [0-65535], yours: %d", p1))
            }
        }

        // Prefer a literal IP from the URL; otherwise resolve the host name.
        cfg.Req.Ips = make([]net.IP, 1)
        cfg.Req.Ips[0] = net.ParseIP(h)

        var e3 error
        if cfg.Req.Ips[0] == nil {
            cfg.Req.Ips, e3 = net.LookupIP(h)
        }
        if e3 == nil && len(cfg.Req.Ips) > 0 {
            if isPrivateIP(cfg.Req.Ips[0]) {
                //we disable inet test for all local targets.
                cfg.Exec.SkipInetTest = true
            }
        }
    }

    if cfg.Res.Code == 0 {
        cfg.Res.Code = 200
    }
    return cfg
}

// getRemotePort returns the TCP port addressed by cfg.Req.Url.
//
// An explicit port in the URL wins. When the port extracted from the URL
// equals N, the scheme default is used: 80 for "http", 443 for "https".
// N is a package-level constant declared outside this file's visible part.
// NOTE(review): presumably N is the empty string — confirm.
//
// 0 is returned when the URL cannot be parsed, when no port can be derived,
// or when the port does not fit into 16 bits.
func (cfg *Config) getRemotePort() uint16 {
    u, err := url.Parse(cfg.Req.Url)
    if err != nil {
        return 0
    }

    // The error is deliberately ignored: a host without a port yields an
    // empty port string, which is handled below.
    _, p, _ := net.SplitHostPort(u.Host)
    if p == N {
        switch u.Scheme {
        case "http":
            p = "80"
        case "https":
            p = "443"
        }
    }

    // bitSize 16 rejects values that would otherwise wrap in the conversion.
    port, err := strconv.ParseUint(p, 10, 16)
    if err != nil {
        return 0
    }
    return uint16(port)
}

// validateReqBody checks the body-related request settings and applies the
// default content type.
//
// Body and FormData are mutually exclusive, and FormData requires the POST
// method; violations are reported through cfg.panic, which exits the process.
//
// With form data present, the default form content type is applied and every
// form key prefixed with "@" is treated as a file upload: its value is read
// as a file path, the file contents are stored in cfg.Req.FormDataFiles under
// that key, and the form entry is replaced by a single-key map from the key
// to the file's base name.
//
// Without form data, the default post content type is applied when the
// method is listed in bodyTypes.
func (cfg *Config) validateReqBody() {
    hasBody := len(cfg.Req.Body) > 0
    hasForm := len(cfg.Req.FormData) > 0

    if hasBody && hasForm {
        cfg.panic("when specifying request body, cannot have form data")
    }

    if !hasForm {
        if contains(bodyTypes, cfg.Req.Method) {
            cfg.setDefaultPostContentType()
        }
        return
    }

    if cfg.Req.Method != "POST" {
        cfg.panic("when specifying form data, method must be POST")
    }
    cfg.setDefaultFormDataContentType()

    cfg.Req.FormDataFiles = make(map[string][]byte)
    for i, fd := range cfg.Req.FormData {
        for k, v := range fd {
            if !strings.HasPrefix(k, "@") {
                continue
            }
            dat, err := os.ReadFile(v)
            if err != nil {
                cfg.panic(fmt.Sprintf("unable to read file: %s", v))
            }
            // Stat by path instead of opening the file a second time, so no
            // file handle is left open.
            fi, err := os.Stat(v)
            if err != nil {
                cfg.panic(fmt.Sprintf("unable to read file: %s", v))
            }
            cfg.Req.FormDataFiles[k] = dat
            cfg.Req.FormData[i] = map[string]string{k: fi.Name()}
        }
    }
}

// setDefaultPostContentType applies "application/json" as the request content
// type without overwriting a Content-Type header that is already configured.
func (cfg *Config) setDefaultPostContentType() {
    const defaultType = "application/json"
    cfg.setContentType(defaultType, false)
}

// setDefaultFormDataContentType applies "application/x-www-form-urlencoded"
// as the request content type without overwriting a Content-Type header that
// is already configured.
func (cfg *Config) setDefaultFormDataContentType() {
    const defaultType = "application/x-www-form-urlencoded"
    cfg.setContentType(defaultType, false)
}

// setContentType makes sure the request carries a "Content-Type" header and
// keeps cfg.Req.ContentType in sync with it.
//
// Header maps are searched for the exact key "Content-Type" (case-sensitive).
// For every match:
//   - with overwrite set, the whole header map is replaced by a single-entry
//     map holding contentType, and cfg.Req.ContentType becomes contentType;
//   - otherwise the existing value is kept, written back under the package
//     level key ct, and copied into cfg.Req.ContentType.
//
// When no header map matches (including the case of no headers at all), a new
// single-entry header map holding contentType is appended.
func (cfg *Config) setContentType(contentType string, overwrite bool) {
    const ctkey = "Content-Type"
    replacement := map[string]string{ctkey: contentType}

    found := false
    for idx, hdr := range cfg.Req.Headers {
        current, present := hdr[ctkey]
        if !present {
            continue
        }
        found = true
        if overwrite {
            cfg.Req.Headers[idx] = replacement
            cfg.Req.ContentType = contentType
            continue
        }
        cfg.Req.Headers[idx][ct] = current
        cfg.Req.ContentType = current
    }

    if !found {
        cfg.Req.Headers = append(cfg.Req.Headers, replacement)
        cfg.Req.ContentType = contentType
    }
}

// hasContentType reports whether the request is configured with exactly the
// given content type: every header map that carries the exact key
// "Content-Type" must hold contentType, and cfg.Req.ContentType must equal
// contentType as well.
func (cfg *Config) hasContentType(contentType string) bool {
    for _, hdr := range cfg.Req.Headers {
        if val, present := hdr["Content-Type"]; present && val != contentType {
            return false
        }
    }
    return cfg.Req.ContentType == contentType
}

// scaffoldHttpClients builds one HTTP client per unit of configured
// concurrency and returns them keyed by their index (0 .. Concurrency-1).
//
// Every client is limited to a single connection. This bypasses connection
// re-use inside a shared pool (parallel requests multiplexed as streams in
// http/2) and gives a clean one-connection-per-goroutine mapping for
// http/1.1, so each goroutine can recover on its own if its connection dies.
func (cfg Config) scaffoldHttpClients() map[int]*http.Client {
    count := cfg.Exec.Concurrency
    clients := make(map[int]*http.Client, count)

    for slot := 0; slot < count; slot++ {
        clients[slot] = cfg.scaffoldHttpClient(1)
    }
    return clients
}

// httpIdleTimeout is the IdleConnTimeout (one second) applied to every
// transport built by scaffoldHttpClientWith.
const httpIdleTimeout = time.Duration(1) * time.Second

// scaffoldHttpClient builds an HTTP client limited to max connections, with
// connection spying switched off and no P0d attached.
func (cfg Config) scaffoldHttpClient(max int) *http.Client {
    const spyOnConn = false
    return cfg.scaffoldHttpClientWith(max, spyOnConn, nil)
}

// scaffoldHttpClientWith builds an HTTP client whose transport is limited to
// max connections (per host and idle) and uses Exec.DialTimeoutSeconds for
// dialing and for the TLS handshake.
//
// TLS is restricted to versions 1.1-1.2 and certificate verification is
// switched off. Response compression is disabled. HTTP/2 is only enabled when
// Exec.HttpVersion is http20; otherwise the empty TLSNextProto map keeps the
// transport on http/1.1. With Exec.Concurrency set to UNLIMITED keep-alives
// are disabled.
//
// When connSpy is true and pod is non-nil, the outcome of the most recent
// dial is stored in pod.sampleConn: the connection on success, nil on failure.
func (cfg Config) scaffoldHttpClientWith(max int, connSpy bool, pod *P0d) *http.Client {
    tlsc := &tls.Config{
        MinVersion:         tls.VersionTLS11,
        MaxVersion:         tls.VersionTLS12,
        InsecureSkipVerify: true,
    }

    dialTimeout := time.Duration(cfg.Exec.DialTimeoutSeconds) * time.Second
    // Spying needs somewhere to record the connection.
    spy := connSpy && pod != nil

    t := &http.Transport{
        DisableCompression: true,
        DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
            c1, e := net.DialTimeout(network, addr, dialTimeout)
            if spy {
                pod.sampleConn = c1
            }
            return c1, e
        },
        DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
            nd := net.Dialer{
                Timeout: dialTimeout,
            }
            c1, e := tls.DialWithDialer(&nd, network, addr, tlsc)
            if e != nil {
                // c1 is a nil *tls.Conn here. Storing or returning it through
                // an interface would yield a non-nil interface wrapping a nil
                // pointer, so use an untyped nil instead.
                if spy {
                    pod.sampleConn = nil
                }
                return nil, e
            }
            if spy {
                pod.sampleConn = c1
            }
            return c1, nil
        },
        //TLS handshake timeout is the same as connection timeout
        TLSHandshakeTimeout: dialTimeout,
        TLSClientConfig:     tlsc,
        MaxConnsPerHost:     max,
        MaxIdleConns:        max,
        MaxIdleConnsPerHost: max,
        IdleConnTimeout:     httpIdleTimeout,
        TLSNextProto:        make(map[string]func(authority string, c *tls.Conn) http.RoundTripper),
    }

    //see https://stackoverflow.com/questions/57683132/turning-off-connection-pool-for-go-http-client
    if cfg.Exec.Concurrency == UNLIMITED {
        t.DisableKeepAlives = true
        logv(Yellow("transport connection pool disabled for http/1.1"))
    }

    if cfg.Exec.HttpVersion == http20 {
        if err := http2.ConfigureTransport(t); err != nil {
            logv(Yellow(fmt.Sprintf("unable to configure http/2 transport: %v", err)))
        }
    }

    return &http.Client{
        Transport: t,
    }
}

// byteCount formats b as a human readable byte count, right-aligned to a
// minimum width of two characters. With Exec.Mode (ignoring surrounding
// whitespace) equal to "decimal" the value is rendered by ByteCountSI; for
// "binary" and any other mode ByteCountIEC is used.
func (cfg Config) byteCount(b int64) string {
    const layout = "%2s"
    if strings.TrimSpace(cfg.Exec.Mode) == "decimal" {
        return fmt.Sprintf(layout, ByteCountSI(b))
    }
    return fmt.Sprintf(layout, ByteCountIEC(b))
}

// panic logs msg in red and terminates the process with exit status -1.
// Despite the name it does not raise a Go panic: nothing can be recovered and
// deferred functions do not run.
func (cfg Config) panic(msg string) {
    const exitStatus = -1
    logv(Red(msg))
    os.Exit(exitStatus)
}