cloudfoundry/stratos

View on GitHub
src/jetstream/plugins/cfappssh/app_ssh.go

Summary

Maintainability
C
7 hrs
Test Coverage
package cfappssh

import (
    "crypto/tls"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "net"
    "net/http"
    "net/url"
    "time"

    "github.com/cloudfoundry-incubator/stratos/src/jetstream/repository/interfaces"
    "github.com/gorilla/websocket"
    "github.com/labstack/echo/v4"
    log "github.com/sirupsen/logrus"
    "golang.org/x/crypto/ssh"
)

// See: https://docs.cloudfoundry.org/devguide/deploy-apps/ssh-apps.html

// WebScoket code based on: https://github.com/gorilla/websocket/blob/master/examples/command/main.go

const (
    // Time allowed to write a message to the peer.
    writeWait = 10 * time.Second

    // Inactivity timeout
    inActivityTimeout = 10 * time.Second

    md5FingerprintLength          = 47 // inclusive of space between bytes
    base64Sha256FingerprintLength = 43
)

// Allow connections from any Origin
var upgrader = websocket.Upgrader{
    CheckOrigin: func(r *http.Request) bool { return true },
}

// KeyCode - JSON object that is passed from the front-end to notify of a key press or a term resize
type KeyCode struct {
    Key  string `json:"key"`
    Cols int    `json:"cols"`
    Rows int    `json:"rows"`
}

func (cfAppSsh *CFAppSSH) appSSH(c echo.Context) error {
    // Need to get info for the endpoint
    // Get the CNSI and app IDs from route parameters
    cnsiGUID := c.Param("cnsiGuid")
    userGUID := c.Get("user_id").(string)

    var p = cfAppSsh.portalProxy

    // Extract the Doppler endpoint from the CNSI record
    cnsiRecord, err := p.GetCNSIRecord(cnsiGUID)
    if err != nil {
        return sendSSHError("Could not get endpoint information")
    }

    // Make the info call to the SSH endpoint info
    // Currently this is not cached, so we must get it each time
    apiEndpoint := cnsiRecord.APIEndpoint

    cfPlugin, err := p.GetEndpointTypeSpec("cf")
    if err != nil {
        return sendSSHError("Can not get Cloud Foundry endpoint plugin")
    }

    _, info, err := cfPlugin.Info(apiEndpoint.String(), cnsiRecord.SkipSSLValidation)
    if err != nil {
        return sendSSHError("Can not get Cloud Foundry info")
    }

    cfInfo, found := info.(interfaces.V2Info)
    if !found {
        return sendSSHError("Can not get Cloud Foundry info")
    }

    appGUID := c.Param("appGuid")
    appInstance := c.Param("appInstance")

    host, _, err := net.SplitHostPort(cfInfo.AppSSHEndpoint)
    if err != nil {
        host = cfInfo.AppSSHEndpoint
    }

    // Build the Username
    // cf:APP-GUID/APP-INSTANCE-INDEX@SSH-ENDPOINT
    username := fmt.Sprintf("cf:%s/%s@%s", appGUID, appInstance, host)

    // Need to get SSH Code
    // Refresh token first - makes sure it will be valid when we make the request to get the code
    refreshedTokenRec, err := p.RefreshOAuthToken(cnsiRecord.SkipSSLValidation, cnsiRecord.GUID, userGUID, cnsiRecord.ClientId, cnsiRecord.ClientSecret, cnsiRecord.TokenEndpoint)
    if err != nil {
        return sendSSHError("Couldn't get refresh token for CNSI with GUID %s", cnsiRecord.GUID)
    }

    code, err := getSSHCode(cnsiRecord.TokenEndpoint, cfInfo.AppSSHOauthCLient, refreshedTokenRec.AuthToken, cnsiRecord.SkipSSLValidation)
    if err != nil {
        return sendSSHError("Couldn't get SSH Code: %s", err)
    }

    sshConfig := &ssh.ClientConfig{
        User: username,
        Auth: []ssh.AuthMethod{
            ssh.Password(code),
        },
        HostKeyCallback: sshHostKeyChecker(cfInfo.AppSSHHostKeyFingerprint),
    }

    connection, err := ssh.Dial("tcp", cfInfo.AppSSHEndpoint, sshConfig)
    if err != nil {
        return fmt.Errorf("Failed to dial: %s", err)
    }

    session, err := connection.NewSession()
    if err != nil {
        return fmt.Errorf("Failed to create session: %s", err)
    }

    defer connection.Close()

    // Upgrade the web socket
    ws, pingTicker, err := interfaces.UpgradeToWebSocket(c)
    if err != nil {
        return err
    }
    defer ws.Close()
    defer pingTicker.Stop()

    modes := ssh.TerminalModes{
        ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
        ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
    }

    // NB: rows, cols
    if err := session.RequestPty("xterm", 84, 80, modes); err != nil {
        session.Close()
        return fmt.Errorf("request for pseudo terminal failed: %s", err)
    }

    stdin, err := session.StdinPipe()
    if err != nil {
        return fmt.Errorf("Unable to setup stdin for session: %v", err)
    }

    stdout, err := session.StdoutPipe()
    if err != nil {
        return fmt.Errorf("Unable to setup stdout for session: %v", err)
    }

    defer session.Close()

    stdoutDone := make(chan struct{})
    go pumpStdout(ws, stdout, stdoutDone)
    go session.Shell()

    // Read the input from the web socket and pipe it to the SSH client
    for {
        _, r, err := ws.ReadMessage()
        if err != nil {
            log.Error("Error reading message from web socket")
            log.Warnf("%+v", err)
            return err
        }

        res := KeyCode{}
        json.Unmarshal(r, &res)

        if res.Cols == 0 {
            stdin.Write([]byte(res.Key))
        } else {
            // Terminal resize request
            if err := windowChange(session, res.Rows, res.Cols); err != nil {
                log.Error("Can not resize the PTY")
            }
        }
    }
}

func sendSSHError(format string, a ...interface{}) error {
    if len(a) == 0 {
        log.Error("App SSH Error: " + format)
    } else {
        log.Errorf("App SSH Error: "+format, a)
    }
    return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf(format, a...))
}

func sshHostKeyChecker(fingerprint string) ssh.HostKeyCallback {
    return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
        switch len(fingerprint) {
        case base64Sha256FingerprintLength:
            if fmt.Sprintf("SHA256:%s", fingerprint) == ssh.FingerprintSHA256(key) {
                return nil
            }
        case md5FingerprintLength:
            if fingerprint == ssh.FingerprintLegacyMD5(key) {
                return nil
            }
        default:
            return errors.New("Unsupported host key fingerprint format")
        }
        return errors.New("Host key fingerprint is incorrect")
    }
}

// RFC 4254 Section 6.7.
type windowChangeRequestMsg struct {
    Columns uint32
    Rows    uint32
    Width   uint32
    Height  uint32
}

func windowChange(s *ssh.Session, h, w int) error {

    req := windowChangeRequestMsg{
        Columns: uint32(w),
        Rows:    uint32(h),
        Width:   uint32(w * 8),
        Height:  uint32(h * 8),
    }
    ok, err := s.SendRequest("window-change", true, ssh.Marshal(&req))
    if err == nil && !ok {
        err = errors.New("ssh: window-change failed")
    }
    return err
}

func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) {
    buffer := make([]byte, 32768)
    for {
        len, err := r.Read(buffer)
        if err != nil {
            if err != io.EOF {
                log.Errorf("App SSH encountered an error reading from stdout; %v", err)
            }
            ws.Close()
            break
        }

        ws.SetWriteDeadline(time.Now().Add(writeWait))
        bytes := fmt.Sprintf("% x\n", buffer[:len])
        if err := ws.WriteMessage(websocket.TextMessage, []byte(bytes)); err != nil {
            log.Error("App SSH Failed to write nessage")
            ws.Close()
            break
        }
    }
}

// ErrPreventRedirect - Error to indicate a redirect - used to make a redirect that we want to prevent later
var ErrPreventRedirect = errors.New("prevent-redirect")

func getSSHCode(authorizeEndpoint, clientID, token string, skipSSLValidation bool) (string, error) {
    authorizeURL, err := url.Parse(authorizeEndpoint)
    if err != nil {
        return "", err
    }

    values := url.Values{}
    values.Set("response_type", "code")
    values.Set("grant_type", "authorization_code")
    values.Set("client_id", clientID)

    authorizeURL.Path += "/oauth/authorize"
    authorizeURL.RawQuery = values.Encode()

    authorizeReq, err := http.NewRequest("GET", authorizeURL.String(), nil)
    if err != nil {
        return "", err
    }

    authorizeReq.Header.Add("authorization", "Bearer "+token)

    httpClientWithoutRedirects := &http.Client{
        CheckRedirect: func(req *http.Request, _ []*http.Request) error {
            return ErrPreventRedirect
        },
        Timeout: 30 * time.Second,
        Transport: &http.Transport{
            DisableKeepAlives: true,
            TLSClientConfig: &tls.Config{
                InsecureSkipVerify: skipSSLValidation,
            },
            Proxy:               http.ProxyFromEnvironment,
            TLSHandshakeTimeout: 10 * time.Second,
        },
    }

    resp, err := httpClientWithoutRedirects.Do(authorizeReq)
    if resp != nil {
        log.Infof("%+v", resp)
    }
    if err == nil {
        return "", errors.New("Authorization server did not redirect with one time code")
    }

    if netErr, ok := err.(*url.Error); !ok || netErr.Err != ErrPreventRedirect {
        return "", errors.New("Error requesting one time code from server")
    }

    loc, err := resp.Location()
    if err != nil {
        return "", errors.New("Error getting the redirected location")
    }

    codes := loc.Query()["code"]
    if len(codes) != 1 {
        return "", errors.New("Unable to acquire one time code from authorization response")
    }

    return codes[0], nil
}