src/jetstream/passthrough.go
package main
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"path"
"strings"
"time"
"github.com/labstack/echo/v4"
log "github.com/sirupsen/logrus"
"github.com/cloudfoundry-incubator/stratos/src/jetstream/repository/interfaces"
)
// API Host Prefix to replace if the custom header is supplied
const apiPrefix = "api."
const (
longRunningTimeoutHeader = "x-cap-long-running"
noTokenHeader = "x-cap-no-token"
)
// Timeout for long-running requests, after which we will return indicating request it still active
// to prevent hitting the 2 minute browser timeout
const longRunningRequestTimeout = 30
type PassthroughErrorStatus struct {
StatusCode int `json:"statusCode"`
Status string `json:"status"`
}
type PassthroughError struct {
Error *PassthroughErrorStatus `json:"error"`
ErrorResponse *json.RawMessage `json:"errorResponse"`
}
func getEchoURL(c echo.Context) url.URL {
log.Debug("getEchoURL")
u := c.Request().URL
// dereference so we get a copy
return *u
}
func getEchoHeaders(c echo.Context) http.Header {
log.Debug("getEchoHeaders")
h := make(http.Header)
originalHeader := c.Request().Header
for k, v := range originalHeader {
if k == "Cookie" {
continue
}
vCopy := make([]string, len(v))
copy(vCopy, v)
h[k] = vCopy
}
return h
}
func makeRequestURI(c echo.Context) *url.URL {
log.Debug("makeRequestURI")
uri := getEchoURL(c)
prefix := strings.TrimSuffix(c.Path(), "*")
uri.Path = strings.TrimPrefix(uri.Path, prefix)
return &uri
}
func getPortalUserGUID(c echo.Context) (string, error) {
log.Debug("getPortalUserGUID")
portalUserGUIDIntf := c.Get("user_id")
if portalUserGUIDIntf == nil {
return "", errors.New("Corrupted session")
}
return portalUserGUIDIntf.(string), nil
}
func getRequestParts(c echo.Context) (*http.Request, []byte, error) {
log.Debug("getRequestParts")
var body []byte
var err error
req := c.Request()
if bodyReader := req.Body; bodyReader != nil {
if body, err = ioutil.ReadAll(bodyReader); err != nil {
return nil, nil, errors.New("Failed to read request body")
}
}
return req, body, nil
}
func buildJSONResponse(cnsiList []string, responses map[string]*interfaces.CNSIRequest) map[string]*json.RawMessage {
log.Debug("buildJSONResponse")
jsonResponse := make(map[string]*json.RawMessage)
for _, guid := range cnsiList {
var response []byte
cnsiResponse, ok := responses[guid]
var errorStatus = &PassthroughErrorStatus{
StatusCode: -1,
}
var errorResponse []byte
switch {
case !ok:
errorStatus.StatusCode = 500
errorStatus.Status = "Request timed out"
case cnsiResponse.Error != nil:
errorStatus.StatusCode = 500
errorStatus.Status = cnsiResponse.Error.Error()
case cnsiResponse.Response != nil:
response = cnsiResponse.Response
}
// Check the HTTP Status code to make sure that it is actually a valid response
if cnsiResponse.StatusCode >= 400 {
errorStatus.Status = cnsiResponse.Status
errorStatus.StatusCode = cnsiResponse.StatusCode
if errorStatus.StatusCode <= 0 {
errorStatus.StatusCode = 500
errorStatus.Status = "Failed to proxy request"
}
// Check that the error response was valid json - convert to string otherwise
if !isValidJSON(cnsiResponse.Response) {
errorResponse = []byte(fmt.Sprintf("%q", cnsiResponse.Response))
} else {
errorResponse = cnsiResponse.Response
}
}
if errorStatus.StatusCode >= 0 {
passthroughError := &PassthroughError{
Error: errorStatus,
ErrorResponse: (*json.RawMessage)(&errorResponse),
}
res, _ := json.Marshal(passthroughError)
jsonResponse[guid] = (*json.RawMessage)(&res)
} else {
if len(response) > 0 {
jsonResponse[guid] = (*json.RawMessage)(&response)
} else {
jsonResponse[guid] = nil
}
}
}
return jsonResponse
}
// When we move to goland 1.9 we can use json.isValid()
func isValidJSON(data []byte) bool {
var res interface{}
err := json.Unmarshal(data, &res)
return err == nil
}
func (p *portalProxy) buildCNSIRequest(cnsiGUID string, userGUID string, method string, uri *url.URL, body []byte, header http.Header) (interfaces.CNSIRequest, error) {
log.Debug("buildCNSIRequest")
cnsiRequest := interfaces.CNSIRequest{
GUID: cnsiGUID,
UserGUID: userGUID,
Method: method,
Body: body,
Header: header,
}
cnsiRec, err := p.GetCNSIRecord(cnsiGUID)
if err != nil {
return cnsiRequest, err
}
cnsiRequest.URL = new(url.URL)
*cnsiRequest.URL = *cnsiRec.APIEndpoint
// The APIEndpoint might have a path already - so join the request URI to it...
// but ensure we don't escape parameters again
extraPath := uri.Path
if len(uri.RawPath) > 0 {
extraPath = uri.RawPath
}
cnsiRequest.URL.RawPath = path.Join(cnsiRequest.URL.Path, extraPath)
cnsiRequest.URL.Path, _ = url.PathUnescape(cnsiRequest.URL.RawPath)
cnsiRequest.URL.RawQuery = uri.RawQuery
return cnsiRequest, nil
}
func (p *portalProxy) validateCNSIList(cnsiList []string) error {
log.Debug("validateCNSIList")
for _, cnsiGUID := range cnsiList {
if _, err := p.GetCNSIRecord(cnsiGUID); err != nil {
return err
}
}
return nil
}
func fwdCNSIStandardHeaders(cnsiRequest *interfaces.CNSIRequest, req *http.Request) {
log.Debug("fwdCNSIStandardHeaders")
for k, v := range cnsiRequest.Header {
switch {
// Skip these
// - "Referer" causes CF to fail with a 403
// - "Connection", "X-Cap-*" and "Cookie" are consumed by us
// - "Accept-Encoding" must be excluded otherwise the transport will expect us to handle the encoding/compression
// - X-Forwarded-* headers - these will confuse Cloud Foundry in some cases (e.g. load balancers)
case k == "Connection", k == "Cookie", k == "Referer", k == "Accept-Encoding",
strings.HasPrefix(strings.ToLower(k), "x-cap-"),
strings.HasPrefix(strings.ToLower(k), "x-forwarded-"):
// Forwarding everything else
default:
req.Header[k] = v
}
}
}
func (p *portalProxy) proxy(c echo.Context) error {
log.Debug("proxy")
responses, err := p.ProxyRequest(c, makeRequestURI(c))
if err != nil {
return err
}
return p.SendProxiedResponse(c, responses)
}
func (p *portalProxy) ProxyRequest(c echo.Context, uri *url.URL) (map[string]*interfaces.CNSIRequest, error) {
log.Debug("ProxyRequest")
cnsiList := strings.Split(c.Request().Header.Get("x-cap-cnsi-list"), ",")
shouldPassthrough := "true" == c.Request().Header.Get("x-cap-passthrough")
longRunning := "true" == c.Request().Header.Get(longRunningTimeoutHeader)
if err := p.validateCNSIList(cnsiList); err != nil {
return nil, echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
header := getEchoHeaders(c)
header.Del("Cookie")
portalUserGUID, err := getPortalUserGUID(c)
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
req, body, err := getRequestParts(c)
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
if shouldPassthrough {
if len(cnsiList) > 1 {
err := errors.New("Requested passthrough to multiple CNSIs. Only single CNSI passthroughs are supported")
return nil, echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
}
// Only support one endpoint for long running operation (due to way we do timeout with the response channel)
if longRunning {
if len(cnsiList) > 1 {
err := errors.New("Requested long-running proxy to multiple CNSIs. Only single CNSI is supported for long running passthrough")
return nil, echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
}
// send the request to each CNSI
done := make(chan *interfaces.CNSIRequest)
for _, cnsi := range cnsiList {
cnsiRequest, buildErr := p.buildCNSIRequest(cnsi, portalUserGUID, req.Method, uri, body, header)
if buildErr != nil {
return nil, echo.NewHTTPError(http.StatusBadRequest, buildErr.Error())
}
cnsiRequest.LongRunning = longRunning
// Allow the host part of the API URL to be overridden
apiHost := c.Request().Header.Get("x-cap-api-host")
// Don't allow any '.' chars in the api name
if apiHost != "" && !strings.ContainsAny(apiHost, ".") {
// Add trailing . for when we replace
apiHost = apiHost + "."
// Override the API URL if needed
if strings.HasPrefix(cnsiRequest.URL.Host, apiPrefix) {
// Replace 'api.' prefix with supplied prefix
cnsiRequest.URL.Host = strings.Replace(cnsiRequest.URL.Host, apiPrefix, apiHost, 1)
} else {
// Add supplied prefix to the domain
cnsiRequest.URL.Host = apiHost + cnsiRequest.URL.Host
}
}
go p.doRequest(&cnsiRequest, done)
}
// Wait for all responses
responses := make(map[string]*interfaces.CNSIRequest)
if !longRunning {
for range cnsiList {
res := <-done
responses[res.GUID] = res
}
} else {
// Long running has a timeout
for range cnsiList {
select {
case res := <-done:
responses[res.GUID] = res
case <-time.After(longRunningRequestTimeout * time.Second):
// For all those that have not completed, add a timeout response
for _, id := range cnsiList {
if _, ok := responses[id]; !ok {
// Did not get a response for the endpoint
responses[id] = &interfaces.CNSIRequest{
GUID: id,
UserGUID: portalUserGUID,
Method: req.Method,
StatusCode: http.StatusAccepted,
Status: "Long Running Operation still active",
Response: makeLongRunningTimeoutError(),
Error: nil,
ResponseGUID: id,
}
}
}
break
}
}
}
return responses, nil
}
func makeLongRunningTimeoutError() []byte {
description := "Long Running Operation still active"
var errorStatus = &PassthroughErrorStatus{
StatusCode: http.StatusAccepted,
Status: description,
}
errorResponse := []byte(fmt.Sprint("{\"longRunningTimeout\": true, \"description\": \"" + description + "\", \"error_code\": \"longRunningTimeout\"}"))
passthroughError := &PassthroughError{}
passthroughError.Error = errorStatus
passthroughError.ErrorResponse = (*json.RawMessage)(&errorResponse)
res, e := json.Marshal(passthroughError)
if e != nil {
log.Errorf("makeLongRunningTimeoutError: could not marshal JSON: %+v", e)
}
return res
}
// TODO: This should be used by the function above
func (p *portalProxy) DoProxyRequest(requests []interfaces.ProxyRequestInfo) (map[string]*interfaces.CNSIRequest, error) {
log.Debug("DoProxyRequest")
// send the request to each endpoint
done := make(chan *interfaces.CNSIRequest)
for _, requestInfo := range requests {
cnsiRequest, buildErr := p.buildCNSIRequest(requestInfo.EndpointGUID, requestInfo.UserGUID, requestInfo.Method, requestInfo.URI, requestInfo.Body, requestInfo.Headers)
cnsiRequest.ResponseGUID = requestInfo.ResultGUID
if buildErr != nil {
return nil, echo.NewHTTPError(http.StatusBadRequest, buildErr.Error())
}
go p.doRequest(&cnsiRequest, done)
}
responses := make(map[string]*interfaces.CNSIRequest)
for range requests {
res := <-done
responses[res.ResponseGUID] = res
}
return responses, nil
}
// Convenience helper for a single request
func (p *portalProxy) DoProxySingleRequest(cnsiGUID, userGUID, method, requestUrl string, headers http.Header, body []byte) (*interfaces.CNSIRequest, error) {
requests := make([]interfaces.ProxyRequestInfo, 0)
proxyURL, err := url.Parse(requestUrl)
if err != nil {
return nil, err
}
req := interfaces.ProxyRequestInfo{}
req.UserGUID = userGUID
req.ResultGUID = "REQ_" + cnsiGUID
req.EndpointGUID = cnsiGUID
req.Method = method
req.URI = proxyURL
if headers != nil {
req.Headers = headers
}
if body != nil {
req.Body = body
}
requests = append(requests, req)
responses, err := p.DoProxyRequest(requests)
if err != nil {
return nil, err
}
return responses[req.ResultGUID], err
}
// Convenience helper for a single request using a token
func (p *portalProxy) DoProxySingleRequestWithToken(cnsiGUID string, token *interfaces.TokenRecord, method, requestURL string, headers http.Header, body []byte) (*interfaces.CNSIRequest, error) {
proxyURL, err := url.Parse(requestURL)
if err != nil {
return nil, err
}
done := make(chan *interfaces.CNSIRequest)
cnsiRequest, buildErr := p.buildCNSIRequest(cnsiGUID, "", method, proxyURL, body, headers)
if buildErr != nil {
return nil, echo.NewHTTPError(http.StatusBadRequest, buildErr.Error())
}
cnsiRequest.Token = token
go p.doRequest(&cnsiRequest, done)
res := <-done
return res, nil
}
func (p *portalProxy) SendProxiedResponse(c echo.Context, responses map[string]*interfaces.CNSIRequest) error {
shouldPassthrough := "true" == c.Request().Header.Get("x-cap-passthrough")
var cnsiList []string
for k := range responses {
cnsiList = append(cnsiList, k)
}
if shouldPassthrough {
cnsiGUID := cnsiList[0]
res, ok := responses[cnsiGUID]
if !ok {
return echo.NewHTTPError(http.StatusRequestTimeout, "Request timed out")
}
// in passthrough mode, set the status code to that of the single response
c.Response().WriteHeader(res.StatusCode)
// we don't care if this fails
_, err := c.Response().Write(res.Response)
if err != nil {
log.Errorf("Failed to write passthrough response %v", err)
}
return nil
}
jsonResponse := buildJSONResponse(cnsiList, responses)
e := json.NewEncoder(c.Response())
err := e.Encode(jsonResponse)
if err != nil {
log.Errorf("Failed to encode JSON: %v\n%#v\n", err, jsonResponse)
}
return err
}
func (p *portalProxy) doRequest(cnsiRequest *interfaces.CNSIRequest, done chan<- *interfaces.CNSIRequest) {
log.Debugf("doRequest for URL: %s", cnsiRequest.URL.String())
var body io.Reader
var res *http.Response
var req *http.Request
var err error
if len(cnsiRequest.Body) > 0 {
body = bytes.NewReader(cnsiRequest.Body)
}
proxyURL := cnsiRequest.URL.String()
req, err = http.NewRequest(cnsiRequest.Method, proxyURL, body)
if err != nil {
cnsiRequest.Error = err
if done != nil {
done <- cnsiRequest
}
return
}
var tokenRec interfaces.TokenRecord
if cnsiRequest.Token != nil {
tokenRec = *cnsiRequest.Token
} else {
// get a cnsi token record and a cnsi record
tokenRec, _, err = p.getCNSIRequestRecords(cnsiRequest)
if err != nil {
cnsiRequest.Error = err
if done != nil {
cnsiRequest.StatusCode = 400
cnsiRequest.Status = "Unable to retrieve CNSI token record"
done <- cnsiRequest
}
return
}
}
// Copy original headers through, except custom portal-proxy Headers
fwdCNSIStandardHeaders(cnsiRequest, req)
// If this is a long running request, add a header which we can use at request time to change the timeout
if cnsiRequest.LongRunning {
req.Header.Set(longRunningTimeoutHeader, "true")
}
// Find the auth provider for the auth type - default ot oauthflow
authHandler := p.GetAuthProvider(tokenRec.AuthType)
if authHandler.Handler != nil {
res, err = authHandler.Handler(cnsiRequest, req)
} else {
res, err = p.DoOAuthFlowRequest(cnsiRequest, req)
}
if err != nil {
cnsiRequest.StatusCode = 500
cnsiRequest.Status = "Error proxing request"
cnsiRequest.Response = []byte(err.Error())
cnsiRequest.Error = err
} else if res.Body != nil {
cnsiRequest.StatusCode = res.StatusCode
cnsiRequest.Status = res.Status
cnsiRequest.Response, cnsiRequest.Error = ioutil.ReadAll(res.Body)
defer res.Body.Close()
}
// If Status Code >=400, log this as a warning
if cnsiRequest.StatusCode >= 400 {
var contentType = "Unknown"
var contentLength int64 = -1
if res != nil {
contentType = res.Header.Get("Content-Type")
contentLength = res.ContentLength
}
log.Warnf("Passthrough response: URL: %s, Status Code: %d, Status: %s, Content Type: %s, Length: %d",
cnsiRequest.URL.String(), cnsiRequest.StatusCode, cnsiRequest.Status, contentType, contentLength)
log.Warn(string(cnsiRequest.Response))
}
if done != nil {
done <- cnsiRequest
}
}
func (p *portalProxy) ProxySingleRequest(c echo.Context) error {
log.Debug("ProxySingleRequest")
cnsi := c.Param("uuid")
uri := url.URL{}
// Ensure we don't escape parameters again
uri.RawPath = c.Param("*")
uri.Path, _ = url.PathUnescape(uri.RawPath)
uri.RawQuery = c.Request().URL.RawQuery
header := getEchoHeaders(c)
header.Del("Cookie")
header.Del(APIKeyHeader)
portalUserGUID, err := getPortalUserGUID(c)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
req, body, err := getRequestParts(c)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
done := make(chan *interfaces.CNSIRequest)
cnsiRequest, buildErr := p.buildCNSIRequest(cnsi, portalUserGUID, req.Method, &uri, body, header)
if buildErr != nil {
return echo.NewHTTPError(http.StatusBadRequest, buildErr.Error())
}
longRunning := "true" == c.Request().Header.Get(longRunningTimeoutHeader)
noToken := "true" == c.Request().Header.Get(noTokenHeader)
cnsiRequest.LongRunning = longRunning
if noToken {
// Fake a token record with no authentication
cnsiRequest.Token = &interfaces.TokenRecord{
AuthType: interfaces.AuthConnectTypeNone,
}
}
go p.doRequest(&cnsiRequest, done)
res := <-done
// FIXME: cnsiRequest.Status info is lost for failures, only get a status code
c.Response().WriteHeader(res.StatusCode)
// we don't care if this fails
_, writeErr := c.Response().Write(res.Response)
if writeErr != nil {
log.Errorf("Failed to write passthrough response %v", err)
}
return nil
}