ClusterLabs/hawk-apiserver

View on GitHub
server/reverseproxy.go

Summary

Maintainability
A
0 mins
Test Coverage
// This file is adapted from code in the net/http/httputil
// package of the Go standard library, which is by the
// Go Authors, and bears this copyright and license info:
//
//   Copyright 2011 The Go Authors. All rights reserved.
//   Use of this source code is governed by a BSD-style
//   license that can be found in the LICENSE file.
//
// This file has been modified from the standard lib to
// meet the needs of the application.

package server

import (
    "crypto/tls"
    "io"
    "net"
    "net/http"
    "net/url"
    "strings"
    "sync"
    "time"
)

var (
    // defaultDialer is the dialer used by the transports built in this file
    // and the fallback returned by getTransportDial. It gives up connecting
    // after 30 seconds and uses a 30 second keep-alive period.
    defaultDialer = &net.Dialer{
        Timeout:   30 * time.Second,
        KeepAlive: 30 * time.Second,
    }

    // bufferPool recycles byte slices created by createBuffer. They serve as
    // scratch space for pooledIoCopy and as replay buffers for
    // connHijackerTransport.
    bufferPool = sync.Pool{New: createBuffer}
)

// createBuffer is the constructor used by bufferPool. It returns an empty
// byte slice (length zero) backed by 32 KiB of capacity.
func createBuffer() interface{} {
    const bufferCapacity = 32 * 1024
    buf := make([]byte, 0, bufferCapacity)
    return buf
}

// pooledIoCopy streams src into dst using a scratch buffer borrowed from
// bufferPool. The buffer is handed back to the pool when the copy is done.
// The result of the copy (byte count and error) is discarded.
func pooledIoCopy(dst io.Writer, src io.Reader) {
    pooled := bufferPool.Get().([]byte)
    defer bufferPool.Put(pooled)

    // io.CopyBuffer only uses the buffer up to its length and panics on a
    // zero-length buffer, so hand it a view that spans the full capacity.
    size := cap(pooled)
    scratch := pooled[:size:size]
    io.CopyBuffer(dst, src, scratch)
}

// onExitFlushLoop is a callback set by tests to detect the state of the
// flushLoop() goroutine. When non-nil, it is invoked by flushLoop right
// before that goroutine returns.
var onExitFlushLoop func()

// ReverseProxy takes an incoming request and sends it to another server,
// proxying the response back to the client.
//
// NOTE: its ServeHTTP method takes an additional respUpdateFn argument and
// returns an error, so a ReverseProxy does not satisfy http.Handler by
// itself.
type ReverseProxy struct {
    // Director must be a function which modifies
    // the request into a new request to be sent
    // using Transport. The upstream response is then
    // relayed back to the original client by ServeHTTP.
    Director func(*http.Request)

    // The transport used to perform proxy requests.
    // If nil, http.DefaultTransport is used. Requests asking for a
    // websocket upgrade are sent through a connHijackerTransport
    // derived from this transport instead.
    Transport http.RoundTripper

    // FlushInterval specifies the flush interval
    // to flush to the client while copying the
    // response body.
    // If zero, no periodic flushing is done.
    FlushInterval time.Duration
}

// socketDial returns a dial function that ignores the network and address it
// is called with and always connects to the unix domain socket named by
// hostName.
//
// Though the relevant directive prefix is just "unix:", url.Parse
// will - assuming the regular URL scheme - add additional slashes
// as if "unix" was a request protocol.
// What we need is just the path, so if "unix:/var/run/www.socket"
// was the proxy directive, the parsed hostName would be
// "unix:///var/run/www.socket", hence the leading "unix://" is trimmed.
//
// The prefix is removed with strings.TrimPrefix instead of slicing at a
// fixed offset, so a hostName shorter than the prefix yields a dial error
// rather than an out-of-range panic.
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) {
    socketPath := strings.TrimPrefix(hostName, "unix://")
    return func(network, addr string) (conn net.Conn, err error) {
        return net.Dial("unix", socketPath)
    }
}

// singleJoiningSlash concatenates a and b so that exactly one slash sits at
// the seam: a doubled slash is collapsed, and a missing slash is inserted
// unless b is empty.
func singleJoiningSlash(a, b string) string {
    trailing := strings.HasSuffix(a, "/")
    leading := strings.HasPrefix(b, "/")

    if trailing && leading {
        // Both sides bring a slash; drop the one from b.
        return a + b[1:]
    }
    if trailing || leading || b == "" {
        // Exactly one slash is already present, or there is nothing to join.
        return a + b
    }
    return a + "/" + b
}

// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
//
// The without prefix, when non-empty, is stripped from the request path
// first: with a target path of "/", an incoming "/api/messages" and without
// set to "/api", the target request will be for /messages.
//
// keepalive controls upstream connection reuse for non-unix targets: the
// default value (http.DefaultMaxIdleConnsPerHost) leaves the Transport unset,
// zero disables keep-alives, and any other value becomes the transport's
// MaxIdleConnsPerHost. Targets with the "unix" scheme are dialed through
// socketDial.
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *ReverseProxy {
    baseQuery := target.RawQuery

    // orElse yields primary unless it is empty, in which case fallback is used.
    orElse := func(primary, fallback string) string {
        if primary == "" {
            return fallback
        }
        return primary
    }

    rewrite := func(req *http.Request) {
        u := req.URL

        if target.Scheme == "unix" {
            // Dial only works with a unix URL when scheme and host are
            // faked; the real destination is supplied by socketDial.
            u.Scheme, u.Host = "http", "socket"
        } else {
            u.Scheme, u.Host = target.Scheme, target.Host
        }

        // Strip the `without` prefix from every representation of the path.
        if without != "" {
            u.Path = strings.TrimPrefix(u.Path, without)
            if u.Opaque != "" {
                u.Opaque = strings.TrimPrefix(u.Opaque, without)
            }
            if u.RawPath != "" {
                u.RawPath = strings.TrimPrefix(u.RawPath, without)
            }
        }

        // Build the final URL by joining the target and request paths.
        //
        // When either side carries an encoded form (URL.Opaque or
        // URL.RawPath), the result must be encoded as well, so the encoded
        // parts are joined, falling back to URL.Path for a side that has
        // no encoded form.
        if u.Opaque != "" || target.Opaque != "" {
            u.Opaque = singleJoiningSlash(
                orElse(target.Opaque, target.Path),
                orElse(u.Opaque, u.Path))
        }
        if u.RawPath != "" || target.RawPath != "" {
            u.RawPath = singleJoiningSlash(
                orElse(target.RawPath, target.Path),
                orElse(u.RawPath, u.Path))
        }
        u.Path = singleJoiningSlash(target.Path, u.Path)

        // For unix targets the join above prefixed the socket file's path
        // to the request path: calling /test on a proxy to
        // unix:/var/run/www.socket yields /var/run/www.socket/test.
        // Remove that prefix again so the upstream sees the real path.
        if target.Scheme == "unix" {
            // See the comment on socketDial for why "unix://" is cut off.
            socketPath := target.String()[len("unix://"):]
            u.Path = strings.TrimPrefix(u.Path, socketPath)
            if u.Opaque != "" {
                u.Opaque = strings.TrimPrefix(u.Opaque, socketPath)
            }
            if u.RawPath != "" {
                u.RawPath = strings.TrimPrefix(u.RawPath, socketPath)
            }
        }

        // Merge the query strings, adding "&" only when both are present.
        if baseQuery != "" && u.RawQuery != "" {
            u.RawQuery = baseQuery + "&" + u.RawQuery
        } else {
            u.RawQuery = baseQuery + u.RawQuery
        }
    }

    // Periodic flushing is good for streaming & server-sent events.
    proxy := &ReverseProxy{Director: rewrite, FlushInterval: 250 * time.Millisecond}

    switch {
    case target.Scheme == "unix":
        proxy.Transport = &http.Transport{
            Dial: socketDial(target.String()),
        }
    case keepalive != http.DefaultMaxIdleConnsPerHost:
        // When keepalive equals the default, the default transport is
        // used instead, which avoids creating a brand new transport.
        custom := &http.Transport{
            Proxy:                 http.ProxyFromEnvironment,
            Dial:                  defaultDialer.Dial,
            TLSHandshakeTimeout:   10 * time.Second,
            ExpectContinueTimeout: 1 * time.Second,
        }
        if keepalive == 0 {
            custom.DisableKeepAlives = true
        } else {
            custom.MaxIdleConnsPerHost = keepalive
        }
        proxy.Transport = custom
    }
    return proxy
}

// UseInsecureTransport is used to facilitate HTTPS proxying
// when it is OK for upstream to be using a bad certificate,
// since this transport skips verification.
//
// With no transport configured, a new one with InsecureSkipVerify set is
// installed. An existing *http.Transport has the flag set on its TLS
// configuration, which is created if missing. Any other kind of
// RoundTripper is left untouched.
func (rp *ReverseProxy) UseInsecureTransport() {
    switch current := rp.Transport.(type) {
    case nil:
        rp.Transport = &http.Transport{
            Proxy:               http.ProxyFromEnvironment,
            Dial:                defaultDialer.Dial,
            TLSHandshakeTimeout: 10 * time.Second,
            TLSClientConfig:     &tls.Config{InsecureSkipVerify: true},
        }
    case *http.Transport:
        if current.TLSClientConfig == nil {
            current.TLSClientConfig = &tls.Config{}
        }
        current.TLSClientConfig.InsecureSkipVerify = true
    }
}

// ServeHTTP serves the proxied request to the upstream by performing a roundtrip.
// It is designed to handle websocket connection upgrades as well.
//
// outreq is rewritten by rp.Director and then sent through rp.Transport
// (http.DefaultTransport when nil). Requests asking for a websocket upgrade
// are sent through a connHijackerTransport instead, which captures the raw
// backend connection. Hop-by-hop headers are removed from the response and
// respUpdateFn, if non-nil, is called with the response before it is relayed.
//
// A "101 Switching Protocols" websocket response causes the client
// connection to be hijacked and data to be piped in both directions until
// the client-to-backend copy ends. Any other response has its headers,
// status code, body and trailers copied to rw.
//
// An error is returned when the roundtrip fails, when rw cannot be hijacked
// for a websocket upgrade (http.ErrNotSupported), or when setting up the
// websocket tunnel fails. Errors that occur while copying data are not
// reported.
func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
    transport := rp.Transport
    if requestIsWebsocket(outreq) {
        transport = newConnHijackerTransport(transport)
    } else if transport == nil {
        transport = http.DefaultTransport
    }

    rp.Director(outreq)

    res, err := transport.RoundTrip(outreq)
    if err != nil {
        return err
    }

    // Connections dialed by a connHijackerTransport ignore Close (see
    // hijackedConn.Close), so nothing but this function can release the raw
    // backend connection. Close it once the exchange is over, no matter
    // whether the upgrade succeeded, was refused by the backend, or failed
    // during setup.
    hjTransport, _ := transport.(*connHijackerTransport)
    if hjTransport != nil && hjTransport.Conn != nil {
        defer hjTransport.Conn.Close()
    }

    // This has to be determined before the hop-by-hop headers, which
    // include "Upgrade", are removed below.
    isWebsocket := res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket"

    // Remove hop-by-hop headers listed in the
    // "Connection" header of the response.
    if c := res.Header.Get("Connection"); c != "" {
        for _, f := range strings.Split(c, ",") {
            if f = strings.TrimSpace(f); f != "" {
                res.Header.Del(f)
            }
        }
    }

    for _, h := range hopHeaders {
        res.Header.Del(h)
    }

    if respUpdateFn != nil {
        respUpdateFn(res)
    }

    if isWebsocket {
        defer res.Body.Close()
        hijacker, ok := rw.(http.Hijacker)
        if !ok {
            // The client connection cannot be taken over, so the upgrade
            // cannot be completed. Report that to the caller instead of
            // panicking.
            return http.ErrNotSupported
        }

        conn, brw, err := hijacker.Hijack()
        if err != nil {
            return err
        }
        defer conn.Close()

        var backendConn net.Conn
        if hjTransport != nil {
            // The backend connection is closed by the defer registered
            // right after the roundtrip.
            backendConn = hjTransport.Conn

            // Replay everything the backend has sent so far (the
            // handshake response) to the client verbatim.
            if _, err := conn.Write(hjTransport.Replay); err != nil {
                return err
            }
            bufferPool.Put(hjTransport.Replay)
        } else {
            backendConn, err = net.Dial("tcp", outreq.URL.Host)
            if err != nil {
                return err
            }
            defer backendConn.Close()
            if err := outreq.Write(backendConn); err != nil {
                return err
            }
        }

        // Proxy backend -> frontend.
        go pooledIoCopy(conn, backendConn)

        // Proxy frontend -> backend.
        //
        // NOTE: Hijack() sometimes returns buffered up bytes in brw which
        // would be lost if we didn't read them out manually below.
        if brw != nil {
            if n := brw.Reader.Buffered(); n > 0 {
                rbuf, err := brw.Reader.Peek(n)
                if err != nil {
                    return err
                }
                if _, err := backendConn.Write(rbuf); err != nil {
                    return err
                }
            }
        }
        pooledIoCopy(backendConn, conn)
    } else {
        // NOTE:
        //   Closing the Body involves acquiring a mutex, which is a
        //   unnecessarily heavy operation, considering that this defer will
        //   pretty much never be executed with the Body still unclosed.
        bodyOpen := true
        closeBody := func() {
            if bodyOpen {
                res.Body.Close()
                bodyOpen = false
            }
        }
        defer closeBody()

        // Copy all headers over.
        // res.Header does not include the "Trailer" header,
        // which means we will have to do that manually below.
        copyHeader(rw.Header(), res.Header)

        // The "Trailer" header isn't included in res' Header map, which
        // is why we have to build one ourselves from res.Trailer.
        //
        // But res.Trailer does not necessarily contain all trailer keys at this
        // point yet. The HTTP spec allows one to send "unannounced trailers"
        // after a request and certain systems like gRPC make use of that.
        announcedTrailerKeyCount := len(res.Trailer)
        if announcedTrailerKeyCount > 0 {
            vv := make([]string, 0, announcedTrailerKeyCount)
            for k := range res.Trailer {
                vv = append(vv, k)
            }
            rw.Header()["Trailer"] = vv
        }

        // Now copy over the status code as well as the response body.
        rw.WriteHeader(res.StatusCode)
        if announcedTrailerKeyCount > 0 {
            // Force chunking if we saw a response trailer.
            // This prevents net/http from calculating the length
            // for short bodies and adding a Content-Length.
            if fl, ok := rw.(http.Flusher); ok {
                fl.Flush()
            }
        }
        rp.copyResponse(rw, res.Body)

        // Now close the body to fully populate res.Trailer.
        closeBody()

        // Since Go does not remove keys from res.Trailer we
        // can safely do a length comparison to check whether
        // we received further, unannounced trailers.
        //
        // Most of the time forceSetTrailers should be false.
        forceSetTrailers := len(res.Trailer) != announcedTrailerKeyCount
        shallowCopyTrailers(rw.Header(), res.Trailer, forceSetTrailers)
    }

    return nil
}

// copyResponse copies src to dst. When a FlushInterval is configured and dst
// is able to flush, the copy goes through a maxLatencyWriter whose flushLoop
// goroutine flushes dst periodically until the copy has finished.
func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
    flusher, flushable := dst.(writeFlusher)
    if rp.FlushInterval == 0 || !flushable {
        pooledIoCopy(dst, src)
        return
    }

    periodic := &maxLatencyWriter{
        dst:     flusher,
        latency: rp.FlushInterval,
        done:    make(chan bool),
    }
    go periodic.flushLoop()
    defer periodic.stop()
    pooledIoCopy(periodic, src)
}

// skipHeaders lists the header fields that copyHeader leaves untouched when
// the destination already carries them.
// see https://github.com/mholt/caddy/pull/1112#discussion_r80092582
var skipHeaders = map[string]struct{}{
    "Content-Type":        {},
    "Content-Disposition": {},
    "Accept-Ranges":       {},
    "Set-Cookie":          {},
    "Cache-Control":       {},
    "Expires":             {},
}

// copyHeader merges the fields of src into dst.
//
// A field that is absent from dst is simply added. For a field that dst
// already has:
//   - fields listed in skipHeaders keep the value from dst
//     (see https://github.com/mholt/caddy/issues/1086);
//   - "Server" gets the values from src appended, so that the reality of the
//     proxying stays visible;
//   - every other field is overwritten to avoid duplicated fields, which can
//     be problematic (see issue #1086).
func copyHeader(dst, src http.Header) {
    for name, values := range src {
        _, present := dst[name]
        _, keepExisting := skipHeaders[name]

        if present && keepExisting {
            continue
        }
        if present && name != "Server" {
            dst.Del(name)
        }
        for _, value := range values {
            dst.Add(name, value)
        }
    }
}

// shallowCopyTrailers copies all headers from srcTrailer to dstHeader.
//
// If forceSetTrailers is set to true, the http.TrailerPrefix will be added to
// all srcTrailer key names. Otherwise the Go stdlib will ignore all keys
// which weren't listed in the Trailer map before submitting the Response.
//
// WARNING: Only a shallow copy will be created! The value slices end up
// shared between srcTrailer and dstHeader.
func shallowCopyTrailers(dstHeader, srcTrailer http.Header, forceSetTrailers bool) {
    // Copy of http.TrailerPrefix moved here for compatibility
    // with go 1.6
    const TrailerPrefix = "Trailer:"

    keyPrefix := ""
    if forceSetTrailers {
        keyPrefix = TrailerPrefix
    }
    for name, values := range srcTrailer {
        dstHeader[keyPrefix+name] = values
    }
}

// Hop-by-hop headers. ServeHTTP removes these from the upstream response
// before the response is relayed to the client.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{
    "Alt-Svc",
    "Alternate-Protocol",
    "Connection",
    "Keep-Alive",
    "Proxy-Authenticate",
    "Proxy-Authorization",
    "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
    "Te",               // canonicalized version of "TE"
    "Trailer",          // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
    "Transfer-Encoding",
    "Upgrade",
}

// respUpdateFn is an optional hook passed to ServeHTTP. It is called with
// the upstream response after the hop-by-hop headers have been removed and
// before the response is relayed to the client.
type respUpdateFn func(resp *http.Response)

type hijackedConn struct {
    net.Conn
    hj *connHijackerTransport
}

func (c *hijackedConn) Read(b []byte) (n int, err error) {
    n, err = c.Conn.Read(b)
    c.hj.Replay = append(c.hj.Replay, b[:n]...)
    return
}

func (c *hijackedConn) Close() error {
    return nil
}

type connHijackerTransport struct {
    *http.Transport
    Conn   net.Conn
    Replay []byte
}

// newConnHijackerTransport builds a single-use transport for websocket
// upgrade requests. It never keeps idle connections and wraps every
// connection it dials in a hijackedConn, so that the raw connection and the
// bytes read from it are available through the returned value's Conn and
// Replay fields.
//
// When base is an *http.Transport, its proxy, TLS and dial settings are
// carried over; otherwise defaults are used. A TLS configuration that lists
// NextProtos is cloned and has the list cleared, leaving the original
// configuration untouched.
//
// A failed dial returns a nil connection together with the error and leaves
// Conn unchanged, instead of handing the transport a wrapper around a nil
// connection.
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
    t := &http.Transport{
        MaxIdleConnsPerHost: -1,
    }
    if b, _ := base.(*http.Transport); b != nil {
        tlsClientConfig := b.TLSClientConfig
        if tlsClientConfig != nil && tlsClientConfig.NextProtos != nil {
            tlsClientConfig = b.Clone().TLSClientConfig.Clone()
            tlsClientConfig.NextProtos = nil
        }

        t.Proxy = b.Proxy
        t.TLSClientConfig = tlsClientConfig
        t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
        t.Dial = b.Dial
        t.DialTLS = b.DialTLS
    } else {
        t.Proxy = http.ProxyFromEnvironment
        t.TLSHandshakeTimeout = 10 * time.Second
    }
    hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]}

    // capture records a successfully dialed connection on hj and wraps it.
    // Failed dials are passed through without a connection.
    capture := func(c net.Conn, err error) (net.Conn, error) {
        if err != nil || c == nil {
            return nil, err
        }
        hj.Conn = c
        return &hijackedConn{c, hj}, nil
    }

    // Both dialers have to be resolved before t.Dial is replaced below,
    // since they fall back to the transport's current dial functions.
    dial := getTransportDial(t)
    dialTLS := getTransportDialTLS(t)
    t.Dial = func(network, addr string) (net.Conn, error) {
        return capture(dial(network, addr))
    }
    t.DialTLS = func(network, addr string) (net.Conn, error) {
        return capture(dialTLS(network, addr))
    }

    return hj
}

// getTransportDial always returns a plain (non-TLS) dial function: the
// transport's own t.Dial when it is set, defaultDialer.Dial otherwise.
func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) {
    dial := t.Dial
    if dial == nil {
        dial = defaultDialer.Dial
    }
    return dial
}

// getTransportDialTLS always returns a TLS Dialer
// and defaults to the existing t.DialTLS.
//
// Without a t.DialTLS, the returned function dials a plain connection via
// getTransportDial and performs the TLS handshake itself, honoring
// t.TLSClientConfig and t.TLSHandshakeTimeout. The TLS configuration is
// cloned for every connection, so filling in ServerName never modifies a
// configuration that is shared with other transports.
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
    if t.DialTLS != nil {
        return t.DialTLS
    }

    // newConnHijackerTransport will modify t.Dial after calling this method
    // => Create a backup reference.
    plainDial := getTransportDial(t)

    // The following DialTLS implementation stems from the Go stdlib and
    // is identical to what happens if DialTLS is not provided.
    // Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051
    return func(network, addr string) (net.Conn, error) {
        plainConn, err := plainDial(network, addr)
        if err != nil {
            return nil, err
        }

        // Work on a private copy, like the stdlib does: t.TLSClientConfig
        // may be shared with the transport this one was derived from, and
        // ServerName is filled in per connection below.
        var tlsClientConfig *tls.Config
        if t.TLSClientConfig != nil {
            tlsClientConfig = t.TLSClientConfig.Clone()
        } else {
            tlsClientConfig = &tls.Config{}
        }
        if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" {
            tlsClientConfig.ServerName = stripPort(addr)
        }

        tlsConn := tls.Client(plainConn, tlsClientConfig)
        // Buffered for both potential senders (the timer and the handshake
        // goroutine) so that neither blocks once the result has been read.
        errc := make(chan error, 2)
        var timer *time.Timer
        if d := t.TLSHandshakeTimeout; d != 0 {
            timer = time.AfterFunc(d, func() {
                errc <- tlsHandshakeTimeoutError{}
            })
        }
        go func() {
            err := tlsConn.Handshake()
            if timer != nil {
                timer.Stop()
            }
            errc <- err
        }()
        if err := <-errc; err != nil {
            plainConn.Close()
            return nil, err
        }
        if !tlsClientConfig.InsecureSkipVerify {
            hostname := tlsClientConfig.ServerName
            if hostname == "" {
                hostname = stripPort(addr)
            }
            if err := tlsConn.VerifyHostname(hostname); err != nil {
                plainConn.Close()
                return nil, err
            }
        }

        return tlsConn, nil
    }
}

// stripPort returns address without its port if it has one and
// works with IP addresses as well as hostnames formatted as host:port.
//
// IPv6 addresses (excluding the port) must be enclosed in
// square brackets similar to the requirements of Go's stdlib.
func stripPort(address string) string {
    lastColon := strings.LastIndex(address, ":")
    closingBracket := strings.LastIndex(address, "]")

    // A colon only introduces a port when it comes after the closing
    // bracket of an IPv6 literal (or when there is no bracket at all);
    // colons inside the brackets belong to the address itself.
    if lastColon <= closingBracket {
        return address
    }
    return address[:lastColon]
}

// tlsHandshakeTimeoutError is the error reported by the dialer built in
// getTransportDialTLS when the TLS handshake does not finish in time.
type tlsHandshakeTimeoutError struct{}

// Timeout reports that the error is a timeout.
func (tlsHandshakeTimeoutError) Timeout() bool {
    return true
}

// Temporary reports that the error is temporary.
func (tlsHandshakeTimeoutError) Temporary() bool {
    return true
}

// Error returns the same message the Go standard library uses for this
// condition.
func (tlsHandshakeTimeoutError) Error() string {
    return "net/http: TLS handshake timeout"
}

func requestIsWebsocket(req *http.Request) bool {
    return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")
}

// writeFlusher is a writer that can also flush. copyResponse only enables
// periodic flushing for destinations that satisfy this interface.
type writeFlusher interface {
    io.Writer
    http.Flusher
}

// maxLatencyWriter forwards writes to dst while a companion goroutine
// (flushLoop) flushes dst every latency interval.
type maxLatencyWriter struct {
    dst     writeFlusher
    latency time.Duration

    lk   sync.Mutex // protects Write + Flush
    done chan bool
}

// Write passes p on to the destination while holding the lock, so that a
// write never overlaps with a flush issued by flushLoop.
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
    m.lk.Lock()
    defer m.lk.Unlock()
    n, err = m.dst.Write(p)
    return n, err
}

// flushLoop flushes the destination on every tick of the latency interval
// until stop is called. It is meant to run in its own goroutine.
func (m *maxLatencyWriter) flushLoop() {
    ticker := time.NewTicker(m.latency)
    defer ticker.Stop()

    for {
        select {
        case <-ticker.C:
            m.lk.Lock()
            m.dst.Flush()
            m.lk.Unlock()
            continue
        case <-m.done:
        }

        // Only reached once stop has signalled the loop to end.
        if onExitFlushLoop != nil {
            onExitFlushLoop()
        }
        return
    }
}

// stop tells flushLoop to exit. The channel is unbuffered, so stop blocks
// until flushLoop has received the signal.
func (m *maxLatencyWriter) stop() {
    m.done <- true
}