gitlabhq/gitlab-shell

View on GitHub
internal/sshd/server_config.go

Summary

Maintainability
A
1 hr
Test Coverage
// Package sshd implements functionality related to SSH server configuration and handling
package sshd

import (
    "context"
    "encoding/base64"
    "fmt"
    "os"
    "path/filepath"
    "strconv"
    "strings"
    "time"

    "golang.org/x/crypto/ssh"

    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/gitlabnet/authorizedcerts"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/gitlabnet/authorizedkeys"

    "gitlab.com/gitlab-org/labkit/log"
)

// supportedMACs is the MAC list installed by configureMACs whenever the
// configuration does not provide its own (cfg.Server.MACs is empty).
var supportedMACs = []string{
    "hmac-sha2-256-etm@openssh.com",
    "hmac-sha2-512-etm@openssh.com",
    "hmac-sha2-256",
    "hmac-sha2-512",
    "hmac-sha1",
}

// supportedKeyExchanges is the key exchange list installed by
// configureKeyExchanges whenever cfg.Server.KexAlgorithms is empty.
var supportedKeyExchanges = []string{
    "curve25519-sha256",
    "curve25519-sha256@libssh.org",
    "ecdh-sha2-nistp256",
    "ecdh-sha2-nistp384",
    "ecdh-sha2-nistp521",
    "diffie-hellman-group14-sha256",
    "diffie-hellman-group14-sha1",
}

// serverConfig holds the state assembled by newServerConfig. get uses it to
// build an *ssh.ServerConfig.
type serverConfig struct {
    // cfg is the gitlab-shell configuration; User and Server.* settings are
    // read from it during authentication and server configuration.
    cfg                   *config.Config
    // hostKeys are the loaded host key signers. parseHostCerts replaces any
    // entry that has a matching host certificate with a certificate signer.
    hostKeys              []ssh.Signer
    // hostKeyToCertMap maps the wire encoding of a host public key
    // (string(key.Marshal())) to the host certificate loaded for it.
    hostKeyToCertMap      map[string]*ssh.Certificate
    // authorizedKeysClient resolves plain public keys (see handleUserKey).
    authorizedKeysClient  *authorizedkeys.Client
    // authorizedCertsClient resolves user certificates by key ID and CA
    // fingerprint (see handleUserCertificate).
    authorizedCertsClient *authorizedcerts.Client
}

// parseHostKeys loads every private host key listed in keyFiles and returns a
// signer for each one that could be read and parsed, in input order.
//
// Files that cannot be read or parsed are logged and skipped rather than
// treated as fatal. The result may therefore be shorter than keyFiles, or nil
// when nothing loads, and callers must check its length.
func parseHostKeys(keyFiles []string) []ssh.Signer {
    var signers []ssh.Signer

    for _, path := range keyFiles {
        fields := log.Fields{"filename": path}

        raw, readErr := os.ReadFile(filepath.Clean(path))
        if readErr != nil {
            log.WithError(readErr).WithFields(fields).Error("Failed to read host key")
            continue
        }

        signer, parseErr := ssh.ParsePrivateKey(raw)
        if parseErr != nil {
            log.WithError(parseErr).WithFields(fields).Error("Failed to parse host key")
            continue
        }

        signers = append(signers, signer)
    }

    return signers
}

// parseHostCerts loads the OpenSSH host certificates in certFiles and attaches
// each one to the host key in hostKeys that it certifies.
//
// When a certificate matches a host key, hostKeys[i] is replaced in place by a
// certificate signer built with ssh.NewCertSigner. A file is logged and skipped
// when it cannot be read or parsed, does not contain a certificate, has no
// matching host key, or is rejected by ssh.NewCertSigner.
//
// The returned map is keyed by the wire encoding (Marshal) of the host public
// key. It contains only certificates that were actually installed as signers,
// so it never lists a certificate that hostKeys does not carry.
func parseHostCerts(hostKeys []ssh.Signer, certFiles []string) map[string]*ssh.Certificate {
    keyToCertMap := map[string]*ssh.Certificate{}
    hostKeyIndex := make(map[string]int, len(hostKeys))

    // Index the plain host keys before any replacement. Once hostKeys[i]
    // becomes a certificate signer, its PublicKey() is the certificate rather
    // than the bare key.
    for index, hostKey := range hostKeys {
        hostKeyIndex[string(hostKey.PublicKey().Marshal())] = index
    }

    for _, filename := range certFiles {
        keyRaw, err := os.ReadFile(filepath.Clean(filename))
        if err != nil {
            log.WithError(err).WithFields(log.Fields{"filename": filename}).Error("failed to read host certificate")
            continue
        }
        publicKey, _, _, _, err := ssh.ParseAuthorizedKey(keyRaw)
        if err != nil {
            log.WithError(err).WithFields(log.Fields{"filename": filename}).Error("failed to parse host certificate")
            continue
        }

        cert, ok := publicKey.(*ssh.Certificate)
        if !ok {
            log.WithFields(log.Fields{"filename": filename}).Error("failed to decode host certificate")
            continue
        }

        hostRawKey := string(cert.Key.Marshal())
        index, found := hostKeyIndex[hostRawKey]
        if !found {
            log.WithFields(log.Fields{"filename": filename}).Errorf("no matching private key for certificate %s", filename)
            continue
        }

        certSigner, err := ssh.NewCertSigner(cert, hostKeys[index])
        if err != nil {
            log.WithError(err).WithFields(log.Fields{"filename": filename}).Error("the host certificate doesn't match the host private key")
            continue
        }

        // Record the certificate only after its signer is installed. This
        // keeps the map and hostKeys in agreement when signer creation fails.
        hostKeys[index] = certSigner
        keyToCertMap[hostRawKey] = cert
    }

    return keyToCertMap
}

// newServerConfig assembles a serverConfig from cfg. It creates the
// authorized-keys and authorized-certs API clients, loads the host keys from
// cfg.Server.HostKeyFiles, and attaches any host certificates from
// cfg.Server.HostCertFiles.
//
// It returns an error if either client cannot be created or if no host key
// could be loaded. Failures on individual key or certificate files are only
// logged by the parse helpers.
func newServerConfig(cfg *config.Config) (*serverConfig, error) {
    keysClient, err := authorizedkeys.NewClient(cfg)
    if err != nil {
        return nil, fmt.Errorf("failed to initialize authorized keys client: %w", err)
    }

    certsClient, err := authorizedcerts.NewClient(cfg)
    if err != nil {
        return nil, fmt.Errorf("failed to initialize authorized certs client: %w", err)
    }

    signers := parseHostKeys(cfg.Server.HostKeyFiles)
    if len(signers) == 0 {
        return nil, fmt.Errorf("no host keys could be loaded, aborting")
    }

    sc := &serverConfig{
        cfg:                   cfg,
        hostKeys:              signers,
        authorizedKeysClient:  keysClient,
        authorizedCertsClient: certsClient,
    }

    // parseHostCerts swaps matching entries of the slice in place for
    // certificate signers. sc.hostKeys shares that backing array, so it picks
    // up the replacements.
    sc.hostKeyToCertMap = parseHostCerts(sc.hostKeys, cfg.Server.HostCertFiles)

    return sc, nil
}

// handleUserKey authenticates a plain (non-certificate) public key.
//
// Only s.cfg.User may log in, and DSA keys are always refused. The key is
// then looked up through the authorized keys client using the unpadded
// base64 encoding of its wire format. On success, the matched key's ID is
// exposed in the "key-id" permission extension. Lookup errors are returned
// unchanged.
func (s *serverConfig) handleUserKey(ctx context.Context, user string, key ssh.PublicKey) (*ssh.Permissions, error) {
    switch {
    case user != s.cfg.User:
        return nil, fmt.Errorf("unknown user")
    case key.Type() == ssh.KeyAlgoDSA:
        return nil, fmt.Errorf("DSA is prohibited")
    }

    encodedKey := base64.RawStdEncoding.EncodeToString(key.Marshal())

    res, err := s.authorizedKeysClient.GetByKey(ctx, encodedKey)
    if err != nil {
        return nil, err
    }

    // Remember which authorized key matched so later stages can identify it.
    extensions := map[string]string{"key-id": strconv.FormatInt(res.ID, 10)}

    return &ssh.Permissions{Extensions: extensions}, nil
}

// handleUserCertificate authenticates a client presenting an OpenSSH user
// certificate. The feature is gated on the FF_GITLAB_SHELL_SSH_CERTIFICATES
// environment variable being exactly "1".
//
// Steps:
//   - Reject anything that is not a user certificate.
//   - Validate the certificate for user with a zero-value ssh.CertChecker.
//   - Look up the certificate's KeyId and the signing CA's SHA256 fingerprint
//     (without the "SHA256:" prefix) through the authorized certs client.
//
// On success, the returned permissions carry the resolved "username" and
// "namespace" extensions plus the certificate's critical options. Validation
// and lookup errors are returned to the caller.
func (s *serverConfig) handleUserCertificate(ctx context.Context, user string, cert *ssh.Certificate) (*ssh.Permissions, error) {
    if os.Getenv("FF_GITLAB_SHELL_SSH_CERTIFICATES") != "1" {
        return nil, fmt.Errorf("handleUserCertificate: feature is disabled")
    }

    fingerprint := ssh.FingerprintSHA256(cert.SignatureKey)

    if cert.CertType != ssh.UserCert {
        return nil, fmt.Errorf("handleUserCertificate: cert has type %d", cert.CertType)
    }

    // A zero-value checker supports no critical options, so CheckCert rejects
    // every critical option except "source-address". That one is skipped
    // because x/crypto/ssh enforces it only when it is present in the
    // Permissions this callback returns (see below).
    certChecker := &ssh.CertChecker{}
    if err := certChecker.CheckCert(user, cert); err != nil {
        return nil, err
    }

    logger := log.WithContextFields(ctx,
        log.Fields{
            "ssh_user":               user,
            "public_key_fingerprint": ssh.FingerprintSHA256(cert),
            "signing_ca_fingerprint": fingerprint,
            "certificate_identity":   cert.KeyId,
        },
    )

    res, err := s.authorizedCertsClient.GetByKey(ctx, cert.KeyId, strings.TrimPrefix(fingerprint, "SHA256:"))
    if err != nil {
        logger.WithError(err).Warn("user certificate is not signed by a trusted key")

        return nil, err
    }

    logger.WithFields(
        log.Fields{
            "certificate_username":  res.Username,
            "certificate_namespace": res.Namespace,
        },
    ).Info("user certificate is signed by a trusted key")

    return &ssh.Permissions{
        // Hand the certificate's critical options back to the ssh server so
        // that a "source-address" restriction is enforced. Without this, a
        // certificate limited to specific addresses would be accepted from
        // anywhere.
        CriticalOptions: cert.CriticalOptions,
        Extensions: map[string]string{
            "username":  res.Username,
            "namespace": res.Namespace,
        },
    }, nil
}

// get builds a new *ssh.ServerConfig from s.
//
// Each PublicKeyCallback invocation derives its own context from parentCtx
// with a 10-second timeout. Certificates are routed to handleUserCertificate
// and plain keys to handleUserKey.
//
// When s.cfg.Server.GSSAPI.Enabled is set and NewGSSAPIServer returns a
// server, gssapi-with-mic logins are accepted for s.cfg.User. The Kerberos
// principal is recorded in the "krb5principal" extension.
//
// MAC, key exchange, cipher and public key algorithm settings are applied, and
// every loaded host key signer is registered via AddHostKey.
func (s *serverConfig) get(parentCtx context.Context) *ssh.ServerConfig {
    var gssapiWithMICConfig *ssh.GSSAPIWithMICConfig
    if s.cfg.Server.GSSAPI.Enabled {
        gssAPIServer, err := NewGSSAPIServer(&s.cfg.Server.GSSAPI)
        if err != nil {
            // The error used to be discarded. GSSAPI was explicitly enabled,
            // so report why it may not be offered instead of failing silently.
            log.WithError(err).Error("failed to initialize GSSAPI server")
        }

        if gssAPIServer != nil {
            gssapiWithMICConfig = &ssh.GSSAPIWithMICConfig{
                AllowLogin: func(conn ssh.ConnMetadata, srcName string) (*ssh.Permissions, error) {
                    if conn.User() != s.cfg.User {
                        return nil, fmt.Errorf("unknown user")
                    }

                    return &ssh.Permissions{
                        // Record the Kerberos principal used for authentication.
                        Extensions: map[string]string{
                            "krb5principal": srcName,
                        },
                    }, nil
                },
                Server: gssAPIServer,
            }
        }
    }

    sshCfg := &ssh.ServerConfig{
        PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
            // Bound each authentication attempt, including its API lookups.
            ctx, cancel := context.WithTimeout(parentCtx, 10*time.Second)
            defer cancel()

            log.WithContextFields(ctx, log.Fields{"ssh_key_type": key.Type()}).Info("public key authentication")

            if cert, ok := key.(*ssh.Certificate); ok {
                return s.handleUserCertificate(ctx, conn.User(), cert)
            }

            return s.handleUserKey(ctx, conn.User(), key)
        },
        GSSAPIWithMICConfig: gssapiWithMICConfig,
        ServerVersion:       "SSH-2.0-GitLab-SSHD",
    }

    s.configureMACs(sshCfg)
    s.configureKeyExchanges(sshCfg)
    s.configureCiphers(sshCfg)
    s.configurePublicKeyAlgorithms(sshCfg)

    for _, key := range s.hostKeys {
        sshCfg.AddHostKey(key)
    }

    return sshCfg
}

// configurePublicKeyAlgorithms copies the configured client public key
// authentication algorithms into sshCfg. If none are configured, sshCfg is
// left untouched.
func (s *serverConfig) configurePublicKeyAlgorithms(sshCfg *ssh.ServerConfig) {
    algorithms := s.cfg.Server.PublicKeyAlgorithms
    if len(algorithms) == 0 {
        return
    }

    sshCfg.PublicKeyAuthAlgorithms = algorithms
}

// configureCiphers copies the configured cipher list into sshCfg. If none is
// configured, sshCfg is left untouched.
func (s *serverConfig) configureCiphers(sshCfg *ssh.ServerConfig) {
    ciphers := s.cfg.Server.Ciphers
    if len(ciphers) == 0 {
        return
    }

    sshCfg.Ciphers = ciphers
}

// configureKeyExchanges sets sshCfg.KeyExchanges to the configured key
// exchange algorithms. If none are configured, it uses supportedKeyExchanges.
func (s *serverConfig) configureKeyExchanges(sshCfg *ssh.ServerConfig) {
    kex := supportedKeyExchanges
    if configured := s.cfg.Server.KexAlgorithms; len(configured) > 0 {
        kex = configured
    }

    sshCfg.KeyExchanges = kex
}

// configureMACs sets sshCfg.MACs to the configured MAC algorithms. If none
// are configured, it uses supportedMACs.
func (s *serverConfig) configureMACs(sshCfg *ssh.ServerConfig) {
    macs := supportedMACs
    if configured := s.cfg.Server.MACs; len(configured) > 0 {
        macs = configured
    }

    sshCfg.MACs = macs
}