gitlabhq/gitlab-shell

View on GitHub
cmd/gitlab-sshd/acceptance_test.go

Summary

Maintainability
A
0 mins
Test Coverage
package main_test

import (
    "bufio"
    "context"
    "crypto/ed25519"
    "encoding/json"
    "encoding/pem"
    "fmt"
    "io"
    "net"
    "net/http"
    "net/http/httptest"
    "os"
    "os/exec"
    "path/filepath"
    "regexp"
    "runtime"
    "strings"
    "testing"

    "github.com/mikesmitty/edkey"
    "github.com/pires/go-proxyproto"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    gitalyClient "gitlab.com/gitlab-org/gitaly/v16/client"
    pb "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb"
    "gitlab.com/gitlab-org/gitaly/v16/streamio"
    "golang.org/x/crypto/ssh"
    "google.golang.org/grpc"

    "gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/testhelper"
)

// Package-level state shared by the acceptance tests; both values are
// assigned in init().
var (
    // sshdPath is the location of the gitlab-sshd binary that the tests run.
    sshdPath string
    // gitalyConnInfo holds the Gitaly connection details decoded from the
    // GITALY_CONNECTION_INFO environment variable. It is nil when nothing
    // was decoded.
    gitalyConnInfo *gitalyConnectionInfo
)

// Coordinates of the repository that the Gitaly-backed tests operate on.
const (
    // testRepo is used as the relative path of the test repository in Gitaly
    // and as the repository argument of the git-* SSH commands.
    testRepo          = "test-gitlab-shell/gitlab-test.git"
    // testRepoNamespace is the leading directory component of testRepo.
    testRepoNamespace = "test-gitlab-shell"
    // testRepoImportURL is the URL the test repository is created from.
    testRepoImportURL = "https://gitlab.com/gitlab-org/gitlab-test.git"
)

// gitalyConnectionInfo describes how to reach the Gitaly server used by the
// tests. It is decoded from the JSON held in GITALY_CONNECTION_INFO.
type gitalyConnectionInfo struct {
    // Address is the address handed to the Gitaly client when dialing.
    Address string `json:"address"`
    // Storage is used as the storage name of the test repository.
    Storage string `json:"storage"`
}

// init locates the gitlab-sshd binary and, when GITALY_CONNECTION_INFO holds
// a value, decodes the Gitaly connection details from it.
//
// It panics when the binary has not been built, or when the environment
// variable holds JSON that cannot be decoded, so that a broken setup is
// reported up front with a clear message instead of surfacing later as a nil
// pointer dereference inside a test.
func init() {
    sshdPath = filepath.Join(rootDir(), "bin", "gitlab-sshd")

    if _, err := os.Stat(sshdPath); os.IsNotExist(err) {
        panic(fmt.Errorf("cannot find executable %s. Please run 'make compile'", sshdPath))
    }

    // An unset or empty variable simply means "no Gitaly available"; the
    // Gitaly-backed tests skip themselves in that case.
    gci, exists := os.LookupEnv("GITALY_CONNECTION_INFO")
    if !exists || gci == "" {
        return
    }

    if err := json.Unmarshal([]byte(gci), &gitalyConnInfo); err != nil {
        panic(fmt.Errorf("parsing GITALY_CONNECTION_INFO: %w", err))
    }
}

// rootDir returns the directory two levels above the directory containing
// this source file. It panics if the caller information is unavailable.
func rootDir() string {
    _, thisFile, _, found := runtime.Caller(0)
    if found {
        return filepath.Join(filepath.Dir(thisFile), "..", "..")
    }

    panic(fmt.Errorf("rootDir: calling runtime.Caller failed"))
}

// ensureGitalyRepository recreates the test repository on the Gitaly server
// described by GITALY_CONNECTION_INFO: any existing copy is removed and a
// fresh one is created from testRepoImportURL.
//
// It returns the Gitaly connection and the repository descriptor. The test is
// skipped when GITALY_CONNECTION_INFO is empty. The connection is closed
// automatically when the test finishes, so callers may ignore it.
func ensureGitalyRepository(t *testing.T) (*grpc.ClientConn, *pb.Repository) {
    t.Helper()

    if os.Getenv("GITALY_CONNECTION_INFO") == "" {
        t.Skip("GITALY_CONNECTION_INFO is not set")
    }
    // Guard against a variable that was set but did not decode into a value
    require.NotNil(t, gitalyConnInfo, "GITALY_CONNECTION_INFO could not be decoded")

    conn, err := gitalyClient.Dial(gitalyConnInfo.Address, gitalyClient.DefaultDialOpts)
    require.NoError(t, err)
    t.Cleanup(func() { conn.Close() })

    ctx := context.Background()
    repository := pb.NewRepositoryServiceClient(conn)

    glRepository := &pb.Repository{StorageName: gitalyConnInfo.Storage, RelativePath: testRepo}

    // Remove the test repository before running the tests.
    // Ignore the error because the repository may not exist
    removeReq := &pb.RemoveRepositoryRequest{Repository: glRepository}
    _, _ = repository.RemoveRepository(ctx, removeReq)

    createReq := &pb.CreateRepositoryFromURLRequest{Repository: glRepository, Url: testRepoImportURL}
    _, err = repository.CreateRepositoryFromURL(ctx, createReq)
    require.NoError(t, err)

    return conn, glRepository
}

// startGitOverHTTPServer starts a test HTTP server exposing two endpoints
// backed by the Gitaly test repository and returns the value produced by
// testserver.StartHTTPServer (callers use it as the server URL):
//
//   - /info/refs answers git-receive-pack ref advertisements by streaming
//     the Gitaly InfoRefsReceivePack response;
//   - /git-receive-pack asserts that the request body is exactly "0000".
//
// The handlers run on server goroutines, so they only report failures with
// non-fatal assertions; FailNow-style calls are reserved for the test
// goroutine.
func startGitOverHTTPServer(t *testing.T) string {
    ctx := context.Background()
    conn, glRepository := ensureGitalyRepository(t)
    client := pb.NewSmartHTTPServiceClient(conn)

    requests := []testserver.TestRequestHandler{
        {
            Path: "/info/refs",
            Handler: func(w http.ResponseWriter, r *http.Request) {
                service := r.URL.Query().Get("service")
                if service != "git-receive-pack" {
                    t.Errorf("unexpected service %q requested from /info/refs", service)
                    http.Error(w, "unsupported service", http.StatusBadRequest)

                    return
                }

                rpcRequest := &pb.InfoRefsRequest{
                    Repository: glRepository,
                }

                stream, err := client.InfoRefsReceivePack(ctx, rpcRequest)
                if !assert.NoError(t, err) {
                    // Without a stream there is nothing to copy
                    http.Error(w, err.Error(), http.StatusInternalServerError)

                    return
                }

                reader := streamio.NewReader(func() ([]byte, error) {
                    resp, err := stream.Recv()
                    return resp.GetData(), err
                })

                _, err = io.Copy(w, reader)
                assert.NoError(t, err)
            },
        },
        {
            Path: "/git-receive-pack",
            Handler: func(_ http.ResponseWriter, r *http.Request) {
                defer r.Body.Close()

                body, err := io.ReadAll(r.Body)
                assert.NoError(t, err)

                assert.Equal(t, "0000", string(body))
            },
        },
    }

    return testserver.StartHTTPServer(t, requests)
}

// buildAllowedResponse loads the fixture at filename, relative to the
// prepared test root, and fills in its placeholders: GITALY_REPOSITORY is
// always replaced with testRepo, while GITALY_ADDRESS and GITALY_STORAGE are
// replaced only when Gitaly connection info is available. Only the first
// occurrence of each placeholder is substituted.
func buildAllowedResponse(t *testing.T, filename string) string {
    fixturePath := filepath.Join(testhelper.PrepareTestRootDir(t), filename)

    raw, err := os.ReadFile(fixturePath)
    require.NoError(t, err)

    out := strings.Replace(string(raw), "GITALY_REPOSITORY", testRepo, 1)
    if gitalyConnInfo == nil {
        return out
    }

    out = strings.Replace(out, "GITALY_ADDRESS", gitalyConnInfo.Address, 1)

    return strings.Replace(out, "GITALY_STORAGE", gitalyConnInfo.Storage, 1)
}

// customHandler lets a test override the stubbed API response for one path.
type customHandler struct {
    // url is compared against the escaped path of the incoming request.
    url    string
    // caller serves the request when url matches.
    caller http.HandlerFunc
}

// successAPI returns a stub of the GitLab internal API.
//
// Requests whose escaped path equals the url of one of the given handlers are
// delegated to that handler (the first match wins). Otherwise the stub serves
// built-in responses for the authorized_keys, allowed and git_audit_event
// endpoints. Any other path marks the test as failed and is answered with
// 404 Not Found.
//
// The returned handler runs on HTTP server goroutines, so it reports problems
// with t.Errorf rather than t.FailNow, which may only be called from the
// goroutine running the test.
func successAPI(t *testing.T, handlers ...customHandler) http.Handler {
    t.Helper()

    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        t.Logf("gitlab-api-mock: received request: %s %s", r.Method, r.RequestURI)
        w.Header().Set("Content-Type", "application/json")

        url := r.URL.EscapedPath()

        for _, handler := range handlers {
            if url == handler.url {
                handler.caller(w, r)

                return
            }
        }

        switch url {
        case "/api/v4/internal/authorized_keys":
            // Echo the presented key back as belonging to key id 1
            fmt.Fprintf(w, `{"id":1, "key":"%s"}`, r.FormValue("key"))
        case "/api/v4/internal/allowed":
            response := buildAllowedResponse(t, "responses/allowed_without_console_messages.json")

            _, err := fmt.Fprint(w, response)
            assert.NoError(t, err)
        case "/api/v4/internal/shellhorse/git_audit_event":
            w.WriteHeader(http.StatusOK)
        default:
            t.Errorf("Unexpected request to successAPI: %s", url)
            w.WriteHeader(http.StatusNotFound)
        }
    })
}

// genServerConfig renders the YAML configuration handed to gitlab-sshd:
// gitlabURL becomes gitlab_url and hostKeyPath becomes the single entry of
// sshd.host_key_files. The server listens on an ephemeral loopback port with
// the PROXY protocol enabled. The result has no trailing newline.
func genServerConfig(gitlabURL, hostKeyPath string) []byte {
    lines := []string{
        `---`,
        `user: "git"`,
        `log_file: ""`,
        `log_format: json`,
        `secret: "0123456789abcdef"`,
        `gitlab_url: "` + gitlabURL + `"`,
        `sshd:`,
        `  listen: "127.0.0.1:0"`,
        `  proxy_protocol: true`,
        `  web_listen: ""`,
        `  host_key_files:`,
        `    - "` + hostKeyPath + `"`,
    }

    return []byte(strings.Join(lines, "\n"))
}

// buildClient opens a TCP connection to addr, sends a PROXY protocol v2
// header announcing 10.1.1.1:1000 as the source address, and then performs an
// SSH handshake as user "git" with a freshly generated ed25519 client key.
// The server must present hostKey. Both the TCP connection and the SSH client
// are closed when the test finishes.
func buildClient(t *testing.T, addr string, hostKey ed25519.PublicKey) *ssh.Client {
    t.Helper()

    expectedHostKey, err := ssh.NewPublicKey(hostKey)
    require.NoError(t, err)

    _, userKey, err := ed25519.GenerateKey(nil)
    require.NoError(t, err)

    userSigner, err := ssh.NewSignerFromKey(userKey)
    require.NoError(t, err)

    // The PROXY protocol lets us spoof the client address seen by the server
    serverAddr, err := net.ResolveTCPAddr("tcp", addr)
    require.NoError(t, err)
    tcpConn, err := net.DialTCP("tcp", nil, serverAddr)
    require.NoError(t, err)
    t.Cleanup(func() { tcpConn.Close() })

    spoofedSource := &net.TCPAddr{
        IP:   net.ParseIP("10.1.1.1"),
        Port: 1000,
    }
    proxyHeader := &proxyproto.Header{
        Version:           2,
        Command:           proxyproto.PROXY,
        TransportProtocol: proxyproto.TCPv4,
        SourceAddr:        spoofedSource,
        DestinationAddr:   serverAddr,
    }

    // The header has to be the very first thing written on the connection
    _, err = proxyHeader.WriteTo(tcpConn)
    require.NoError(t, err)

    sshConfig := &ssh.ClientConfig{
        User:            "git",
        Auth:            []ssh.AuthMethod{ssh.PublicKeys(userSigner)},
        HostKeyCallback: ssh.FixedHostKey(expectedHostKey),
    }
    sshConn, chans, reqs, err := ssh.NewClientConn(tcpConn, addr, sshConfig)
    require.NoError(t, err)

    sshClient := ssh.NewClient(sshConn, chans, reqs)
    t.Cleanup(func() { sshClient.Close() })

    return sshClient
}

// configureSSHD prepares a temporary configuration directory for gitlab-sshd.
// It writes config.yml (pointing at apiServer) and a freshly generated
// ed25519 host key in OpenSSH PEM format, and returns the directory together
// with the public half of the host key.
func configureSSHD(t *testing.T, apiServer string) (string, ed25519.PublicKey) {
    t.Helper()

    workDir := t.TempDir()
    hostKeyPath := filepath.Join(workDir, "hostkey")

    publicKey, privateKey, err := ed25519.GenerateKey(nil)
    require.NoError(t, err)

    err = os.WriteFile(filepath.Join(workDir, "config.yml"), genServerConfig(apiServer, hostKeyPath), 0644)
    require.NoError(t, err)

    pemBytes := pem.EncodeToMemory(&pem.Block{
        Type:  "OPENSSH PRIVATE KEY",
        Bytes: edkey.MarshalED25519PrivateKey(privateKey),
    })
    // The host key is written read-only for the owner
    err = os.WriteFile(hostKeyPath, pemBytes, 0400)
    require.NoError(t, err)

    return workDir, publicKey
}

// startSSHD launches gitlab-sshd with dir as its configuration directory and
// returns the TCP address it reports listening on.
//
// The address is extracted from the "tcp_address" field that the process logs
// on stderr. Once it has been found, the rest of stderr (including anything
// already buffered) is forwarded to the real stderr. If the process exits
// before logging an address, the test fails instead of blocking forever.
//
// The process is killed and reaped when the test finishes.
func startSSHD(t *testing.T, dir string) string {
    t.Helper()

    // The process writes its stderr into this pipe so that we can scan the
    // first lines for the listen address.
    pr, pw := io.Pipe()
    t.Cleanup(func() { pr.Close() })
    t.Cleanup(func() { pw.Close() })

    extractor := regexp.MustCompile(`"tcp_address":"([0-9a-f\[\]\.:]+)"`)

    ctx, cancel := context.WithCancel(context.Background())
    cmd := exec.CommandContext(ctx, sshdPath, "-config-dir", dir)
    cmd.Stdout = os.Stdout
    cmd.Stderr = pw
    if err := cmd.Start(); err != nil {
        cancel()
        require.NoError(t, err)
    }
    t.Logf("gitlab-sshd: Start(): success")

    // Reap the process in the background and close the write side of the pipe
    // as soon as it is gone, so that readers observe EOF rather than hanging
    // when gitlab-sshd dies early.
    waitErr := make(chan error, 1)
    go func() {
        err := cmd.Wait()
        pw.Close()
        waitErr <- err
    }()

    // Cleanups run last-in first-out: cancel kills the process, after which
    // the Wait() result becomes available for logging.
    t.Cleanup(func() { t.Logf("gitlab-sshd: Wait(): %v", <-waitErr) })
    t.Cleanup(cancel)

    stderr := bufio.NewReader(pr)

    var listenAddr string
    for listenAddr == "" {
        line, err := stderr.ReadBytes('\n')
        if matches := extractor.FindSubmatch(line); len(matches) == 2 {
            listenAddr = string(matches[1])
        }
        if err != nil {
            break
        }
    }
    require.NotEmpty(t, listenAddr, "Couldn't extract listen address from gitlab-sshd")

    // Copy from the buffered reader so that bytes it has already read past
    // the matched line are forwarded too. The goroutine ends when the pipe
    // is closed.
    go io.Copy(os.Stderr, stderr)

    return listenAddr
}

// runSSHD starts a stub GitLab API server backed by apiHandler, launches a
// gitlab-sshd instance configured to talk to it, and returns an SSH client
// that is already connected to that instance.
func runSSHD(t *testing.T, apiHandler http.Handler) *ssh.Client {
    t.Helper()

    // The stub stands in for the GitLab API that gitlab-sshd calls
    stub := httptest.NewServer(apiHandler)
    t.Logf("gitlab-api-mock: started: url=%q", stub.URL)
    t.Cleanup(func() {
        stub.Close()
        t.Logf("gitlab-api-mock: closed")
    })

    configDir, hostPublicKey := configureSSHD(t, stub.URL)
    sshdAddr := startSSHD(t, configDir)

    return buildClient(t, sshdAddr, hostPublicKey)
}

// TestDiscoverSuccess checks that the "discover" command greets the user
// returned by the stubbed discover endpoint.
func TestDiscoverSuccess(t *testing.T) {
    discover := customHandler{
        url: "/api/v4/internal/discover",
        caller: func(w http.ResponseWriter, _ *http.Request) {
            fmt.Fprint(w, `{"id": 1000, "name": "Test User", "username": "test-user"}`)
        },
    }

    sshClient := runSSHD(t, successAPI(t, discover))

    sess, err := sshClient.NewSession()
    require.NoError(t, err)
    defer sess.Close()

    greeting, err := sess.Output("discover")
    require.NoError(t, err)
    require.Equal(t, "Welcome to GitLab, @test-user!\n", string(greeting))
}

// TestPersonalAccessTokenSuccess checks that the "personal_access_token"
// command prints the token, scopes and expiry returned by the stubbed
// endpoint.
func TestPersonalAccessTokenSuccess(t *testing.T) {
    tokenAPI := customHandler{
        url: "/api/v4/internal/personal_access_token",
        caller: func(w http.ResponseWriter, _ *http.Request) {
            fmt.Fprint(w, `{"success": true, "token": "testtoken", "scopes": ["api"], "expires_at": "9001-01-01"}`)
        },
    }

    sshClient := runSSHD(t, successAPI(t, tokenAPI))

    sess, err := sshClient.NewSession()
    require.NoError(t, err)
    defer sess.Close()

    printed, err := sess.Output("personal_access_token test api")
    require.NoError(t, err)
    require.Equal(t, "Token:   testtoken\nScopes:  api\nExpires: 9001-01-01\n", string(printed))
}

// TestTwoFactorAuthRecoveryCodesSuccess walks through the interactive
// "2fa_recovery_codes" command: it expects the two confirmation lines,
// answers "yes" and checks that the codes returned by the stubbed endpoint
// are printed.
func TestTwoFactorAuthRecoveryCodesSuccess(t *testing.T) {
    handler := customHandler{
        url: "/api/v4/internal/two_factor_recovery_codes",
        caller: func(w http.ResponseWriter, _ *http.Request) {
            fmt.Fprint(w, `{"success": true, "recovery_codes": ["code1", "code2"]}`)
        },
    }
    client := runSSHD(t, successAPI(t, handler))
    session, stdin, stdout := newSession(t, client)

    reader := bufio.NewReader(stdout)

    err := session.Start("2fa_recovery_codes")
    require.NoError(t, err)

    line, err := reader.ReadString('\n')
    require.NoError(t, err)
    require.Equal(t, "Are you sure you want to generate new two-factor recovery codes?\n", line)

    line, err = reader.ReadString('\n')
    require.NoError(t, err)
    require.Equal(t, "Any existing recovery codes you saved will be invalidated. (yes/no)\n", line)

    _, err = fmt.Fprintln(stdin, "yes")
    require.NoError(t, err)

    // Keep reading through the buffered reader: it may already hold bytes
    // beyond the two prompt lines, and reading stdout directly would skip
    // them.
    output, err := io.ReadAll(reader)
    require.NoError(t, err)
    require.Equal(t, `
Your two-factor authentication recovery codes are:

code1
code2

During sign in, use one of the codes above when prompted for
your two-factor code. Then, visit your Profile Settings and add
a new device so you do not lose access to your account again.
`, string(output))
}

// TwoFactorAuthVerifySuccess walks through the interactive "2fa_verify"
// command: it expects the "OTP: " prompt, submits a one-time password and
// checks the success message.
//
// NOTE(review): this function lacks the Test prefix, so `go test` does not
// run it. Confirm whether that is intentional before renaming it.
func TwoFactorAuthVerifySuccess(t *testing.T) {
    handler := customHandler{
        url: "/api/v4/internal/two_factor_otp_check",
        caller: func(w http.ResponseWriter, _ *http.Request) {
            fmt.Fprint(w, `{"success": true}`)
        },
    }
    client := runSSHD(t, successAPI(t, handler))
    session, stdin, stdout := newSession(t, client)

    reader := bufio.NewReader(stdout)

    err := session.Start("2fa_verify")
    require.NoError(t, err)

    // The expected prompt has no trailing newline, so a ReadString('\n')
    // could never return it together with a nil error. Read exactly as many
    // bytes as the prompt is long instead.
    prompt := make([]byte, len("OTP: "))
    _, err = io.ReadFull(reader, prompt)
    require.NoError(t, err)
    require.Equal(t, "OTP: ", string(prompt))

    _, err = fmt.Fprintln(stdin, "otp123")
    require.NoError(t, err)

    // Continue through the buffered reader so bytes it already holds are
    // part of the output.
    output, err := io.ReadAll(reader)
    require.NoError(t, err)
    require.Equal(t, "OTP validation successful. Git operations are now allowed.\n", string(output))
}

// TestGitLfsAuthenticateSuccess checks that "git-lfs-authenticate" turns the
// credentials returned by the stubbed endpoint into the expected JSON
// payload with a Basic authorization header.
func TestGitLfsAuthenticateSuccess(t *testing.T) {
    lfsAPI := customHandler{
        url: "/api/v4/internal/lfs_authenticate",
        caller: func(w http.ResponseWriter, _ *http.Request) {
            fmt.Fprint(w, `{"username": "test-user", "lfs_token": "testlfstoken", "repo_path": "foo", "expires_in": 7200}`)
        },
    }

    sshClient := runSSHD(t, successAPI(t, lfsAPI))

    sess, err := sshClient.NewSession()
    require.NoError(t, err)
    defer sess.Close()

    payload, err := sess.Output("git-lfs-authenticate test-user/repo.git download")
    require.NoError(t, err)

    // The payload is a single JSON line terminated by a newline
    want := `{"header":{"Authorization":"Basic dGVzdC11c2VyOnRlc3RsZnN0b2tlbg=="},"href":"/info/lfs","expires_in":7200}
`
    require.Equal(t, want, string(payload))
}

// TestGitReceivePackSuccess runs "git-receive-pack" against the Gitaly test
// repository and checks that every output line except the last looks like a
// ref advertisement and that the output ends with "0000".
func TestGitReceivePackSuccess(t *testing.T) {
    ensureGitalyRepository(t)

    sshClient := runSSHD(t, successAPI(t))
    sess, in, out := newSession(t, sshClient)

    require.NoError(t, sess.Start("git-receive-pack "+testRepo))

    // Gracefully close connection
    _, err := fmt.Fprintln(in, "0000")
    require.NoError(t, err)
    in.Close()

    raw, err := io.ReadAll(out)
    require.NoError(t, err)

    lines := strings.Split(string(raw), "\n")
    last := len(lines) - 1

    for _, refLine := range lines[:last] {
        require.Regexp(t, "^[0-9a-f]{44} refs/(heads|tags)/[^ ]+", refLine)
    }

    require.Equal(t, "0000", lines[last])
}

// TestGeoGitReceivePackSuccess runs "git-receive-pack" while the stubbed
// /allowed endpoint answers with status 300 and a Geo push payload whose
// PRIMARY_REPO placeholder points at the Git-over-HTTP test server. The
// output must consist of ref advertisement lines followed by "0000".
func TestGeoGitReceivePackSuccess(t *testing.T) {
    primaryURL := startGitOverHTTPServer(t)

    allowed := customHandler{
        url: "/api/v4/internal/allowed",
        caller: func(w http.ResponseWriter, _ *http.Request) {
            payload := strings.Replace(
                buildAllowedResponse(t, "responses/allowed_with_geo_push_payload.json"),
                "PRIMARY_REPO", primaryURL, 1,
            )

            w.WriteHeader(http.StatusMultipleChoices)
            _, err := fmt.Fprint(w, payload)
            assert.NoError(t, err)
        },
    }

    sshClient := runSSHD(t, successAPI(t, allowed))
    sess, in, out := newSession(t, sshClient)

    require.NoError(t, sess.Start("git-receive-pack "+testRepo))

    // Gracefully close connection
    _, err := fmt.Fprintln(in, "0000")
    require.NoError(t, err)
    in.Close()

    raw, err := io.ReadAll(out)
    require.NoError(t, err)

    lines := strings.Split(string(raw), "\n")
    last := len(lines) - 1

    for _, refLine := range lines[:last] {
        require.Regexp(t, "^[0-9a-f]{44} refs/(heads|tags)/[^ ]+", refLine)
    }

    require.Equal(t, "0000", lines[last])
}

// TestGitUploadPackSuccess runs "git-upload-pack" in three consecutive
// sessions over the same SSH client. Each session must start with a HEAD
// line, continue with ref advertisement lines and end with "0000".
func TestGitUploadPackSuccess(t *testing.T) {
    ensureGitalyRepository(t)

    sshClient := runSSHD(t, successAPI(t))
    defer sshClient.Close()

    const sessionCount = 3
    for n := 0; n < sessionCount; n++ {
        name := fmt.Sprintf("session #%v", n)

        t.Run(name, func(t *testing.T) {
            sess, in, out := newSession(t, sshClient)
            buffered := bufio.NewReader(out)

            require.NoError(t, sess.Start("git-upload-pack "+testRepo))

            headLine, err := buffered.ReadString('\n')
            require.NoError(t, err)
            require.Regexp(t, "^[0-9a-f]{44} HEAD.+", headLine)

            // Gracefully close connection
            _, err = fmt.Fprintln(in, "0000")
            require.NoError(t, err)

            // NOTE(review): this reads from the raw stream, so bytes that
            // the buffered reader already consumed are not part of rest;
            // presumably that is why validation starts at index 1 - confirm.
            rest, err := io.ReadAll(out)
            require.NoError(t, err)

            lines := strings.Split(string(rest), "\n")
            last := len(lines) - 1

            for idx := 1; idx < last; idx++ {
                require.Regexp(t, "^[0-9a-f]{44} refs/(heads|tags)/[^ ]+", lines[idx])
            }

            require.Equal(t, "0000", lines[last])
        })
    }
}

// TestGitUploadArchiveSuccess runs "git-upload-archive" against the Gitaly
// test repository, requests HEAD, expects an ACK line and checks that the
// remaining output ends with "0000".
func TestGitUploadArchiveSuccess(t *testing.T) {
    ensureGitalyRepository(t)

    client := runSSHD(t, successAPI(t))
    session, stdin, stdout := newSession(t, client)
    reader := bufio.NewReader(stdout)

    err := session.Start(fmt.Sprintf("git-upload-archive %s", testRepo))
    require.NoError(t, err)

    _, err = fmt.Fprintln(stdin, "0012argument HEAD\n0000")
    require.NoError(t, err)

    // Check the error first so a failed read is reported as such rather
    // than as a confusing string mismatch.
    line, err := reader.ReadString('\n')
    require.NoError(t, err)
    require.Equal(t, "0008ACK\n", line)

    // Gracefully close connection
    _, err = fmt.Fprintln(stdin, "0000")
    require.NoError(t, err)

    // Read through the buffered reader so that bytes it already holds are
    // not lost.
    output, err := io.ReadAll(reader)
    require.NoError(t, err)

    t.Logf("output: %q", output)
    // A suffix check cannot panic when fewer than four bytes were received
    require.True(t, strings.HasSuffix(string(output), "0000"), "output must end with 0000")
}

// newSession opens a session on client and returns it together with its
// stdin and stdout pipes. The session is closed when the test finishes, also
// when setting up one of the pipes fails.
func newSession(t *testing.T, client *ssh.Client) (*ssh.Session, io.WriteCloser, io.Reader) {
    t.Helper()

    session, err := client.NewSession()
    require.NoError(t, err)

    // Register the cleanup before the remaining assertions so that a failure
    // below does not leak the session.
    t.Cleanup(func() {
        session.Close()
    })

    stdin, err := session.StdinPipe()
    require.NoError(t, err)

    stdout, err := session.StdoutPipe()
    require.NoError(t, err)

    return session, stdin, stdout
}