fabiocicerchia/go-proxy-cache

View on GitHub
server/balancer/balancer.go

Summary

Maintainability
A
2 hrs
Test Coverage
package balancer

//                                                                         __
// .-----.-----.______.-----.----.-----.--.--.--.--.______.----.---.-.----|  |--.-----.
// |  _  |  _  |______|  _  |   _|  _  |_   _|  |  |______|  __|  _  |  __|     |  -__|
// |___  |_____|      |   __|__| |_____|__.__|___  |      |____|___._|____|__|__|_____|
// |_____|            |__|                   |_____|
//
// Copyright (c) 2023 Fabio Cicerchia. https://fabiocicerchia.it. MIT License
// Repo: https://github.com/fabiocicerchia/go-proxy-cache

import (
    "crypto/tls"
    "fmt"
    "net"
    "net/http"
    "net/url"
    "strconv"
    "time"

    "github.com/fabiocicerchia/go-proxy-cache/config"
    "github.com/fabiocicerchia/go-proxy-cache/logger"
    "github.com/fabiocicerchia/go-proxy-cache/telemetry"
    "github.com/fabiocicerchia/go-proxy-cache/telemetry/metrics"
    "github.com/fabiocicerchia/go-proxy-cache/utils/slice"
)

const lBIpHash = "ip-hash"
const lBLeastConnections = "least-connections"
const lBRandom = "random"
const lBRoundRobin = "round-robin"
const enableHealthchecks = true
const defaultClientTimeout = 5 * time.Second

func initLB() {
    if len(lb) == 0 {
        lb = make(LoadBalancing)
    }
}

func convertEndpoints(endpoints []string) []Item {
    items := []Item{}
    for _, v := range endpoints {
        item := Item{Healthy: true, Endpoint: v}
        items = append(items, item)
    }

    return items
}

// Init - Initialise the LB algorithm.
func Init(name string, config config.Upstream) {
    switch config.BalancingAlgorithm {
    case lBIpHash:
        InitIpHash(name, config, enableHealthchecks)
    case lBLeastConnections:
        InitLeastConnection(name, config, enableHealthchecks)
    case lBRandom:
        InitRandom(name, config, enableHealthchecks)
    case lBRoundRobin:
        InitRoundRobin(name, config, enableHealthchecks)
    default: // round-robin (default)
        InitRoundRobin(name, config, enableHealthchecks)
    }
}

// InitRoundRobin - Initialise the LB algorithm for round robin selection.
func InitRoundRobin(name string, config config.Upstream, enableHealthchecks bool) {
    initLB()
    items := convertEndpoints(config.Endpoints)

    lb[name] = NewRoundRobinBalancer(name, items)

    if enableHealthchecks {
        CheckHealth(&lb[name].(*RoundRobinBalancer).NodeBalancer, config.Host, config.HealthCheck)
    }
}

// InitRandom - Initialise the LB algorithm for random selection.
func InitRandom(name string, config config.Upstream, enableHealthchecks bool) {
    initLB()
    items := convertEndpoints(config.Endpoints)

    lb[name] = NewRandomBalancer(name, items)

    if enableHealthchecks {
        CheckHealth(&lb[name].(*RandomBalancer).NodeBalancer, config.Host, config.HealthCheck)
    }
}

// InitLeastConnection - Initialise the LB algorithm for least-connection selection.
func InitLeastConnection(name string, config config.Upstream, enableHealthchecks bool) {
    initLB()
    items := convertEndpoints(config.Endpoints)

    lb[name] = NewLeastConnectionsBalancer(name, items)

    if enableHealthchecks {
        CheckHealth(&lb[name].(*LeastConnectionsBalancer).NodeBalancer, config.Host, config.HealthCheck)
    }
}

// InitIpHash - Initialise the LB algorithm for ip-hash selection.
func InitIpHash(name string, config config.Upstream, enableHealthchecks bool) {
    initLB()
    items := convertEndpoints(config.Endpoints)

    lb[name] = NewIpHashBalancer(name, items)

    if enableHealthchecks {
        CheckHealth(&lb[name].(*IpHashBalancer).NodeBalancer, config.Host, config.HealthCheck)
    }
}

// GetUpstreamNode - Returns backend server using current algorithm.
func GetUpstreamNode(name string, requestURL url.URL, defaultHost string) string {
    var err error

    endpoint := ""

    if lbDomain, ok := lb[name]; ok {
        endpoint, err = lbDomain.Pick(requestURL.String())
    }

    if err != nil || endpoint == "" {
        return defaultHost
    }

    return endpoint
}

// CheckHealth - Periodic check on nodes status.
func CheckHealth(b *NodeBalancer, host string, config config.HealthCheck) {
    period := config.Interval
    if period == 0 {
        period = HealthCheckInterval
    }

    go func() {
        t := time.NewTicker(period)

        for {
            <-t.C

            healthyCounter := 0
            unhealthyCounter := 0
            for k, v := range b.Items {
                DoHealthCheck(&v, host, config)

                if v.Healthy {
                    healthyCounter++
                } else {
                    unhealthyCounter++
                }

                b.M.Lock()
                b.Items[k] = v
                b.M.Unlock()
            }

            telemetry.RegisterHostHealth(healthyCounter, unhealthyCounter)
        }
    }()
}

func getClient(timeout time.Duration, tlsFlag bool, allowInsecure bool) *http.Client {
    if timeout == 0 {
        timeout = defaultClientTimeout
    }

    c := &http.Client{
        // return the 301/302
        CheckRedirect: func(req *http.Request, via []*http.Request) error {
            return http.ErrUseLastResponse
        },
        Timeout: timeout,
    }

    if tlsFlag {
        c.Transport = &http.Transport{
            TLSClientConfig: &tls.Config{
                InsecureSkipVerify: allowInsecure,
            },
        } //#nosec G402
    }

    return c
}

func DoHealthCheck(v *Item, host string, config config.HealthCheck) {
    url, _ := url.Parse(v.Endpoint)
    scheme := url.Scheme
    if scheme == "" || (scheme != "http" && scheme != "https") {
        scheme = config.Scheme
    }

    hostWithPort := url.Host
    if hostWithPort == "" {
        hostWithPort = v.Endpoint
    }
    _, port, err := net.SplitHostPort(hostWithPort)

    overridePort := ""
    if err != nil || port == "" {
        overridePort = fmt.Sprintf(":%s", config.Port)
    }

    overrideScheme := ""
    if url.Scheme != scheme {
        overrideScheme = fmt.Sprintf("%s://", scheme)
    }

    endpointURL := fmt.Sprintf("%s%s%s", overrideScheme, v.Endpoint, overridePort)

    req, err := http.NewRequest("HEAD", endpointURL, nil)
    if err != nil {
        logger.GetGlobal().Errorf("Healthcheck request failed for %s / %s: %s", host, endpointURL, err) // TODO: Add to trace span?
        return
    }
    res, err := getClient(config.Timeout, scheme == "https", config.AllowInsecure).Do(req)

    v.Healthy = err == nil
    if err != nil {
        logger.GetGlobal().Errorf("Healthcheck failed for %s: %s", endpointURL, err) // TODO: Add to trace span?
        metrics.SetUpstreamServerHealthChecksFails(host, hostWithPort)
    } else {
        v.Healthy = slice.ContainsString(config.StatusCodes, strconv.Itoa(res.StatusCode))

        if !v.Healthy {
            logger.GetGlobal().Errorf("Endpoint %s is not healthy (%d).", endpointURL, res.StatusCode) // TODO: Add to trace span?
            metrics.SetUpstreamServerHealthChecksUnhealthy(host, hostWithPort)
        } else {
            metrics.SetUpstreamServerHealthChecksHealthy(host, hostWithPort)
        }
    }
}

// GetHealthyNodes - Retrieves healthy nodes.
func (b *NodeBalancer) GetHealthyNodes() []Item {
    healthyNodes := []Item{}

    for _, v := range b.Items {
        if v.Healthy {
            healthyNodes = append(healthyNodes, v)
        }
    }

    return healthyNodes
}