gitlabhq/gitlab-shell

View on GitHub
internal/sshd/session.go

Summary

Maintainability
A
1 hr
Test Coverage
package sshd

import (
    "context"
    "errors"
    "fmt"
    "reflect"
    "time"

    "gitlab.com/gitlab-org/labkit/log"
    "golang.org/x/crypto/ssh"
    grpccodes "google.golang.org/grpc/codes"
    grpcstatus "google.golang.org/grpc/status"

    shellCmd "gitlab.com/gitlab-org/gitlab-shell/v14/cmd/gitlab-shell/command"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/shared/disallowedcommand"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/console"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/metrics"
    "gitlab.com/gitlab-org/gitlab-shell/v14/internal/sshenv"
)

// session holds the state of a single SSH session channel: the identity and
// connection details handed over by the connection, plus the values collected
// from the client's requests before a command is run.
type session struct {
    // State set up by the connection.
    // cfg is passed to the command constructors; channel carries the command's
    // stdin/stdout/stderr. gitlabKrb5Principal, gitlabUsername and gitlabKeyId
    // select how the command is built (checked in that order by handleShell).
    // namespace and remoteAddr are copied into the command environment.
    cfg                 *config.Config
    channel             ssh.Channel
    gitlabKeyId         string
    gitlabKrb5Principal string
    gitlabUsername      string
    namespace           string
    remoteAddr          string

    // State managed by the session.
    // execCmd is filled in by handleExec, gitProtocolVersion by handleEnv.
    // started is the reference point for the "established session duration"
    // reported by handleShell.
    execCmd            string
    gitProtocolVersion string
    started            time.Time
}

// execRequest is the payload of an "exec" request, decoded with ssh.Unmarshal
// in handleExec.
// NOTE(review): field order presumably defines the wire layout used by
// ssh.Unmarshal — confirm before reordering or adding fields.
type execRequest struct {
    Command string
}

// envRequest is the payload of an "env" request, decoded with ssh.Unmarshal
// in handleEnv. Name is the variable the client wants to set, Value its value.
// NOTE(review): field order presumably defines the wire layout used by
// ssh.Unmarshal — confirm before reordering or adding fields.
type envRequest struct {
    Name  string
    Value string
}

// exitStatusReq is the payload of the "exit-status" request that exit encodes
// with ssh.Marshal and sends to the client.
type exitStatusReq struct {
    ExitStatus uint32
}

// handle processes the SSH requests arriving for this session until the
// requests channel is closed or a request ends the session.
//
// "env" requests go to handleEnv, "exec" requests to handleExec and "shell"
// requests to handleShell. Any other request type is refused (when the client
// wants a reply) without terminating the session. As soon as a handler asks
// to stop, the SSH channel is closed and the loop exits.
//
// It returns the context produced by the last exec/shell handler (ctx itself
// when none ran) and the error reported by the most recent env/exec/shell
// handler.
func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) (context.Context, error) {
    ctxWithLogData := ctx
    ctxlog := log.ContextLogger(ctx)

    ctxlog.Debug("session: handle: entering request loop")

    var err error
    for req := range requests {
        sessionLog := ctxlog.WithFields(log.Fields{
            "bytesize":   len(req.Payload),
            "type":       req.Type,
            "want_reply": req.WantReply,
        })
        sessionLog.Debug("session: handle: request received")

        var shouldContinue bool
        switch req.Type {
        case "env":
            shouldContinue, err = s.handleEnv(ctx, req)
        case "exec":
            // The command has been executed as `ssh user@host command` or `exec` channel has been used
            // in the app implementation
            ctxWithLogData, shouldContinue, err = s.handleExec(ctx, req)
        case "shell":
            // The command has been entered into the shell or `shell` channel has been used
            // in the app implementation. shouldContinue stays false: a shell
            // request always ends the session.
            var status uint32
            ctxWithLogData, status, err = s.handleShell(ctx, req)
            // Pass the context returned by handleShell, exactly as handleExec does.
            s.exit(ctxWithLogData, status)
        default:
            // Ignore unknown requests but don't terminate the session
            shouldContinue = true

            if req.WantReply {
                // Distinct name so the handler error tracked by the loop is not shadowed.
                if replyErr := req.Reply(false, []byte{}); replyErr != nil {
                    sessionLog.WithError(replyErr).Debug("session: handle: Failed to reply")
                }
            }
        }

        sessionLog.WithField("should_continue", shouldContinue).Debug("session: handle: request processed")

        if !shouldContinue {
            // Closing is best-effort, but a failure is worth a trace in the logs.
            if closeErr := s.channel.Close(); closeErr != nil {
                sessionLog.WithError(closeErr).Debug("session: handle: Failed to close channel")
            }
            break
        }
    }

    ctxlog.Debug("session: handle: exiting request loop")

    return ctxWithLogData, err
}

// handleEnv processes an "env" request. Only the variable named by
// sshenv.GitProtocolEnv is accepted, and its value is remembered as the Git
// protocol version; every other variable is rejected without ending the
// session.
//
// It returns (false, err) when the payload cannot be decoded and (true, nil)
// otherwise, regardless of whether the variable was accepted.
func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) {
    var request envRequest
    if err := ssh.Unmarshal(req.Payload, &request); err != nil {
        log.ContextLogger(ctx).WithError(err).Error("session: handleEnv: failed to unmarshal request")
        return false, err
    }

    // Any variable other than the Git protocol one is forbidden: nothing to do.
    accepted := request.Name == sshenv.GitProtocolEnv
    if accepted {
        s.gitProtocolVersion = request.Value
    }

    if req.WantReply {
        if replyErr := req.Reply(accepted, []byte{}); replyErr != nil {
            log.ContextLogger(ctx).WithError(replyErr).Debug("session: handleEnv: Failed to reply")
        }
    }

    fields := log.Fields{"accepted": accepted, "env_request": request}
    log.WithContextFields(ctx, fields).Debug("session: handleEnv: processed")

    return true, nil
}

// handleExec processes an "exec" request: it decodes the requested command,
// stores it on the session, runs it through handleShell and reports the
// resulting exit status to the client.
//
// The boolean result is always false, i.e. an exec request always ends the
// session. When the payload cannot be decoded, the original ctx and the
// decoding error are returned and nothing is executed.
func (s *session) handleExec(ctx context.Context, req *ssh.Request) (context.Context, bool, error) {
    var payload execRequest
    if unmarshalErr := ssh.Unmarshal(req.Payload, &payload); unmarshalErr != nil {
        return ctx, false, unmarshalErr
    }
    s.execCmd = payload.Command

    shellCtx, exitCode, shellErr := s.handleShell(ctx, req)
    s.exit(shellCtx, exitCode)

    return shellCtx, false, shellErr
}

// handleShell builds and executes the gitlab-shell command for this session.
//
// It acknowledges the request when a reply is wanted, assembles the command
// environment from the session state, wires the SSH channel up as the
// command's stdin/stdout/stderr and picks the command constructor based on
// the identity available (Kerberos principal, then username, then key ID).
//
// Results:
//   - the command could not be built: original ctx, status 128 and the error
//     (a message is written to the client's stderr);
//   - execution failed: original ctx, status 1 and the error (the message is
//     written to the client's stderr unless the gRPC code is Internal);
//   - success: a context carrying the log data (including the number of bytes
//     written to the channel) under the "logData" key, status 0 and nil.
func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Context, uint32, error) {
    logger := log.ContextLogger(ctx)

    if req.WantReply {
        if replyErr := req.Reply(true, []byte{}); replyErr != nil {
            logger.WithError(replyErr).Debug("session: handleShell: Failed to reply")
        }
    }

    env := sshenv.Env{
        IsSSHConnection:    true,
        OriginalCommand:    s.execCmd,
        GitProtocolVersion: s.gitProtocolVersion,
        RemoteAddr:         s.remoteAddr,
        NamespacePath:      s.namespace,
    }

    // Wrap stdout so the number of bytes sent to the client can be reported.
    byteCounter := &readwriter.CountingWriter{W: s.channel}
    streams := &readwriter.ReadWriter{
        Out:    byteCounter,
        In:     s.channel,
        ErrOut: s.channel.Stderr(),
    }

    var (
        cmd command.Command
        err error
    )

    switch {
    case s.gitlabKrb5Principal != "":
        cmd, err = shellCmd.NewWithKrb5Principal(s.gitlabKrb5Principal, env, s.cfg, streams)
    case s.gitlabUsername != "":
        cmd, err = shellCmd.NewWithUsername(s.gitlabUsername, env, s.cfg, streams)
    default:
        cmd, err = shellCmd.NewWithKey(s.gitlabKeyId, env, s.cfg, streams)
    }

    if err != nil {
        if errors.Is(err, disallowedcommand.Error) {
            s.toStderr(ctx, "ERROR: Unknown command: %v\n", s.execCmd)
        } else {
            s.toStderr(ctx, "ERROR: Failed to parse command: %v\n", err.Error())
        }

        return ctx, 128, err
    }

    commandType := reflect.TypeOf(cmd).String()
    sessionSetupSeconds := time.Since(s.started).Seconds()

    logger.WithFields(log.Fields{
        "env":                            env,
        "command":                        commandType,
        "established_session_duration_s": sessionSetupSeconds,
    }).Info("session: handleShell: executing command")
    metrics.SshdSessionEstablishedDuration.Observe(sessionSetupSeconds)

    executedCtx, err := cmd.Execute(ctx)

    // Collect the log data before looking at the error, then re-attach it to
    // the original context together with the written byte count.
    logData := extractDataFromContext(executedCtx)
    logData.WrittenBytes = byteCounter.N
    ctxWithLogData := context.WithValue(ctx, "logData", logData)

    if err != nil {
        // Internal gRPC errors are not shown to the client.
        if st := grpcstatus.Convert(err); st.Code() != grpccodes.Internal {
            s.toStderr(ctx, "ERROR: %v\n", st.Message())
        }

        return ctx, 1, err
    }

    logger.Info("session: handleShell: command executed successfully")

    return ctxWithLogData, 0, nil
}

// toStderr formats a message, records it in the debug log under the "stderr"
// field and displays it as a warning on the channel's stderr stream.
func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) {
    message := fmt.Sprintf(format, args...)

    fields := log.Fields{"stderr": message}
    log.WithContextFields(ctx, fields).Debug("session: toStderr: output")

    stderr := s.channel.Stderr()
    console.DisplayWarningMessage(message, stderr)
}

// exit logs the exit status, closes the write side of the channel and sends
// the status to the client in an "exit-status" request (no reply wanted).
//
// Both channel operations are best-effort: a failure is logged at debug level
// and does not stop the remaining step from being attempted.
func (s *session) exit(ctx context.Context, status uint32) {
    exitLog := log.WithContextFields(ctx, log.Fields{"exit_status": status})
    exitLog.Info("session: exit: exiting")

    if err := s.channel.CloseWrite(); err != nil {
        exitLog.WithError(err).Debug("session: exit: Failed to close write side of channel")
    }

    req := exitStatusReq{ExitStatus: status}
    if _, err := s.channel.SendRequest("exit-status", false, ssh.Marshal(req)); err != nil {
        exitLog.WithError(err).Debug("session: exit: Failed to send exit status")
    }
}