gitlabhq/gitlab-shell

View on GitHub
internal/sshd/sshd_test.go

Summary

Maintainability
B
5 hrs
Test Coverage
package sshd

import (
    "context"
    "fmt"
    "net"
    "net/http"
    "net/http/httptest"
    "os"
    "path"
    "testing"
    "time"

    "github.com/pires/go-proxyproto"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "golang.org/x/crypto/ssh"

    "gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/testhelper"
)

// Fixed values shared by every test in this file.
const (
    // user is the SSH login name used by clients and configured as the
    // gitlab-shell user of the test server.
    user = "git"
    // serverURL is the address the test server listens on and clients dial.
    serverURL = "127.0.0.1:50000"
)

// Mutable state shared between the fake GitLab API handlers and the tests
// that inspect or configure them.
var (
    // correlationID holds the X-Request-Id header of the most recent
    // authorized_keys request; the discover handler compares against it.
    correlationID string
    // xForwardedFor is the X-Forwarded-For value the fake API handlers
    // expect; it stays empty unless a test overrides it.
    xForwardedFor string
)

// TestListenAndServe checks graceful shutdown:
//   - After Shutdown, the server reports StatusOnShutdown.
//   - An already-established client can still run a session.
//   - New connections are refused.
//   - The server reaches StatusClosed after that client disconnects.
func TestListenAndServe(t *testing.T) {
    s, testRoot := setupServer(t)

    client, err := ssh.Dial("tcp", serverURL, clientConfig(t, testRoot))
    require.NoError(t, err)
    defer client.Close()

    require.NoError(t, s.Shutdown())
    verifyStatus(t, s, StatusOnShutdown)

    // The pre-existing connection must remain usable while shutting down.
    holdSession(t, client)

    // EqualError fails the test cleanly if the dial unexpectedly succeeds,
    // instead of panicking on a nil err.Error().
    _, err = ssh.Dial("tcp", serverURL, clientConfig(t, testRoot))
    require.EqualError(t, err, "dial tcp 127.0.0.1:50000: connect: connection refused")

    client.Close()

    verifyStatus(t, s, StatusClosed)
}

// TestListenAndServe_proxyProtocolEnabled runs the server with ProxyProtocol
// enabled under each ProxyPolicy / ProxyAllowed combination. For each case it
// checks whether a client completes the SSH handshake, both with and without
// sending a PROXY v2 header first.
func TestListenAndServe_proxyProtocolEnabled(t *testing.T) {
    testRoot := testhelper.PrepareTestRootDir(t)

    target, err := net.ResolveTCPAddr("tcp", serverURL)
    require.NoError(t, err)

    // PROXY protocol v2 header with source 10.1.1.1:1000 and the test server
    // as destination.
    header := &proxyproto.Header{
        Version:           2,
        Command:           proxyproto.PROXY,
        TransportProtocol: proxyproto.TCPv4,
        SourceAddr: &net.TCPAddr{
            IP:   net.ParseIP("10.1.1.1"),
            Port: 1000,
        },
        DestinationAddr: target,
    }
    // The fake GitLab API handlers assert that X-Forwarded-For equals this value.
    xForwardedFor = "127.0.0.1"
    defer func() {
        xForwardedFor = "" // Cleanup for other test cases
    }()

    testCases := []struct {
        desc string
        // proxyPolicy is passed to Server.ProxyPolicy; "" is the USE default.
        proxyPolicy string
        // proxyAllowed is passed to Server.ProxyAllowed.
        proxyAllowed []string
        // header, when non-nil, is written before the SSH handshake starts.
        header *proxyproto.Header
        // isRejected reports whether the SSH handshake is expected to fail.
        isRejected bool
    }{
        {
            desc:        "USE (default) without a header",
            proxyPolicy: "",
            header:      nil,
            isRejected:  false,
        },
        {
            desc:        "USE (default) with a header",
            proxyPolicy: "",
            header:      header,
            isRejected:  false,
        },
        {
            desc:        "REQUIRE without a header",
            proxyPolicy: "require",
            header:      nil,
            isRejected:  true,
        },
        {
            desc:        "REQUIRE with a header",
            proxyPolicy: "require",
            header:      header,
            isRejected:  false,
        },
        {
            desc:        "REJECT without a header",
            proxyPolicy: "reject",
            header:      nil,
            isRejected:  false,
        },
        {
            desc:        "REJECT with a header",
            proxyPolicy: "reject",
            header:      header,
            isRejected:  true,
        },
        {
            desc:        "IGNORE without a header",
            proxyPolicy: "ignore",
            header:      nil,
            isRejected:  false,
        },
        {
            desc:        "IGNORE with a header",
            proxyPolicy: "ignore",
            header:      header,
            isRejected:  false,
        },
        {
            desc:         "Allow-listed IP with a header",
            proxyAllowed: []string{"127.0.0.1"},
            header:       header,
            isRejected:   false,
        },
        {
            desc:         "Allow-listed IP without a header",
            proxyAllowed: []string{"127.0.0.1"},
            header:       nil,
            isRejected:   false,
        },
        {
            desc:         "Allow-listed range with a header",
            proxyAllowed: []string{"127.0.0.0/24"},
            header:       header,
            isRejected:   false,
        },
        {
            desc:         "Allow-listed range without a header",
            proxyAllowed: []string{"127.0.0.0/24"},
            header:       nil,
            isRejected:   false,
        },
        {
            desc:         "Not allow-listed IP with a header",
            proxyAllowed: []string{"192.168.1.1"},
            header:       header,
            isRejected:   true,
        },
        {
            desc:         "Not allow-listed IP without a header",
            proxyAllowed: []string{"192.168.1.1"},
            header:       nil,
            isRejected:   false,
        },
        {
            desc:         "Not allow-listed range with a header",
            proxyAllowed: []string{"192.168.1.0/24"},
            header:       header,
            isRejected:   true,
        },
        {
            desc:         "Not allow-listed range without a header",
            proxyAllowed: []string{"192.168.1.0/24"},
            header:       nil,
            isRejected:   false,
        },
    }

    for _, tc := range testCases {
        t.Run(tc.desc, func(t *testing.T) {
            setupServerWithConfig(t, &config.Config{
                Server: config.ServerConfig{
                    ProxyProtocol: true,
                    ProxyPolicy:   tc.proxyPolicy,
                    ProxyAllowed:  tc.proxyAllowed,
                },
            })

            conn, err := net.DialTCP("tcp", nil, target)
            require.NoError(t, err)
            // Release the raw connection even if the test fails before the
            // SSH layer takes it over. Any error from a second Close is ignored.
            defer conn.Close()

            if tc.header != nil {
                // Write this case's own header, not the shared outer variable,
                // so each case fully controls what is sent on the wire.
                _, writeToErr := tc.header.WriteTo(conn)
                require.NoError(t, writeToErr)
            }

            sshConn, sshChans, sshRequs, err := ssh.NewClientConn(conn, serverURL, clientConfig(t, testRoot))
            if sshConn != nil {
                defer sshConn.Close()
            }

            if tc.isRejected {
                require.Error(t, err, "Expected plain SSH request to be failed")
                require.Regexp(t, "ssh: handshake failed", err.Error())

                return
            }

            require.NoError(t, err)
            client := ssh.NewClient(sshConn, sshChans, sshRequs)
            defer client.Close()

            holdSession(t, client)
        })
    }
}

// TestCorrelationId opens two separate SSH connections and checks that the
// X-Request-Id recorded by the fake authorized_keys handler differs between
// them.
func TestCorrelationId(t *testing.T) {
    _, testRoot := setupServer(t)

    first, err := ssh.Dial("tcp", serverURL, clientConfig(t, testRoot))
    require.NoError(t, err)
    defer first.Close()

    holdSession(t, first)
    firstID := correlationID

    second, err := ssh.Dial("tcp", serverURL, clientConfig(t, testRoot))
    require.NoError(t, err)
    defer second.Close()

    holdSession(t, second)
    secondID := correlationID

    require.NotEqual(t, firstID, secondID)
}

// TestReadinessProbe checks the status code that GET /start returns on the
// monitoring mux at each server state:
//   - 503 while the server is starting.
//   - 200 once it is ready.
//   - 503 again after shutdown has begun.
func TestReadinessProbe(t *testing.T) {
    s := &Server{Config: &config.Config{Server: config.DefaultServerConfig}}

    require.Equal(t, StatusStarting, s.getStatus())

    mux := s.MonitoringServeMux()
    req := httptest.NewRequest("GET", "/start", nil)

    // probe serves req once through the mux and reports the response code.
    probe := func() int {
        rec := httptest.NewRecorder()
        mux.ServeHTTP(rec, req)

        res := rec.Result()
        defer res.Body.Close()

        return res.StatusCode
    }

    require.Equal(t, http.StatusServiceUnavailable, probe())

    s.changeStatus(StatusReady)
    require.Equal(t, http.StatusOK, probe())

    s.changeStatus(StatusOnShutdown)
    require.Equal(t, http.StatusServiceUnavailable, probe())
}

// TestLivenessProbe checks that GET /health on the monitoring mux answers 200
// for a freshly constructed server.
func TestLivenessProbe(t *testing.T) {
    srv := &Server{Config: &config.Config{Server: config.DefaultServerConfig}}

    rec := httptest.NewRecorder()
    srv.MonitoringServeMux().ServeHTTP(rec, httptest.NewRequest("GET", "/health", nil))

    resp := rec.Result()
    defer resp.Body.Close()

    require.Equal(t, http.StatusOK, resp.StatusCode)
}

// TestInvalidClientConfig checks that dialing as a user other than the one
// the test server is configured for fails.
func TestInvalidClientConfig(t *testing.T) {
    _, testRoot := setupServer(t)

    badCfg := clientConfig(t, testRoot)
    badCfg.User = "unknown"

    _, dialErr := ssh.Dial("tcp", serverURL, badCfg)
    require.Error(t, dialErr)
}

// TestInvalidServerConfig checks that ListenAndServe reports a listen error
// for an address without a port. It also checks that Shutdown still succeeds
// on such a server.
func TestInvalidServerConfig(t *testing.T) {
    srv := &Server{Config: &config.Config{Server: config.ServerConfig{Listen: "invalid"}}}

    listenErr := srv.ListenAndServe(context.Background())
    require.Error(t, listenErr)
    require.EqualError(t, listenErr, "failed to listen for connection: listen tcp: address invalid: missing port in address")

    require.NoError(t, srv.Shutdown())
}

// TestClosingHangedConnections checks that the server still reaches
// StatusClosed when a client connection is stuck mid-handshake. Reaching that
// state requires both calling Shutdown and cancelling the server's context.
func TestClosingHangedConnections(t *testing.T) {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    s, testRoot := setupServerWithContext(ctx, t, nil)

    unauthenticatedRequestStatus := make(chan string)
    // completed is closed when the test returns. This unblocks the host-key
    // callback below, so the dialing goroutine can exit instead of leaking.
    completed := make(chan struct{})
    defer close(completed)

    clientCfg := clientConfig(t, testRoot)
    clientCfg.HostKeyCallback = func(_ string, _ net.Addr, _ ssh.PublicKey) error {
        unauthenticatedRequestStatus <- "authentication-started"
        <-completed // Block until the test has finished.

        return nil
    }

    go func() {
        // Start an SSH connection that stays stuck in the handshake. Its
        // outcome is irrelevant; it only needs to occupy the server.
        _, _ = ssh.Dial("tcp", serverURL, clientCfg)
    }()

    require.Equal(t, "authentication-started", <-unauthenticatedRequestStatus)

    require.NoError(t, s.Shutdown())
    cancel()
    verifyStatus(t, s, StatusClosed)
}

// TestLoginGraceTime checks that a connection stuck mid-handshake is dropped
// after LoginGraceTime (50ms here). As a result, Shutdown alone lets the
// server reach StatusClosed, without cancelling its context.
func TestLoginGraceTime(t *testing.T) {
    cfg := &config.Config{
        Server: config.ServerConfig{
            LoginGraceTime: config.YamlDuration(50 * time.Millisecond),
        },
    }
    s, testRoot := setupServerWithConfig(t, cfg)

    unauthenticatedRequestStatus := make(chan string)
    // completed is closed when the test returns. This unblocks the host-key
    // callback below, so the dialing goroutine can exit instead of leaking.
    completed := make(chan struct{})
    defer close(completed)

    clientCfg := clientConfig(t, testRoot)
    clientCfg.HostKeyCallback = func(_ string, _ net.Addr, _ ssh.PublicKey) error {
        unauthenticatedRequestStatus <- "authentication-started"
        <-completed // Block until the test has finished.

        return nil
    }

    go func() {
        // Start an SSH connection that stays stuck in the handshake. Its
        // outcome is irrelevant; it only needs to occupy the server.
        _, _ = ssh.Dial("tcp", serverURL, clientCfg)
    }()

    require.Equal(t, "authentication-started", <-unauthenticatedRequestStatus)

    // Shutdown the server and verify that it's closed.
    // If LoginGraceTime doesn't work, the never-ending connection will prevent the server from closing.
    // The close won't happen until the context is canceled, as in TestClosingHangedConnections.
    require.NoError(t, s.Shutdown())
    verifyStatus(t, s, StatusClosed)
}

func TestExtractMetaDataFromContext(t *testing.T) {
    username := "alex-doe"
    rootNameSpace := "flightjs"
    project := fmt.Sprintf("%s/Flight", rootNameSpace)
    projectID := 1
    rootNamespaceID := 2
    ctx := context.WithValue(context.Background(), logInfo{}, command.NewLogData(project, username, projectID, rootNamespaceID))

    data := extractLogDataFromContext(ctx)

    require.Equal(t, command.LogData{Username: username, Meta: command.LogMetadata{Project: project, RootNamespace: rootNameSpace, ProjectID: projectID, RootNamespaceID: rootNamespaceID}}, data)
}

// TestExtractMetaDataFromContextWithoutMetaData checks that a context with no
// stored log data yields the zero LogData.
func TestExtractMetaDataFromContextWithoutMetaData(t *testing.T) {
    require.Equal(t, command.LogData{}, extractLogDataFromContext(context.Background()))
}

// TestExtractMetaDataFromNilContext checks that passing a nil context yields
// the zero LogData without panicking.
func TestExtractMetaDataFromNilContext(t *testing.T) {
    var nilCtx context.Context // deliberately left nil

    got := extractLogDataFromContext(nilCtx)
    require.Equal(t, command.LogData{}, got)
}

func setupServer(t *testing.T) (*Server, string) {
    t.Helper()

    return setupServerWithConfig(t, nil)
}

// setupServerWithConfig starts a test server with the caller-supplied cfg
// (which may be nil) and runs it under context.Background().
func setupServerWithConfig(t *testing.T, cfg *config.Config) (*Server, string) {
    t.Helper()

    ctx := context.Background()

    return setupServerWithContext(ctx, t, cfg)
}

// setupServerWithContext starts a test server listening on serverURL and waits
// until it reports StatusReady.
//   - The server talks to a fake GitLab internal API started with
//     testserver.StartSocketHTTPServer.
//   - cfg may be nil. When non-nil it is modified in place with the fixed test
//     settings below.
//   - Shutdown is registered with t.Cleanup.
//
// It returns the server and the prepared test root directory.
func setupServerWithContext(ctx context.Context, t *testing.T, cfg *config.Config) (*Server, string) {
    t.Helper()

    testRoot := testhelper.PrepareTestRootDir(t)

    // Records the request's X-Request-Id in correlationID so the discover
    // handler and tests can compare it.
    authorizedKeys := func(w http.ResponseWriter, r *http.Request) {
        correlationID = r.Header.Get("X-Request-Id")

        assert.NotEmpty(t, correlationID)
        assert.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For"))

        fmt.Fprint(w, `{"id": 1000, "key": "key"}`)
    }

    // Expects the same X-Request-Id that authorized_keys last recorded.
    discover := func(w http.ResponseWriter, r *http.Request) {
        assert.Equal(t, correlationID, r.Header.Get("X-Request-Id"))
        assert.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For"))

        fmt.Fprint(w, `{"id": 1000, "name": "Test User", "username": "test-user"}`)
    }

    gitlabURL := testserver.StartSocketHTTPServer(t, []testserver.TestRequestHandler{
        {Path: "/api/v4/internal/authorized_keys", Handler: authorizedKeys},
        {Path: "/api/v4/internal/discover", Handler: discover},
    })

    if cfg == nil {
        cfg = new(config.Config)
    }

    // Settings that tests do not need to vary (yet).
    cfg.GitlabUrl = gitlabURL
    cfg.RootDir = "/tmp"
    cfg.User = user
    cfg.Server.Listen = serverURL
    cfg.Server.ConcurrentSessionsLimit = 1
    cfg.Server.HostKeyFiles = []string{path.Join(testRoot, "certs/valid/server.key")}

    s, err := NewServer(cfg)
    require.NoError(t, err)

    // The error returned by ListenAndServe is not inspected here.
    go func() { s.ListenAndServe(ctx) }()
    //nolint:godox // NOTE: Changing the below to { require.NoError(t, s.Shutdown()) } results in failed tests
    t.Cleanup(func() { s.Shutdown() })

    verifyStatus(t, s, StatusReady)

    return s, testRoot
}

// clientConfig builds an SSH client configuration for the test server:
//   - It logs in as user.
//   - It authenticates with the private key in certs/client/key.pem under
//     testRoot.
//   - It accepts only the host key stored in certs/valid/server_authorized_key.
func clientConfig(t *testing.T, testRoot string) *ssh.ClientConfig {
    t.Helper()

    // Check the read error so a missing fixture is reported directly rather
    // than as a key-parsing failure on empty input.
    keyRaw, err := os.ReadFile(path.Join(testRoot, "certs/valid/server_authorized_key"))
    require.NoError(t, err)
    pKey, _, _, _, err := ssh.ParseAuthorizedKey(keyRaw) //nolint:dogsled
    require.NoError(t, err)

    key, err := os.ReadFile(path.Join(testRoot, "certs/client/key.pem"))
    require.NoError(t, err)
    signer, err := ssh.ParsePrivateKey(key)
    require.NoError(t, err)

    return &ssh.ClientConfig{
        User: user,
        Auth: []ssh.AuthMethod{
            ssh.PublicKeys(signer),
        },
        HostKeyCallback: ssh.FixedHostKey(pKey),
    }
}

// holdSession opens a session on c and runs the "discover" command. It asserts
// that the output is the greeting for the fake API's test-user, and closes the
// session on return.
func holdSession(t *testing.T, c *ssh.Client) {
    t.Helper() // report failures at the caller's line

    session, err := c.NewSession()
    require.NoError(t, err)
    defer session.Close()

    output, err := session.Output("discover")
    require.NoError(t, err)
    require.Equal(t, "Welcome to GitLab, @test-user!\n", string(output))
}

// verifyStatus waits for s to report status st. It polls every millisecond for
// up to two seconds and fails the test if the status is never reached.
func verifyStatus(t *testing.T, s *Server, st status) {
    t.Helper() // report failures at the caller's line

    require.Eventually(t, func() bool { return s.getStatus() == st }, 2*time.Second, time.Millisecond)
}