bnkamalesh/webgo

View on GitHub
router.go

Summary

Maintainability
A
0 mins
Test Coverage
package webgo

import (
    "bufio"
    "context"
    "encoding/json"
    "errors"
    "fmt"
    "net"
    "net/http"
    "sync"
)

// httpResponseWriter lists every interface the custom response writer is
// required to implement: the base http.ResponseWriter plus the optional
// http.Flusher, http.Hijacker and http.Pusher extensions.
// init converts a *customResponseWriter to this type to enforce that.
type httpResponseWriter interface {
    http.ResponseWriter
    http.Flusher
    http.Hijacker
    http.Pusher
}

// init pre-marshals the JSON body describing an internal server error
// (status 500) into jsonErrPayload, panicking if it cannot be encoded, and
// checks that *customResponseWriter satisfies httpResponseWriter.
func init() {
    payload, err := json.Marshal(errOutput{
        Errors: ErrInternalServer,
        Status: http.StatusInternalServerError,
    })
    if err != nil {
        panic(err)
    }
    jsonErrPayload = payload

    // this fails to compile if the custom response writer misses any of the
    // required functions
    var _ httpResponseWriter = (*customResponseWriter)(nil)
}

var (
    // validHTTPMethods lists the HTTP methods a route may be registered with;
    // httpHandlers rejects any route whose Method is not in this list
    validHTTPMethods = []string{
        http.MethodOptions,
        http.MethodHead,
        http.MethodGet,
        http.MethodPost,
        http.MethodPut,
        http.MethodPatch,
        http.MethodDelete,
    }

    // ctxPool recycles ContextPayload values; see newContext & releaseContext
    ctxPool = &sync.Pool{
        New: func() interface{} {
            return new(ContextPayload)
        },
    }
    // crwPool recycles customResponseWriter values; see newCRW & releaseCRW
    crwPool = &sync.Pool{
        New: func() interface{} {
            return new(customResponseWriter)
        },
    }
)

// customResponseWriter is a custom HTTP response writer. It wraps an
// http.ResponseWriter and keeps track of the response status code and of
// whether the header and the body have been written.
type customResponseWriter struct {
    http.ResponseWriter
    // statusCode is the final status code sent, or the one Write will send
    // implicitly if WriteHeader was never called
    statusCode int
    // written is set once Write has been called
    written bool
    // headerWritten is set once a final (non-informational) header was sent
    headerWritten bool
}

// WriteHeader is the interface implementation to get HTTP response code and add
// it to the custom response writer.
//
// Only the first final status code is recorded and forwarded; later calls are
// ignored. Informational (1xx) codes other than 101 Switching Protocols are
// forwarded to the wrapped writer without being recorded, because net/http
// allows any number of them to precede the final header. Latching on them
// would silently drop the real status code that follows.
func (crw *customResponseWriter) WriteHeader(code int) {
    if crw.headerWritten {
        return
    }

    informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
    if informational {
        crw.ResponseWriter.WriteHeader(code)
        return
    }

    crw.headerWritten = true
    crw.statusCode = code
    crw.ResponseWriter.WriteHeader(code)
}

// Write is the interface implementation to respond to the HTTP request.
// If no final header was sent yet, the stored status code is sent first.
// It returns whatever the wrapped writer's Write returns.
func (crw *customResponseWriter) Write(body []byte) (int, error) {
    if !crw.headerWritten {
        code := crw.statusCode
        if code == 0 {
            // a writer which was never given a status code (zero value) falls
            // back to 200, since 0 is not a valid HTTP status code
            code = http.StatusOK
        }
        crw.WriteHeader(code)
    }

    crw.written = true
    return crw.ResponseWriter.Write(body)
}

// Flush calls the http.Flusher to clear/flush the buffer. It does nothing if
// the wrapped writer is not an http.Flusher.
func (crw *customResponseWriter) Flush() {
    if rw, ok := crw.ResponseWriter.(http.Flusher); ok {
        rw.Flush()
    }
}

// Hijack implements the http.Hijacker interface. It returns an error if the
// wrapped writer is not an http.Hijacker.
func (crw *customResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
    if hj, ok := crw.ResponseWriter.(http.Hijacker); ok {
        return hj.Hijack()
    }

    return nil, nil, errors.New("unable to create hijacker")
}

// Push implements the http.Pusher interface. It returns an error if the
// wrapped writer is not an http.Pusher.
func (crw *customResponseWriter) Push(target string, opts *http.PushOptions) error {
    if n, ok := crw.ResponseWriter.(http.Pusher); ok {
        return n.Push(target, opts)
    }
    return errors.New("pusher not implemented")
}

// reset clears all the fields, including the reference to the wrapped writer.
// releaseCRW calls it before handing the writer back to the pool.
func (crw *customResponseWriter) reset() {
    crw.statusCode = 0
    crw.written = false
    crw.headerWritten = false
    crw.ResponseWriter = nil
}

// Middleware is the signature of WebGo's middleware. It receives the response
// writer, the request and an http.HandlerFunc; UseOnSpecialHandlers passes the
// handler being wrapped as that third argument (presumably the middleware
// calls it to continue processing the request).
type Middleware func(http.ResponseWriter, *http.Request, http.HandlerFunc)

// discoverRoute goes through routes in order and returns the first route
// matching the given path, along with the parameters produced by the match.
// It returns (nil, nil) if none of the routes match.
func discoverRoute(path string, routes []*Route) (*Route, map[string]string) {
    for idx := range routes {
        candidate := routes[idx]

        matched, params := candidate.matchPath(path)
        if !matched {
            continue
        }

        return candidate, params
    }

    return nil, nil
}

// Router is the HTTP router
type Router struct {
    // one list of routes per HTTP method; methodRoutes picks the list for the
    // method of the request being served
    optHandlers    []*Route
    headHandlers   []*Route
    getHandlers    []*Route
    postHandlers   []*Route
    putHandlers    []*Route
    patchHandlers  []*Route
    deleteHandlers []*Route
    // allHandlers has the routes keyed by HTTP method; Use iterates over it
    // to attach middleware
    allHandlers    map[string][]*Route

    // NotFound is the generic handler for 404 resource not found response
    NotFound http.HandlerFunc

    // NotImplemented is the generic handler for 501 method not implemented
    NotImplemented http.HandlerFunc

    // config has all the app config
    config *Config

    // httpServer is the server handler for the active HTTP server
    httpServer *http.Server
    // httpsServer is the server handler for the active HTTPS server
    httpsServer *http.Server
}

// methodRoutes returns the routes registered for the given HTTP method.
// The result is nil for any method other than the 7 handled below.
func (rtr *Router) methodRoutes(method string) (routes []*Route) {
    switch method {
    case http.MethodOptions:
        routes = rtr.optHandlers
    case http.MethodHead:
        routes = rtr.headHandlers
    case http.MethodGet:
        routes = rtr.getHandlers
    case http.MethodPost:
        routes = rtr.postHandlers
    case http.MethodPut:
        routes = rtr.putHandlers
    case http.MethodPatch:
        routes = rtr.patchHandlers
    case http.MethodDelete:
        routes = rtr.deleteHandlers
    }

    return routes
}

// ServeHTTP finds the route matching the request's method and escaped path and
// lets it serve the request. It responds using NotImplemented when there are
// no routes for the method, and using NotFound when no route matches the path.
func (rtr *Router) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
    // a custom response writer is used to set the appropriate HTTP status code
    // in case of encoding errors. i.e. if there's a JSON encoding issue while
    // responding, the HTTP status code would say 200, and the JSON payload
    // {"status": 500}
    crw := newCRW(rw, http.StatusOK)

    candidates := rtr.methodRoutes(r.Method)
    if candidates == nil {
        // HTTP method is not implemented, serve 501
        crw.statusCode = http.StatusNotImplemented
        rtr.NotImplemented(crw, r)
        releaseCRW(crw)
        return
    }

    matched, uriParams := discoverRoute(r.URL.EscapedPath(), candidates)
    if matched == nil {
        // none of the routes match, serve 404
        crw.statusCode = http.StatusNotFound
        rtr.NotFound(crw, r)
        releaseCRW(crw)
        return
    }

    payload := newContext()
    payload.Route = matched
    payload.URIParams = uriParams

    // webgo context is injected to the HTTP request context. The request is
    // overwritten in place, so whoever holds r sees the new context as well
    ctx := context.WithValue(r.Context(), wgoCtxKey, payload)
    *r = *r.WithContext(ctx)

    // pooled resources are handed back only after the route is done serving
    defer releasePoolResources(crw, payload)
    matched.serve(crw, r)
}

// Use adds the given middleware to every route registered so far, except the
// routes which have skipMiddleware set.
func (rtr *Router) Use(mm ...Middleware) {
    for _, list := range rtr.allHandlers {
        for _, route := range list {
            if !route.skipMiddleware {
                route.use(mm...)
            }
        }
    }
}

// UseOnSpecialHandlers adds middleware to the 2 special handlers of webgo,
// i.e. NotFound & NotImplemented. Each middleware wraps the handlers as they
// are at that point, so the middleware added last is the outermost layer.
//
// v3.2.1 introduced the feature of adding middleware to both the not found and
// the not implemented handlers. It was added considering an `accesslog`
// middleware, where all requests should be logged. It is kept as a separate
// function considering an authentication middleware, where all requests
// including 404 & 501 would respond with `not authenticated` unless the
// middleware has special handling. It is cleaner to let users add their
// middleware separately to the NOTFOUND & NOTIMPLEMENTED handlers.
func (rtr *Router) UseOnSpecialHandlers(mm ...Middleware) {
    // wrap receives its values as parameters, so every returned handler holds
    // on to its own middleware and its own next handler
    wrap := func(mw Middleware, next http.HandlerFunc) http.HandlerFunc {
        return func(rw http.ResponseWriter, req *http.Request) {
            mw(rw, req, next)
        }
    }

    for _, m := range mm {
        rtr.NotFound = wrap(m, rtr.NotFound)
        rtr.NotImplemented = wrap(m, rtr.NotImplemented)
    }
}

// Add is a convenience method used to add a new route to an already initialized router
// Important: `.Use` should be used only after all routes are added
//
// The routes are validated & grouped by HTTP method using httpHandlers, then
// appended to the method specific lists as well as to allHandlers.
func (rtr *Router) Add(routes ...*Route) {
    hmap := httpHandlers(routes)
    rtr.optHandlers = append(rtr.optHandlers, hmap[http.MethodOptions]...)
    rtr.headHandlers = append(rtr.headHandlers, hmap[http.MethodHead]...)
    rtr.getHandlers = append(rtr.getHandlers, hmap[http.MethodGet]...)
    rtr.postHandlers = append(rtr.postHandlers, hmap[http.MethodPost]...)
    rtr.putHandlers = append(rtr.putHandlers, hmap[http.MethodPut]...)
    rtr.patchHandlers = append(rtr.patchHandlers, hmap[http.MethodPatch]...)
    rtr.deleteHandlers = append(rtr.deleteHandlers, hmap[http.MethodDelete]...)

    all := rtr.allHandlers
    if all == nil {
        all = map[string][]*Route{}
    }

    // validHTTPMethods is the list httpHandlers validates the routes against,
    // hence it covers every key hmap can have
    for _, key := range validHTTPMethods {
        newlist, hasKey := hmap[key]
        if !hasKey {
            continue
        }
        if all[key] == nil {
            // sized by the number of routes being added for this method
            all[key] = make([]*Route, 0, len(newlist))
        }
        all[key] = append(all[key], newlist...)
    }

    rtr.allHandlers = all
}

// newCRW gets a customResponseWriter from crwPool, makes it wrap rw and sets
// rCode as its initial status code. Use releaseCRW to hand it back to the pool.
func newCRW(rw http.ResponseWriter, rCode int) *customResponseWriter {
    crw := crwPool.Get().(*customResponseWriter)
    crw.ResponseWriter = rw
    crw.statusCode = rCode
    return crw
}

// releaseCRW resets crw and puts it back in crwPool. crw must not be used by
// the caller after this.
func releaseCRW(crw *customResponseWriter) {
    // reset before Put, so the pooled value no longer refers to the wrapped writer
    crw.reset()
    crwPool.Put(crw)
}

// newContext gets a ContextPayload from ctxPool. Use releaseContext to hand it
// back to the pool.
func newContext() *ContextPayload {
    return ctxPool.Get().(*ContextPayload)
}

// releaseContext resets cp and puts it back in ctxPool. cp must not be used by
// the caller after this.
func releaseContext(cp *ContextPayload) {
    // reset before Put (presumably clears the per-request values; reset is
    // defined outside this file's visible section)
    cp.reset()
    ctxPool.Put(cp)
}

// releasePoolResources hands both the custom response writer and the context
// payload back to their respective pools. ServeHTTP defers it for requests
// served by a matching route.
func releasePoolResources(crw *customResponseWriter, cp *ContextPayload) {
    releaseCRW(crw)
    releaseContext(cp)
}

// NewRouter initializes & returns a new router instance with the given
// configuration and routes. NotFound defaults to http.NotFound, and
// NotImplemented to a handler which uses Send to respond with status 501.
func NewRouter(cfg *Config, routes ...*Route) *Router {
    notImplemented := func(rw http.ResponseWriter, req *http.Request) {
        Send(rw, "", "501 Not Implemented", http.StatusNotImplemented)
    }

    rtr := &Router{
        config:         cfg,
        NotFound:       http.NotFound,
        NotImplemented: notImplemented,
    }
    rtr.Add(routes...)

    return rtr
}

// checkDuplicateRoutes compares route against all the routes which precede it
// (i.e. routes[0] to routes[idx-1]) and logs every duplicate found. A duplicate
// name is logged as info. A route of the same HTTP method whose matchPath
// accepts route's pattern is logged as a warning. Nothing is returned.
func checkDuplicateRoutes(idx int, route *Route, routes []*Route) {
    for i := 0; i < idx; i++ {
        earlier := routes[i]

        if earlier.Name == route.Name {
            LOGHANDLER.Info(
                fmt.Sprintf(
                    "Duplicate route name('%s') detected",
                    earlier.Name,
                ),
            )
        }

        // URI patterns can clash only if the HTTP method is the same
        if earlier.Method == route.Method {
            // regex pattern match
            if matched, _ := earlier.matchPath(route.Pattern); matched {
                LOGHANDLER.Warn(
                    fmt.Sprintf(
                        "Duplicate URI pattern detected.\nPattern: '%s'\nDuplicate pattern: '%s'",
                        earlier.Pattern,
                        route.Pattern,
                    ),
                )
                LOGHANDLER.Warn("Only the first route to match the URI pattern would handle the request")
            }
        }
    }
}

// httpHandlers validates & initializes the routes, and returns them grouped by
// HTTP method. HEAD and GET always have an entry in the map, even if empty.
//
// A route with an unsupported method, without handlers, or which fails to
// initialize is reported using LOGHANDLER.Fatal. If Fatal returns, nil is
// returned instead of the map.
func httpHandlers(routes []*Route) map[string][]*Route {
    handlers := map[string][]*Route{
        http.MethodHead: {},
        http.MethodGet:  {},
    }

    for idx, route := range routes {
        supported := false
        for i := 0; i < len(validHTTPMethods) && !supported; i++ {
            supported = route.Method == validHTTPMethods[i]
        }

        if !supported {
            LOGHANDLER.Fatal(
                fmt.Sprintf(
                    "Unsupported HTTP method provided. Method: '%s'",
                    route.Method,
                ),
            )
            return nil
        }

        // len is 0 for nil as well, so this covers both nil & empty
        if len(route.Handlers) == 0 {
            LOGHANDLER.Fatal(
                fmt.Sprintf(
                    "No handlers provided for the route '%s', method '%s'",
                    route.Pattern,
                    route.Method,
                ),
            )
            return nil
        }

        if err := route.init(); err != nil {
            LOGHANDLER.Fatal("Unsupported URI pattern.", route.Pattern, err)
            return nil
        }

        // only logs the duplicates, the route is added regardless
        checkDuplicateRoutes(idx, route, routes)

        handlers[route.Method] = append(handlers[route.Method], route)
    }

    return handlers
}