gitlabhq/gitlab-shell

View on GitHub
internal/sshd/sshd.go

Summary

Maintainability
A
0 mins
Test Coverage
// Package sshd provides functionality for handling SSH connections
package sshd

import (
    "context"
    "fmt"
    "net"
    "net/http"
    "strings"
    "sync"
    "time"

    proxyproto "github.com/pires/go-proxyproto"
    "golang.org/x/crypto/ssh"

    "gitlab.com/gitlab-org/gitlab-shell/v14/client"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/gitlabnet"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/metrics"

    "gitlab.com/gitlab-org/labkit/correlation"
    "gitlab.com/gitlab-org/labkit/log"
)

// status describes the lifecycle state of a Server. The zero value is
// StatusStarting, so a freshly constructed Server starts in that state.
type status int

const (
    // StatusStarting represents the starting status of the SSH server
    StatusStarting status = iota

    // StatusReady represents the ready status of the SSH server
    StatusReady

    // StatusOnShutdown represents the status when the SSH server is shutting down
    StatusOnShutdown

    // StatusClosed represents the closed status of the SSH server
    StatusClosed
)

// Server represents an SSH server instance
type Server struct {
    // Config is the configuration the server was created with.
    Config *config.Config

    // status is the current lifecycle state; it is read and written only
    // through getStatus/changeStatus, which guard it with statusMu.
    status       status
    statusMu     sync.RWMutex
    // wg counts in-flight handleConn goroutines so serve can wait for them.
    wg           sync.WaitGroup
    // listener is set by listen and closed by Shutdown / ListenAndServe.
    listener     net.Listener
    // serverConfig is built from Config by newServerConfig in NewServer.
    serverConfig *serverConfig
}

// logInfo is the context key type under which extractLogDataFromContext looks
// up a command.LogData value.
type logInfo struct{}

// NewServer builds a Server for the given configuration. The SSH server
// configuration is derived from cfg via newServerConfig; any error from that
// step is returned unchanged together with a nil Server.
func NewServer(cfg *config.Config) (*Server, error) {
    sshCfg, err := newServerConfig(cfg)
    if err != nil {
        return nil, err
    }

    srv := &Server{
        Config:       cfg,
        serverConfig: sshCfg,
    }

    return srv, nil
}

// ListenAndServe opens the listener and then blocks serving connections until
// the accept loop terminates. It returns an error only when the listener
// could not be set up; once serving has started the result is always nil.
func (s *Server) ListenAndServe(ctx context.Context) error {
    err := s.listen(ctx)
    if err != nil {
        return err
    }

    // Make sure the listener is released on the way out; the close error is
    // intentionally discarded.
    defer func() {
        _ = s.listener.Close()
    }()

    s.serve(ctx)

    return nil
}

// Shutdown marks the server as shutting down and closes its listener, which
// makes the accept loop in serve exit. It is a no-op returning nil when the
// server has no listener yet; otherwise it returns the listener's Close error.
func (s *Server) Shutdown() error {
    if s.listener != nil {
        // The status must change before the listener is closed so that serve
        // recognizes the resulting Accept error as a shutdown.
        s.changeStatus(StatusOnShutdown)

        return s.listener.Close()
    }

    return nil
}

// MonitoringServeMux returns a new ServeMux exposing the monitoring probes.
//
// The readiness probe (Config.Server.ReadinessProbe) answers 200 while the
// server status is StatusReady and 503 otherwise. The liveness probe
// (Config.Server.LivenessProbe) always answers 200.
func (s *Server) MonitoringServeMux() *http.ServeMux {
    readiness := func(w http.ResponseWriter, _ *http.Request) {
        code := http.StatusServiceUnavailable
        if s.getStatus() == StatusReady {
            code = http.StatusOK
        }

        w.WriteHeader(code)
    }

    liveness := func(w http.ResponseWriter, _ *http.Request) {
        w.WriteHeader(http.StatusOK)
    }

    mux := http.NewServeMux()
    mux.HandleFunc(s.Config.Server.ReadinessProbe, readiness)
    mux.HandleFunc(s.Config.Server.LivenessProbe, liveness)

    return mux
}

// listen opens a TCP listener on Config.Server.Listen and stores it in
// s.listener. When Config.Server.ProxyProtocol is set, the listener is wrapped
// in a proxyproto.Listener that uses the policy from proxyPolicy and
// Config.Server.ProxyHeaderTimeout as the header read timeout.
//
// It returns a wrapped error when the address cannot be listened on or when
// the proxy policy configuration is invalid. In the latter case the TCP
// listener that was already opened is closed before returning, so the socket
// is not leaked.
func (s *Server) listen(ctx context.Context) error {
    sshListener, err := net.Listen("tcp", s.Config.Server.Listen)
    if err != nil {
        return fmt.Errorf("failed to listen for connection: %w", err)
    }

    if s.Config.Server.ProxyProtocol {
        policy, err := s.proxyPolicy()
        if err != nil {
            // Release the socket we just bound; the policy error is the one
            // worth reporting, so the close error is deliberately dropped.
            _ = sshListener.Close()

            return fmt.Errorf("invalid policy configuration: %w", err)
        }

        sshListener = &proxyproto.Listener{
            Listener:          sshListener,
            Policy:            policy,
            ReadHeaderTimeout: time.Duration(s.Config.Server.ProxyHeaderTimeout),
        }

        log.ContextLogger(ctx).Info("Proxy protocol is enabled")
    }

    fields := log.Fields{
        "tcp_address": sshListener.Addr().String(),
    }

    // Only report the algorithms when they were explicitly configured.
    if len(s.serverConfig.cfg.Server.PublicKeyAlgorithms) > 0 {
        fields["supported_public_key_algorithms"] = s.serverConfig.cfg.Server.PublicKeyAlgorithms
    }

    log.WithContextFields(ctx, fields).Info("Listening for SSH connections")

    s.listener = sshListener

    return nil
}

// serve marks the server ready and runs the accept loop. Every accepted
// connection is handled in its own goroutine tracked by s.wg. An Accept error
// ends the loop only when the server status is StatusOnShutdown; any other
// Accept error is logged as a warning and the loop keeps going. After the
// loop ends, serve waits for all in-flight connections and then sets the
// status to StatusClosed.
func (s *Server) serve(ctx context.Context) {
    s.changeStatus(StatusReady)

    for {
        accepted, err := s.listener.Accept()
        if err == nil {
            s.wg.Add(1)
            go s.handleConn(ctx, accepted)

            continue
        }

        if s.getStatus() == StatusOnShutdown {
            break
        }

        log.ContextLogger(ctx).WithError(err).Warn("Failed to accept connection")
    }

    s.wg.Wait()

    s.changeStatus(StatusClosed)
}

// changeStatus sets the server status to st while holding the write lock.
func (s *Server) changeStatus(st status) {
    s.statusMu.Lock()
    defer s.statusMu.Unlock()

    s.status = st
}

// getStatus returns the current server status, read under the read lock.
func (s *Server) getStatus() status {
    s.statusMu.RLock()
    current := s.status
    s.statusMu.RUnlock()

    return current
}

// contextWithValues derives a per-connection context from parent. The result
// always carries a freshly generated correlation ID. When nconn is a
// *proxyproto.Conn, it additionally carries, under
// client.OriginalRemoteIPContextKey, the IP parsed from the remote address of
// the underlying raw connection.
func contextWithValues(parent context.Context, nconn net.Conn) context.Context {
    ctx := correlation.ContextWithCorrelation(parent, correlation.SafeRandomID())

    proxied, isProxy := nconn.(*proxyproto.Conn)
    if !isProxy {
        return ctx
    }

    // PROXY connection: register the IP taken from the raw connection.
    rawRemote := proxied.Raw().RemoteAddr().String()
    ip := gitlabnet.ParseIP(rawRemote)

    return context.WithValue(ctx, client.OriginalRemoteIPContextKey{}, ip)
}

// handleConn serves a single accepted connection and is meant to run in its
// own goroutine (serve calls s.wg.Add(1) before starting it; the matching
// Done is deferred here).
//
// It tracks the connection in the in-flight gauge, derives a cancellable
// per-connection context, hands the connection to conn.handle with a callback
// that runs a session per channel, and finally logs an "access: finish" line
// with the duration and the log data collected by the session. A panic while
// handling is recovered, logged and counted, so it does not escape this
// goroutine; in that case the "access: finish" line is not emitted.
func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
    defer s.wg.Done()

    metrics.SshdConnectionsInFlight.Inc()
    defer metrics.SshdConnectionsInFlight.Dec()

    ctx, cancel := context.WithCancel(contextWithValues(ctx, nconn))
    defer cancel()
    // Watcher goroutine: it ends once ctx is done, which happens at the
    // latest when the deferred cancel above runs on return.
    go func() {
        <-ctx.Done()
        _ = nconn.Close() // Close the connection when context is canceled
    }()

    remoteAddr := nconn.RemoteAddr().String()
    ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": remoteAddr})

    // Prevent a panic in a single connection from taking out the whole server
    defer func() {
        if err := recover(); err != nil {
            ctxlog.WithField("recovered_error", err).Error("panic handling session")

            metrics.SliSshdSessionsErrorsTotal.Inc()
        }
    }()

    started := time.Now()
    conn := newConnection(s.Config, nconn)

    // Filled in by the callback below; stays nil if the callback never runs.
    var ctxWithLogData context.Context

    conn.handle(ctx, s.serverConfig.get(ctx), func(ctx context.Context, sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) error {
        // Identity fields are read from the permission extensions of the
        // established SSH connection.
        session := &session{
            cfg:                 s.Config,
            channel:             channel,
            gitlabKeyID:         sconn.Permissions.Extensions["key-id"],
            gitlabKrb5Principal: sconn.Permissions.Extensions["krb5principal"],
            gitlabUsername:      sconn.Permissions.Extensions["username"],
            namespace:           sconn.Permissions.Extensions["namespace"],
            remoteAddr:          remoteAddr,
            started:             time.Now(),
        }

        var err error
        ctxWithLogData, err = session.handle(ctx, requests)

        return err
    })

    // A nil ctxWithLogData yields an empty LogData here.
    logData := extractLogDataFromContext(ctxWithLogData)

    ctxlog.WithFields(log.Fields{
        "duration_s":    time.Since(started).Seconds(),
        "written_bytes": logData.WrittenBytes,
        "meta":          logData.Meta,
    }).Info("access: finish")
}

// proxyPolicy builds the PROXY protocol policy function from the server
// configuration. A non-empty Config.Server.ProxyAllowed list takes precedence
// and produces a strict allow-list policy (which may fail on invalid
// entries). Otherwise Config.Server.ProxyPolicy is matched case-insensitively
// against "require", "ignore" and "reject"; any other value selects USE.
func (s *Server) proxyPolicy() (proxyproto.PolicyFunc, error) {
    if allowed := s.Config.Server.ProxyAllowed; len(allowed) > 0 {
        return proxyproto.StrictWhiteListPolicy(allowed)
    }

    // Policy values are taken from
    // https://github.com/pires/go-proxyproto/blob/195fedcfbfc1be163f3a0d507fac1709e9d81fed/policy.go#L20
    // USE is the fallback for unknown or empty settings.
    selected := proxyproto.USE

    switch strings.ToLower(s.Config.Server.ProxyPolicy) {
    case "require":
        selected = proxyproto.REQUIRE
    case "ignore":
        selected = proxyproto.IGNORE
    case "reject":
        selected = proxyproto.REJECT
    }

    return staticProxyPolicy(selected), nil
}

// extractDataFromContext returns the command.LogData stored in ctx under the
// string key "logData". It returns an empty LogData when ctx is nil, when no
// value is stored under that key, or when the stored value is not a
// command.LogData.
func extractDataFromContext(ctx context.Context) command.LogData {
    if ctx == nil {
        return command.LogData{}
    }

    // One lookup with a checked assertion: a value of an unexpected type
    // yields the empty LogData instead of panicking.
    if logData, ok := ctx.Value("logData").(command.LogData); ok {
        return logData
    }

    return command.LogData{}
}

// extractLogDataFromContext returns the command.LogData stored in ctx under
// the logInfo{} key. It returns an empty LogData when ctx is nil, when no
// value is stored under that key, or when the stored value is not a
// command.LogData.
func extractLogDataFromContext(ctx context.Context) command.LogData {
    if ctx == nil {
        return command.LogData{}
    }

    // One lookup with a checked assertion: a value of an unexpected type
    // yields the empty LogData instead of panicking.
    if logData, ok := ctx.Value(logInfo{}).(command.LogData); ok {
        return logData
    }

    return command.LogData{}
}

// staticProxyPolicy returns a PolicyFunc that ignores the upstream address
// and always yields the given policy with a nil error.
func staticProxyPolicy(policy proxyproto.Policy) proxyproto.PolicyFunc {
    var constant proxyproto.PolicyFunc = func(net.Addr) (proxyproto.Policy, error) {
        return policy, nil
    }

    return constant
}