gitlabhq/gitlab-shell

View on GitHub
internal/handler/exec_test.go

Summary

Maintainability
A
1 hr
Test Coverage
package handler

import (
    "context"
    "errors"
    "testing"

    "github.com/stretchr/testify/require"
    "google.golang.org/grpc"
    grpccodes "google.golang.org/grpc/codes"
    "google.golang.org/grpc/metadata"
    grpcstatus "google.golang.org/grpc/status"

    pb "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/gitlabnet/accessverifier"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/sshenv"
    "google.golang.org/protobuf/proto"
    "google.golang.org/protobuf/types/known/anypb"
    "google.golang.org/protobuf/types/known/durationpb"
)

func makeHandler(t *testing.T, err error) func(context.Context, *grpc.ClientConn) (int32, error) {
    return func(ctx context.Context, client *grpc.ClientConn) (int32, error) {
        require.NotNil(t, ctx)
        require.NotNil(t, client)

        return 0, err
    }
}

func TestRunGitalyCommand(t *testing.T) {
    cmd := NewGitalyCommand(
        newConfig(),
        string(commandargs.UploadPack),
        &accessverifier.Response{
            Gitaly: accessverifier.Gitaly{Address: "tcp://localhost:9999"},
        },
    )

    err := cmd.RunGitalyCommand(context.Background(), makeHandler(t, nil))
    require.NoError(t, err)

    expectedErr := errors.New("error")
    err = cmd.RunGitalyCommand(context.Background(), makeHandler(t, expectedErr))
    require.Equal(t, err, expectedErr)
}

func TestCachingOfGitalyConnections(t *testing.T) {
    ctx := context.Background()
    cfg := newConfig()
    response := &accessverifier.Response{
        Username: "user",
        Gitaly: accessverifier.Gitaly{
            Address: "tcp://localhost:9999",
            Token:   "token",
        },
    }

    cmd := NewGitalyCommand(cfg, string(commandargs.UploadPack), response)

    conn, err := cmd.getConn(ctx)
    require.NoError(t, err)

    // Reuses connection for different users
    response.Username = "another-user"
    cmd = NewGitalyCommand(cfg, string(commandargs.UploadPack), response)
    newConn, err := cmd.getConn(ctx)
    require.NoError(t, err)
    require.Equal(t, conn, newConn)
}

func TestMissingGitalyAddress(t *testing.T) {
    cmd := GitalyCommand{Config: newConfig()}

    err := cmd.RunGitalyCommand(context.Background(), makeHandler(t, nil))
    require.EqualError(t, err, "no gitaly_address given")
}

func TestUnavailableGitalyErr(t *testing.T) {
    cmd := NewGitalyCommand(
        newConfig(),
        string(commandargs.UploadPack),
        &accessverifier.Response{
            Gitaly: accessverifier.Gitaly{Address: "tcp://localhost:9999"},
        },
    )

    err := cmd.RunGitalyCommand(context.Background(), makeHandler(t, grpcstatus.Error(grpccodes.Unavailable, "error")))
    require.Equal(t, err, grpcstatus.Error(grpccodes.Unavailable, "The git server, Gitaly, is not available at this time. Please contact your administrator."))
}

func TestGitalyLimitErr(t *testing.T) {
    cmd := NewGitalyCommand(
        newConfig(),
        string(commandargs.UploadPack),
        &accessverifier.Response{
            Gitaly: accessverifier.Gitaly{Address: "tcp://localhost:9999"},
        },
    )
    limitErr := errWithDetail(t, &pb.LimitError{
        ErrorMessage: "concurrency queue wait time reached",
        RetryAfter:   durationpb.New(0)})
    err := cmd.RunGitalyCommand(context.Background(), makeHandler(t, limitErr))
    require.Equal(t, err, grpcstatus.Error(grpccodes.Unavailable, "GitLab is currently unable to handle this request due to load."))
}

func TestRunGitalyCommandMetadata(t *testing.T) {
    tests := []struct {
        name string
        gc   *GitalyCommand
        want map[string]string
    }{
        {
            name: "gitaly_feature_flags",
            gc: NewGitalyCommand(
                newConfig(),
                string(commandargs.UploadPack),
                &accessverifier.Response{
                    Gitaly: accessverifier.Gitaly{
                        Address: "tcp://localhost:9999",
                        Features: map[string]string{
                            "gitaly-feature-cache_invalidator":        "true",
                            "other-ff":                                "true",
                            "gitaly-feature-inforef_uploadpack_cache": "false",
                        },
                    },
                },
            ),
            want: map[string]string{
                "gitaly-feature-cache_invalidator":        "true",
                "gitaly-feature-inforef_uploadpack_cache": "false",
            },
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            cmd := tt.gc

            err := cmd.RunGitalyCommand(context.Background(), func(ctx context.Context, _ *grpc.ClientConn) (int32, error) {
                md, exists := metadata.FromOutgoingContext(ctx)
                require.True(t, exists)
                require.Len(t, md, len(tt.want))

                for k, v := range tt.want {
                    values := md.Get(k)
                    require.Len(t, values, 1)
                    require.Equal(t, v, values[0])
                }

                return 0, nil
            })

            require.NoError(t, err)
        })
    }
}

func TestPrepareContext(t *testing.T) {
    tests := []struct {
        name     string
        gc       *GitalyCommand
        env      sshenv.Env
        repo     *pb.Repository
        response *accessverifier.Response
        want     map[string]string
    }{
        {
            name: "client_identity",
            gc: NewGitalyCommand(
                &config.Config{},
                string(commandargs.UploadPack),
                &accessverifier.Response{
                    KeyID:    1,
                    KeyType:  "key",
                    UserID:   "6",
                    Username: "jane.doe",
                    Gitaly: accessverifier.Gitaly{
                        Address: "tcp://localhost:9999",
                    },
                },
            ),
            env: sshenv.Env{
                GitProtocolVersion: "protocol",
                IsSSHConnection:    true,
                RemoteAddr:         "10.0.0.1",
            },
            repo: &pb.Repository{
                StorageName:                   "default",
                RelativePath:                  "@hashed/5f/9c/5f9c4ab08cac7457e9111a30e4664920607ea2c115a1433d7be98e97e64244ca.git",
                GitObjectDirectory:            "path/to/git_object_directory",
                GitAlternateObjectDirectories: []string{"path/to/git_alternate_object_directory"},
                GlRepository:                  "project-26",
                GlProjectPath:                 "group/private",
            },
            want: map[string]string{
                "key_id":    "1",
                "key_type":  "key",
                "user_id":   "6",
                "username":  "jane.doe",
                "remote_ip": "10.0.0.1",
            },
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            ctx := context.Background()

            ctx, cancel := tt.gc.PrepareContext(ctx, tt.repo, tt.env)
            defer cancel()

            md, exists := metadata.FromOutgoingContext(ctx)
            require.True(t, exists)
            require.Len(t, md, len(tt.want))

            for k, v := range tt.want {
                values := md.Get(k)
                require.Len(t, values, 1)
                require.Equal(t, v, values[0])
            }

        })
    }
}

func newConfig() *config.Config {
    cfg := &config.Config{}
    cfg.GitalyClient.InitSidechannelRegistry(context.Background())
    return cfg
}

// errWithDetail adds the given details to the error if it is a gRPC status whose code is not OK.
func errWithDetail(t *testing.T, detail proto.Message) error {
    st := grpcstatus.New(grpccodes.Unavailable, "too busy")

    proto := st.Proto()
    marshaled, err := anypb.New(detail)
    require.NoError(t, err)

    proto.Details = append(proto.Details, marshaled)

    return grpcstatus.ErrorProto(proto)
}