bnkamalesh/webgo

View on GitHub
middleware/cors/cors.go

Summary

Maintainability
A
0 mins
Test Coverage
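/*
Package cors sets the appropriate CORS (https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS)
response headers and lets you customize them. The following customizations are allowed:
  - provide a list of allowed domains
  - provide a list of allowed headers
  - set the max-age of the CORS preflight response (Access-Control-Max-Age)

The allowed methods default to HEAD, GET, POST, PUT, PATCH, DELETE and OPTIONS;
when routes are configured, the distinct methods of those routes are used instead.
*/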
package cors

import (
    "fmt"
    "net/http"
    "regexp"
    "sort"
    "strings"

    "github.com/bnkamalesh/webgo/v7"
)

const (
    // Response header carrying the origin allowed to read the response;
    // Middleware echoes the request's Origin header into it.
    headerOrigin           = "Access-Control-Allow-Origin"
    // Response header listing the methods permitted for cross-origin requests.
    headerMethods          = "Access-Control-Allow-Methods"
    // Response header permitting credentialed cross-origin requests.
    headerCreds            = "Access-Control-Allow-Credentials"
    // Response header listing the request headers the client may send.
    headerAllowHeaders     = "Access-Control-Allow-Headers"
    // Preflight request header naming the headers the client intends to send;
    // Middleware appends its value to the Access-Control-Allow-Headers response.
    headerReqHeaders       = "Access-Control-Request-Headers"
    // Response header giving how long, in seconds, a preflight result may be cached.
    headerAccessControlAge = "Access-Control-Max-Age"
    // allowHeaders is the default Access-Control-Allow-Headers prefix used when no
    // headers are configured. The trailing comma is intentional: Middleware
    // concatenates the request's Access-Control-Request-Headers value directly
    // after it.
    allowHeaders           = "Accept,Content-Type,Content-Length,Accept-Encoding,Access-Control-Request-Headers,"
)

var (
    // defaultAllowMethods is sent as Access-Control-Allow-Methods when no routes
    // are configured.
    defaultAllowMethods = "HEAD,GET,POST,PUT,PATCH,DELETE,OPTIONS"
)

// allowedDomains returns the origins permitted when the caller configures
// none: a single "*" entry, meaning every origin is allowed.
func allowedDomains() []string {
    return []string{"*"}
}

// getReqOrigin returns the first value of the request's Origin header, or ""
// when the header is absent. "Origin" is already in canonical MIME form, so a
// direct map lookup is equivalent to Header.Get.
func getReqOrigin(r *http.Request) string {
    if values := r.Header["Origin"]; len(values) > 0 {
        return values[0]
    }
    return ""
}

// allowedOriginsRegex compiles the configured allowed origins into the
// expressions that the request's Origin header is matched against.
//
// Each entry is reduced to a host name by originHost, so "example.com",
// "example.com:8080" and "https://example.com:8080" are equivalent. The
// resulting expression accepts an optional http/https scheme, any subdomain of
// the host and any port. The host is regex-escaped, so "." matches only a
// literal dot and no entry can make regexp compilation panic.
//
// With no entries, or when any entry is "*", a single expression matching any
// non-empty origin is returned and the other entries are ignored. Blank
// entries are skipped, so the result may be empty.
func allowedOriginsRegex(allowedOrigins ...string) []regexp.Regexp {
    // "*" anywhere means "all origins"; the remaining values are irrelevant.
    for _, origin := range allowedOrigins {
        if strings.TrimSpace(origin) == "*" {
            allowedOrigins = nil
            break
        }
    }
    if len(allowedOrigins) == 0 {
        return []regexp.Regexp{*regexp.MustCompile(".+")}
    }

    regexes := make([]regexp.Regexp, 0, len(allowedOrigins))
    for _, origin := range allowedOrigins {
        host := originHost(origin)
        switch host {
        case "":
            continue
        case "*":
            // A wildcard carrying a scheme or port (e.g. "*:8080") still allows
            // every origin, which makes any later entries redundant.
            return append(regexes, *regexp.MustCompile(".+"))
        }

        // QuoteMeta prevents "api.example.com" from also matching attacker
        // controlled hosts such as "apixexample.com".
        pattern := fmt.Sprintf(`^(https?://)?(.+\.)?%s(:[0-9]+)?$`, regexp.QuoteMeta(host))
        regexes = append(regexes, *regexp.MustCompile(pattern))
    }
    return regexes
}

// originHost reduces an allowed-origin entry to a bare host name by removing
// surrounding whitespace, a leading scheme ("https://"), any path or port, and
// a leading "*." wildcard label. Subdomains are always accepted by
// allowedOriginsRegex, so "*.example.com" behaves like "example.com" (the apex
// host included). Bracketed IPv6 literals are not supported.
func originHost(origin string) string {
    host := strings.TrimSpace(origin)
    if i := strings.Index(host, "://"); i >= 0 {
        host = host[i+len("://"):]
    }
    if i := strings.IndexAny(host, "/:"); i >= 0 {
        host = host[:i]
    }
    return strings.TrimSpace(strings.TrimPrefix(host, "*."))
}

// allowedMethods builds the Access-Control-Allow-Methods value from routes:
// each distinct route method once, sorted and comma-separated. With no routes
// it falls back to defaultAllowMethods.
func allowedMethods(routes []*webgo.Route) string {
    if len(routes) == 0 {
        return defaultAllowMethods
    }

    seen := make(map[string]struct{}, len(routes))
    unique := make([]string, 0, len(routes))
    for _, route := range routes {
        if _, dup := seen[route.Method]; dup {
            continue
        }
        seen[route.Method] = struct{}{}
        unique = append(unique, route.Method)
    }

    sort.Strings(unique)
    return strings.Join(unique, ",")
}

// Config holds the settings available for customizing this middleware.
type Config struct {
    // TimeoutSecs is sent as Access-Control-Max-Age, in seconds. CORS applies
    // a 30-minute default only when the whole Config is nil; a zero value in a
    // non-nil Config is sent as "0".
    TimeoutSecs    int
    // Routes determine Access-Control-Allow-Methods (their distinct methods,
    // sorted). When empty, a default method list is used.
    Routes         []*webgo.Route
    // AllowedOrigins lists the domains (e.g. "example.com") whose requests get
    // CORS headers; subdomains and any port of a listed domain are accepted.
    // Empty, or any entry equal to "*", allows every origin.
    AllowedOrigins []string
    // AllowedHeaders is the base Access-Control-Allow-Headers list; when empty,
    // a default list is used. The request's Access-Control-Request-Headers
    // value is always appended to it.
    AllowedHeaders []string
}

// allowedHeaders builds the Access-Control-Allow-Headers prefix from the
// configured header names, falling back to the package default allowHeaders
// when none are configured.
//
// The result always ends with a comma because Middleware appends the request's
// Access-Control-Request-Headers value directly to it.
func allowedHeaders(headers []string) string {
    joined := strings.Join(headers, ",")
    if joined == "" {
        // Covers both an empty slice and a slice holding only empty strings;
        // indexing the last byte of an empty string would panic.
        return allowHeaders
    }
    if !strings.HasSuffix(joined, ",") {
        joined += ","
    }
    return joined
}

// allowOrigin reports whether a request with the given Origin should receive
// CORS headers. With no patterns configured nothing is allowed. Otherwise a
// request without an Origin header is accepted, and any other origin must
// match at least one pattern.
func allowOrigin(reqOrigin string, allowedOriginRegex []regexp.Regexp) bool {
    if len(allowedOriginRegex) == 0 {
        return false
    }
    if reqOrigin == "" {
        return true
    }
    // Index instead of ranging by value to avoid copying each Regexp.
    for i := range allowedOriginRegex {
        if allowedOriginRegex[i].MatchString(reqOrigin) {
            return true
        }
    }
    return false
}

// Middleware builds the CORS middleware from precomputed values; CORS is the
// usual entry point, which derives these values from a Config. The returned
// function has the plain signature func(http.ResponseWriter, *http.Request,
// http.HandlerFunc).
//
//   - allowedOriginRegex: patterns the request's Origin header is matched against.
//   - corsTimeout: value sent as Access-Control-Max-Age.
//   - allowedMethods: value sent as Access-Control-Allow-Methods.
//   - allowedHeaders: prefix of Access-Control-Allow-Headers; the request's
//     Access-Control-Request-Headers value is appended to it verbatim.
//
// For allowed requests the request's own Origin is echoed back together with
// Access-Control-Allow-Credentials: true, and OPTIONS requests are answered
// with 200 without calling next. Requests from disallowed origins reach next
// without any CORS headers.
func Middleware(allowedOriginRegex []regexp.Regexp, corsTimeout, allowedMethods, allowedHeaders string) webgo.Middleware {
    return func(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
        // Both the presence of CORS headers and the Access-Control-Allow-Origin
        // value depend on the request's Origin, so caches must key on it;
        // otherwise a response cached for one origin may be served to another.
        rw.Header().Add("Vary", "Origin")

        reqOrigin := getReqOrigin(req)
        if !allowOrigin(reqOrigin, allowedOriginRegex) {
            // If CORS failed, no respective headers are set. But the execution is allowed to continue
            // Earlier this middleware blocked access altogether, which was considered an added
            // security measure despite it being outside the scope of this middleware. Though, such
            // restrictions create unnecessary complexities during inter-app communication.
            next(rw, req)
            return
        }

        // The origin is echoed and credentials are allowed, so an all-origins
        // configuration permits credentialed requests from any site.
        header := rw.Header()
        header.Set(headerOrigin, reqOrigin)
        header.Set(headerAccessControlAge, corsTimeout)
        header.Set(headerCreds, "true")
        header.Set(headerMethods, allowedMethods)
        header.Set(headerAllowHeaders, allowedHeaders+req.Header.Get(headerReqHeaders))

        // Preflight requests are fully answered here.
        if req.Method == http.MethodOptions {
            webgo.SendHeader(rw, http.StatusOK)
            return
        }

        next(rw, req)
    }
}

// AddOptionsHandlers returns routes followed by one OPTIONS route per distinct
// pattern in routes, each with an empty-bodied handler and TrailingSlash set.
// When routes is empty, a single catch-all OPTIONS route ("/:w*") is returned
// instead.
//
// NOTE(review): presumably these routes exist so the router has a match for
// preflight requests, which Middleware then answers itself — confirm.
func AddOptionsHandlers(routes []*webgo.Route) []*webgo.Route {
    dummyHandler := func(w http.ResponseWriter, r *http.Request) {}
    if len(routes) == 0 {
        return []*webgo.Route{
            {
                Name:          "cors",
                Pattern:       "/:w*",
                Method:        http.MethodOptions,
                TrailingSlash: true,
                Handlers:      []http.HandlerFunc{dummyHandler},
            },
        }
    }

    // Room for the original routes plus at most one OPTIONS route each.
    list := make([]*webgo.Route, 0, 2*len(routes))
    list = append(list, routes...)

    // Routes sharing a pattern (e.g. GET and POST on the same path) would
    // otherwise each produce an identical OPTIONS route.
    seen := make(map[string]struct{}, len(routes))
    for _, r := range routes {
        if _, dup := seen[r.Pattern]; dup {
            continue
        }
        seen[r.Pattern] = struct{}{}

        list = append(list, &webgo.Route{
            Name:          fmt.Sprintf("%s-CORS", r.Name),
            Method:        http.MethodOptions,
            Pattern:       r.Pattern,
            TrailingSlash: true,
            Handlers:      []http.HandlerFunc{dummyHandler},
        })
    }

    return list
}

// CORS builds one middleware that applies CORS headers for the whole app.
//
// A nil cfg is treated as an empty Config with a TimeoutSecs of 30 minutes.
// For a non-nil cfg, TimeoutSecs is used as-is. Empty fields fall back to
// defaults: all origins ("*"), the package's default allowed headers, and the
// default method list when no routes are given. With routes, the distinct
// methods of cfg.Routes are used.
func CORS(cfg *Config) webgo.Middleware {
    if cfg == nil {
        // 30 minutes
        cfg = &Config{TimeoutSecs: 30 * 60}
    }

    origins := cfg.AllowedOrigins
    if len(origins) == 0 {
        origins = allowedDomains()
    }

    return Middleware(
        allowedOriginsRegex(origins...),
        fmt.Sprintf("%d", cfg.TimeoutSecs),
        allowedMethods(cfg.Routes),
        allowedHeaders(cfg.AllowedHeaders),
    )
}