caidooss/grafana-auth-proxy

View on GitHub
pkg/validation/validation.go

Summary

Maintainability
A
1 hr
Test Coverage
package validation

import (
    "crypto/rsa"
    "errors"
    "fmt"

    "github.com/dgrijalva/jwt-go"
    "github.com/lestrrat-go/jwx/jwk"
)

// TokenValidator checks JWTs against a set of public keys, a list of accepted
// signing algorithms, and an expected audience and issuer.
//
// NOTE: the field order matters: NewTokenValidator fills this struct with a
// positional (unkeyed) literal.
type TokenValidator struct {
    // keys is the JWK set searched by key ID ("kid" header) for the verification key.
    keys       *jwk.Set
    // algorithms lists the "alg" header values that are accepted.
    algorithms []string
    // audience is the value the token's "aud" claim must equal.
    audience   string
    // issuer is the value the token's "iss" claim must equal.
    issuer     string
}

// NewTokenValidator returns a TokenValidator configured with the given key
// set, accepted signing algorithms, expected audience and expected issuer.
// The arguments are stored as given; no copies are made.
func NewTokenValidator(keys *jwk.Set, algorithms []string, audience string, issuer string) *TokenValidator {
    validator := &TokenValidator{
        keys:       keys,
        algorithms: algorithms,
        audience:   audience,
        issuer:     issuer,
    }
    return validator
}

// Validate parses tokenString with jwt.Parse, using
// getTokenAssociatedPublicKey as the key function, and then checks that the
// "aud" and "iss" claims equal the validator's configured audience and issuer.
//
// When parsing fails the returned token is nil:
//   - a *jwt.ValidationError whose error code equals jwt.ValidationErrorExpired
//     exactly is converted to an *Error with code ErrorExpired;
//   - any other *jwt.ValidationError is converted to an *Error with code
//     ErrorValidation;
//   - any other error is returned unchanged.
//
// When parsing succeeds but a later check fails, the parsed token is returned
// together with an *Error (ErrorToken, ErrorAudience or ErrorIssuer).
//
// A missing or non-string "aud"/"iss" claim is reported as a mismatch instead
// of panicking; only the single-string form of "aud" is accepted.
func (tv *TokenValidator) Validate(tokenString string) (*jwt.Token, error) {
    // Extract token
    token, err := jwt.Parse(tokenString, tv.getTokenAssociatedPublicKey)
    if err != nil {
        validationError, ok := err.(*jwt.ValidationError)
        if !ok {
            return nil, err
        }

        errorMessage := validationError.Inner
        errorCode := validationError.Errors

        // Strict equality on purpose: a token that is expired AND fails another
        // check must not be reported as merely expired.
        if errorCode == jwt.ValidationErrorExpired {
            errorM := fmt.Sprintf("TOKEN EXPIRED : error_message=\"%v\" error_code=%v", errorMessage, errorCode)
            return nil, &Error{errorM, ErrorExpired}
        }

        errorM := fmt.Sprintf("VALIDATION ERROR : error_message=\"%v\" error_code=%v", errorMessage, errorCode)
        return nil, &Error{errorM, ErrorValidation}
    }

    // Ensure validity and claims (https://auth0.com/docs/api-auth/tutorials/verify-access-token)
    if !token.Valid {
        return token, &Error{"Token is invalid", ErrorToken}
    }

    // Claims come from untrusted input, so every type assertion is checked.
    claims, ok := token.Claims.(jwt.MapClaims)
    if !ok {
        return token, &Error{"Token is invalid", ErrorToken}
    }

    if tokenAud, ok := claims["aud"].(string); !ok || tokenAud != tv.audience {
        return token, &Error{"audience does not match", ErrorAudience}
    }

    if tokenIssuer, ok := claims["iss"].(string); !ok || tokenIssuer != tv.issuer {
        return token, &Error{"issuer does not match", ErrorIssuer}
    }

    return token, nil
}

// getTokenAssociatedPublicKey is the key function handed to jwt.Parse. It
// checks the token's "alg" header against the validator's accepted algorithms
// and returns the RSA public key registered under the token's "kid" header.
//
// Errors:
//   - *Error with ErrorNilAlg when the "alg" header is absent;
//   - *Error with ErrorUnsupportedAlg when "alg" is not a string or is not in
//     the accepted list;
//   - *Error with ErrorNilKid when the "kid" header is absent;
//   - a plain error when "kid" is not a string, when no key matches the key
//     ID, when the key cannot be materialized, or when the materialized key is
//     not an *rsa.PublicKey.
//
// Header values are untrusted input, so every type assertion is checked
// rather than allowed to panic.
func (tv *TokenValidator) getTokenAssociatedPublicKey(token *jwt.Token) (interface{}, error) {
    // Verify ALG: it should at least be not "none". We decided to restrict it further to a set of trusted algorithms.
    // See vulnerability: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
    algHeader := token.Header["alg"]
    if algHeader == nil {
        return nil, &Error{"token ALG header is nil", ErrorNilAlg}
    }
    alg, ok := algHeader.(string)
    if !ok || !stringInSlice(alg, tv.algorithms) {
        errorMessage := fmt.Sprintf("algorithm %v is not in accepted algoritmns (%v)", algHeader, tv.algorithms)
        return nil, &Error{errorMessage, ErrorUnsupportedAlg}
    }

    // Fetch keys associated with the KID
    kidHeader := token.Header["kid"]
    if kidHeader == nil {
        return nil, &Error{"token KID header is nil", ErrorNilKid}
    }
    kid, ok := kidHeader.(string)
    if !ok {
        return nil, fmt.Errorf("token KID header is not a string (got %T)", kidHeader)
    }

    keys := tv.keys.LookupKeyID(kid)
    if len(keys) == 0 {
        return nil, errors.New("failed to lookup key")
    }

    // Use the first key
    key, err := keys[0].Materialize()
    if err != nil {
        return nil, fmt.Errorf("failed to generate public key: %s", err)
    }

    // The key set may contain non-RSA keys; reject them instead of panicking.
    publicKey, ok := key.(*rsa.PublicKey)
    if !ok {
        return nil, fmt.Errorf("key %q is not an RSA public key (got %T)", kid, key)
    }

    return publicKey, nil
}

// stringInSlice reports whether list contains an element exactly equal to a.
// The comparison is case-sensitive; a nil or empty list yields false.
func stringInSlice(a string, list []string) bool {
    for i := 0; i < len(list); i++ {
        if list[i] == a {
            return true
        }
    }
    return false
}