asteris-llc/converge

View on GitHub
rpc/jwt.go

Summary

Maintainability
A
50 mins
Test Coverage
// Copyright © 2016 Asteris, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package rpc

import (
    "errors"
    "fmt"
    "net/http"
    "strings"
    "time"

    jwt "github.com/dgrijalva/jwt-go"
    "golang.org/x/net/context"
    "google.golang.org/grpc"
    "google.golang.org/grpc/metadata"
)

// JWTLifetime is the exact lifetime stamped on every token produced by New.
// Verify rejects any token whose expiry is not exactly this far past its
// issue time.
const JWTLifetime = 30 * time.Second

// JWTAlg names the HMAC signing algorithm used both when signing tokens and
// when checking the "alg" header during verification.
const JWTAlg = "HS512"

// errAuthNotProvided is returned when the incoming request metadata is absent
// or carries no authorization entry.
var errAuthNotProvided = errors.New("authorization not provided")

// JWTAuth authenticates traffic between client and server using short-lived
// JWTs signed with a shared secret.
type JWTAuth struct {
	token []byte // shared secret used as the HMAC key for signing and verifying
}

// NewJWTAuth returns a JWTAuth that signs and verifies tokens with the given
// shared secret.
func NewJWTAuth(token string) *JWTAuth {
	auth := new(JWTAuth)
	auth.token = []byte(token)
	return auth
}

// GetRequestMetadata mints a fresh token via New for every call and returns it
// as an "authorization" metadata entry of the form "BEARER <token>". The
// method set (this plus RequireTransportSecurity) matches grpc's per-RPC
// credentials interface; presumably that is how it is wired in — confirm at
// the call site. ctx and uri are unused.
func (j *JWTAuth) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	signed, err := j.New()
	if err != nil {
		return nil, err
	}

	md := make(map[string]string, 1)
	md["authorization"] = "BEARER " + signed
	return md, nil
}

// RequireTransportSecurity reports whether these credentials must only be sent
// over a secure transport. It always answers false, so tokens may travel over
// plaintext connections; each token is only accepted by Verify until its
// expiry (JWTLifetime after issue).
func (j *JWTAuth) RequireTransportSecurity() bool {
	const needsSecureTransport = false
	return needsSecureTransport
}

// New creates a token signed with the shared secret using JWTAlg. The token
// carries only the standard "iat" and "exp" claims, with exp exactly
// JWTLifetime after iat (which Verify requires).
func (j *JWTAuth) New() (string, error) {
	// Read the clock once. Two separate time.Now() calls can straddle a
	// second boundary, producing exp-iat of JWTLifetime+1s after truncation
	// to Unix seconds, and Verify would then reject the fresh token.
	now := time.Now()

	token := jwt.NewWithClaims(
		jwt.GetSigningMethod(JWTAlg),
		jwt.StandardClaims{
			IssuedAt:  now.Unix(),
			ExpiresAt: now.Add(JWTLifetime).Unix(),
		},
	)

	return token.SignedString(j.token)
}

// Verify checks that material is a token signed with this JWTAuth's secret
// under JWTAlg, that its issued-at and expires-at claims are present and valid
// relative to the current time, and that exp-iat equals JWTLifetime exactly.
// It returns nil when all checks pass.
func (j *JWTAuth) Verify(material string) error {
	// Only hand out the key when the header names the expected algorithm;
	// this stops a token from choosing a different verification method.
	keyFor := func(t *jwt.Token) (interface{}, error) {
		alg, present := t.Header["alg"]
		if present && alg == JWTAlg {
			return j.token, nil
		}
		return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
	}

	parsed, err := jwt.ParseWithClaims(material, &jwt.StandardClaims{}, keyFor)
	if err != nil {
		return err
	}

	claims, isStandard := parsed.Claims.(*jwt.StandardClaims)
	if !isStandard {
		return errors.New("internal error, standard claims not present")
	}

	// Both claims are required (second argument true). Clock skew between
	// issuer and verifier is deliberately not tolerated.
	if !claims.VerifyIssuedAt(time.Now().Unix(), true) {
		return errors.New("issued at was invalid")
	}
	if !claims.VerifyExpiresAt(time.Now().Unix(), true) {
		return errors.New("expires at was invalid")
	}

	// Reject tokens whose lifetime differs from the one New stamps on.
	lifetime := time.Duration(claims.ExpiresAt)*time.Second - time.Duration(claims.IssuedAt)*time.Second
	if lifetime != JWTLifetime {
		return fmt.Errorf("lifetime too large. Expected %s, was %s", JWTLifetime, lifetime)
	}

	return nil
}

// VerifyContext verifies a token in context metadata
func (j *JWTAuth) VerifyContext(ctx context.Context) error {
    md, ok := metadata.FromContext(ctx)
    if !ok {
        return errAuthNotProvided
    }

    tokens, ok := md["authorization"]
    if !ok {
        return errAuthNotProvided
    }

    return j.Verify(strings.TrimLeft(tokens[0], "BEARER "))
}

// Protect wraps next so that it only runs for requests carrying a valid token.
// The token is looked for, in order:
//   - in the "jwt" querystring parameter,
//   - in an "Authorization: BEARER <token>" header (scheme matched
//     case-insensitively),
//   - in the "jwt" cookie.
//
// Requests with no token, or with one that fails Verify, get a
// 401 Unauthorized response and next is not called.
func (j *JWTAuth) Protect(next http.Handler) http.Handler {
	const scheme = "BEARER "

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var token string

		bearer := r.Header.Get("Authorization")
		hasBearer := len(bearer) >= len(scheme) && strings.EqualFold(bearer[:len(scheme)], scheme)

		// Get token out of querystring, header, or cookie
		if query := r.URL.Query().Get("jwt"); query != "" {
			token = query
		} else if hasBearer {
			// Slice the scheme off as a prefix; strings.TrimLeft would treat
			// it as a character set and could strip token characters too.
			token = bearer[len(scheme):]
		} else if cookie, err := r.Cookie("jwt"); err == nil && cookie.Value != "" {
			// r.Cookie returns a nil cookie together with ErrNoCookie, so the
			// value may only be read once err is known to be nil.
			token = cookie.Value
		}

		// Token is required
		if token == "" {
			http.Error(w, "authorization is required", http.StatusUnauthorized)
			return
		}

		// Validate token
		if err := j.Verify(token); err != nil {
			http.Error(w, err.Error(), http.StatusUnauthorized)
			return
		}

		// looks like we're good, call the next handler
		next.ServeHTTP(w, r)
	})
}

// StreamInterceptor is a grpc stream server interceptor. It verifies the
// token in the stream's context metadata via VerifyContext and only invokes
// handler when verification succeeds; otherwise the verification error is
// returned unchanged.
func (j *JWTAuth) StreamInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
	streamCtx := stream.Context()
	if authErr := j.VerifyContext(streamCtx); authErr != nil {
		return authErr
	}
	return handler(srv, stream)
}

// UnaryInterceptor is a grpc unary server interceptor. It verifies the token
// in ctx's metadata via VerifyContext; on failure it returns a nil response
// and the verification error without calling handler.
func (j *JWTAuth) UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
	// Assign to the named result rather than shadowing it.
	if err = j.VerifyContext(ctx); err != nil {
		return nil, err
	}
	return handler(ctx, req)
}