cyberark/secretless-broker

View on GitHub
internal/plugin/connectors/ssh/service_connector.go

Summary

Maintainability
D
1 day
Test Coverage
D
66%
package ssh

import (
    "fmt"
    "io"
    "reflect"
    "strings"
    "time"

    validation "github.com/go-ozzo/ozzo-validation"
    "golang.org/x/crypto/ssh"

    "github.com/cyberark/secretless-broker/pkg/secretless/log"
    "github.com/cyberark/secretless-broker/pkg/secretless/plugin/connector"
)

// ServerConfig is the configuration info for the target server.
//
// It carries the three arguments that Connect passes to ssh.Dial: the network
// and address of the SSH backend, and the SSH client configuration (login
// user, authentication methods and host key verification callback).
type ServerConfig struct {
    // Network is the network name handed to ssh.Dial (serverConfig uses "tcp").
    Network      string
    // Address is the backend "host:port" handed to ssh.Dial.
    Address      string
    // ClientConfig holds the user, auth methods and HostKeyCallback used when
    // dialing the backend.
    ClientConfig ssh.ClientConfig
}

// ServiceConnector proxies the SSH channels opened by a connected client to an
// SSH backend whose connection details are supplied as credentials to Connect.
type ServiceConnector struct {
    // channels delivers the channel-open requests made by the client; Connect
    // ranges over it, proxying each channel to the backend, until it is closed.
    channels <-chan ssh.NewChannel
    // logger receives the connector's debug, warning and error messages.
    logger   log.Logger
}

// serverConfig builds the ServerConfig used to dial the SSH backend from the
// resolved credential values.
//
// Recognized keys:
//   - "address":    backend host, optionally with a port. Port 22 is appended
//                   when none is given; bare and bracketed IPv6 literals
//                   ("::1", "[::1]") are handled.
//   - "user":       login user; defaults to "root".
//   - "hostKey":    backend public key, parsed with ssh.ParsePublicKey. When
//                   absent, any backend host key is accepted and a warning is
//                   logged.
//   - "privateKey": key used for public-key authentication, parsed with
//                   ssh.ParsePrivateKey.
//
// On error the returned ServerConfig is the zero value and the error names the
// credential that could not be parsed.
func (h *ServiceConnector) serverConfig(values map[string][]byte) (ServerConfig, error) {
    // Only the key names are logged; the values are secrets.
    keys := reflect.ValueOf(values).MapKeys()
    h.logger.Debugf("SSH backend connection parameters: %s", keys)

    var config ServerConfig
    config.Network = "tcp"
    if address, ok := values["address"]; ok {
        config.Address = withDefaultSSHPort(string(address))
    }

    // XXX: Should this be the user that the client was trying to connect as?
    config.ClientConfig.User = "root"
    if user, ok := values["user"]; ok {
        config.ClientConfig.User = string(user)
    }
    h.logger.Debugf("Trying to connect with user: %s", config.ClientConfig.User)

    if hostKeyBytes, ok := values["hostKey"]; ok {
        // NOTE(review): ssh.ParsePublicKey expects the SSH wire encoding; a key
        // in authorized_keys text form ("ssh-rsa AAAA...") would need
        // ssh.ParseAuthorizedKey — confirm the expected credential format.
        hostKey, err := ssh.ParsePublicKey(hostKeyBytes)
        if err != nil {
            return ServerConfig{}, fmt.Errorf("unable to parse public key from 'hostKey': %s", err)
        }
        config.ClientConfig.HostKeyCallback = ssh.FixedHostKey(hostKey)
    } else {
        h.logger.Warnf("No SSH hostKey specified. Secretless will accept any backend host key!")
        config.ClientConfig.HostKeyCallback = ssh.InsecureIgnoreHostKey()
    }

    if privateKeyBytes, ok := values["privateKey"]; ok {
        signer, err := ssh.ParsePrivateKey(privateKeyBytes)
        if err != nil {
            return ServerConfig{}, fmt.Errorf("unable to parse private key from 'privateKey': %s", err)
        }
        config.ClientConfig.Auth = []ssh.AuthMethod{ssh.PublicKeys(signer)}
    }

    return config, nil
}

// withDefaultSSHPort returns address unchanged when it already names a port and
// otherwise appends the standard SSH port 22. Unlike a plain "contains ':'"
// check, it recognizes that IPv6 literals contain colons without naming a port.
func withDefaultSSHPort(address string) string {
    // Bracketed IPv6 literal: a port is present only if "]:" appears.
    if strings.HasPrefix(address, "[") {
        if strings.Contains(address, "]:") {
            return address
        }
        return address + ":22"
    }
    switch strings.Count(address, ":") {
    case 0:
        // Plain hostname or IPv4 address without a port.
        return address + ":22"
    case 1:
        // "host:port" (or "host:" — left untouched, as before).
        return address
    default:
        // Unbracketed IPv6 literal such as "::1"; brackets are required
        // before a port can be attached.
        return "[" + address + "]:22"
    }
}

// Connect opens the connection to the target server and proxies requests.
//
// It dials the SSH backend described by credentialValuesByID ("address" and
// "privateKey" are required; see serverConfig for all recognized keys) and
// then, for every channel the client opens on h.channels, opens a matching
// backend channel and forwards channel requests and data in both directions.
//
// Connect blocks until h.channels is closed, and closes the backend connection
// before returning. It returns an error when credentials are missing or
// invalid, when the backend cannot be dialed, or when a channel can no longer
// be opened or accepted.
func (h *ServiceConnector) Connect(
    credentialValuesByID connector.CredentialValuesByID,
) error {
    missing := validation.Errors{}
    for _, credential := range [...]string{"address", "privateKey"} {
        if _, hasCredential := credentialValuesByID[credential]; !hasCredential {
            missing[credential] = fmt.Errorf("must have credential '%s'", credential)
        }
    }
    if err := missing.Filter(); err != nil {
        return err
    }

    serverConfig, err := h.serverConfig(credentialValuesByID)
    if err != nil {
        return fmt.Errorf("could not resolve server config: '%s'", err)
    }

    server, err := ssh.Dial(serverConfig.Network, serverConfig.Address, &serverConfig.ClientConfig)
    if err != nil {
        return fmt.Errorf("failed to dial SSH backend '%s': %s", serverConfig.Address, err)
    }
    // Release the backend connection once the client stops opening channels or
    // an unrecoverable error ends the session. A close error is not actionable.
    defer server.Close()

    // Service the incoming Channel channel.
    for newChannel := range h.channels {
        if err := h.proxyChannel(server, newChannel); err != nil {
            return err
        }
    }

    return nil
}

// proxyChannel opens a backend channel matching newChannel, accepts the client
// channel and starts the goroutines that forward requests and data between the
// two. A backend refusal is relayed to the client as a rejection and is not
// fatal (nil is returned so other channels keep being served); any other
// failure is returned.
func (h *ServiceConnector) proxyChannel(server ssh.Conn, newChannel ssh.NewChannel) error {
    // This delay is to prevent closing of channels on the other side
    // too early when we receive an EOF but have not had the chance to
    // pass that on to the client/server.
    // TODO: Maybe use a better logic for handling EOF conditions
    const softDelay = 2 * time.Second

    serverChannel, serverRequests, err := server.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData())
    if err != nil {
        // OpenChannel reports a backend refusal as *ssh.OpenChannelError. Other
        // errors carry no rejection reason, so they are reported to the client
        // as ConnectionFailed instead of being type-asserted (which panicked).
        reason, message := ssh.ConnectionFailed, err.Error()
        openErr, refused := err.(*ssh.OpenChannelError)
        if refused {
            reason, message = openErr.Reason, openErr.Message
        }
        if rejectErr := newChannel.Reject(reason, message); rejectErr != nil {
            h.logger.Errorf("Failed to send new channel rejection : %s", rejectErr)
        }
        if refused {
            // Only this channel was declined; keep serving the client's others.
            h.logger.Debugf("Backend refused %s channel : %s", newChannel.ChannelType(), err)
            return nil
        }
        return err
    }

    clientChannel, clientRequests, err := newChannel.Accept()
    if err != nil {
        h.logger.Errorf("Failed to accept client channel : %s", err)
        serverChannel.Close()
        return err
    }

    go h.forwardRequests(serverChannel, clientRequests, "Client", "Server")
    go h.forwardRequests(clientChannel, serverRequests, "Server", "Client")
    go h.proxyChannelData(serverChannel, clientChannel, "Client", "Server", softDelay)
    go h.proxyChannelData(clientChannel, serverChannel, "Server", "Client", softDelay)

    return nil
}

// forwardRequests relays every request received on in to dst until in is
// closed. from/to are capitalized labels ("Client"/"Server") used in log
// messages. When the sender asked for a reply, dst's answer is passed back.
func (h *ServiceConnector) forwardRequests(dst ssh.Channel, in <-chan *ssh.Request, from, to string) {
    for req := range in {
        h.logger.Debugf("%s request : %s", from, req.Type)
        ok, err := dst.SendRequest(req.Type, req.WantReply, req.Payload)
        if err != nil {
            h.logger.Warnf("Failed to send %s request to %s channel : %s",
                strings.ToLower(from), strings.ToLower(to), err)
        }
        if req.WantReply {
            h.logger.Debugf("%s reply is %v", to, ok)
            // x/crypto/ssh requires Reply for every request with WantReply set;
            // relay dst's verdict so the sender is not left waiting for it.
            if err := req.Reply(ok, nil); err != nil {
                h.logger.Warnf("Failed to reply to %s request : %s", strings.ToLower(from), err)
            }
        }
    }
}

// proxyChannelData copies data read from src into dst until src reports an
// error. from/to are capitalized labels used in log messages. Bytes returned
// together with an error are still forwarded. Write errors are logged and
// reading continues. On io.EOF, dst is closed after softDelay; on any other read
// error, it is closed immediately instead of retrying forever.
func (h *ServiceConnector) proxyChannelData(dst, src ssh.Channel, from, to string, softDelay time.Duration) {
    buf := make([]byte, 1024)
    for {
        n, readErr := src.Read(buf)
        if n > 0 {
            if _, err := dst.Write(buf[:n]); err != nil {
                h.logger.Debugf("Error writing %d bytes to %s channel : %s", n, strings.ToLower(to), err)
            }
        }

        switch {
        case readErr == nil:
            continue
        case readErr == io.EOF:
            h.logger.Debugf("%s channel is closed", from)
            time.Sleep(softDelay)
        default:
            h.logger.Debugf("Error reading from %s channel : %s", strings.ToLower(from), readErr)
        }
        dst.Close()
        return
    }
}