efritz/chevron

View on GitHub
middleware/cache.go

Summary

Maintainability
A
35 mins
Test Coverage
package middleware

import (
    "context"
    "encoding/json"
    "fmt"
    "net/http"
    "strings"

    "github.com/efritz/gache"
    "github.com/efritz/nacelle"
    "github.com/efritz/response"

    "github.com/efritz/chevron"
)

// CacheMiddleware is the chevron middleware returned by NewResponseCache.
// For cacheable requests it answers from a serialized response stored in
// cache, invoking the wrapped handler only when no cached value exists.
type CacheMiddleware struct {
    // cache holds serialized responses keyed by makeCacheKey. When nil,
    // every request is passed straight through to the wrapped handler.
    cache        gache.Cache
    // tags are attached to every value written to cache via SetValue.
    // NOTE(review): presumably populated by a CacheMiddlewareConfigFunc
    // defined elsewhere in the package — confirm.
    tags         []string
    // errorFactory turns a cache or (de)serialization error into the
    // response returned to the client; NewResponseCache defaults it to
    // defaultErrorFactory.
    errorFactory ErrorFactory
}

// NewResponseCache returns middleware that keeps whole responses in the
// given cache. When a cached payload exists for an incoming request, the
// wrapped handler is skipped entirely. The defaults (the supplied cache and
// defaultErrorFactory) are set first; each config function is then applied
// to the middleware in the order given.
func NewResponseCache(
    cache gache.Cache,
    configs ...CacheMiddlewareConfigFunc,
) chevron.Middleware {
    middleware := &CacheMiddleware{errorFactory: defaultErrorFactory}
    middleware.cache = cache

    for i := range configs {
        configs[i](middleware)
    }

    return middleware
}

// Convert wraps f so that cacheable requests are answered from the cache
// when possible. On a miss, f is invoked and its response is stored
// before being returned. Errors from the cache path or from decoding the
// cached payload are logged and converted into a response by errorFactory.
// The returned error is always nil.
func (m *CacheMiddleware) Convert(f chevron.Handler) (chevron.Handler, error) {
    wrapped := func(ctx context.Context, req *http.Request, logger nacelle.Logger) response.Response {
        // Without a cache, or for a method that may have side effects, the
        // cache is neither read nor written.
        bypass := m.cache == nil || !shouldCache(req)
        if bypass {
            return f(ctx, req, logger)
        }

        payload, cacheErr := m.generateCacheValue(ctx, req, logger, f)
        if cacheErr != nil {
            logger.Error("failed to retrieve response from cache (%s)", cacheErr.Error())
            return m.errorFactory(cacheErr)
        }

        // The payload is a serialized string whether it came from the cache
        // or was just produced by f. A response freshly returned by f cannot
        // be reused (its underlying reader was consumed during
        // serialization), so a new response is always rebuilt from the
        // string.
        resp, decodeErr := deserialize(payload)
        if decodeErr != nil {
            logger.Error("failed to round-trip response (%s)", decodeErr.Error())
            return m.errorFactory(decodeErr)
        }

        return resp
    }

    return wrapped, nil
}

// generateCacheValue returns the serialized response payload for req.
//
// It first looks up the request's key in the cache. A non-empty value, or
// any lookup error, is returned as-is. On a miss, f is invoked with ctx,
// req and logger. Its response is serialized and stored under the same key
// (with the middleware's tags), and the new payload is returned. On error
// the returned string is always empty.
//
// NOTE(review): every response produced by f is cached regardless of its
// status code — confirm that caching error responses is intended.
func (m *CacheMiddleware) generateCacheValue(
    ctx context.Context,
    req *http.Request,
    logger nacelle.Logger,
    f chevron.Handler,
) (string, error) {
    key := m.makeCacheKey(req)

    // An empty string means "not cached"; a successful serialize never
    // produces one because it always emits a JSON object.
    if cached, err := m.cache.GetValue(key); cached != "" || err != nil {
        return cached, err
    }

    val, err := serialize(f(ctx, req, logger))
    if err != nil {
        return "", fmt.Errorf("failed to serialize response (%s)", err.Error())
    }

    if err := m.cache.SetValue(key, val, m.tags...); err != nil {
        return "", fmt.Errorf("failed to insert response into cache (%s)", err.Error())
    }

    return val, nil
}

// makeCacheKey derives the cache key for req from its request URI (escaped
// path plus raw query). Surrounding slashes are trimmed and the remaining
// slashes become dots, so "/a/b?c=d" maps to "a.b?c=d".
//
// GET keys keep exactly that form. Any other method admitted by
// shouldCache (currently OPTIONS) is prefixed with "METHOD|" so that, for
// example, an OPTIONS response is never served for a GET of the same URI.
// '|' is always percent-escaped in a URL path, so a prefixed key can never
// equal a GET key.
func (m *CacheMiddleware) makeCacheKey(req *http.Request) string {
    // TODO - additional keys
    key := strings.Replace(strings.Trim(req.URL.RequestURI(), "/"), "/", ".", -1)
    if req.Method == http.MethodGet {
        return key
    }

    return req.Method + "|" + key
}

//
// Helpers

// shouldCache reports whether req may be served from, and stored in, the
// response cache. Only the safe (side-effect-free) methods GET and OPTIONS
// qualify; the method comparison is case-sensitive.
func shouldCache(req *http.Request) bool {
    switch req.Method {
    case "GET", "OPTIONS":
        return true
    default:
        return false
    }
}

// deserialize rebuilds a response from a cache value produced by serialize.
// The JSON object's "status", "headers" and "body" fields become the status
// code, headers (every value of every header is added) and body of a new
// response. A JSON decoding error is returned unchanged with a nil response.
func deserialize(data string) (response.Response, error) {
    var payload struct {
        Status  int         `json:"status"`
        Headers http.Header `json:"headers"`
        Body    string      `json:"body"`
    }

    err := json.Unmarshal([]byte(data), &payload)
    if err != nil {
        return nil, err
    }

    resp := response.Respond([]byte(payload.Body))
    resp.SetStatusCode(payload.Status)

    for name, values := range payload.Headers {
        for i := range values {
            resp.AddHeader(name, values[i])
        }
    }

    return resp, nil
}

// serialize encodes resp as a JSON cache value holding its status code,
// headers and body; deserialize performs the inverse. Errors from
// response.Serialize or json.Marshal are returned unchanged.
func serialize(resp response.Response) (string, error) {
    headers, body, serializeErr := response.Serialize(resp)
    if serializeErr != nil {
        return "", serializeErr
    }

    // Fields are declared in alphabetical order of their JSON names so the
    // encoded bytes match what json.Marshal emits for an equivalent map
    // (whose keys it sorts).
    encoded, marshalErr := json.Marshal(struct {
        Body    string      `json:"body"`
        Headers interface{} `json:"headers"`
        Status  int         `json:"status"`
    }{
        Status:  resp.StatusCode(),
        Headers: headers,
        Body:    string(body),
    })

    return string(encoded), marshalErr
}