kubenetworks/kubevpn

View on GitHub
pkg/util/portforward.go

Summary

Maintainability
A
3 hrs
Test Coverage
package util

import (
    "errors"
    "fmt"
    "io"
    "net"
    "net/http"
    "sort"
    "strconv"
    "strings"
    "sync"
    "sync/atomic"
    "time"

    "k8s.io/api/core/v1"
    k8serrors "k8s.io/apimachinery/pkg/api/errors"
    "k8s.io/apimachinery/pkg/util/httpstream"
    "k8s.io/apimachinery/pkg/util/runtime"
    "k8s.io/client-go/tools/portforward"
)

// PortForwarder knows how to listen for local connections and forward them to
// a remote pod via an upgraded HTTP request.
//
// Build one with NewOnAddresses and run it with ForwardPorts.
type PortForwarder struct {
    // addresses are the local bind addresses produced by parseAddresses.
    addresses []listenAddress
    // ports are the requested port pairs; a Local of 0 is overwritten with the
    // port actually bound once its listener is created.
    ports     []ForwardedPort
    // stopChan is closed by the caller to end forwarding.
    stopChan  <-chan struct{}
    // errChan holds at most one pending error (NewOnAddresses creates it with
    // capacity 1 and senders never block): a "failed to find socat" error from
    // handleConnection, or a pod-not-found dial error from tryToCreateStream.
    // forward returns it after stopChan is closed.
    errChan chan error

    // dialer opens the upgraded stream connection.
    dialer        httpstream.Dialer
    // streamConn is the connection opened by ForwardPorts; tryToCreateStream
    // may close and replace it.
    streamConn    httpstream.Connection
    // listeners are the local listeners opened so far; Close closes them.
    listeners     []io.Closer
    // Ready, if non-nil, is closed by forward once at least one port listens.
    Ready         chan struct{}
    // requestIDLock guards requestID, which nextRequestID hands out.
    requestIDLock sync.Mutex
    requestID     int
    // out and errOut receive progress and error messages; nil is checked at
    // most write sites to suppress output.
    out           io.Writer
    errOut        io.Writer
}

// ForwardedPort contains a Local:Remote port pairing.
//
// A Local of 0 asks the OS to pick a free port; getListener then overwrites
// Local with the port actually bound.
type ForwardedPort struct {
    // Local is the port listened on locally.
    Local  uint16
    // Remote is the pod port traffic is forwarded to; parsePorts rejects 0.
    Remote uint16
}

// parsePorts converts port specifications into ForwardedPort pairs.
//
// Valid port specifications:
//
//	5000       forwards from localhost:5000 to pod:5000
//	8888:5000  forwards from localhost:8888 to pod:5000
//	0:5000     selects a random available local port and
//	:5000      forwards from localhost:<random port> to pod:5000
//
// Each port must fit in 16 bits and the remote port must be non-zero; any
// other shape (e.g. "1:2:3") is rejected. Number-parsing failures are wrapped
// with %w so callers can reach the underlying *strconv.NumError via errors.As.
func parsePorts(ports []string) ([]ForwardedPort, error) {
	var forwards []ForwardedPort
	for _, portString := range ports {
		parts := strings.Split(portString, ":")

		var localString, remoteString string
		switch len(parts) {
		case 1:
			localString, remoteString = parts[0], parts[0]
		case 2:
			localString, remoteString = parts[0], parts[1]
			if localString == "" {
				// support ":5000" as shorthand for "0:5000"
				localString = "0"
			}
		default:
			return nil, fmt.Errorf("invalid port format '%s'", portString)
		}

		localPort, err := strconv.ParseUint(localString, 10, 16)
		if err != nil {
			return nil, fmt.Errorf("error parsing local port '%s': %w", localString, err)
		}

		remotePort, err := strconv.ParseUint(remoteString, 10, 16)
		if err != nil {
			return nil, fmt.Errorf("error parsing remote port '%s': %w", remoteString, err)
		}
		if remotePort == 0 {
			return nil, errors.New("remote port must be > 0")
		}

		forwards = append(forwards, ForwardedPort{Local: uint16(localPort), Remote: uint16(remotePort)})
	}

	return forwards, nil
}

// listenAddress is one local address to bind, as produced by parseAddresses.
type listenAddress struct {
    // address is the literal IP string to bind (127.0.0.1 and ::1 for the
    // "localhost" expansion, otherwise the string the caller supplied).
    address     string
    // protocol is the net.Listen network: "tcp4" or "tcp6".
    protocol    string
    // failureMode controls how listenOnPort treats bind failures:
    // "all" (localhost expansion) fails only if every "all" address failed;
    // "any" (explicit IPs) fails if any "any" address failed.
    failureMode string
}

// parseAddresses turns user-supplied bind addresses into listenAddress values.
//
// "localhost" expands to 127.0.0.1 (tcp4) and ::1 (tcp6), both with failure
// mode "all". Every other entry must be a literal IPv4 or IPv6 address and
// gets failure mode "any". Duplicate addresses collapse to one entry: an
// explicit IP replaces a localhost expansion of the same address, but a later
// "localhost" never replaces an explicit IP. The result is sorted by address
// so its order is deterministic.
func parseAddresses(addressesToParse []string) ([]listenAddress, error) {
	byAddress := make(map[string]listenAddress)
	addIfAbsent := func(la listenAddress) {
		if _, seen := byAddress[la.address]; !seen {
			byAddress[la.address] = la
		}
	}

	for _, raw := range addressesToParse {
		if raw == "localhost" {
			addIfAbsent(listenAddress{address: "127.0.0.1", protocol: "tcp4", failureMode: "all"})
			addIfAbsent(listenAddress{address: "::1", protocol: "tcp6", failureMode: "all"})
			continue
		}

		ip := net.ParseIP(raw)
		switch {
		case ip.To4() != nil:
			byAddress[raw] = listenAddress{address: raw, protocol: "tcp4", failureMode: "any"}
		case ip != nil:
			byAddress[raw] = listenAddress{address: raw, protocol: "tcp6", failureMode: "any"}
		default:
			return nil, fmt.Errorf("%s is not a valid IP", raw)
		}
	}

	result := make([]listenAddress, 0, len(byAddress))
	for _, la := range byAddress {
		result = append(result, la)
	}
	// Map iteration order is random; sort so callers see a stable order.
	sort.Slice(result, func(i, j int) bool { return result[i].address < result[j].address })

	return result, nil
}

// NewOnAddresses creates a new PortForwarder with custom listen addresses.
//
// addresses must be non-empty and hold "localhost" or literal IPs (see
// parseAddresses). ports must be non-empty and follow the formats accepted by
// parsePorts. Validation runs in that order and stops at the first failure.
// The constructor only validates and stores its arguments; it neither dials
// nor opens listeners. That happens in ForwardPorts.
func NewOnAddresses(dialer httpstream.Dialer, addresses []string, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, out, errOut io.Writer) (*PortForwarder, error) {
	if len(addresses) == 0 {
		return nil, errors.New("you must specify at least 1 address")
	}
	listenAddrs, err := parseAddresses(addresses)
	if err != nil {
		return nil, err
	}

	if len(ports) == 0 {
		return nil, errors.New("you must specify at least 1 port")
	}
	forwardedPorts, err := parsePorts(ports)
	if err != nil {
		return nil, err
	}

	pf := &PortForwarder{
		addresses: listenAddrs,
		ports:     forwardedPorts,
		dialer:    dialer,
		stopChan:  stopChan,
		// Capacity 1 lets non-blocking senders park one error for forward.
		errChan: make(chan error, 1),
		Ready:   readyChan,
		out:     out,
		errOut:  errOut,
	}
	return pf, nil
}

// ForwardPorts formats and executes a port forwarding request. The connection will remain
// open until stopChan is closed.
//
// It dials the upgraded stream connection and then runs forward. On return it
// closes every listener, plus whichever connection pf.streamConn holds at
// that moment. tryToCreateStream may have replaced the original connection
// while forwarding, so the connection is looked up when the deferred close
// runs, not when it is registered.
func (pf *PortForwarder) ForwardPorts() error {
	defer pf.Close()

	var err error
	pf.streamConn, _, err = pf.dialer.Dial(portforward.PortForwardProtocolV1Name)
	if err != nil {
		return fmt.Errorf("error upgrading connection: %w", err)
	}
	// `defer pf.streamConn.Close()` would bind the connection value now. A
	// replacement dialed later would then leak, and the original would be
	// closed twice.
	defer func() {
		if pf.streamConn != nil {
			_ = pf.streamConn.Close()
		}
	}()

	return pf.forward()
}

// forward starts listeners for every port in pf.ports on each configured
// address. It closes pf.Ready (when non-nil) once at least one port is
// listening, then blocks until pf.stopChan is closed.
//
// It returns an error if no port could be listened on. Otherwise it returns
// the pending error on pf.errChan, if any: for example a "failed to find
// socat" error from handleConnection or a pod-not-found error from
// tryToCreateStream. It returns nil when nothing is pending.
func (pf *PortForwarder) forward() error {
	listenSuccess := false
	for i := range pf.ports {
		// Pass a pointer so getListener can record the actual bound port
		// when Local was 0.
		port := &pf.ports[i]
		if err := pf.listenOnPort(port); err != nil {
			if pf.errOut != nil {
				fmt.Fprintf(pf.errOut, "Unable to listen on port %d: %v\n", port.Local, err)
			}
			continue
		}
		listenSuccess = true
	}

	if !listenSuccess {
		return fmt.Errorf("unable to listen on any of the requested ports: %v", pf.ports)
	}

	if pf.Ready != nil {
		close(pf.Ready)
	}

	// Wait for the caller to stop us.
	// NOTE(review): this logs "lost connection to pod" on every stop,
	// including a deliberate one. Confirm that is intended.
	<-pf.stopChan
	runtime.HandleError(errors.New("lost connection to pod"))

	// Report a pending error without blocking. A receive from a closed
	// channel yields nil, which is also the correct "no error" result.
	select {
	case err := <-pf.errChan:
		return err
	default:
		return nil
	}
}

// listenOnPort delegates listener creation and waits for connections on requested bind addresses.
// An error is raised based on address groups and their failure modes:
//   - "all" (the localhost expansion): fail only if every such address failed;
//   - "any" (explicit IPs): fail if any such address failed.
//
// Listeners that were created successfully remain registered in pf.listeners
// even when an error is returned; Close shuts them down.
func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error {
	// Named errs (not errors) so the errors package stays visible here.
	var errs []error
	failCounters := make(map[string]int, 2)
	successCounters := make(map[string]int, 2)
	for _, addr := range pf.addresses {
		if err := pf.listenOnPortAndAddress(port, addr.protocol, addr.address); err != nil {
			errs = append(errs, err)
			failCounters[addr.failureMode]++
			continue
		}
		successCounters[addr.failureMode]++
	}

	allGroupFailed := successCounters["all"] == 0 && failCounters["all"] > 0
	anyGroupFailed := failCounters["any"] > 0
	if allGroupFailed || anyGroupFailed {
		return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", errs)
	}
	return nil
}

// listenOnPortAndAddress opens a listener for port on the given address and
// protocol, registers it in pf.listeners so Close can shut it down, and
// accepts connections on it in a background goroutine. The goroutine gets a
// copy of *port taken after getListener has filled in the bound local port.
func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol string, address string) error {
	l, err := pf.getListener(protocol, address, port)
	if err == nil {
		pf.listeners = append(pf.listeners, l)
		go pf.waitForConnection(l, *port)
	}
	return err
}

// getListener creates a listener on the interface targeted by the given hostname on the given port with
// the given protocol. protocol is in net.Listen style which basically admits values like tcp, tcp4, tcp6.
//
// When port.Local is 0 the OS chooses a free port, and port.Local is updated
// in place to the port actually bound. If anything fails after the socket has
// been opened, the listener is closed before returning so it is not leaked.
func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) {
	listener, err := net.Listen(protocol, net.JoinHostPort(hostname, strconv.Itoa(int(port.Local))))
	if err != nil {
		return nil, fmt.Errorf("unable to create listener: Error %w", err)
	}

	listenerAddress := listener.Addr().String()
	host, localPort, err := net.SplitHostPort(listenerAddress)
	if err != nil {
		_ = listener.Close()
		return nil, fmt.Errorf("error splitting listener address %s: %w", listenerAddress, err)
	}

	localPortUInt, err := strconv.ParseUint(localPort, 10, 16)
	if err != nil {
		_ = listener.Close()
		if pf.out != nil {
			// Report the requested port; the parsed value is meaningless here.
			fmt.Fprintf(pf.out, "Failed to forward from %s:%d -> %d\n", hostname, port.Local, port.Remote)
		}
		return nil, fmt.Errorf("error parsing local port: %w from %s (%s)", err, listenerAddress, host)
	}

	port.Local = uint16(localPortUInt)
	if pf.out != nil {
		fmt.Fprintf(pf.out, "Forwarding from %s -> %d\n", net.JoinHostPort(hostname, strconv.Itoa(int(localPortUInt))), port.Remote)
	}

	return listener, nil
}

// waitForConnection accepts connections on listener until Accept fails and
// handles each one in its own goroutine via handleConnection. If Accept fails
// because the listener was closed (for example by Close), the loop ends
// silently. Any other Accept error is reported through runtime.HandleError
// before returning.
func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
	for {
		conn, err := listener.Accept()
		if err != nil {
			// net.ErrClosed is the supported signal for a closed listener.
			// The string match is kept as a fallback for listener
			// implementations that may not wrap it.
			closed := errors.Is(err, net.ErrClosed) ||
				strings.Contains(strings.ToLower(err.Error()), "use of closed network connection")
			if !closed {
				runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err))
			}
			return
		}
		go pf.handleConnection(conn, port)
	}
}

// nextRequestID returns the next request ID, starting at 0 and increasing by
// one per call. The counter is guarded by requestIDLock, so concurrent
// callers never receive the same ID.
func (pf *PortForwarder) nextRequestID() int {
	pf.requestIDLock.Lock()
	current := pf.requestID
	pf.requestID = current + 1
	pf.requestIDLock.Unlock()
	return current
}

// handleConnection copies data between the local connection and the stream to
// the remote server.
//
// For each accepted connection it:
//   - allocates a request ID;
//   - opens an error stream and a data stream on pf.streamConn, both tagged
//     with the remote port and request ID;
//   - copies bytes in both directions until the remote->local copy finishes,
//     the local->remote copy fails, or pf.stopChan is closed.
//
// If the error-stream reader has already produced an error by then, that
// error is logged. A "failed to find socat" error is also offered to
// pf.errChan with a non-blocking send.
func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
	defer conn.Close()

	if pf.out != nil {
		fmt.Fprintf(pf.out, "Handling connection for %d\n", port.Local)
	}

	requestID := pf.nextRequestID()

	// create error stream
	headers := http.Header{}
	headers.Set(v1.StreamType, v1.StreamTypeError)
	headers.Set(v1.PortHeader, strconv.Itoa(int(port.Remote)))
	headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(requestID))
	errorStream, err := pf.streamConn.CreateStream(headers)
	if err != nil {
		runtime.HandleError(fmt.Errorf("error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err))
		return
	}
	// we're not writing to this stream
	errorStream.Close()

	// Buffered: the receive at the end of this function is non-blocking and
	// may give up before the reader finishes. Without room for one value, the
	// reader would then block forever on its send and leak.
	errorChan := make(chan error, 1)
	go func() {
		message, err := io.ReadAll(errorStream)
		switch {
		case err != nil:
			errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
		case len(message) > 0:
			errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
		}
		close(errorChan)
	}()

	// create data stream
	headers.Set(v1.StreamType, v1.StreamTypeData)
	dataStream, err := pf.streamConn.CreateStream(headers)
	if err != nil {
		runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
		return
	}

	localError := make(chan struct{})
	remoteDone := make(chan struct{})

	go func() {
		// Copy from the remote side to the local port.
		if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
			runtime.HandleError(fmt.Errorf("error copying from remote stream to local connection: %v", err))
		}

		// inform the select below that the remote copy is done
		close(remoteDone)
	}()

	go func() {
		// inform server we're not sending any more data after copy unblocks
		defer dataStream.Close()

		// Copy from the local port to the remote side.
		if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
			runtime.HandleError(fmt.Errorf("error copying from local connection to remote stream: %v", err))
			// break out of the select below without waiting for the other copy to finish
			close(localError)
		}
	}()

	// wait for either a local->remote error, for copying from remote->local to
	// finish, or for the forwarder to be stopped
	select {
	case <-remoteDone:
	case <-localError:
	case <-pf.stopChan:
		runtime.HandleError(errors.New("lost connection to pod"))
	}

	// Only consume an error-stream result that is already available; a
	// closed channel yields nil.
	select {
	case err = <-errorChan:
	default:
	}
	if err != nil {
		if strings.Contains(err.Error(), "failed to find socat") {
			select {
			case pf.errChan <- err:
			default:
			}
		}
		runtime.HandleError(err)
	}
}

// Close stops all listeners of PortForwarder.
//
// A failure to close one listener is reported through runtime.HandleError and
// does not prevent the remaining listeners from being closed. Close does not
// touch pf.streamConn.
func (pf *PortForwarder) Close() {
	for i := range pf.listeners {
		err := pf.listeners[i].Close()
		if err == nil {
			continue
		}
		runtime.HandleError(fmt.Errorf("error closing listener: %v", err))
	}
}

// GetPorts returns the forwarded ports. It can be used to learn the locally
// bound port when 0 was requested. It fails if no Ready channel was provided,
// or if Ready has not been closed yet (listeners not up). Once Ready is
// closed it always succeeds.
//
// The returned slice is pf.ports itself, not a copy.
func (pf *PortForwarder) GetPorts() ([]ForwardedPort, error) {
	ready := pf.Ready
	if ready == nil {
		return nil, fmt.Errorf("no ready channel provided")
	}

	select {
	case <-ready:
	default:
		return nil, fmt.Errorf("listeners not ready")
	}
	return pf.ports, nil
}

// tryToCreateStream creates a stream with the given headers on the current
// pf.streamConn, allowing up to one second for that to succeed.
//
// If the attempt fails or times out, it:
//   - closes the connection it tried;
//   - dials a new one with pf.dialer, storing it in pf.streamConn;
//   - replaces the request-ID header in *header with a fresh ID;
//   - creates the stream on the new connection.
//
// A not-found dial error is also offered (non-blocking) to pf.errChan so
// forward can return it.
//
// NOTE(review): if the first attempt succeeds only after the timeout, that
// stream is not returned. Closing the old connection is expected to release
// it; confirm.
func (pf *PortForwarder) tryToCreateStream(header *http.Header) (httpstream.Stream, error) {
	// Room for both the timer's and the worker's send, so neither can block
	// once this function stops listening.
	errorChan := make(chan error, 2)
	var resultChan atomic.Value
	timer := time.AfterFunc(time.Second*1, func() {
		errorChan <- errors.New("timeout")
	})
	defer timer.Stop()

	// Snapshot the connection and headers. pf.streamConn may be replaced and
	// *header mutated below while the worker is still running.
	conn := pf.streamConn
	attemptHeader := header.Clone()
	go func() {
		if conn != nil {
			if stream, err := conn.CreateStream(attemptHeader); err == nil && stream != nil {
				// Publish the stream BEFORE signalling success, so the
				// receiver is guaranteed to observe it.
				resultChan.Store(stream)
				errorChan <- nil
				return
			}
		}
		errorChan <- errors.New("unable to create stream on current connection")
	}()

	if err := <-errorChan; err == nil {
		if stream, ok := resultChan.Load().(httpstream.Stream); ok {
			return stream, nil
		}
	}

	// close old connection in case of resource leak
	if conn != nil {
		_ = conn.Close()
	}
	var err error
	pf.streamConn, _, err = pf.dialer.Dial(portforward.PortForwardProtocolV1Name)
	if err != nil {
		if k8serrors.IsNotFound(err) {
			runtime.HandleError(fmt.Errorf("pod not found: %s", err))
			select {
			case pf.errChan <- err:
			default:
			}
		} else {
			runtime.HandleError(fmt.Errorf("error upgrading connection: %s", err))
		}
		return nil, err
	}
	header.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(pf.nextRequestID()))
	return pf.streamConn.CreateStream(*header)
}