client/testserver/gitalyserver.go
package testserver
import (
"context"
"fmt"
"net"
"os"
"path"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitaly/v16/client"
pb "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb"
"gitlab.com/gitlab-org/labkit/log"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
)
// TestGitalyServer is a stand-in SSHService implementation used by tests.
// Each handler defined on it replies with a fixed, recognizable payload and
// records the metadata of the call it just served.
type TestGitalyServer struct {
// ReceivedMD is overwritten by every handled RPC with the metadata taken
// from that call's incoming context.
// NOTE(review): handlers write this field without synchronization — confirm
// callers only read it once the RPC has completed.
ReceivedMD metadata.MD
// Embedded generated type; presumably supplies default behavior for RPCs
// that are not overridden here — verify against the generated gitalypb code.
pb.UnimplementedSSHServiceServer
}
// SSHReceivePack reads the first request from the stream, stores the call's
// incoming metadata in ReceivedMD, and answers with a single response whose
// Stdout is "ReceivePack: <GlId> <GlRepository>".
//
// Request fields are read through the generated getters so that a request
// without a Repository produces empty strings rather than a nil-pointer panic
// inside the handler.
//
// It returns the error from stream.Recv or stream.Send, if any.
func (s *TestGitalyServer) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error {
	req, err := stream.Recv()
	if err != nil {
		return err
	}

	// The ok flag is deliberately ignored: ReceivedMD simply mirrors whatever
	// FromIncomingContext returned for this call.
	s.ReceivedMD, _ = metadata.FromIncomingContext(stream.Context())

	response := []byte("ReceivePack: " + req.GetGlId() + " " + req.GetRepository().GetGlRepository())

	return stream.Send(&pb.SSHReceivePackResponse{Stdout: response})
}
// SSHUploadPack reads the first request from the stream, stores the call's
// incoming metadata in ReceivedMD, and answers with a single response whose
// Stdout is "UploadPack: <GlRepository>".
//
// The repository is read through the generated getters so that a request
// without a Repository produces an empty string rather than a nil-pointer
// panic inside the handler.
//
// It returns the error from stream.Recv or stream.Send, if any.
func (s *TestGitalyServer) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error {
	req, err := stream.Recv()
	if err != nil {
		return err
	}

	// The ok flag is deliberately ignored: ReceivedMD simply mirrors whatever
	// FromIncomingContext returned for this call.
	s.ReceivedMD, _ = metadata.FromIncomingContext(stream.Context())

	response := []byte("UploadPack: " + req.GetRepository().GetGlRepository())

	return stream.Send(&pb.SSHUploadPackResponse{Stdout: response})
}
// SSHUploadPackWithSidechannel opens the server side of the sidechannel for
// this call, stores the call's incoming metadata in ReceivedMD, and writes
// "SSHUploadPackWithSidechannel: <GlRepository>" to the sidechannel as one
// framed packet before returning an empty response.
//
// The repository is read through the generated getters so that a request
// without a Repository produces an empty string rather than a nil-pointer
// panic inside the handler.
//
// It returns any error from opening, writing to, or closing the sidechannel.
func (s *TestGitalyServer) SSHUploadPackWithSidechannel(ctx context.Context, req *pb.SSHUploadPackWithSidechannelRequest) (*pb.SSHUploadPackWithSidechannelResponse, error) {
	conn, err := client.OpenServerSidechannel(ctx)
	if err != nil {
		return nil, err
	}
	// Covers the early-return paths; on success the explicit Close below runs
	// first so that its error can be reported to the caller.
	defer conn.Close()

	// The ok flag is deliberately ignored: ReceivedMD simply mirrors whatever
	// FromIncomingContext returned for this call.
	s.ReceivedMD, _ = metadata.FromIncomingContext(ctx)

	response := []byte("SSHUploadPackWithSidechannel: " + req.GetRepository().GetGlRepository())

	// Frame layout: 4 hex digits of total length (payload + 4 length digits +
	// 1 marker byte), the byte 0x01, then the payload.
	// NOTE(review): this looks like git's pkt-line sideband framing — confirm.
	if _, err := fmt.Fprintf(conn, "%04x\x01%s", len(response)+5, response); err != nil {
		return nil, err
	}

	if err := conn.Close(); err != nil {
		return nil, err
	}

	return &pb.SSHUploadPackWithSidechannelResponse{}, nil
}
// SSHUploadArchive reads the first request from the stream, stores the call's
// incoming metadata in ReceivedMD, and answers with a single response whose
// Stdout is "UploadArchive: <GlRepository>".
//
// The repository is read through the generated getters so that a request
// without a Repository produces an empty string rather than a nil-pointer
// panic inside the handler.
//
// It returns the error from stream.Recv or stream.Send, if any.
func (s *TestGitalyServer) SSHUploadArchive(stream pb.SSHService_SSHUploadArchiveServer) error {
	req, err := stream.Recv()
	if err != nil {
		return err
	}

	// The ok flag is deliberately ignored: ReceivedMD simply mirrors whatever
	// FromIncomingContext returned for this call.
	s.ReceivedMD, _ = metadata.FromIncomingContext(stream.Context())

	response := []byte("UploadArchive: " + req.GetRepository().GetGlRepository())

	return stream.Send(&pb.SSHUploadArchiveResponse{Stdout: response})
}
// StartGitalyServer launches a TestGitalyServer for the given network and
// returns the address to dial together with the server instance.
//
// Supported values of network and the shape of the returned address:
//   - "unix": listens on a socket inside a fresh temporary directory (removed
//     on test cleanup) and returns "unix:<socket path>".
//   - "tcp": listens on 127.0.0.1 with port 0 and returns "tcp://<host:port>".
//   - "dns": uses the same listener as "tcp" and returns "dns:///<host:port>".
//
// Any other value panics.
func StartGitalyServer(t *testing.T, network string) (string, *TestGitalyServer) {
	t.Helper()

	switch network {
	case "unix":
		// t.TempDir() is not usable here: the directory it creates has a path
		// that far exceeds the 108 character limit, so creating the socket
		// would fail.
		//
		// See https://gitlab.com/gitlab-org/gitlab-shell/-/issues/696#note_1664726924
		// for more detail.
		socketDir, err := os.MkdirTemp("", "gitaly")
		require.NoError(t, err)
		t.Cleanup(func() { require.NoError(t, os.RemoveAll(socketDir)) })

		socketPath := path.Join(socketDir, "gitaly.sock")
		require.NoError(t, os.MkdirAll(filepath.Dir(socketPath), 0700))

		listenAddr, srv := doStartTestServer(t, "unix", socketPath)
		return "unix:" + listenAddr, srv

	case "tcp", "dns":
		listenAddr, srv := doStartTestServer(t, "tcp", "127.0.0.1:0")
		if network == "dns" {
			// gRPC URLs with the DNS scheme follow this format:
			// https://grpc.github.io/grpc/core/md_doc_naming.html
			// When the authority is dropped, the URL has 3 slashes.
			return "dns:///" + listenAddr, srv
		}
		return "tcp://" + listenAddr, srv
	}

	panic(fmt.Sprintf("Unsupported network %s", network))
}
// doStartTestServer creates a gRPC server with the sidechannel server option,
// registers a fresh TestGitalyServer on it, and serves on a listener opened
// for the given network and address (path).
//
// It returns the listener's actual address and the registered server. The
// gRPC server is gracefully stopped on test cleanup, and cleanup waits for
// the serving goroutine to exit so that it can never report on a test that
// has already finished.
func doStartTestServer(t *testing.T, network string, path string) (string, *TestGitalyServer) {
	t.Helper()

	server := grpc.NewServer(
		client.SidechannelServer(log.ContextLogger(context.Background()), insecure.NewCredentials()),
	)

	listener, err := net.Listen(network, path)
	require.NoError(t, err)

	testServer := &TestGitalyServer{}
	pb.RegisterSSHServiceServer(server, testServer)

	served := make(chan struct{})
	go func() {
		defer close(served)
		// Report with t.Errorf instead of require: FailNow must only be called
		// from the goroutine running the test, not from one it spawned.
		if err := server.Serve(listener); err != nil {
			t.Errorf("serving test gitaly server: %v", err)
		}
	}()

	t.Cleanup(func() {
		server.GracefulStop()
		// Serve returns once the server is stopped; wait for it so the
		// goroutine's lifetime is bounded by the test.
		<-served
	})

	return listener.Addr().String(), testServer
}