proxy.go
package j8a
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"strconv"
"strings"
"time"
unicode "unicode"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jws"
"github.com/lestrrat-go/jwx/jwt"
"github.com/rs/zerolog"
"golang.org/x/net/idna"
"github.com/rs/zerolog/log"
)
// httpUpstreamMaxAttempts is a package-level integer that is only declared here.
// NOTE(review): nothing in the visible part of this file reads or writes it; the retry
// limit used below comes from Runner.Connection.Upstream.MaxAttempts. Confirm it is still needed.
var httpUpstreamMaxAttempts int
// TLSType is the string label for the TLS protocol version of a downstream connection,
// as returned (converted to string) by parseTlsVersion.
type TLSType string
const (
// TLS12 is reported when the connection state carries tls.VersionTLS12.
TLS12 TLSType = "1.2"
// TLS13 is reported when the connection state carries tls.VersionTLS13.
TLS13 TLSType = "1.3"
// TLS_UNKNOWN is reported for a TLS connection with any other version.
TLS_UNKNOWN TLSType = "unknown"
// TLS_NONE is reported when the request has no TLS connection state.
TLS_NONE TLSType = "none"
// Authorization is the request header name validateJwt reads the token from.
Authorization = "Authorization"
// Sep is the separator validateJwt uses to split the Authorization header value.
Sep = " "
)
// HTTP method classification. The lists build on each other: repeatable methods are
// the safe plus the idempotent ones, legal methods add the remaining request methods.
var (
	// httpSafeMethods lists the safe methods, see RFC7231 4.2.1.
	httpSafeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"}

	// httpIdempotentMethods lists the additional idempotent methods, see RFC7231 4.2.2.
	httpIdempotentMethods = []string{"PUT", "DELETE"}

	// httpRepeatableMethods are the methods shouldRetryUpstreamAttempt allows to be retried.
	httpRepeatableMethods = append(httpSafeMethods, httpIdempotentMethods...)

	// httpLegalMethods are all methods accepted by hasLegalHTTPMethod, see RFC7231 4.3.
	httpLegalMethods = append(httpRepeatableMethods, "POST", "PATCH", "CONNECT")
)
// ContentEncoding is a single content-coding token as used in the Accept-Encoding
// request header and the Content-Encoding response header. Values produced by
// NewContentEncoding are trimmed and lower-cased.
type ContentEncoding string
const (
// EncStar is the wildcard token.
EncStar ContentEncoding = "*"
// EncIdentity denotes "no encoding"; isEncoded and isCustom treat it like an empty value.
EncIdentity ContentEncoding = "identity"
// EncBrotli is the brotli token.
EncBrotli ContentEncoding = "br"
// EncGzip and EncXGzip are the gzip token and its "x-" prefixed alias.
EncGzip ContentEncoding = "gzip"
EncXGzip ContentEncoding = "x-gzip"
// EncDeflate and EncXDeflate are the deflate token and its "x-" prefixed alias.
EncDeflate ContentEncoding = "deflate"
EncXDeflate ContentEncoding = "x-deflate"
// EncCompress and EncXCompress are the compress token and its "x-" prefixed alias.
EncCompress ContentEncoding = "compress"
EncXCompress ContentEncoding = "x-compress"
)
// GzipContentEncodings are the tokens isGzip matches against.
var GzipContentEncodings = AcceptEncoding{EncGzip, EncXGzip}
// CompressedContentEncodings are the tokens isCompressed matches against.
var CompressedContentEncodings = AcceptEncoding{EncBrotli, EncGzip, EncXGzip, EncDeflate, EncXDeflate, EncCompress, EncXCompress}
// SupportedContentEncodings are the tokens isSupported matches against.
var SupportedContentEncodings = AcceptEncoding{EncStar, EncIdentity, EncBrotli, EncGzip, EncXGzip}
// UnsupportedContentEncodings are the tokens isUnSupported matches against.
var UnsupportedContentEncodings = AcceptEncoding{EncDeflate, EncXDeflate, EncCompress, EncXCompress}
// NewContentEncoding normalises a raw header token into a ContentEncoding.
// Leading and trailing non-graphic runes are stripped first, then surrounding
// whitespace is trimmed and the result is lower-cased.
func NewContentEncoding(raw string) ContentEncoding {
	notGraphic := func(r rune) bool {
		return !unicode.IsGraphic(r)
	}
	trimmed := strings.TrimSpace(strings.TrimFunc(raw, notGraphic))
	return ContentEncoding(strings.ToLower(trimmed))
}
// isSupported reports whether c is one of SupportedContentEncodings.
func (c ContentEncoding) isSupported() bool {
	for i := range SupportedContentEncodings {
		if SupportedContentEncodings[i] == c {
			return true
		}
	}
	return false
}
// isCompressed reports whether c is one of CompressedContentEncodings.
func (c ContentEncoding) isCompressed() bool {
	for i := range CompressedContentEncodings {
		if CompressedContentEncodings[i] == c {
			return true
		}
	}
	return false
}
// isEncoded reports whether c names an actual encoding, i.e. it is neither
// empty nor the identity token.
func (c ContentEncoding) isEncoded() bool {
	if len(c) == 0 {
		return false
	}
	return c != EncIdentity
}
// isAtomic reports whether c is a single, non-empty token, i.e. it does not
// contain the COMMA list separator.
func (c ContentEncoding) isAtomic() bool {
	if len(c) == 0 {
		return false
	}
	return strings.Index(string(c), COMMA) < 0
}
// isGzip reports whether c is one of GzipContentEncodings.
func (c ContentEncoding) isGzip() bool {
	for i := range GzipContentEncodings {
		if GzipContentEncodings[i] == c {
			return true
		}
	}
	return false
}
// isUnSupported reports whether c is one of UnsupportedContentEncodings.
func (c ContentEncoding) isUnSupported() bool {
	for i := range UnsupportedContentEncodings {
		if UnsupportedContentEncodings[i] == c {
			return true
		}
	}
	return false
}
// isCustom reports whether c is a non-empty, non-identity value that is not
// one of the known compressed encodings.
func (c ContentEncoding) isCustom() bool {
	switch {
	case len(c) == 0, c == EncIdentity:
		return false
	default:
		return !c.isCompressed()
	}
}
// isBrotli reports whether c is exactly the brotli token EncBrotli.
func (c ContentEncoding) isBrotli() bool {
return c == EncBrotli
}
// xdash is the prefix of the "x-" aliases of content encodings (e.g. x-gzip).
const xdash string = "x-"

// matches reports whether the accepted encoding c is satisfied by encoding.
//
// Rules, evaluated in order:
//   - an empty c matches identity
//   - an empty encoding matches nothing else
//   - equal values match, and c == "*" matches any non-empty encoding
//   - otherwise c is compared against the "x-" alias counterpart of encoding:
//     the prefix is removed if encoding has it, added if it does not.
func (c ContentEncoding) matches(encoding ContentEncoding) bool {
	enc := string(encoding)
	switch {
	case len(c) == 0 && encoding == EncIdentity:
		return true
	case len(enc) == 0:
		return false
	case c == encoding, c == STAR:
		return true
	case len(enc) >= len(xdash) && enc[:len(xdash)] == xdash:
		return c == NewContentEncoding(enc[len(xdash):])
	default:
		return c == NewContentEncoding(xdash+enc)
	}
}
// print returns c converted to a plain string.
func (c ContentEncoding) print() string {
return string(c)
}
// AcceptEncoding is a list of content encodings. parseAcceptEncoding builds one from the
// downstream Accept-Encoding header; the package-level encoding sets use the same type.
type AcceptEncoding []ContentEncoding
// hasAtLeastOneValidEncoding reports whether the list is acceptable: it is
// empty, consists of a single empty token, or contains at least one supported
// encoding.
func (ae AcceptEncoding) hasAtLeastOneValidEncoding() bool {
	if len(ae) == 0 {
		return true
	}
	if len(ae) == 1 && string(ae[0]) == emptyString {
		return true
	}
	for _, candidate := range ae {
		if candidate.isSupported() {
			return true
		}
	}
	return false
}
// isCompatible reports whether any entry of the list matches enc according to
// ContentEncoding.matches. An empty list is compatible with nothing.
func (ae AcceptEncoding) isCompatible(enc ContentEncoding) bool {
	for _, accepted := range ae {
		if accepted.matches(enc) {
			return true
		}
	}
	return false
}
// commaSpace separates the entries rendered by AcceptEncoding.Print.
const commaSpace = ", "

// Print renders the list as its tokens joined by ", ".
//
// An empty (or nil) list renders as the empty string. The previous
// implementation sliced off a trailing separator unconditionally and therefore
// panicked with a slice-bounds error when the list had no entries.
func (ae AcceptEncoding) Print() string {
	tokens := make([]string, len(ae))
	for i, ce := range ae {
		tokens[i] = string(ce)
	}
	return strings.Join(tokens, commaSpace)
}
// proxyfunc is the signature of a function that operates on a *Proxy.
// NOTE(review): no use of this type is visible in this file.
type proxyfunc func(*Proxy)
// Atmpt wraps connection attempts to specific upstreams that are already mapped by label
type Atmpt struct {
// URL is the upstream base URL; resolveUpstreamURI prepends its string form to the request URI.
URL *URL
// Label is the upstream label passed to firstAttempt and carried over by nextAttempt.
Label string
// Count is the 1-based attempt number; compared to the configured max attempts for retries.
Count int
// StatusCode is the upstream status copied downstream by copyUpstreamStatusCodeHeader.
StatusCode int
// ContentEncoding is the encoding of respBody as consulted by encodeUpstreamResponseBody.
ContentEncoding ContentEncoding
// resp is the upstream response; nil until an attempt has produced one (see hasMadeUpstreamAttempt).
resp *http.Response
// respBody points to the upstream response body bytes, nil when there is none.
respBody *[]byte
// CompleteHeader and CompleteBody are created per attempt.
// NOTE(review): presumably closed when header/body reading finishes; not used in this file.
CompleteHeader chan struct{}
CompleteBody chan struct{}
// Aborted is polled without blocking by hasUpstreamAttemptAborted.
Aborted <-chan struct{}
// AbortedFlag records that the attempt was aborted.
AbortedFlag bool
// CancelFunc, when non-nil, is invoked by abortAllUpstreamAttempts.
CancelFunc func()
// startDate is the time the attempt was initialised.
startDate time.Time
}
// print renders the attempt as "<count>/<max>", where max is the configured
// Runner.Connection.Upstream.MaxAttempts.
func (atmpt Atmpt) print() string {
	limit := Runner.Connection.Upstream.MaxAttempts
	return fmt.Sprintf("%d/%d", atmpt.Count, limit)
}
// Resp wraps downstream http response writer and data
type Resp struct {
// Writer is the downstream response writer installed by setOutgoing.
Writer http.ResponseWriter
// StatusCode and Message are set together by respondWith; StatusCode is what
// sendDownstreamStatusCodeHeader writes.
StatusCode int
Message string
// Body points to the bytes written downstream by pipeDownstreamResponse; may be nil.
Body *[]byte
// ContentLength is computed by setContentLengthHeader.
ContentLength int64
// ContentEncoding is the final downstream encoding chosen by encodeUpstreamResponseBody.
ContentEncoding ContentEncoding
}
// Up wraps upstream
type Up struct {
// Atmpt points at the current attempt, i.e. the last element of Atmpts.
Atmpt *Atmpt
// Atmpts holds every attempt made for this request, in order.
Atmpts []Atmpt
// Count mirrors the Count of the current attempt.
Count int
}
// Down wraps downstream exchange
type Down struct {
// Req is the incoming request as passed to parseIncoming.
Req *http.Request
// Resp holds the response writer and the response data being assembled.
Resp Resp
// Method is the upper-cased request method.
Method string
// Host is the request host without port, converted with idna.ToASCII.
Host string
// Path is the escaped URL path, URI the full request URI.
Path string
URI string
// UserAgent is the User-Agent header or "unknown" when absent.
UserAgent string
// AcceptEncoding is the parsed Accept-Encoding header.
AcceptEncoding AcceptEncoding
// Body is the request body read by parseRequestBody; left unset when the read
// failed or the body was refused.
Body []byte
// Aborted is the done channel of the request context; AbortedFlag latches it.
Aborted <-chan struct{}
AbortedFlag bool
// Timeout is the done channel of the roundtrip timeout context; TimeoutFlag latches it.
Timeout <-chan struct{}
TimeoutFlag bool
// ReqTooLarge is set when the request body is refused for its size.
ReqTooLarge bool
// startDate is the time parseIncoming started; used for elapsed-time log fields.
startDate time.Time
// HttpVer is "<major>.<minor>", TlsVer one of the TLSType values as string.
HttpVer string
TlsVer string
// Port and Listener describe the configured listener (HTTP or TLS) that received the request.
Port int
Listener string
}
// Proxy wraps data for a single downstream request/response with multiple upstream HTTP request/response cycles.
type Proxy struct {
// XRequestID is taken from the request header or generated by createXRequestID.
XRequestID string
// XRequestInfo raises per-request log events to info level (see infoOrTraceEv, infoOrDebugEv).
XRequestInfo bool
// Up holds the upstream attempts, Dwn the downstream exchange.
Up Up
Dwn Down
// Route is the matched route installed by setRoute.
Route *Route
}
// hasDownstreamAbortedOrTimedout polls the downstream timeout and abort
// channels without blocking and latches the result into TimeoutFlag and
// AbortedFlag. When a flag is set the downstream response is prepared as 504
// (timeout, takes precedence) or 499 (abort). It returns true if either flag
// is set.
func (proxy *Proxy) hasDownstreamAbortedOrTimedout() bool {
	dwn := &proxy.Dwn

	// non blocking read: neither channel signalling leaves the flags untouched
	select {
	case <-dwn.Timeout:
		dwn.TimeoutFlag = true
	case <-dwn.Aborted:
		dwn.AbortedFlag = true
	default:
	}

	switch {
	case dwn.TimeoutFlag:
		proxy.respondWith(504, gatewayTimeoutTriggeredByDownstreamEvent)
	case dwn.AbortedFlag:
		proxy.respondWith(499, connectionClosedByRemoteUserAgent)
	}
	return dwn.AbortedFlag || dwn.TimeoutFlag
}
// resolveUpstreamURI builds the URI for the current upstream attempt: the
// upstream base URL followed by the downstream request URI. If the route has a
// Transform, the first occurrence of the route path inside the URI is replaced
// by it; a Transform of "/" replaces the route path with nothing.
func (proxy *Proxy) resolveUpstreamURI() string {
	transform := proxy.Route.Transform
	if len(transform) == 0 {
		return proxy.Up.Atmpt.URL.String() + proxy.Dwn.URI
	}
	if transform == "/" {
		transform = ""
	}
	rewritten := strings.Replace(proxy.Dwn.URI, proxy.Route.Path, transform, 1)
	return proxy.Up.Atmpt.URL.String() + rewritten
}
// abortedUpstreamAttempt is logged for every attempt whose CancelFunc was invoked.
const abortedUpstreamAttempt = "upstream attempt aborted"

// abortAllUpstreamAttempts marks every recorded upstream attempt as aborted and
// invokes its CancelFunc when one is set.
//
// The attempts are addressed by index: ranging by value would set AbortedFlag
// on a copy of each element, leaving the stored attempts (and the current
// attempt that proxy.Up.Atmpt points at) unmarked.
func (proxy *Proxy) abortAllUpstreamAttempts() {
	for i := range proxy.Up.Atmpts {
		atmpt := &proxy.Up.Atmpts[i]
		atmpt.AbortedFlag = true
		if atmpt.CancelFunc != nil {
			atmpt.CancelFunc()
			scaffoldUpAttemptLog(proxy).
				Msg(abortedUpstreamAttempt)
		}
	}
}
// hasUpstreamAttemptAborted polls the abort channel of the current upstream
// attempt without blocking, latches a signal into AbortedFlag and returns the
// flag.
func (proxy *Proxy) hasUpstreamAttemptAborted() bool {
	current := proxy.Up.Atmpt
	select {
	case <-current.Aborted:
		current.AbortedFlag = true
	default:
		// nothing signalled, keep whatever was recorded before
	}
	return current.AbortedFlag
}
// upstreamRetriesStopped is logged whenever no further upstream attempt will be made.
const upstreamRetriesStopped = "upstream retries stopped"

// shouldRetryUpstreamAttempt tells us if we can safely retry with another
// upstream attempt. A retry requires a repeatable request method (we don't
// retry i.e. POST), an attempt count below the configured maximum and a
// downstream that has neither aborted nor timed out.
func (proxy *Proxy) shouldRetryUpstreamAttempt() bool {
	repeatable := false
	for _, candidate := range httpRepeatableMethods {
		if candidate == proxy.Dwn.Method {
			repeatable = true
			break
		}
	}
	// the attempt counter is only consulted for repeatable methods
	retry := repeatable && proxy.Up.Atmpt.Count < Runner.Connection.Upstream.MaxAttempts

	// once downstream context has signalled, do not re-attempt upstream.
	// this check always runs because it also prepares the downstream response.
	if proxy.hasDownstreamAbortedOrTimedout() {
		retry = false
	}

	if !retry {
		scaffoldUpAttemptLog(proxy).
			Msg(upstreamRetriesStopped)
	}
	return retry
}
// hasMadeUpstreamAttempt reports whether a current attempt exists and has
// received an upstream response.
func (proxy *Proxy) hasMadeUpstreamAttempt() bool {
	current := proxy.Up.Atmpt
	if current == nil {
		return false
	}
	return current.resp != nil
}
// headerParsed is the log message emitted once the request metadata has been captured.
const headerParsed = "downstream request headers successfully parsed"
// bodyBytes, method and path are log field keys.
const bodyBytes = "bodyBytes"
const method = "method"
const path = "path"
// parseIncoming populates the downstream part of the proxy from the incoming request:
// request id, timeout and abort channels, host, path, URI, accepted encodings, protocol
// versions, user agent, method, listener and port. It then logs headerParsed, reads the
// request body via parseRequestBody and returns the same proxy for chaining.
func (proxy *Proxy) parseIncoming(request *http.Request) *Proxy {
proxy.Dwn.startDate = time.Now()
proxy.XRequestID = createXRequestID(request)
//set request new request context for timeout
// the context is cancelled by a timer after the configured downstream roundtrip
// timeout; its Done channel is what hasDownstreamAbortedOrTimedout polls.
ctx, cancel := context.WithCancel(context.TODO())
proxy.Dwn.Timeout = ctx.Done()
time.AfterFunc(Runner.getDownstreamRoundTripTimeoutDuration(), func() {
cancel()
})
//this is separate context for abort. abort is manual close
proxy.Dwn.Aborted = request.Context().Done()
// XRequestInfo must be resolved before the first infoOrTraceEv call below,
// because it selects the log level.
if !Runner.DisableXRequestInfo {
proxy.XRequestInfo = parseXRequestInfo(request)
}
proxy.Dwn.Host = parseHost(request)
proxy.Dwn.Path = request.URL.EscapedPath()
proxy.Dwn.URI = request.URL.RequestURI()
proxy.Dwn.AcceptEncoding = parseAcceptEncoding(request)
proxy.Dwn.HttpVer = parseHTTPVer(request)
proxy.Dwn.TlsVer = parseTlsVersion(request)
proxy.Dwn.UserAgent = parseUserAgent(request)
proxy.Dwn.Method = parseMethod(request)
proxy.Dwn.Listener = parseListener(request)
proxy.Dwn.Port = parsePort(request)
proxy.Dwn.Req = request
proxy.Dwn.AbortedFlag = false
infoOrTraceEv(proxy).Str(path, proxy.Dwn.Path).
Str(method, proxy.Dwn.Method).
Int64(dwnElpsdMicros, time.Since(proxy.Dwn.startDate).Microseconds()).
Str(XRequestID, proxy.XRequestID).
Msg(headerParsed)
proxy.parseRequestBody(request)
return proxy
}
// AcceptEncodingS is the name of the Accept-Encoding request header.
const AcceptEncodingS = "Accept-Encoding"

// COMMA separates list entries inside a header value.
const COMMA = ","

// STAR is the wildcard token.
const STAR = "*"

// parseAcceptEncoding splits the Accept-Encoding header of the request at
// commas and normalises each token with NewContentEncoding. The header lookup
// is case insensitive. The header is not assumed to be set: when it is absent
// the result holds a single empty encoding.
func parseAcceptEncoding(request *http.Request) AcceptEncoding {
	var accepted AcceptEncoding
	header := request.Header.Get(AcceptEncodingS)
	for _, token := range strings.Split(header, COMMA) {
		accepted = append(accepted, NewContentEncoding(token))
	}
	return accepted
}
// parsePort returns the configured downstream port of the listener that
// received the request: the TLS port when the request carries TLS connection
// state, the HTTP port otherwise.
func parsePort(request *http.Request) int {
	if request.TLS != nil {
		return Runner.Connection.Downstream.Tls.Port
	}
	return Runner.Connection.Downstream.Http.Port
}
func parseListener(request *http.Request) string {
if request.TLS == nil {
return HTTP
} else {
return TLS
}
}
// Log messages emitted while reading the downstream request body.
const dwnHeaderContentLengthZero = "downstream request has content-length 0, body not read"
const dwnBodyContentLengthExceedsMaxBytes = "downstream request body content-length %d exceeds max allowed bytes %d, refuse reading body"
const dwnBodyTooLarge = "downstream request body too large. %d body bytes > server max %d"
const dwnBodyReadTimeout = "downstream request body read timed out, cause: %v"
const dwnBodyReadAbort = "downstream request body read aborted, cause: %v"
const dwnBodyRead = "downstream request body read (%d/%d) bytes/content-length"

// timeout is the substring that classifies a body read error as a timeout.
const timeout = "timeout"

// dwnBodyMaxBytesReaderErrMsg is the substring of the error http.MaxBytesReader
// returns once more than the permitted number of bytes has been read.
const dwnBodyMaxBytesReaderErrMsg = "request body too large"

// parseRequestBody reads the downstream request body into proxy.Dwn.Body.
//
// Outcomes:
//   - content-length 0: nothing is read.
//   - declared content-length at or above the configured maximum: the body is
//     refused unread and ReqTooLarge is set.
//   - the body exceeds the maximum while being read (e.g. chunked transfer
//     without content-length): ReqTooLarge is set.
//   - any other read error: TimeoutFlag is set if the error text mentions a
//     timeout, AbortedFlag otherwise.
//   - success: the bytes are stored in proxy.Dwn.Body.
func (proxy *Proxy) parseRequestBody(request *http.Request) {
	maxBytes := Runner.Connection.Downstream.MaxBodyBytes

	//content length 0, do not read just go back
	if request.ContentLength == 0 {
		infoOrTraceEv(proxy).
			Int64(dwnElpsdMicros, time.Since(proxy.Dwn.startDate).Microseconds()).
			Str(XRequestID, proxy.XRequestID).
			Msg(dwnHeaderContentLengthZero)
		return
	}

	//only try to parse the request if supplied content-length is within limits
	if request.ContentLength >= maxBytes {
		proxy.Dwn.ReqTooLarge = true
		infoOrTraceEv(proxy).
			Str(XRequestID, proxy.XRequestID).
			Int64(dwnElpsdMicros, time.Since(proxy.Dwn.startDate).Microseconds()).
			Msgf(dwnBodyContentLengthExceedsMaxBytes, request.ContentLength, maxBytes)
		return
	}

	//create buffered reader so we can fetch chunks of request as they come.
	//No need to close request.Body of type io.ReadCloser, see: https://golang.org/pkg/net/http/#Request
	bodyReader := bufio.NewReader(http.MaxBytesReader(proxy.Dwn.Resp.Writer, request.Body, maxBytes))

	//read body. knows how to deal with transfer encoding: chunked, identity
	buf, err := ioutil.ReadAll(bodyReader)
	n := len(buf)
	failed := err != nil && err != io.EOF

	// http.MaxBytesReader never hands out more than maxBytes bytes; exceeding the
	// limit surfaces as a read error instead. Checking the byte count alone can
	// therefore never detect an oversized body, so the error is inspected as well.
	tooLarge := int64(n) > maxBytes ||
		(failed && strings.Contains(err.Error(), dwnBodyMaxBytesReaderErrMsg))

	switch {
	case tooLarge:
		proxy.Dwn.ReqTooLarge = true
		infoOrTraceEv(proxy).
			Str(path, proxy.Dwn.Path).
			Str(method, proxy.Dwn.Method).
			Str(XRequestID, proxy.XRequestID).
			Int64(dwnElpsdMicros, time.Since(proxy.Dwn.startDate).Microseconds()).
			Msgf(dwnBodyTooLarge, n, maxBytes)
	case failed:
		ev := infoOrTraceEv(proxy).
			Str(path, proxy.Dwn.Path).
			Str(method, proxy.Dwn.Method).
			Str(XRequestID, proxy.XRequestID).
			Int64(dwnElpsdMicros, time.Since(proxy.Dwn.startDate).Microseconds())
		if strings.Contains(err.Error(), timeout) {
			proxy.Dwn.TimeoutFlag = true
			ev.Msgf(dwnBodyReadTimeout, err)
		} else {
			proxy.Dwn.AbortedFlag = true
			ev.Msgf(dwnBodyReadAbort, err)
		}
	default:
		proxy.Dwn.Body = buf
		infoOrTraceEv(proxy).
			Str(path, proxy.Dwn.Path).
			Str(method, proxy.Dwn.Method).
			Str(XRequestID, proxy.XRequestID).
			Int(bodyBytes, len(proxy.Dwn.Body)).
			Int64(dwnElpsdMicros, time.Since(proxy.Dwn.startDate).Microseconds()).
			Msgf(dwnBodyRead, n, request.ContentLength)
	}
}
// infoOrTraceEv starts a log event at info level when the request asked for
// request info (proxy.XRequestInfo), at trace level otherwise.
func infoOrTraceEv(proxy *Proxy) *zerolog.Event {
	if proxy.XRequestInfo {
		return log.Info()
	}
	return log.Trace()
}
// infoOrDebugEv starts a log event at info level when the request asked for
// request info (proxy.XRequestInfo), at debug level otherwise.
func infoOrDebugEv(proxy *Proxy) *zerolog.Event {
	if proxy.XRequestInfo {
		return log.Info()
	}
	return log.Debug()
}
// colon is the host/port separator character.
const colon = ":"

// isIPv6 reports whether address parses as an IP address that has no IPv4
// representation. Input that is not an IP address at all yields false.
func isIPv6(address string) bool {
	parsed := net.ParseIP(address)
	if parsed == nil {
		return false
	}
	return parsed.To4() == nil
}
// parseHost returns the request host without its port, converted to ASCII
// with idna.ToASCII. A conversion error is ignored and the value returned by
// idna.ToASCII is used as is.
func parseHost(request *http.Request) string {
	host := request.Host

	//trim port for ipv4: exactly one colon means "name:port"
	if parts := strings.Split(host, ":"); len(parts) == 2 {
		host = parts[0]
	}

	//trim port for ipv6: keep everything up to and including the closing bracket
	if closing := strings.LastIndex(host, "]"); closing >= 0 {
		host = host[:closing+1]
	}

	ascii, _ := idna.ToASCII(host)
	return ascii
}
// parseMethod returns the request method in upper case.
func parseMethod(request *http.Request) string {
	verb := request.Method
	return strings.ToUpper(verb)
}
// parseUserAgent returns the User-Agent header of the request, or "unknown"
// when the header is absent or empty.
func parseUserAgent(request *http.Request) string {
	if agent := request.UserAgent(); len(agent) > 0 {
		return agent
	}
	return "unknown"
}
// parseHTTPVer returns the protocol version of the request as "<major>.<minor>".
func parseHTTPVer(request *http.Request) string {
	major, minor := request.ProtoMajor, request.ProtoMinor
	return fmt.Sprintf("%d.%d", major, minor)
}
// xRequestInfo and xRequestDebug are the request headers that switch on
// info-level logging for a single request.
const xRequestInfo = "X-REQUEST-INFO"
const xRequestDebug = "X-REQUEST-DEBUG"

// trueStr is the (lower-case) header value that enables the switch.
const trueStr = "true"

// parseXRequestInfo reports whether either the X-REQUEST-INFO or the
// X-REQUEST-DEBUG header is set to "true", compared case-insensitively.
func parseXRequestInfo(request *http.Request) bool {
	for _, name := range []string{xRequestInfo, xRequestDebug} {
		value := request.Header.Get(name)
		if len(value) > 0 && strings.ToLower(value) == trueStr {
			return true
		}
	}
	return false
}
// parseTlsVersion checks the TLS version of the incoming request.
// It returns the TLS version as a string.
// If the request has a valid TLS connection, the function checks
// the TLS version and returns "1.2" for TLS 1.2, "1.3" for TLS 1.3,
// and "unknown" for other TLS versions.
// If the request does not have a TLS connection, the function returns
// "none".
// Usage Example:
//
// req, _ := http.NewRequest("GET", "/hello", nil)
// req.TLS = &tls.ConnectionState{
// Version: tls.VersionTLS12,
// }
// if "1.2" != parseTlsVersion(req) {
// t.Errorf("wrong TLS version")
// }
func parseTlsVersion(request *http.Request) string {
if request.TLS != nil {
if request.TLS.Version == tls.VersionTLS12 {
return string(TLS12)
}
if request.TLS.Version == tls.VersionTLS13 {
return string(TLS13)
}
return string(TLS_UNKNOWN)
} else {
return string(TLS_NONE)
}
}
// createXRequestID returns the request id supplied in the XRequestID header
// (header lookup is case insensitive). When the header is absent a new id of
// the form "XR-<ID>-<random uuid>" is generated; an error from the uuid
// generator is ignored.
func createXRequestID(request *http.Request) string {
	if supplied := request.Header.Get(XRequestID); len(supplied) != 0 {
		return supplied
	}
	random, _ := uuid.NewRandom()
	return fmt.Sprintf("XR-%s-%s", ID, random)
}
// setOutgoing resets the downstream response to a fresh Resp that writes to
// out and returns the proxy for chaining.
func (proxy *Proxy) setOutgoing(out http.ResponseWriter) *Proxy {
	fresh := Resp{}
	fresh.Writer = out
	proxy.Dwn.Resp = fresh
	return proxy
}
// bodyReader returns a reader over the downstream request body, or nil when
// no body bytes were stored.
func (proxy Proxy) bodyReader() io.Reader {
	if len(proxy.Dwn.Body) == 0 {
		return nil
	}
	return bytes.NewReader(proxy.Dwn.Body)
}
// upstreamAttemptInitialized is logged whenever a new upstream attempt has been set up.
const upstreamAttemptInitialized = "upstream attempt initialized"

// firstAttempt resets the upstream state of the proxy to a single attempt
// (count 1) against URL under the given label, makes it the current attempt,
// logs it and returns the proxy for chaining.
func (proxy *Proxy) firstAttempt(URL *URL, label string) *Proxy {
	proxy.Up.Atmpts = []Atmpt{{
		URL:            URL,
		Label:          label,
		Count:          1,
		CompleteHeader: make(chan struct{}),
		CompleteBody:   make(chan struct{}),
		Aborted:        make(chan struct{}),
		startDate:      time.Now(),
	}}
	proxy.Up.Count = 1
	// the current attempt always points into the slice, never at a copy
	proxy.Up.Atmpt = &proxy.Up.Atmpts[0]

	scaffoldUpAttemptLog(proxy).
		Str(upResource, URL.String()).
		Msg(upstreamAttemptInitialized)
	return proxy
}
// upAtmptCnt is the log field key for the upstream attempt counter.
const upAtmptCnt = "upAtmptCnt"

// nextAttempt appends a fresh attempt that reuses URL and label of the current
// one with the counter incremented, makes it the current attempt, logs it and
// returns the proxy for chaining.
func (proxy *Proxy) nextAttempt() *Proxy {
	previous := proxy.Up.Atmpt
	following := Atmpt{
		URL:            previous.URL,
		Label:          previous.Label,
		Count:          previous.Count + 1,
		CompleteHeader: make(chan struct{}),
		CompleteBody:   make(chan struct{}),
		Aborted:        make(chan struct{}),
		startDate:      time.Now(),
	}

	proxy.Up.Atmpts = append(proxy.Up.Atmpts, following)
	proxy.Up.Count = following.Count
	// re-point after append: the slice may have been reallocated
	last := len(proxy.Up.Atmpts) - 1
	proxy.Up.Atmpt = &proxy.Up.Atmpts[last]

	scaffoldUpAttemptLog(proxy).
		Int(upAtmptCnt, proxy.Up.Count).
		Str(upResource, following.URL.String()).
		Msg(upstreamAttemptInitialized)
	return proxy
}
// copyUpstreamResponseHeaders adds every value of each upstream response
// header accepted by shouldProxyHeader to the downstream response headers.
func (proxy *Proxy) copyUpstreamResponseHeaders() {
	upstream := proxy.Up.Atmpt.resp.Header
	for name, values := range upstream {
		if !shouldProxyHeader(name) {
			continue
		}
		for i := range values {
			proxy.Dwn.Resp.Writer.Header().Add(name, values[i])
		}
	}
}
// Log messages describing how the upstream body was handed downstream.
const upstreamEncodeFlate = "upstream response body re-encoded with flate before passing downstream"
const upstreamEncodeBr = "upstream response body re-encoded with brotli before passing downstream"
const upstreamEncodeGzip = "upstream response body re-encoded with gzip before passing downstream"
const upstreamCopyNoRecode = "upstream response body copied without re-coding before passing downstream"
const upstreamResponseNoBody = "upstream response has no body, nothing to copy before passing downstream"
// varyS is the name of the Vary response header.
const varyS = "Vary"
// encodeUpstreamResponseBody decides how the body of the current upstream attempt is
// passed downstream and sets proxy.Dwn.Resp.Body and proxy.Dwn.Resp.ContentEncoding.
//
// For a non-empty upstream body, in order of precedence:
//  1. upstream body already encoded (anything but empty/identity): passed through as is.
//  2. downstream accepts gzip: body is gzipped.
//  3. downstream accepts brotli: body is brotli encoded.
//  4. otherwise: body is passed through unencoded.
//
// The Content-Encoding header is set when the resulting encoding is non-empty, and a
// Vary header is set when the result is not compatible with the downstream Accept-Encoding.
// With no upstream body, the downstream body is set to an empty, non-nil byte slice.
func (proxy *Proxy) encodeUpstreamResponseBody() {
// works on a copy of the current attempt; only read access is needed here.
atmpt := *proxy.Up.Atmpt
if atmpt.respBody != nil && len(*atmpt.respBody) > 0 {
//we pass through all compressed responses as is, including unsupported deflate and compress codecs.
//this includes custom encodings, i.e. multiple compressions in series.
if atmpt.ContentEncoding.isEncoded() {
proxy.Dwn.Resp.Body = atmpt.respBody
proxy.Dwn.Resp.ContentEncoding = atmpt.ContentEncoding
scaffoldUpAttemptLog(proxy).
Msgf(upstreamCopyNoRecode)
} else if proxy.Dwn.AcceptEncoding.isCompatible(EncGzip) {
proxy.Dwn.Resp.Body = Gzip(*atmpt.respBody)
proxy.Dwn.Resp.ContentEncoding = EncGzip
scaffoldUpAttemptLog(proxy).
Msg(upstreamEncodeGzip)
} else if proxy.Dwn.AcceptEncoding.isCompatible(EncBrotli) {
proxy.Dwn.Resp.Body = BrotliEncode(*atmpt.respBody)
proxy.Dwn.Resp.ContentEncoding = EncBrotli
scaffoldUpAttemptLog(proxy).
Msg(upstreamEncodeBr)
} else {
proxy.Dwn.Resp.Body = atmpt.respBody
// this branch is only reached when the upstream encoding is empty or identity
// (see isEncoded), so both arms below result in identity or the upstream's own
// identity value.
if len(atmpt.ContentEncoding) > 0 {
//only set this if it was present upstream, otherwise assume nothing and leave empty.
proxy.Dwn.Resp.ContentEncoding = atmpt.ContentEncoding
} else {
proxy.Dwn.Resp.ContentEncoding = EncIdentity
}
scaffoldUpAttemptLog(proxy).
Msgf(upstreamCopyNoRecode)
}
//set this when present, but do not give instructions for empty values
if len(proxy.Dwn.Resp.ContentEncoding) > 0 {
proxy.Dwn.Resp.Writer.Header().Set(contentEncoding, proxy.Dwn.Resp.ContentEncoding.print())
}
//send a vary header for accept encoding if final downstream content encoding
//doesn't match expectations for content negotiation, i.e. when upstream was passed through.
if !proxy.Dwn.AcceptEncoding.isCompatible(proxy.Dwn.Resp.ContentEncoding) {
proxy.Dwn.Resp.Writer.Header().Set(varyS, acceptEncoding)
}
} else {
//just in case golang tries to use this value downstream.
nobody := make([]byte, 0)
proxy.Dwn.Resp.Body = &nobody
scaffoldUpAttemptLog(proxy).
Msg(upstreamResponseNoBody)
}
}
// setRoute stores the given route on the proxy.
func (proxy *Proxy) setRoute(route *Route) {
proxy.Route = route
}
// connectS and head are the method names that get special content-length treatment.
const connectS = "CONNECT"
const head = "HEAD"
// setContentLengthHeader computes proxy.Dwn.Resp.ContentLength and writes it to the
// downstream Content-Length header. See RFC7230, section 3.3.2.
//
// The length defaults to the size of the downstream body (0 when there is none). It is
// forced to 0 when a Transfer-Encoding header is already set downstream, for status 204,
// for 1xx statuses and for CONNECT requests. For HEAD requests the upstream
// Content-Length header is used when it parses as an integer, otherwise 0.
//
// NOTE(review): the HEAD branch dereferences proxy.Up.Atmpt.resp without a nil check;
// confirm callers only reach it after an upstream response exists (see hasMadeUpstreamAttempt).
func (proxy *Proxy) setContentLengthHeader() {
proxy.Dwn.Resp.ContentLength = 0
if proxy.Dwn.Resp.Body != nil {
proxy.Dwn.Resp.ContentLength = int64(len(*proxy.Dwn.Resp.Body))
}
if te := proxy.Dwn.Resp.Writer.Header().Get(transferEncoding); len(te) != 0 ||
//we set 0 for status code 204 because of RFC7230, 4.3.7, see: https://tools.ietf.org/html/rfc7231#page-31
//however golang removes this in its own implementation.
//Spec ambiguous, see Errata: https://www.rfc-editor.org/errata/eid5806
//overall there is little harm done by absent header. J8a tests distinguish between
//Content-Length==0 and no header present to detect when/if future golang version changes behavior.
proxy.Dwn.Resp.StatusCode == 204 ||
(proxy.Dwn.Resp.StatusCode >= 100 && proxy.Dwn.Resp.StatusCode < 200) ||
proxy.Dwn.Method == connectS {
proxy.Dwn.Resp.ContentLength = 0
} else if proxy.Dwn.Method == head {
//special case for upstream HEAD response with intact content-length we do copy
//see RFC7231 4.3.2: https://tools.ietf.org/html/rfc7231#page-25
cl := proxy.Up.Atmpt.resp.Header.Get(contentLength)
cli, err := strconv.ParseInt(cl, 10, 64)
if len(cl) > 0 && err == nil {
proxy.Dwn.Resp.ContentLength = cli
} else {
proxy.Dwn.Resp.ContentLength = 0
}
}
proxy.Dwn.Resp.Writer.Header().Set(contentLength, fmt.Sprintf("%d", proxy.Dwn.Resp.ContentLength))
}
// pipeDownstreamResponse writes the prepared downstream body to the response
// writer.
//
// A nil body is skipped instead of dereferenced; setContentLengthHeader treats
// a nil body as a valid state, so this function must not panic on it either.
func (proxy *Proxy) pipeDownstreamResponse() {
	body := proxy.Dwn.Resp.Body
	if body == nil {
		return
	}
	// best effort: a failed write means the downstream connection is gone and
	// there is nobody left to report the error to.
	_, _ = proxy.Dwn.Resp.Writer.Write(*body)
}
// copyUpstreamStatusCodeHeader records the status code of the current upstream attempt
// as the downstream status code, with the message "none". It only stores the values via
// respondWith; nothing is written to the response writer here.
// status Code must be last, no headers may be written after this one.
func (proxy *Proxy) copyUpstreamStatusCodeHeader() {
proxy.respondWith(proxy.Up.Atmpt.StatusCode, "none")
}
// sendDownstreamStatusCodeHeader writes the recorded downstream status code to the
// response writer via WriteHeader.
func (proxy *Proxy) sendDownstreamStatusCodeHeader() {
proxy.Dwn.Resp.Writer.WriteHeader(proxy.Dwn.Resp.StatusCode)
}
// respondWith records statusCode and message on the downstream response and returns
// the proxy for chaining. Nothing is written to the response writer.
func (proxy *Proxy) respondWith(statusCode int, message string) *Proxy {
proxy.Dwn.Resp.StatusCode = statusCode
proxy.Dwn.Resp.Message = message
return proxy
}
// hasLegalHTTPMethod reports whether the downstream request method is one of
// httpLegalMethods.
func (proxy *Proxy) hasLegalHTTPMethod() bool {
	requested := proxy.Dwn.Method
	for i := range httpLegalMethods {
		if httpLegalMethods[i] == requested {
			return true
		}
	}
	return false
}
// validateJwt extracts the token from the Authorization header of the downstream request
// and validates it against the jwt configuration of the route. It returns true only if
// the token parsed, its signature check (where the configured algorithm requires one),
// its date claims (with the configured skew) and its mandatory claims all passed.
// The outcome is logged at trace level.
func (proxy *Proxy) validateJwt() bool {
var token string = ""
var err error
ok := false
ev := log.Trace().
Str("dwnReqPath", proxy.Dwn.Path).
Str(XRequestID, proxy.XRequestID)
auth := proxy.Dwn.Req.Header.Get(Authorization)
// the header is split at spaces and the second element is taken as the token.
// NOTE(review): the first element (auth scheme) is not compared against "Bearer".
bearer := strings.Split(auth, Sep)
if len(bearer) > 1 {
token = bearer[1]
routeSec := Runner.Jwt[proxy.Route.Jwt]
alg := *new(jwa.SignatureAlgorithm)
alg.Accept(routeSec.Alg)
var parsed jwt.Token
// pick the key material that belongs to the configured algorithm family.
switch alg {
case jwa.RS256, jwa.RS384, jwa.RS512, jwa.PS256, jwa.PS384, jwa.PS512:
parsed, err = proxy.verifyJwtSignature(token, routeSec.RSAPublic, alg, ev)
case jwa.ES256, jwa.ES384, jwa.ES512:
parsed, err = proxy.verifyJwtSignature(token, routeSec.ECDSAPublic, alg, ev)
case jwa.HS256, jwa.HS384, jwa.HS512:
parsed, err = proxy.verifyJwtSignature(token, routeSec.Secret, alg, ev)
case jwa.NoSignature:
parsed, err = jwt.Parse([]byte(token))
default:
// NOTE(review): any algorithm not listed above is parsed without a verify option,
// same as NoSignature — confirm configuration validation prevents reaching this.
parsed, err = jwt.Parse([]byte(token))
}
//date claims are verified separately to signature including skew
// a non-numeric skew setting silently becomes 0 because the Atoi error is dropped.
skew, _ := strconv.Atoi(routeSec.AcceptableSkewSeconds)
if parsed != nil && err == nil {
err = verifyDateClaims(token, skew, ev)
}
if parsed != nil && err == nil {
err = proxy.verifyMandatoryJwtClaims(parsed, ev)
}
if parsed != nil {
logDateClaims(parsed, ev)
}
ok = parsed != nil && err == nil
} else {
err = errors.New("jwt bearer token not present")
}
if ok {
ev.Int64("dwnElapsedMicros", time.Since(proxy.Dwn.startDate).Microseconds()).
Msg("jwt token validated")
} else {
ev.Int64("dwnElapsedMicros", time.Since(proxy.Dwn.startDate).Microseconds()).
Msgf("jwt token rejected, cause: %v", err)
}
return ok
}
// verifyMandatoryJwtClaims checks the token against the claim expressions
// configured for the route's jwt settings. The first expression that yields a
// match makes the check pass (nil is returned). Without any match the result
// is the last error recorded: the initial "failed to match any claims" error
// when mandatory claims are configured, an evaluation error, or a
// "claim not matched" error. Progress is recorded as fields on ev.
func (proxy *Proxy) verifyMandatoryJwtClaims(token jwt.Token, ev *zerolog.Event) error {
	const matchedAnyKey = "jwtClaimsMatchRequiredAny"
	const hasRequiredKey = "jwtClaimsHasRequiredAny"

	var result error
	cfg := Runner.Jwt[proxy.Route.Jwt]

	mandatory := cfg.hasMandatoryClaims()
	if mandatory {
		// pessimistic default, cleared by returning nil on the first match
		result = errors.New("failed to match any claims required by route")
		ev.Bool(matchedAnyKey, false)
	}
	ev.Bool(hasRequiredKey, mandatory)

	for idx, claim := range cfg.Claims {
		if len(claim) == 0 {
			continue
		}
		claimKey := "jwtClaimsMatchRequired[" + claim + "]"
		ev.Bool(claimKey, false)

		claims, _ := token.AsMap(context.Background())
		value, ok := cfg.claimsVal[idx].Run(claims).Next()
		if value == nil {
			continue
		}
		if failure, isErr := value.(error); isErr {
			result = failure
			continue
		}
		if ok {
			ev.Bool(matchedAnyKey, true)
			ev.Bool(claimKey, ok)
			return nil
		}
		result = errors.New(fmt.Sprintf("claim not matched %s", claim))
	}
	return result
}
// verifyJwtSignature parses token and verifies its signature with alg against
// the keys in keySet.
//
// The key whose id matches the token's "kid" header is tried first. If the
// token has no kid, no key matches the kid (which also triggers a background
// key rotation check), or verification with the matching key failed and the
// set holds further keys, every key in the set is tried until one verifies.
//
// It returns the parsed token on success. An error is returned when the token
// cannot be parsed as JWS, carries no signature, or no key verifies it.
func (proxy *Proxy) verifyJwtSignature(token string, keySet KeySet, alg jwa.SignatureAlgorithm, ev *zerolog.Event) (jwt.Token, error) {
	// a malformed token yields a nil message; bail out before touching it.
	msg, err := jws.Parse([]byte(token))
	if err != nil {
		return nil, err
	}
	if msg == nil || len(msg.Signatures()) == 0 {
		return nil, errors.New("no signature found on jwt token")
	}

	var parsed jwt.Token

	//first we try to validate by a key with the kid parameter to match.
	kid := extractKid(token)
	var key interface{}
	if len(kid) > 0 {
		ev.Str("jwtKid", kid)
		key = keySet.Find(kid)
		if key != nil {
			parsed, err = jwt.Parse([]byte(token),
				jwt.WithVerify(alg, key))
		} else {
			proxy.triggerKeyRotationCheck(kid)
		}
	}

	//TODO: try this with x5t SHA1 thumbprint on previously loaded keys to augment kid. If you're reading this
	//TODO: comment feel free to get in touch with a github issue.

	//if it didn't validate above, we try other keys, provided there are any
	if len(kid) == 0 ||
		key == nil ||
		(err != nil && len(keySet) > 1) {
		for _, kp := range keySet {
			parsed, err = jwt.Parse([]byte(token),
				jwt.WithVerify(alg, kp.Key))
			if err == nil {
				break
			}
		}
	}
	return parsed, err
}
// triggerKeyRotationCheck starts a background reload of the JWKS keys for the
// route's jwt configuration when that configuration has a JWKS URL, and logs
// the unmatched kid that caused it. Without a JWKS URL it does nothing.
func (proxy *Proxy) triggerKeyRotationCheck(kid string) {
	route := proxy.Route
	jwtCfg := Runner.Jwt[route.Jwt]
	if len(jwtCfg.JwksUrl) == 0 {
		return
	}

	//MUST run async since it will block on loading remote JWKS key
	go jwtCfg.LoadJwks()

	log.Info().
		Str("route", route.Path).
		Str("jwt", route.Jwt).
		Str(XRequestID, proxy.XRequestID).
		Msgf("unmatched kid [%v] on incoming req triggered background key rotation search for route [%v] jwt [%v]", kid, route.Path, route.Jwt)
}
// logDateClaims records the iat, nbf and exp claims of the token on ev. For
// each claim a presence flag is logged; a claim counts as present when its
// unix time is greater than 1. Present claims are additionally logged as
// RFC3339 in the claim's own and the local time zone and as unix seconds.
func logDateClaims(parsed jwt.Token, ev *zerolog.Event) {
	record := func(name string, at time.Time) {
		present := at.Unix() > 1
		ev.Bool("jwtClaims"+name, present)
		if !present {
			return
		}
		ev.Str("jwt"+name+"UtcIso", at.Format(time.RFC3339))
		ev.Str("jwt"+name+"LclIso", at.Local().Format(time.RFC3339))
		ev.Int64("jwt"+name+"Unix", at.Unix())
	}

	record("Iat", parsed.IssuedAt())
	record("Nbf", parsed.NotBefore())
	record("Exp", parsed.Expiration())
}
// verifyDateClaims validates the iat, nbf and exp claims of token, tolerating
// a clock skew of skew seconds: iat and nbf are moved into the past and exp
// into the future by that amount before validation. Which claims were
// validated, and which failed, is recorded on ev.
//
// It returns an error when the token cannot be parsed or a date claim is
// invalid. The parse result is checked before it is used, so an unparsable
// token yields an error instead of a nil dereference.
func verifyDateClaims(token string, skew int, ev *zerolog.Event) error {
	//arghh i need a deep copy of this token so i can modify it, but it's an interface wrapping a package private jwt.stdToken
	//so i need to parse it again.
	skewed, err := jwt.Parse([]byte(token))
	if err != nil {
		return err
	}
	if skewed == nil {
		return errors.New("jwt token could not be parsed for date claim validation")
	}

	tolerance := time.Second * time.Duration(skew)

	if skewed.IssuedAt().Unix() > int64(skew*1000) {
		ev.Bool("jwtClaimsIatValidated", true)
		skewed.Set("iat", skewed.IssuedAt().Add(-tolerance))
	}
	if skewed.NotBefore().Unix() > int64(skew*1000) {
		ev.Bool("jwtClaimsNbfValidated", true)
		skewed.Set("nbf", skewed.NotBefore().Add(-tolerance))
	}
	if skewed.Expiration().Unix() > 1 {
		ev.Bool("jwtClaimsExpValidated", true)
		skewed.Set("exp", skewed.Expiration().Add(tolerance))
	}

	err = jwt.Validate(skewed)
	if err != nil {
		// the failing claim is identified by its name inside the error text
		cause := err.Error()
		if strings.Contains(cause, "iat") {
			ev.Bool("jwtClaimsIatValidated", false)
		}
		if strings.Contains(cause, "nbf") {
			ev.Bool("jwtClaimsNbfValidated", false)
		}
		if strings.Contains(cause, "exp") {
			ev.Bool("jwtClaimsExpValidated", false)
		}
	}
	return err
}
// extractKid returns the "kid" value from the header (first dot-separated
// segment) of a compact serialized token. It returns the empty string when the
// header is not valid unpadded base64url, is not a JSON object, or has no
// string-typed "kid" entry.
func extractKid(token string) string {
	segments := strings.Split(token, ".")
	raw, err := base64.RawURLEncoding.DecodeString(segments[0])
	if err != nil {
		return ""
	}

	fields := make(map[string]interface{})
	if err = json.Unmarshal(raw, &fields); err != nil {
		return ""
	}

	if kid, isString := fields["kid"].(string); isString {
		return kid
	}
	return ""
}