package httpcache

import (


// Module lifecycle:
// 1. Loaded (the Unmarshaler?)
// 2. Provisioned and validated
// 3. Used
// 4. Cleaned up

var (
    entries     []map[string][]*Entry
    entriesLock []*sync.RWMutex
    l           sync.RWMutex
    keyBufPool  = sync.Pool{
        New: func() interface{} {
            return new(bytes.Buffer)

// now reports the current time in UTC. It is a package-level variable so
// that tests can replace it with a fixed clock.
var now = func() time.Time {
    current := time.Now()
    return current.UTC()
}

// RuleMatcherType specifies the type of matching rule to cache.
type RuleMatcherType string

// The following constants list the different ways of deciding whether a
// request is matched, and therefore whether its response may be cached.
const (
    // MatcherTypePath names the path based rule type.
    MatcherTypePath RuleMatcherType = "path"
    // MatcherTypeHeader names the header based rule type.
    MatcherTypeHeader RuleMatcherType = "header"
)

// RuleMatcherRawWithType stores the marshal content for unmarshalling in provision stage
type RuleMatcherRawWithType struct {
    Type RuleMatcherType
    Data json.RawMessage

// RuleMatcher determines whether the request should be cached or not.
type RuleMatcher interface {
    // matches receives the request, the response status code and the
    // response headers, and reports whether the rule is satisfied.
    matches(*http.Request, int, http.Header) bool
}

// PathRuleMatcher determines whether the request's path is matched.
type PathRuleMatcher struct {
    // Path is the prefix the request's URL path has to start with.
    Path string `json:"path"`
}

// matches reports whether the URL path of req begins with p.Path. The status
// code and the response headers are not consulted.
func (p *PathRuleMatcher) matches(req *http.Request, statusCode int, resHeaders http.Header) bool {
    return strings.HasPrefix(req.URL.Path, p.Path)
}

// HeaderRuleMatcher determines whether the request's header is matched.
type HeaderRuleMatcher struct {
    // Header is the name of the response header to inspect.
    Header string `json:"header"`
    // Value lists the accepted values for that header.
    Value []string `json:"value"`
}

// matches reports whether the first value of the response header named
// h.Header equals one of the entries of h.Value. The request and the status
// code are not consulted.
func (h *HeaderRuleMatcher) matches(req *http.Request, statusCode int, resHeaders http.Header) bool {
    // Header.Get only returns the first value stored under the name.
    headerValue := resHeaders.Get(h.Header)
    for _, value := range h.Value {
        if value == headerValue {
            return true
        }
    }

    return false
}

func expirationObject(obj *cacheobject.Object, rv *cacheobject.ObjectResults) {
     * Okay, lets calculate Freshness/Expiration now. woo:

       o  If the cache is shared and the s-maxage response directive
          (Section is present, use its value, or

       o  If the max-age response directive (Section is present,
          use its value, or

       o  If the Expires response header field (Section 5.3) is present, use
          its value minus the value of the Date response header field, or

       o  Otherwise, no explicit expiration time is present in the response.
          A heuristic freshness lifetime might be applicable; see
          Section 4.2.2.

    var expiresTime time.Time

    if obj.RespDirectives.SMaxAge != -1 && !obj.CacheIsPrivate {
        expiresTime = obj.NowUTC.Add(time.Second * time.Duration(obj.RespDirectives.SMaxAge))
    } else if obj.RespDirectives.MaxAge != -1 {
        expiresTime = obj.NowUTC.UTC().Add(time.Second * time.Duration(obj.RespDirectives.MaxAge))
    } else if !obj.RespExpiresHeader.IsZero() {
        serverDate := obj.RespDateHeader
        if serverDate.IsZero() {
            // common enough case when a Date: header has not yet been added to an
            // active response.
            serverDate = obj.NowUTC
        expiresTime = obj.NowUTC.Add(obj.RespExpiresHeader.Sub(serverDate))
    } else {
        expiresTime = obj.NowUTC

    rv.OutExpirationTime = expiresTime

func judgeResponseShouldCacheOrNot(req *http.Request,
    statusCode int,
    respHeaders http.Header,
    privateCache bool) ([]cacheobject.Reason, time.Time, []cacheobject.Warning, *cacheobject.Object, error) {

    var reqHeaders http.Header
    var reqMethod string
    var reqDir *cacheobject.RequestCacheDirectives

    respDir, err := cacheobject.ParseResponseCacheControl(respHeaders.Get("Cache-Control"))
    if err != nil {
        return nil, time.Time{}, nil, nil, err

    if req != nil {
        reqDir, err = cacheobject.ParseRequestCacheControl(req.Header.Get("Cache-Control"))
        if err != nil {
            return nil, time.Time{}, nil, nil, err
        reqHeaders = req.Header
        reqMethod = req.Method

    var expiresHeader time.Time
    var dateHeader time.Time
    var lastModifiedHeader time.Time

    if respHeaders.Get("Expires") != "" {
        expiresHeader, err = http.ParseTime(respHeaders.Get("Expires"))
        if err != nil {
            // sometimes servers will return `Expires: 0` or `Expires: -1` to
            // indicate expired content
            expiresHeader = time.Time{}
        expiresHeader = expiresHeader.UTC()

    if respHeaders.Get("Date") != "" {
        dateHeader, err = http.ParseTime(respHeaders.Get("Date"))
        if err != nil {
            return nil, time.Time{}, nil, nil, err
        dateHeader = dateHeader.UTC()

    if respHeaders.Get("Last-Modified") != "" {
        lastModifiedHeader, err = http.ParseTime(respHeaders.Get("Last-Modified"))
        if err != nil {
            return nil, time.Time{}, nil, nil, err
        lastModifiedHeader = lastModifiedHeader.UTC()

    obj := cacheobject.Object{
        CacheIsPrivate: privateCache,

        RespDirectives:         respDir,
        RespHeaders:            respHeaders,
        RespStatusCode:         statusCode,
        RespExpiresHeader:      expiresHeader,
        RespDateHeader:         dateHeader,
        RespLastModifiedHeader: lastModifiedHeader,

        ReqDirectives: reqDir,
        ReqHeaders:    reqHeaders,
        ReqMethod:     reqMethod,

        NowUTC: now(),
    rv := cacheobject.ObjectResults{}

    cacheobject.CachableObject(&obj, &rv)
    if rv.OutErr != nil {
        return nil, time.Time{}, nil, nil, rv.OutErr

    expirationObject(&obj, &rv)
    if rv.OutErr != nil {
        return nil, time.Time{}, nil, nil, rv.OutErr

    return rv.OutReasons, rv.OutExpirationTime, rv.OutWarnings, &obj, nil

func getCacheStatus(req *http.Request, response *Response, config *Config) (bool, time.Time) {
    // NOTE: it seems that we can remove lock timeout
    if response.Code == http.StatusPartialContent || response.snapHeader.Get("Content-Range") != "" {
        return false, now().Add(config.LockTimeout)

    if response.Code == http.StatusNotModified {
        return false, now()

    reasonsNotToCache, expiration, _, _, err := judgeResponseShouldCacheOrNot(req, response.Code, response.snapHeader, false)
    if err != nil {
        return false, time.Time{}

    isPublic := len(reasonsNotToCache) == 0
    if !isPublic {
        return false, now().Add(config.LockTimeout)

    varyHeader := response.snapHeader.Get("Vary")
    if varyHeader == "*" {
        return false, now().Add(config.LockTimeout)

    for _, rule := range config.RuleMatchers {
        if !rule.matches(req, response.Code, response.snapHeader) {
            return false, now()

    if now().After(expiration.Add(-1 * time.Second)) {
        expiration = now().Add(config.DefaultMaxAge)

    return true, expiration

func matchVary(curReq *http.Request, entry *Entry) bool {
    // NOTE:
    vary := entry.Response.HeaderMap.Get("Vary")

    for _, searchedHeader := range strings.Split(vary, ",") {
        searchedHeader = strings.TrimSpace(searchedHeader)
        if curReq.Header.Get(searchedHeader) != entry.Request.Header.Get(searchedHeader) {
            return false

    return true

// Entry consists of a cache key and one or more response corresponding to
// the prior requests.
type Entry struct {
    isPublic   bool
    expiration time.Time
    key        string
    Request    *http.Request
    Response   *Response

// NewEntry creates a new Entry for the given request and response
// and it also calculates whether it is public or not
func NewEntry(key string, request *http.Request, response *Response, config *Config) *Entry {
    isPublic, expiration := getCacheStatus(request, response, config)

    return &Entry{
        isPublic:   isPublic,
        key:        key,
        expiration: expiration,
        Request:    request,
        Response:   response,

// Key return the key for the entry
func (e *Entry) Key() string {
    return e.key

// Clean purges the cache
func (e *Entry) Clean() error {
    return e.Response.Clean()

func (e *Entry) writePublicResponse(w http.ResponseWriter) error {
    // TODO: Maybe we can redesign here to get a better performance
    reader, err := e.Response.GetReader()

    if err != nil {
        return err

    defer reader.Close()

    // In io.copy will write the status code.
    length := w.Header().Get("Content-Length")

    if length == "" {
        contentLength := strconv.Itoa(e.Response.body.Length())
        if contentLength != "0" {
            w.Header().Set("Content-Length", contentLength)

    // wow, we should write the header before calling this function
    _, err = io.Copy(w, reader)
    return err

func (e *Entry) writePrivateResponse(w http.ResponseWriter) error {
    // wrap the original response writer
    return nil

// WriteBodyTo sends the body to the http.ResponseWritter
func (e *Entry) WriteBodyTo(w http.ResponseWriter) error {
    // the definition of private response seems come from
    // the package cacheobject
    if !e.isPublic {
        return e.writePrivateResponse(w)

    return e.writePublicResponse(w)

// IsFresh indicates this entry is not expired
func (e *Entry) IsFresh() bool {
    return e.expiration.After(time.Now())

func (e *Entry) keyWithRespectVary() string {
    buf := keyBufPool.Get().(*bytes.Buffer)

    defer keyBufPool.Put(buf)

    vary := e.Response.snapHeader.Get("Vary")
    for _, header := range strings.Split(vary, ",") {

    return url.PathEscape(buf.String())

func (e *Entry) setBackend(ctx context.Context, config *Config) error {
    var backend backends.Backend
    var err error

    switch config.Type {
    case file:
        backend, err = backends.NewFileBackend(config.Path)
    case inMemory:
        backend, err = backends.NewInMemoryBackend(ctx, e.keyWithRespectVary(), e.expiration)
    case redis:
        backend, err = backends.NewRedisBackend(ctx, e.keyWithRespectVary(), e.expiration)

    return err

// HTTPCache is a http cache for http request which is focus on static files
type HTTPCache struct {
    cacheKeyTemplate string
    cacheBucketsNum  int
    entries          []map[string][]*Entry
    entriesLock      []*sync.RWMutex
    isDistributed    bool

// NewHTTPCache new a HTTPCache to handle cache entries
func NewHTTPCache(config *Config, distributedOn bool) *HTTPCache {
    // TODO: how to handle when the bucket's num is changed
    defer l.Unlock()

    if entries == nil {
        entries = make([]map[string][]*Entry, config.CacheBucketsNum)
        for i := 0; i < config.CacheBucketsNum; i++ {
            entries[i] = make(map[string][]*Entry)

    if entriesLock == nil {
        entriesLock = make([]*sync.RWMutex, config.CacheBucketsNum)
        for i := 0; i < config.CacheBucketsNum; i++ {
            entriesLock[i] = new(sync.RWMutex)

    return &HTTPCache{
        cacheKeyTemplate: config.CacheKeyTemplate,
        cacheBucketsNum:  config.CacheBucketsNum,
        entries:          entries,
        entriesLock:      entriesLock,
        isDistributed:    distributedOn,


func (h *HTTPCache) getBucketIndexForKey(key string) uint32 {
    return uint32(math.Mod(float64(crc32.ChecksumIEEE([]byte(key))), float64(h.cacheBucketsNum)))

// In caddy2, it is automatically add the map by addHTTPVarsToReplacer
func getKey(cacheKeyTemplate string, r *http.Request) string {
    repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)

    // Add contentlength and bodyhash when not added before
    if _, ok := repl.Get("http.request.contentlength"); !ok {
        repl.Set("http.request.contentlength", r.ContentLength)
        repl.Map(func(key string) (interface{}, bool) {
            if key == "http.request.bodyhash" {
                return bodyHash(r), true
            return nil, false

    return repl.ReplaceKnown(cacheKeyTemplate, "")

// bodyHash calculates a hash value of the request body.
//
// It returns the hex encoded SHA-1 digest of the body, or "" when the body
// cannot be read. The body is consumed while hashing and then replaced by an
// in-memory copy, so it can still be read afterwards. A request without a
// body hashes like an empty body. SHA-1 only serves as a cache key component
// here, not as a security measure.
func bodyHash(r *http.Request) string {
    var body []byte
    if r.Body != nil {
        var err error
        body, err = io.ReadAll(r.Body)
        if err != nil {
            return ""
        }
        r.Body = io.NopCloser(bytes.NewBuffer(body))
    }

    h := sha1.New()
    // Feed the body into the hash; hash.Hash.Write never returns an error.
    h.Write(body)

    return fmt.Sprintf("%x", h.Sum(nil))
}

// Get returns the cached response
func (h *HTTPCache) Get(key string, request *http.Request, includeStale bool) (*Entry, bool) {
    b := h.getBucketIndexForKey(key)
    defer h.entriesLock[b].RUnlock()

    previousEntries, exists := h.entries[b][key]

    if !exists {
        return nil, false

    for _, entry := range previousEntries {
        if (entry.IsFresh() || includeStale) && matchVary(request, entry) {
            return entry, true

    return nil, false

// Keys list the keys holden by this cache
func (h *HTTPCache) Keys() []string {
    keys := []string{}

    for index, l := range h.entriesLock {
        for k, v := range h.entries[index] {
            if len(v) != 0 {
                keys = append(keys, k)

    return keys

// Del purge the key immediately
func (h *HTTPCache) Del(key string) error {
    b := h.getBucketIndexForKey(key)
    previousEntries, exists := h.entries[b][key]

    if !exists {
        return nil

    // the schedule will clean the entry automatically
    for _, entry := range previousEntries {
        if entry.IsFresh() {
            err := h.cleanEntry(entry)
            if err != nil {
                caddy.Log().Named("http.handlers.http_cache").Error(fmt.Sprintf("clean entry error: %s", err.Error()))
                return err

    return nil

// Put adds the entry in the cache
func (h *HTTPCache) Put(request *http.Request, entry *Entry, config *Config) {
    key := entry.Key()
    bucket := h.getBucketIndexForKey(key)

    defer h.entriesLock[bucket].Unlock()

    h.scheduleCleanEntry(entry, config.StaleMaxAge)

    for i, previousEntry := range h.entries[bucket][key] {
        if matchVary(entry.Request, previousEntry) {
            go previousEntry.Clean()
            h.entries[bucket][key][i] = entry

    h.entries[bucket][key] = append(h.entries[bucket][key], entry)

func (h *HTTPCache) distributedClean(key string, entry *Entry) error {
    // implement a simple Leader Election system
    // acquire a distributed lock here if the distributed mode on
    dl, err := distributed.NewDistributedLock(key)
    if err != nil {
        return err

    leaderCh, err := dl.Lock()
    if err != nil {
        return nil

    if leaderCh == nil {
        return nil

    select {
    case <-leaderCh:
        // do clean entry if this node is leader
        caddy.Log().Named("distributed cache").Debug("perform clean entry")
        err = entry.Clean()
        // sleep a little time wait for other nodes performing deleting cache
        time.Sleep(3 * time.Second)


    return err

func (h *HTTPCache) cleanEntry(entry *Entry) error {
    key := entry.Key()
    bucket := h.getBucketIndexForKey(key)

    defer h.entriesLock[bucket].Unlock()

    for i, otherEntry := range h.entries[bucket][key] {
        if entry == otherEntry {
            h.entries[bucket][key] = append(h.entries[bucket][key][:i], h.entries[bucket][key][i+1:]...)
            if !h.isDistributed {
                return entry.Clean()
            return h.distributedClean(key, entry)

    return nil

func (h *HTTPCache) scheduleCleanEntry(entry *Entry, staleMaxAge time.Duration) {
    go func(entry *Entry) {
        expiration := entry.expiration
        expiration = expiration.Add(staleMaxAge)