oauth2-proxy/oauth2-proxy

View on GitHub
pkg/http/server.go

Summary

Maintainability
A
1 hr
Test Coverage
A
91%
package http

import (
    "context"
    "crypto/tls"
    "errors"
    "fmt"
    "net"
    "net/http"
    "strings"
    "time"

    "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
    "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util"
    "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
    "golang.org/x/sync/errgroup"
)

// Server represents an HTTP or HTTPS server.
//
// Instances are created with NewServer, which opens the configured listeners
// up front; Start then serves requests on those already-bound listeners.
type Server interface {
    // Start blocks and runs the server.
    //
    // It is expected to return when ctx is cancelled or when serving fails;
    // see the concrete implementation for the exact error semantics.
    Start(ctx context.Context) error
}

// Opts contains the information required to set up the server.
type Opts struct {
    // Handler is the http.Handler to be used to serve http pages by the server.
    // The same handler serves requests arriving on both the HTTP and the HTTPS
    // listener.
    Handler http.Handler

    // BindAddress is the address the HTTP server should listen on.
    // An empty value or "-" disables the HTTP listener. The address may carry a
    // "<network>://" prefix: no prefix or "http" listens on "tcp"; any other
    // prefix (e.g. "unix") is passed to net.Listen as the network type.
    BindAddress string

    // SecureBindAddress is the address the HTTPS server should listen on.
    // An empty value or "-" disables the HTTPS listener. Any "<scheme>://"
    // prefix is stripped and the listener is always opened on "tcp".
    SecureBindAddress string

    // TLS is the TLS configuration for the server.
    // It must be non-nil whenever SecureBindAddress enables the HTTPS listener,
    // otherwise server construction fails.
    TLS *options.TLS
}

// NewServer creates a new Server from the options given.
//
// Listeners are opened eagerly: first the plain HTTP listener (when
// opts.BindAddress enables it), then the TLS listener (when
// opts.SecureBindAddress enables it). Binding here, rather than in Start,
// surfaces address/permission/certificate problems at construction time.
//
// If the TLS listener cannot be set up, the HTTP listener that was already
// bound is closed before the error is returned, so a failed construction does
// not keep the port (or file descriptor) occupied.
func NewServer(opts Opts) (Server, error) {
    s := &server{
        handler: opts.Handler,
    }
    if err := s.setupListener(opts); err != nil {
        return nil, fmt.Errorf("error setting up listener: %v", err)
    }
    if err := s.setupTLSListener(opts); err != nil {
        if s.listener != nil {
            // Best-effort cleanup: the setup error is what the caller needs,
            // and a failure to close here has no further recovery path.
            _ = s.listener.Close()
        }
        return nil, fmt.Errorf("error setting up TLS listener: %v", err)
    }

    return s, nil
}

// server is an implementation of the Server interface.
type server struct {
    // handler serves every request, on both listeners.
    handler http.Handler

    // listener is the plain HTTP listener; nil when BindAddress disables it.
    listener    net.Listener
    // tlsListener is a TLS-wrapped TCP listener; nil when SecureBindAddress
    // disables it.
    tlsListener net.Listener
}

// setupListener binds the plain HTTP listener and stores it on s.
//
// The HTTP server is disabled, and nothing is bound, when opts.BindAddress is
// empty or "-". Otherwise the optional "<network>://" prefix of the address
// selects the network passed to net.Listen, with no prefix or "http" meaning
// "tcp".
func (s *server) setupListener(opts Opts) error {
    addr := opts.BindAddress
    if addr == "-" || addr == "" {
        // HTTP serving is switched off.
        return nil
    }

    network, address := getNetworkScheme(addr), getListenAddress(addr)

    ln, err := net.Listen(network, address)
    if err != nil {
        return fmt.Errorf("listen (%s, %s) failed: %v", network, address, err)
    }

    s.listener = ln
    return nil
}

// parseCipherSuites maps Go cipher suite names (as reported by
// tls.CipherSuites and tls.InsecureCipherSuites) to their numeric IDs,
// preserving the order of names.
//
// Insecure suites are accepted on purpose so operators can opt into them.
// The first unknown name aborts parsing with an error. An empty input yields
// an empty, non-nil slice.
func parseCipherSuites(names []string) ([]uint16, error) {
    known := map[string]uint16{}
    for _, suites := range [][]*tls.CipherSuite{tls.CipherSuites(), tls.InsecureCipherSuites()} {
        for _, suite := range suites {
            known[suite.Name] = suite.ID
        }
    }

    ids := make([]uint16, 0, len(names))
    for _, name := range names {
        id, ok := known[name]
        if !ok {
            return nil, fmt.Errorf("unknown TLS cipher suite name specified %q", name)
        }
        ids = append(ids, id)
    }
    return ids, nil
}

// setupTLSListener binds the HTTPS listener and stores it on s.
//
// The HTTPS server is disabled, and nothing is bound, when
// opts.SecureBindAddress is empty or "-". Otherwise opts.TLS must be set. The
// resulting config:
//   - serves the certificate loaded from opts.TLS;
//   - allows TLS 1.2 through TLS 1.3, where opts.TLS.MinVersion ("TLS1.2" or
//     "TLS1.3") may override the minimum;
//   - restricts cipher suites when opts.TLS.CipherSuites is non-empty;
//   - advertises only "http/1.1" via ALPN.
//
// The listener is always TCP, and accepted connections get TCP keep-alive
// before the TLS handshake.
func (s *server) setupTLSListener(opts Opts) error {
    addr := opts.SecureBindAddress
    if addr == "-" || addr == "" {
        // HTTPS serving is switched off.
        return nil
    }

    tlsOpts := opts.TLS
    if tlsOpts == nil {
        return errors.New("no TLS config provided")
    }

    cert, err := getCertificate(tlsOpts)
    if err != nil {
        return fmt.Errorf("could not load certificate: %v", err)
    }

    cfg := &tls.Config{
        Certificates: []tls.Certificate{cert},
        MinVersion:   tls.VersionTLS12, // default; tlsOpts.MinVersion may override it
        MaxVersion:   tls.VersionTLS13,
        NextProtos:   []string{"http/1.1"},
    }

    if len(tlsOpts.CipherSuites) != 0 {
        suites, parseErr := parseCipherSuites(tlsOpts.CipherSuites)
        if parseErr != nil {
            return fmt.Errorf("could not parse cipher suites: %v", parseErr)
        }
        cfg.CipherSuites = suites
    }

    switch tlsOpts.MinVersion {
    case "":
        // Not configured: keep the default minimum set above.
    case "TLS1.2":
        cfg.MinVersion = tls.VersionTLS12
    case "TLS1.3":
        cfg.MinVersion = tls.VersionTLS13
    default:
        return errors.New("unknown TLS MinVersion config provided")
    }

    listenAddr := getListenAddress(addr)
    tcpLn, err := net.Listen("tcp", listenAddr)
    if err != nil {
        return fmt.Errorf("listen (%s) failed: %v", listenAddr, err)
    }

    // net.Listen("tcp", ...) always yields a *net.TCPListener, so the
    // assertion cannot fail here.
    keepAlive := tcpKeepAliveListener{tcpLn.(*net.TCPListener)}
    s.tlsListener = tls.NewListener(keepAlive, cfg)
    return nil
}

// Start serves on every listener that NewServer opened (HTTP and/or HTTPS),
// each in its own goroutine, and blocks until all of them have stopped.
//
// Cancelling ctx shuts every server down. If one server fails, the others are
// shut down too, and only the first error is returned. When no listener is
// configured there is nothing to run, and Start returns nil immediately.
func (s *server) Start(ctx context.Context) error {
    group, runCtx := errgroup.WithContext(ctx)

    if s.listener != nil {
        group.Go(func() error {
            err := s.startServer(runCtx, s.listener)
            if err == nil {
                return nil
            }
            return fmt.Errorf("error starting insecure server: %v", err)
        })
    }

    if s.tlsListener != nil {
        group.Go(func() error {
            err := s.startServer(runCtx, s.tlsListener)
            if err == nil {
                return nil
            }
            return fmt.Errorf("error starting secure server: %v", err)
        })
    }

    return group.Wait()
}

// startServer runs a fresh http.Server on listener until ctx is cancelled or
// serving fails, and then shuts it down.
//
// Two goroutines cooperate through an errgroup:
//   - one waits for cancellation and calls Shutdown;
//   - the other calls Serve.
// A Serve failure cancels the group context, which in turn triggers Shutdown.
// http.ErrServerClosed, which Serve returns after Shutdown, is not an error.
// Only the first error is returned.
func (s *server) startServer(ctx context.Context, listener net.Listener) error {
    httpServer := &http.Server{
        Handler:           s.handler,
        ReadHeaderTimeout: time.Minute,
    }
    group, groupCtx := errgroup.WithContext(ctx)

    // shutdown waits for cancellation, then stops the server. No deadline is
    // applied, so Shutdown waits for in-flight connections to become idle.
    shutdown := func() error {
        <-groupCtx.Done()

        err := httpServer.Shutdown(context.Background())
        if err == nil {
            return nil
        }
        return fmt.Errorf("error shutting down server: %v", err)
    }

    // serve blocks in Serve until the listener fails or Shutdown is called.
    serve := func() error {
        err := httpServer.Serve(listener)
        if err == nil || errors.Is(err, http.ErrServerClosed) {
            return nil
        }
        return fmt.Errorf("could not start server: %v", err)
    }

    group.Go(shutdown)
    group.Go(serve)

    return group.Wait()
}

// getNetworkScheme returns the net.Listen network type for a bind address.
//
// The network is whatever precedes the first "://" in addr. A missing prefix,
// an empty prefix, or "http" all map to "tcp"; any other prefix (for example
// "unix" or "tcp4") is returned unchanged.
func getNetworkScheme(addr string) string {
    sep := strings.Index(addr, "://")
    if sep < 0 {
        return "tcp"
    }
    if prefix := addr[:sep]; prefix != "" && prefix != "http" {
        return prefix
    }
    return "tcp"
}

// getListenAddress strips everything up to and including the first "//" from
// addr, returning the part to hand to net.Listen. Addresses without "//" are
// returned unchanged, e.g. "unix:///tmp/s.sock" -> "/tmp/s.sock" and
// "127.0.0.1:4180" -> "127.0.0.1:4180".
func getListenAddress(addr string) string {
    if sep := strings.Index(addr, "//"); sep >= 0 {
        return addr[sep+len("//"):]
    }
    return addr
}

// getCertificate resolves the key and certificate secrets referenced by opts
// and parses them as a PEM-encoded X.509 key pair via tls.X509KeyPair.
//
// The key is loaded before the certificate, so a missing key is reported
// first. On any failure the zero tls.Certificate is returned with an error.
func getCertificate(opts *options.TLS) (tls.Certificate, error) {
    var none tls.Certificate

    keyPEM, err := getSecretValue(opts.Key)
    if err != nil {
        return none, fmt.Errorf("could not load key data: %v", err)
    }
    certPEM, err := getSecretValue(opts.Cert)
    if err != nil {
        return none, fmt.Errorf("could not load cert data: %v", err)
    }

    pair, err := tls.X509KeyPair(certPEM, keyPEM)
    if err != nil {
        return none, fmt.Errorf("could not parse certificate data: %v", err)
    }
    return pair, nil
}

// getSecretValue reads the bytes behind src via util.GetSecretValue. A nil
// src is reported as an error instead of being passed through, so callers
// get a clear message when a required secret is not configured.
func getSecretValue(src *options.SecretSource) ([]byte, error) {
    if src != nil {
        return util.GetSecretValue(src)
    }
    return nil, errors.New("no configuration provided")
}

// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used so that dead TCP connections (e.g. closing laptop
// mid-download) eventually go away.
//
// It embeds *net.TCPListener, so methods other than Accept (such as Close and
// Addr) are promoted unchanged from the embedded listener.
type tcpKeepAliveListener struct {
    *net.TCPListener
}

// Accept implements the net.Listener interface.
//
// It accepts the next TCP connection, enables TCP keep-alive on it and sets the
// keep-alive period to 3 minutes. Failures to configure keep-alive are logged
// at error level and the connection is still returned: keep-alive is applied
// on a best-effort basis and never causes a connection to be rejected.
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
    tc, err := ln.AcceptTCP()
    if err != nil {
        return nil, err
    }
    // Both keep-alive failures have the same impact, so log them consistently.
    if err = tc.SetKeepAlive(true); err != nil {
        logger.Errorf("Error setting Keep-Alive: %v", err)
    }
    if err = tc.SetKeepAlivePeriod(3 * time.Minute); err != nil {
        logger.Errorf("Error setting Keep-Alive period: %v", err)
    }
    return tc, nil
}