opcotech/elemo

View on GitHub
internal/transport/http/middleware.go

Summary

Maintainability
A
0 mins
Test Coverage
F
34%
package http

import (
    "context"
    "fmt"
    "net/http"
    "reflect"
    "runtime"
    "strings"
    "time"

    "github.com/go-chi/chi/v5/middleware"
    "github.com/go-oauth2/oauth2/v4"
    "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
    "go.opentelemetry.io/otel/attribute"

    httpMetricsProm "github.com/slok/go-http-metrics/metrics/prometheus"
    httpMetricsMiddleware "github.com/slok/go-http-metrics/middleware"
    httpMetricsMiddlewareStd "github.com/slok/go-http-metrics/middleware/std"

    "github.com/opcotech/elemo/internal/model"
    "github.com/opcotech/elemo/internal/pkg"
    "github.com/opcotech/elemo/internal/pkg/log"
    "github.com/opcotech/elemo/internal/pkg/tracing"
)

// ctxCallbackFunc computes the value that withContextObject stores in a
// request's context. withContextObject invokes it once per request, passing
// the response writer and the incoming request.
type ctxCallbackFunc func(w http.ResponseWriter, r *http.Request) any

// getMiddlewareName returns the short name and the fully qualified runtime
// name of fn.
//
// The fully qualified name is the one reported by runtime.FuncForPC; the short
// name is the part of it after the last ".".
//
// Resolving the name involves reflection and a runtime symbol lookup, so
// callers on a hot path should call this once and reuse the result.
func getMiddlewareName(fn func(next http.Handler) http.Handler) (string, string) {
	path := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()

	// LastIndexByte returns -1 when there is no ".", in which case the short
	// name is the whole path (same result as taking the last Split element).
	name := path[strings.LastIndexByte(path, '.')+1:]

	return name, path
}

// withContextObject builds a middleware that calls cb for every request and
// stores the returned value in the request context under ctxKey before handing
// the request on to next.
func withContextObject(ctxKey pkg.CtxKey, cb ctxCallbackFunc) func(next http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			parent := r.Context()
			value := cb(w, r)
			next.ServeHTTP(w, r.WithContext(context.WithValue(parent, ctxKey, value)))
		}
		return http.HandlerFunc(serve)
	}
	return wrap
}

// WithContextObject returns a middleware that stores value in each request's
// context under key. The same value is stored for every request.
func WithContextObject(key pkg.CtxKey, value any) func(next http.Handler) http.Handler {
	constant := func(http.ResponseWriter, *http.Request) any { return value }
	return withContextObject(key, constant)
}

// WithOtelTracer wraps next in an OpenTelemetry HTTP server handler (otelhttp),
// naming each server span after the request's URL path.
func WithOtelTracer(next http.Handler) http.Handler {
	// Build the instrumented handler once. otelhttp.NewHandler sets up its
	// tracer and metric instruments when it is constructed, so calling it per
	// request repeats that setup on every request. The per-request span name
	// comes from the formatter, which ignores the operation argument.
	return otelhttp.NewHandler(next, "",
		otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string {
			return r.URL.Path
		}),
	)
}

func WithPrometheusMetrics(next http.Handler) http.Handler {
    return httpMetricsMiddlewareStd.Handler("", httpMetricsMiddleware.New(httpMetricsMiddleware.Config{
        Service:  "elemo",
        Recorder: httpMetricsProm.NewRecorder(httpMetricsProm.Config{}),
    }), next)
}

// WithTracedMiddleware wraps middleware so that every request passing through
// it is recorded in a span named "transport.http.middleware/<name>", where
// <name> is the short function name of middleware. The fully qualified
// function name is attached as the "middleware.path" span attribute.
//
// The span's context is passed to middleware (and through it to next), and the
// span ends only after middleware's handler returns, so it covers the
// middleware and every handler it invokes downstream.
func WithTracedMiddleware(tracer tracing.Tracer, middleware func(next http.Handler) http.Handler) func(next http.Handler) http.Handler {
	// Resolving the name uses reflection and a runtime symbol lookup; do it
	// once when the chain is built instead of on every request.
	name, path := getMiddlewareName(middleware)
	spanName := fmt.Sprintf("transport.http.middleware/%s", name)

	return func(next http.Handler) http.Handler {
		// Build the wrapped handler once. Calling middleware(next) per request
		// would re-run any construction-time setup the middleware performs.
		wrapped := middleware(next)

		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx, span := tracer.Start(r.Context(), spanName)
			defer span.End()

			span.SetAttributes(attribute.String("middleware.path", path))

			wrapped.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// WithRequestLogger returns a middleware that writes one log entry per request
// after it has been served. The entry includes the protocol, method, path,
// request ID, remote address, user agent, response size, response status and
// handling duration in seconds.
//
// The middleware depends on WithLogger, which must be applied before it. The
// request ID is read with chi's middleware.GetReqID, so chi's RequestID
// middleware must also run earlier for the ID to appear in the log entry.
func WithRequestLogger(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		wrappedWriter := middleware.NewWrapResponseWriter(w, r.ProtoMajor)

		// Keep the monotonic clock reading: time.Time.UTC strips it, which
		// would make the measured duration sensitive to wall-clock changes.
		start := time.Now()

		// Deferred so the entry is still written if next panics.
		defer func(ctx context.Context, w middleware.WrapResponseWriter, r *http.Request, t time.Time) {
			log.Info(ctx, "serve http request",
				log.WithProtocol(r.Proto),
				log.WithMethod(r.Method),
				log.WithPath(r.URL.Path),
				log.WithRequestID(middleware.GetReqID(ctx)),
				log.WithRemoteAddr(r.RemoteAddr),
				log.WithUserAgent(r.UserAgent()),
				log.WithSize(int64(w.BytesWritten())),
				log.WithStatus(w.Status()),
				log.WithDuration(time.Since(t).Seconds()),
				log.WithAction(log.ActionHTTPRequestHandle),
			)
		}(r.Context(), wrappedWriter, r, start)

		next.ServeHTTP(wrappedWriter, r)
	})
}

// WithUserID returns a middleware that stores the requesting user's ID in the
// request context under pkg.CtxKeyUserID.
//
// tokenValidator is called for every request. When it succeeds and returns
// token info, the user ID from that info is parsed and stored. When validation
// fails, returns no token info, or the user ID cannot be parsed, the nil user
// ID (model.MustNewNilID(model.ResourceTypeUser)) is stored instead, so
// downstream handlers always find a user ID value under the key.
func WithUserID(tokenValidator func(r *http.Request) (oauth2.TokenInfo, error)) func(next http.Handler) http.Handler {
	return withContextObject(pkg.CtxKeyUserID, func(_ http.ResponseWriter, r *http.Request) any {
		info, err := tokenValidator(r)
		if err != nil || info == nil {
			// Invalid or missing tokens are not rejected here; the request
			// continues as the nil (anonymous) user.
			return model.MustNewNilID(model.ResourceTypeUser)
		}

		id, err := model.NewIDFromString(info.GetUserID(), model.ResourceTypeUser.String())
		if err != nil {
			// Never store a zero-value ID; fall back to the same nil user ID
			// used for unauthenticated requests.
			return model.MustNewNilID(model.ResourceTypeUser)
		}

		return id
	})
}