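// Web page header captured together with this source file:
//   repository:      vorteil/direktiv (View on GitHub)
//   file:            pkg/refactor/gateway/endpoints/tree.go
//   summary:         Maintainability F (3 days)
//   test coverage:   (no value captured)
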
// nolint
package endpoints

// This file has been copied and modified from https://github.com/go-chi/chi
// The radix tree implementation below is based on the original work by
// Armon Dadgar in https://github.com/armon/go-radix/blob/master/radix.go
// (MIT licensed). It has been heavily modified for use as an HTTP routing tree.

import (
    "fmt"
    "net/http"
    "regexp"
    "sort"
    "strings"

    "github.com/direktiv/direktiv/pkg/refactor/core"
    "github.com/go-chi/chi/v5"
)

// methodTyp is a bit set of HTTP methods. Every constant below is a distinct
// power of two (1 << iota), so several methods can be OR-ed into one value.
type methodTyp uint

const (
    // mSTUB occupies the lowest bit. It has no entry in methodMap, i.e. it
    // does not stand for a real HTTP method.
    mSTUB methodTyp = 1 << iota
    mCONNECT
    mDELETE
    mGET
    mHEAD
    mOPTIONS
    mPATCH
    mPOST
    mPUT
    mTRACE
)

// mALL is the union of all HTTP method bits listed in methodMap.
// mSTUB is not included.
var mALL = mCONNECT | mDELETE | mGET | mHEAD |
    mOPTIONS | mPATCH | mPOST | mPUT | mTRACE

// methodMap translates net/http method names into their methodTyp bit.
// methodTypString scans it in the opposite direction (bit -> name), and
// setEndpoint ranges over it to register a handler for every method.
var methodMap = map[string]methodTyp{
    http.MethodConnect: mCONNECT,
    http.MethodDelete:  mDELETE,
    http.MethodGet:     mGET,
    http.MethodHead:    mHEAD,
    http.MethodOptions: mOPTIONS,
    http.MethodPatch:   mPATCH,
    http.MethodPost:    mPOST,
    http.MethodPut:     mPUT,
    http.MethodTrace:   mTRACE,
}

// nodeTyp classifies a tree node by the kind of pattern segment it holds.
//
// The numeric order matters: node.children is an array indexed by nodeTyp and
// findRoute ranges over that array front to back, so static edges are tried
// before regexp edges, then param edges, and catch-all edges last.
type nodeTyp uint8

const (
    ntStatic   nodeTyp = iota // /home
    ntRegexp                  // /{id:[0-9]+}
    ntParam                   // /{user}
    ntCatchAll                // /api/v1/*
)

// node is one vertex of the routing tree. A node represents a single pattern
// segment (static text, a regexp param, a plain param or a catch-all) and
// holds its children grouped by segment type.
type node struct {
    // subroutes on the leaf node
    subroutes chi.Routes

    // regexp matcher for regexp nodes; compiled in addChild, nil otherwise
    rex *regexp.Regexp

    // HTTP handler endpoints on the leaf node; nil until setEndpoint is
    // called on this node (see isLeaf)
    endpoints endpoints

    // prefix is the literal path text matched by a static node, or the
    // anchored regexp source for a regexp node. The code in this file does
    // not read it for plain param and catch-all nodes.
    prefix string

    // child nodes should be stored in-order for iteration,
    // in groups of the node type.
    children [ntCatchAll + 1]nodes

    // tail is the delimiter byte that ends a param/regexp segment: the
    // pattern byte right after the closing '}', or '/' when the param is the
    // last thing in the pattern (see patNextSegment). Static and catch-all
    // segments leave it at 0.
    tail byte

    // node type: static, regexp, param, catchAll
    typ nodeTyp

    // label is the first byte of the pattern text this node was created
    // from: the first byte of the prefix for static nodes, '{' or '*' for
    // param, regexp and catch-all nodes.
    label byte
}

// endpoints is a mapping of http method constants to handlers
// for a given route. Besides the single-method bits, setEndpoint also stores
// entries under the mSTUB and mALL keys.
type endpoints map[methodTyp]*endpoint

// endpoint is what a node stores for one method key: the handler together
// with the pattern it was registered under and that pattern's param names.
type endpoint struct {
    // endpoint handler
    handler *core.Endpoint

    // pattern is the routing pattern for handler nodes
    pattern string

    // parameter keys recorded on handler nodes, in the order they appear in
    // pattern (computed by patParamKeys)
    paramKeys []string
}

// Value returns the endpoint stored for method. When the map has no entry for
// method yet, an empty endpoint is created, stored and returned, so callers
// can assign to the result directly.
func (s endpoints) Value(method methodTyp) *endpoint {
    if existing, found := s[method]; found {
        return existing
    }

    fresh := &endpoint{}
    s[method] = fresh
    return fresh
}

// InsertRoute registers handler for method under pattern in the tree rooted at
// n and returns the node that ends up holding the endpoint.
//
// The pattern is consumed from left to right: existing edges are followed, a
// static edge that only partially matches is split at the common prefix, and
// any part of the pattern that has no edge yet is handed to addChild.
// Malformed patterns panic inside patNextSegment, addChild or setEndpoint.
func (n *node) InsertRoute(method methodTyp, pattern string, handler *core.Endpoint) *node {
    var parent *node
    search := pattern

    for {
        // Handle key exhaustion
        if len(search) == 0 {
            // Insert or update the node's leaf handler
            n.setEndpoint(method, handler, pattern)
            return n
        }
        // We're going to be searching for a wild node next,
        // in this case, we need to get the tail
        //
        // For a static label the seg* variables keep their zero values, which
        // means segTyp == ntStatic and segTail == 0.
        label := search[0]
        var segTail byte
        var segEndIdx int
        var segTyp nodeTyp
        var segRexpat string
        if label == '{' || label == '*' {
            segTyp, _, segRexpat, segTail, _, segEndIdx = patNextSegment(search)
        }

        // Regexp edges are additionally identified by their regexp source.
        var prefix string
        if segTyp == ntRegexp {
            prefix = segRexpat
        }
        // Look for the edge to attach to
        parent = n
        n = n.getEdge(segTyp, label, segTail, prefix)

        // No edge, create one
        if n == nil {
            child := &node{label: label, tail: segTail, prefix: search}
            hn := parent.addChild(child, search)
            hn.setEndpoint(method, handler, pattern)

            return hn
        }

        // Found an edge to match the pattern

        if n.typ > ntStatic {
            // We found a param node, trim the param from the search path and continue.
            // This param/wild pattern segment would already be on the tree from a previous
            // call to addChild when creating a new node.
            search = search[segEndIdx:]
            continue
        }

        // Static nodes fall below here.
        // Determine longest prefix of the search key on match.
        commonPrefix := longestPrefix(search, n.prefix)
        if commonPrefix == len(n.prefix) {
            // the common prefix is as long as the current node's prefix we're attempting to insert.
            // keep the search going.
            search = search[commonPrefix:]
            continue
        }

        // Split the node: child holds the shared prefix and takes over n's
        // slot under parent (replaceChild also copies label and tail onto it).
        child := &node{
            typ:    ntStatic,
            prefix: search[:commonPrefix],
        }
        parent.replaceChild(search[0], segTail, child)

        // Restore the existing node: it keeps the part of its prefix that
        // follows the shared prefix and becomes a child of the split node.
        n.label = n.prefix[commonPrefix]
        n.prefix = n.prefix[commonPrefix:]
        child.addChild(n, n.prefix)

        // If the new key is a subset, set the method/handler on this node and finish.
        search = search[commonPrefix:]
        if len(search) == 0 {
            child.setEndpoint(method, handler, pattern)
            return child
        }

        // Create a new edge for the node
        subchild := &node{
            typ:    ntStatic,
            label:  search[0],
            prefix: search,
        }
        hn := child.addChild(subchild, search)
        hn.setEndpoint(method, handler, pattern)
        return hn
    }
}

// addChild appends the new `child` node to the tree using the `pattern` as the trie key.
// For a URL router like chi's, we split the static, param, regexp and wildcard segments
// into different nodes. In addition, addChild will recursively call itself until every
// pattern segment is added to the url pattern tree as individual nodes, depending on type.
//
// The return value is the node the caller should attach the handler to: child
// itself when prefix holds a single segment, otherwise the deepest node created
// by the recursive calls. An invalid regexp in a route param causes a panic.
func (n *node) addChild(child *node, prefix string) *node {
    search := prefix

    // handler leaf node added to the tree is the child.
    // this may be overridden later down the flow
    hn := child

    // Parse next segment
    segTyp, _, segRexpat, segTail, segStartIdx, segEndIdx := patNextSegment(search)

    // Add child depending on next up segment
    switch segTyp {

    case ntStatic:
        // Search prefix is all static (that is, has no params in path)
        // noop

    default:
        // Search prefix contains a param, regexp or wildcard

        if segTyp == ntRegexp {
            rex, err := regexp.Compile(segRexpat)
            if err != nil {
                panic(fmt.Sprintf("chi: invalid regexp pattern '%s' in route param", segRexpat))
            }
            // Set unconditionally here; the static-prefix branch below
            // overwrites both fields again when child turns out to be static.
            child.prefix = segRexpat
            child.rex = rex
        }

        if segStartIdx == 0 {
            // Route starts with a param
            child.typ = segTyp

            // After this block segStartIdx is the index where the rest of
            // the pattern begins: the end of the param segment, or
            // len(search) for a catch-all, which consumes everything.
            if segTyp == ntCatchAll {
                segStartIdx = -1
            } else {
                segStartIdx = segEndIdx
            }
            if segStartIdx < 0 {
                segStartIdx = len(search)
            }
            child.tail = segTail // for params, we set the tail

            if segStartIdx != len(search) {
                // add static edge for the remaining part, split the end.
                // it's not possible to have adjacent param nodes, so it's certainly
                // going to be a static node next.

                search = search[segStartIdx:] // advance search position

                nn := &node{
                    typ:    ntStatic,
                    label:  search[0],
                    prefix: search,
                }
                hn = child.addChild(nn, search)
            }

        } else if segStartIdx > 0 {
            // Route has some param

            // starts with a static segment
            child.typ = ntStatic
            child.prefix = search[:segStartIdx]
            child.rex = nil

            // add the param edge node
            search = search[segStartIdx:]

            nn := &node{
                typ:   segTyp,
                label: search[0],
                tail:  segTail,
            }
            hn = child.addChild(nn, search)

        }
    }

    // Register child in the group for its (now final) type and keep that
    // group in traversal order.
    n.children[child.typ] = append(n.children[child.typ], child)
    n.children[child.typ].Sort()
    return hn
}

// replaceChild swaps the existing child of n that has the given label and tail
// (looked up in the group for child.typ) for child. The label and tail are
// copied onto child so it keeps identifying the same edge. It panics when no
// such child exists.
func (n *node) replaceChild(label, tail byte, child *node) {
    siblings := n.children[child.typ]
    for pos, existing := range siblings {
        if existing.label != label || existing.tail != tail {
            continue
        }
        child.label = label
        child.tail = tail
        siblings[pos] = child
        return
    }
    panic("chi: replacing missing child")
}

// getEdge returns the first child of n in the ntyp group whose label and tail
// match. For regexp groups the child's prefix (its regexp source) must equal
// prefix as well. It returns nil when there is no such child.
func (n *node) getEdge(ntyp nodeTyp, label, tail byte, prefix string) *node {
    for _, candidate := range n.children[ntyp] {
        if candidate.label != label || candidate.tail != tail {
            continue
        }
        if ntyp == ntRegexp && candidate.prefix != prefix {
            continue
        }
        return candidate
    }
    return nil
}

// setEndpoint stores handler on n for the given method bits.
//
// The endpoints map is allocated on first use. When the mSTUB bit is set the
// stub entry's handler is updated. When every bit of mALL is set, the handler,
// pattern and param keys are written to the mALL entry and to the entry of each
// method in methodMap; otherwise they are written to the single entry keyed by
// method. patParamKeys panics on a pattern with duplicate param keys.
func (n *node) setEndpoint(method methodTyp, handler *core.Endpoint, pattern string) {
    if n.endpoints == nil {
        n.endpoints = make(endpoints)
    }

    keys := patParamKeys(pattern)

    // bind fills in the endpoint stored under key m.
    bind := func(m methodTyp) {
        ep := n.endpoints.Value(m)
        ep.handler = handler
        ep.pattern = pattern
        ep.paramKeys = keys
    }

    if method&mSTUB == mSTUB {
        n.endpoints.Value(mSTUB).handler = handler
    }

    if method&mALL != mALL {
        bind(method)
        return
    }

    bind(mALL)
    for _, m := range methodMap {
        bind(m)
    }
}

// FindRoute looks up path for method in the tree rooted at n.
//
// On a match it returns the matching node, that node's endpoints and the
// handler registered for method; the captured params are appended to
// rctx.URLParams and a non-empty pattern is recorded in rctx.routePattern and
// rctx.RoutePatterns. Without a match all three results are nil.
func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, endpoints, *core.Endpoint) {
    // Start from a clean slate; the param slices keep their capacity.
    rctx.routePattern = ""
    rctx.routeParams.Keys = rctx.routeParams.Keys[:0]
    rctx.routeParams.Values = rctx.routeParams.Values[:0]

    matched := n.findRoute(rctx, method, path)
    if matched == nil {
        return nil, nil, nil
    }

    // Publish the params collected during the search.
    rctx.URLParams.Keys = append(rctx.URLParams.Keys, rctx.routeParams.Keys...)
    rctx.URLParams.Values = append(rctx.URLParams.Values, rctx.routeParams.Values...)

    // findRoute only returns nodes that have a handler for method, so the
    // lookup below yields that endpoint.
    ep := matched.endpoints[method]
    if pattern := ep.pattern; pattern != "" {
        rctx.routePattern = pattern
        rctx.RoutePatterns = append(rctx.RoutePatterns, pattern)
    }

    return matched, matched.endpoints, ep.handler
}

// Recursive edge traversal by checking all nodeTyp groups along the way.
// It's like searching through a multi-dimensional radix trie.
//
// findRoute returns the node that has a handler for method and matches all of
// path, or nil. Child groups are tried in nodeTyp order (static, regexp, param,
// catch-all). Param values are pushed onto rctx.routeParams.Values while
// descending and removed again when a branch fails. When a node matches the
// path but has no handler for method, its registered methods are appended to
// rctx.methodsAllowed and rctx.methodNotAllowed is set.
func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node {
    nn := n
    search := path

    for t, nds := range nn.children {

        ntyp := nodeTyp(t)
        if len(nds) == 0 {
            continue
        }

        var xn *node
        xsearch := search

        // label stays 0 when nothing is left to match.
        var label byte
        if search != "" {
            label = search[0]
        }

        switch ntyp {
        case ntStatic:
            xn = nds.findEdge(label)
            if xn == nil || !strings.HasPrefix(xsearch, xn.prefix) {
                continue
            }
            xsearch = xsearch[len(xn.prefix):]

        case ntParam, ntRegexp:
            // short-circuit and return no matching route for empty param values
            if xsearch == "" {
                continue
            }

            // serially loop through each node grouped by the tail delimiter
            for idx := 0; idx < len(nds); idx++ {
                xn = nds[idx]

                // tail for param nodes is the delimiter byte
                p := strings.IndexByte(xsearch, xn.tail)

                if p < 0 {
                    // No delimiter left: a '/'-tailed param may take the
                    // rest of the path, any other tail cannot match.
                    if xn.tail == '/' {
                        p = len(xsearch)
                    } else {
                        continue
                    }
                } else if ntyp == ntRegexp && p == 0 {
                    // a regexp param is not matched against an empty value
                    continue
                }

                if ntyp == ntRegexp && xn.rex != nil {
                    if !xn.rex.MatchString(xsearch[:p]) {
                        continue
                    }
                } else if strings.IndexByte(xsearch[:p], '/') != -1 {
                    // avoid a match across path segments
                    continue
                }

                // Remember the stack height so a failed branch can be undone.
                prevlen := len(rctx.routeParams.Values)
                rctx.routeParams.Values = append(rctx.routeParams.Values, xsearch[:p])
                xsearch = xsearch[p:]

                if len(xsearch) == 0 {
                    if xn.isLeaf() {
                        h := xn.endpoints[method]
                        if h != nil && h.handler != nil {
                            rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...)
                            return xn
                        }

                        // The loop variable is a method key; it shadows the
                        // endpoints type inside this loop.
                        for endpoints := range xn.endpoints {
                            if endpoints == mALL || endpoints == mSTUB {
                                continue
                            }
                            rctx.methodsAllowed = append(rctx.methodsAllowed, endpoints)
                        }

                        // flag that the routing context found a route, but not a corresponding
                        // supported method
                        rctx.methodNotAllowed = true
                    }
                }

                // recursively find the next node on this branch
                fin := xn.findRoute(rctx, method, xsearch)
                if fin != nil {
                    return fin
                }

                // not found on this branch, reset vars
                rctx.routeParams.Values = rctx.routeParams.Values[:prevlen]
                xsearch = search
            }

            // No node in this group matched. Push a placeholder value so the
            // generic clean-up at the bottom of the outer loop, which pops one
            // value for non-static nodes, leaves Values balanced.
            rctx.routeParams.Values = append(rctx.routeParams.Values, "")

        default:
            // catch-all nodes
            rctx.routeParams.Values = append(rctx.routeParams.Values, search)
            xn = nds[0]
            xsearch = ""
        }

        if xn == nil {
            continue
        }

        // did we find it yet?
        if len(xsearch) == 0 {
            if xn.isLeaf() {
                h := xn.endpoints[method]
                if h != nil && h.handler != nil {
                    rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...)
                    return xn
                }

                // The loop variable is a method key; it shadows the
                // endpoints type inside this loop.
                for endpoints := range xn.endpoints {
                    if endpoints == mALL || endpoints == mSTUB {
                        continue
                    }
                    rctx.methodsAllowed = append(rctx.methodsAllowed, endpoints)
                }

                // flag that the routing context found a route, but not a corresponding
                // supported method
                rctx.methodNotAllowed = true
            }
        }

        // recursively find the next node..
        fin := xn.findRoute(rctx, method, xsearch)
        if fin != nil {
            return fin
        }

        // Did not find final handler, let's remove the param here if it was set
        if xn.typ > ntStatic {
            if len(rctx.routeParams.Values) > 0 {
                rctx.routeParams.Values = rctx.routeParams.Values[:len(rctx.routeParams.Values)-1]
            }
        }

    }

    return nil
}

// isLeaf reports whether the node has an endpoints map, which setEndpoint
// allocates the first time a handler is registered on the node.
func (n *node) isLeaf() bool {
    return n.endpoints != nil
}

// Routes walks the tree below n and returns one Route per distinct pattern
// found on each visited node. Nodes that carry a stub handler and no subroutes
// are skipped. The result is never nil; its order follows Go map iteration and
// is therefore not stable.
func (n *node) Routes() []Route {
    collected := []Route{}

    visit := func(eps endpoints, subroutes chi.Routes) bool {
        if stub := eps[mSTUB]; stub != nil && stub.handler != nil && subroutes == nil {
            return false
        }

        // Bucket the node's endpoints by the pattern they were registered with.
        byPattern := make(map[string]endpoints)
        for mt, ep := range eps {
            if ep.pattern == "" {
                continue
            }
            group, known := byPattern[ep.pattern]
            if !known {
                group = endpoints{}
                byPattern[ep.pattern] = group
            }
            group[mt] = ep
        }

        for pattern, group := range byPattern {
            handlers := make(map[string]*core.Endpoint)
            if all := group[mALL]; all != nil && all.handler != nil {
                handlers["*"] = all.handler
            }

            // Falls back to the pattern when no named method has a handler.
            filePath := pattern
            for mt, ep := range group {
                if ep.handler == nil {
                    continue
                }
                name := methodTypString(mt)
                if name == "" {
                    continue
                }
                handlers[name] = ep.handler

                // in direktiv methods can be only served from one file
                filePath = ep.handler.FilePath
            }

            collected = append(collected, Route{
                SubRoutes: subroutes,
                Handlers:  handlers,
                Pattern:   pattern,
                FilePath:  filePath,
            })
        }

        // Never stop the walk early.
        return false
    }

    n.walk(visit)

    return collected
}

// walk visits n and then all of its descendants depth-first, group by group.
// fn is invoked for every node that has endpoints or subroutes; as soon as fn
// returns true the traversal stops and walk returns true. It returns false
// when the whole subtree was visited.
func (n *node) walk(fn func(eps endpoints, subroutes chi.Routes) bool) bool {
    carriesData := n.endpoints != nil || n.subroutes != nil
    if carriesData && fn(n.endpoints, n.subroutes) {
        return true
    }

    for _, group := range n.children {
        for _, child := range group {
            if child.walk(fn) {
                return true
            }
        }
    }

    return false
}

// patNextSegment locates the first param, regexp or wildcard segment in
// pattern and returns, in order: the node type, the param key, the regexp
// source (anchored with ^ and $, empty unless a non-empty regexp was given),
// the tail byte following the segment, and the segment's start and end index.
//
// A pattern without '{' and '*' is reported as ntStatic with end index
// len(pattern). A '{...}' segment gets the byte after its closing brace as
// tail, or '/' when it ends the pattern. A wildcard is reported as ntCatchAll
// with key "*" and tail 0.
//
// It panics when a '*' precedes a '{', when a '{' is never closed, and when a
// '*' is not the last byte of the pattern.
func patNextSegment(pattern string) (nodeTyp, string, string, byte, int, int) {
    braceAt := strings.Index(pattern, "{")
    starAt := strings.Index(pattern, "*")

    switch {
    case braceAt < 0 && starAt < 0:
        // Nothing dynamic in here: the whole pattern is one static segment.
        return ntStatic, "", "", 0, 0, len(pattern)

    case braceAt >= 0 && starAt >= 0 && starAt < braceAt:
        panic("chi: wildcard '*' must be the last pattern in a route, otherwise use a '{param}'")

    case braceAt < 0:
        // Only a wildcard is present; it has to close the pattern.
        if starAt < len(pattern)-1 {
            panic("chi: wildcard '*' must be the last value in a route. trim trailing text or use a '{param}' instead")
        }
        return ntCatchAll, "*", "", 0, starAt, len(pattern)
    }

    // A '{...}' segment comes first. Find its closing brace, counting nested
    // braces so that regexps such as {n:\d{2}} are handled.
    closeAt := braceAt
    depth := 0
    for offset, r := range pattern[braceAt:] {
        if r == '{' {
            depth++
            continue
        }
        if r != '}' {
            continue
        }
        depth--
        if depth == 0 {
            closeAt = braceAt + offset
            break
        }
    }
    if closeAt == braceAt {
        panic("chi: route param closing delimiter '}' is missing")
    }

    key := pattern[braceAt+1 : closeAt]
    end := closeAt + 1 // first index after the closing brace

    tail := byte('/') // default when the segment ends the pattern
    if end < len(pattern) {
        tail = pattern[end]
    }

    // "name:regexp" turns the segment into a regexp node.
    typ := ntParam
    rexpat := ""
    if colon := strings.Index(key, ":"); colon >= 0 {
        typ = ntRegexp
        rexpat = key[colon+1:]
        key = key[:colon]
    }

    // Anchor the expression on both sides unless the author already did.
    if rexpat != "" {
        if !strings.HasPrefix(rexpat, "^") {
            rexpat = "^" + rexpat
        }
        if !strings.HasSuffix(rexpat, "$") {
            rexpat += "$"
        }
    }

    return typ, key, rexpat, tail, braceAt, end
}

// patParamKeys returns the param keys of pattern in order of appearance
// (a wildcard contributes the key "*"). The result is an empty, non-nil slice
// for a purely static pattern. It panics when the same key occurs twice.
func patParamKeys(pattern string) []string {
    keys := []string{}
    remaining := pattern

    for {
        typ, key, _, _, _, end := patNextSegment(remaining)
        if typ == ntStatic {
            // no further dynamic segment
            return keys
        }

        for _, seen := range keys {
            if seen == key {
                panic(fmt.Sprintf("chi: routing pattern '%s' contains duplicate param key, '%s'", pattern, key))
            }
        }

        keys = append(keys, key)
        remaining = remaining[end:]
    }
}

// longestPrefix returns the number of leading bytes that k1 and k2 have in
// common.
func longestPrefix(k1, k2 string) int {
    limit := len(k1)
    if len(k2) < limit {
        limit = len(k2)
    }

    for pos := 0; pos < limit; pos++ {
        if k1[pos] != k2[pos] {
            return pos
        }
    }
    return limit
}

// methodTypString returns the HTTP method name that methodMap associates with
// exactly the given bit, or "" when there is none (for example for mSTUB, mALL
// or a combination of several bits).
func methodTypString(method methodTyp) string {
    for name, bit := range methodMap {
        if bit != method {
            continue
        }
        return name
    }
    return ""
}

// nodes is a list of sibling nodes. It implements sort.Interface with label
// as the ordering key.
type nodes []*node

// Sort orders the list by label and then lets tailSort adjust the position of
// a '/'-tailed param node.
func (ns nodes) Sort() {
    sort.Sort(ns)
    ns.tailSort()
}

// Len returns the number of nodes in the list.
func (ns nodes) Len() int {
    return len(ns)
}

// Swap exchanges the nodes at positions i and j.
func (ns nodes) Swap(i, j int) {
    ns[j], ns[i] = ns[i], ns[j]
}

// Less reports whether the label of node i sorts before the label of node j.
func (ns nodes) Less(i, j int) bool {
    return ns[j].label > ns[i].label
}

// tailSort scans the list from the back for a non-static node whose tail is
// '/' and swaps the first one it finds with the last element; at most one swap
// is performed. The list order determines the traversal order, so such a node
// is tried after its siblings.
func (ns nodes) tailSort() {
    last := len(ns) - 1
    for pos := last; pos >= 0; pos-- {
        candidate := ns[pos]
        if candidate.typ <= ntStatic || candidate.tail != '/' {
            continue
        }
        ns.Swap(pos, last)
        return
    }
}

// findEdge returns the node whose label equals label, or nil when the list
// holds no such node. An empty (or nil) list yields nil.
//
// The lookup is a binary search, so ns must be ordered by ascending label,
// which is how Sort leaves a list of static nodes.
func (ns nodes) findEdge(label byte) *node {
    lo, hi := 0, len(ns)-1
    for lo <= hi {
        mid := lo + (hi-lo)/2
        switch {
        case label > ns[mid].label:
            lo = mid + 1
        case label < ns[mid].label:
            hi = mid - 1
        default:
            return ns[mid]
        }
    }
    // The range is exhausted (or the list was empty): no edge for this label.
    return nil
}

// Route describes the details of a routing handler.
// Handlers map key is an HTTP method
// ("*" is used for a handler registered for all methods).
type Route struct {
    // SubRoutes are the subroutes attached to the node the route was read from.
    SubRoutes chi.Routes
    // Handlers maps HTTP method names to the endpoint serving them.
    Handlers  map[string]*core.Endpoint
    // Pattern is the routing pattern the handlers were registered under.
    Pattern   string

    // FilePath is taken from the FilePath of one of the route's endpoints;
    // Routes falls back to the pattern when no method-specific handler exists.
    FilePath string
}