vorteil/direktiv

View on GitHub
pkg/refactor/gateway/gateway.go

Summary

Maintainability
C
7 hrs
Test Coverage
// nolint
package gateway

import (
    "bytes"
    "context"
    "fmt"
    "log/slog"
    "net/http"
    "path/filepath"
    "strings"
    "sync"
    "time"

    "github.com/direktiv/direktiv/pkg/flow/nohome/recipient"
    "github.com/direktiv/direktiv/pkg/refactor/core"
    "github.com/direktiv/direktiv/pkg/refactor/database"
    "github.com/direktiv/direktiv/pkg/refactor/filestore"
    "github.com/direktiv/direktiv/pkg/refactor/gateway/consumer"
    "github.com/direktiv/direktiv/pkg/refactor/gateway/endpoints"
    "github.com/direktiv/direktiv/pkg/refactor/gateway/plugins"
    _ "github.com/direktiv/direktiv/pkg/refactor/gateway/plugins/auth"
    _ "github.com/direktiv/direktiv/pkg/refactor/gateway/plugins/inbound"
    _ "github.com/direktiv/direktiv/pkg/refactor/gateway/plugins/outbound"
    _ "github.com/direktiv/direktiv/pkg/refactor/gateway/plugins/target"
    "github.com/go-chi/chi/v5"
    "go.opentelemetry.io/otel/trace"
)

// namespaceGateway bundles the gateway state kept for a single namespace:
// the list of endpoints (routes) and the list of consumers that were loaded
// from the namespace's files.
type namespaceGateway struct {
    EndpointList *endpoints.EndpointList
    ConsumerList *consumer.List
}

// gatewayManager keeps one namespaceGateway per namespace, keyed by the
// namespace name, together with the database handle used to (re)load them.
type gatewayManager struct {
    db         *database.SQLStore
    nsGateways map[string]*namespaceGateway
    // lock is held by updateNamespace while it reads and modifies nsGateways.
    lock       sync.RWMutex
}

// anonymousUsername is the consumer username ServeHTTP assigns when no auth
// plugin set one and the request was still allowed to proceed.
const anonymousUsername = "Anonymous"

// NewGatewayManager returns a gateway manager backed by db. The manager
// starts with an empty, non-nil set of namespace gateways.
func NewGatewayManager(db *database.SQLStore) core.GatewayManager {
    manager := new(gatewayManager)
    manager.db = db
    manager.nsGateways = map[string]*namespaceGateway{}

    return manager
}

// updateNamespace reloads the endpoints and consumers of namespace ns from the
// file store and installs them in the namespace's gateway, creating the
// gateway entry if it does not exist yet.
//
// The manager's write lock is held for the whole call. If listing the files
// fails, the error is logged and the previously installed lists are left
// untouched (a freshly created gateway stays empty).
//
// Consumer files that fail to parse or carry an invalid username are skipped.
// Endpoint files that fail to parse are still registered, carrying the parse
// error in Errors, so the problem is visible through the API.
func (ep *gatewayManager) updateNamespace(ns string) {
    track := recipient.Namespace.String() + "." + ns

    slog.Debug("updating namespace gateway", slog.String("namespace", ns), "track", track)

    ep.lock.Lock()
    defer ep.lock.Unlock()

    gw, ok := ep.nsGateways[ns]
    if !ok {
        gw = &namespaceGateway{
            EndpointList: endpoints.NewEndpointList(),
            ConsumerList: consumer.NewConsumerList(),
        }
        ep.nsGateways[ns] = gw
    }

    files, err := ep.db.FileStore().ForNamespace(ns).ListDirektivFilesWithData(context.Background())
    if err != nil {
        slog.Error("list files", "err", err, "track", track)

        return
    }

    eps := make([]*core.Endpoint, 0)
    consumers := make([]*core.ConsumerFile, 0)

    for _, file := range files {
        // files of any other type are irrelevant for the gateway
        switch file.Typ {
        case filestore.FileTypeConsumer:
            item, err := core.ParseConsumerFile(file.Data)
            if err != nil {
                slog.Error("parse consumer file", "err", err, "track", track)

                continue
            }

            // username can not be empty or contain a colon for basic auth
            if item.Username == "" || strings.Contains(item.Username, ":") {
                slog.Warn("username invalid", slog.String("user", item.Username), "track", track)

                continue
            }

            consumers = append(consumers, item)

        case filestore.FileTypeEndpoint:
            // named entry (not ep) so the receiver is not shadowed
            entry := &core.Endpoint{
                Methods:                 make([]string, 0),
                Errors:                  make([]string, 0),
                Warnings:                make([]string, 0),
                AuthPluginInstances:     make([]core.PluginInstance, 0),
                InboundPluginInstances:  make([]core.PluginInstance, 0),
                OutboundPluginInstances: make([]core.PluginInstance, 0),
                TargetPluginInstance:    nil,
                FilePath:                file.Path,
                Namespace:               ns,
            }

            item, err := core.ParseEndpointFile(file.Data)
            // if parsing fails, the endpoint is still getting added to report
            // an error in the API
            if err != nil {
                slog.Error("parse endpoint file", "err", err, "track", track)
                entry.Errors = append(entry.Errors, err.Error())
                eps = append(eps, entry)

                continue
            }

            entry.ServerPath = filepath.Join("/ns", ns, item.Path)
            entry.AllowAnonymous = item.AllowAnonymous
            entry.Timeout = item.Timeout
            entry.Methods = item.Methods
            entry.Path = item.Path
            entry.Plugins = item.Plugins

            endpoints.MakeEndpointPluginChain(entry, &item.Plugins)

            eps = append(eps, entry)
        }
    }

    gw.EndpointList.SetEndpoints(eps)
    gw.ConsumerList.SetConsumers(consumers)
}

// UpdateAll discards all namespace gateways and rebuilds them for every
// namespace currently stored in the database.
//
// The namespaces are listed before anything is discarded: if the listing
// fails, the error is logged and the existing gateways stay in place instead
// of being wiped. The map is replaced under the manager's write lock; the lock
// is released again before the per-namespace reload because updateNamespace
// acquires it itself.
func (ep *gatewayManager) UpdateAll() {
    nsList, err := ep.db.DataStore().Namespaces().GetAll(context.Background())
    if err != nil {
        slog.Error("listing namespaces", "err", err)

        return
    }

    ep.lock.Lock()
    ep.nsGateways = make(map[string]*namespaceGateway, len(nsList))
    ep.lock.Unlock()

    for _, ns := range nsList {
        ep.updateNamespace(ns.Name)
    }
}

// DummyWriter is an http.ResponseWriter that records the response in memory
// instead of sending it: headers go to HeaderMap, written bytes are appended
// to Body and the status code is kept in Code.
type DummyWriter struct {
    HeaderMap http.Header
    Body      *bytes.Buffer
    Code      int
}

// NewDummyWriter returns a DummyWriter with an empty header map, an empty
// body buffer and the status code preset to http.StatusOK.
func NewDummyWriter() *DummyWriter {
    rec := new(DummyWriter)
    rec.HeaderMap = http.Header{}
    rec.Body = &bytes.Buffer{}
    rec.Code = http.StatusOK

    return rec
}

// Header returns the recorded header map; changes made through the returned
// value are visible in HeaderMap.
func (rec *DummyWriter) Header() http.Header {
    headers := rec.HeaderMap

    return headers
}

// Write appends buf to Body and returns the result of the buffer write.
func (rec *DummyWriter) Write(buf []byte) (int, error) {
    written, err := rec.Body.Write(buf)

    return written, err
}

// WriteHeader stores statusCode in Code; a later call overwrites an earlier
// one.
func (rec *DummyWriter) WriteHeader(statusCode int) {
    rec.Code = statusCode
}

// ServeHTTP dispatches a gateway request to the endpoint registered for it and
// runs the endpoint's plugin chain.
//
// The namespace comes from the URL when the request was routed through
// "/ns/{namespace}/*", otherwise core.SystemNamespace is used. Unknown
// namespaces and routes are answered via plugins.ReportNotFound; endpoints
// carrying configuration errors are answered with a 500.
//
// Processing order:
//  1. auth plugins run until one of them sets the consumer's username,
//  2. a request without consumer is rejected with 401 unless the endpoint
//     allows anonymous access (the consumer is then named anonymousUsername),
//  3. inbound plugins run; the first one returning false ends the request,
//  4. the target plugin runs,
//  5. outbound plugins run; the first one returning false stops the chain.
//
// If the endpoint has outbound plugins, the target and outbound plugins write
// into an in-memory DummyWriter which is copied to w at the end. The buffered
// response is also copied to w when the target plugin fails, so that whatever
// the target reported reaches the client.
//
// The whole plugin chain runs under a timeout of endpoint.Timeout seconds
// (30 seconds if unset).
func (ep *gatewayManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    const defaultTimeoutSeconds = 30

    ctx := r.Context()
    span := trace.SpanFromContext(ctx)
    defer span.End()
    spanContext := span.SpanContext()
    traceID := spanContext.TraceID().String()
    spanID := spanContext.SpanID()
    logger := slog.With("trace", traceID, "span", spanID, "component", "gateway")
    logger.Info("serving gateway request")
    chiCtx := chi.RouteContext(ctx)
    namespace := core.SystemNamespace
    routePath := chi.URLParam(r, "*")

    // get namespace from URL or use magical one
    if chiCtx != nil && chiCtx.RoutePattern() == "/ns/{namespace}/*" {
        namespace = chi.URLParam(r, "namespace")
    }

    // updateNamespace modifies the map under the write lock, so the lookup
    // has to hold the read lock
    ep.lock.RLock()
    gw, ok := ep.nsGateways[namespace]
    ep.lock.RUnlock()
    if !ok {
        plugins.ReportNotFound(w)

        return
    }

    endpointEntry, urlParams := gw.EndpointList.FindRoute(routePath, r.Method)
    if endpointEntry == nil {
        plugins.ReportNotFound(w)

        return
    }

    // if there are configuration errors, return it
    if len(endpointEntry.Errors) > 0 {
        // the joined text is passed as an argument, never as the format
        // string, so a '%' inside an error message is reported verbatim
        plugins.ReportError(ctx, w, http.StatusInternalServerError, "plugin has errors",
            fmt.Errorf("%s", strings.Join(endpointEntry.Errors, ", ")))

        return
    }

    // logger already carries trace, span and component
    routeLogger := logger.With(
        "track", recipient.Route.String()+"."+namespace+"."+endpointEntry.Path,
        "namespace", namespace, "endpoint", endpointEntry.Path,
        "route", routePath)

    // add url params e.g. /{id}
    ctx = context.WithValue(ctx, plugins.URLParamCtxKey, urlParams)
    ctx = context.WithValue(ctx, plugins.ConsumersParamCtxKey, gw.ConsumerList)

    // timeout is 30 secs if not set
    t := endpointEntry.Timeout
    if t == 0 {
        t = defaultTimeoutSeconds
    }

    tracer := span.TracerProvider().Tracer("direktiv/flow")
    ctx, childSpan := tracer.Start(ctx, "plugins-processing")
    defer childSpan.End()

    routeLogger.Info("Serving plugins")

    ctx, cancel := context.WithTimeout(ctx, time.Second*time.Duration(t))
    defer cancel()
    ctx = context.WithValue(ctx, plugins.NamespaceCtxKey, namespace)
    ctx = context.WithValue(ctx, plugins.EndpointCtxKey, endpointEntry.Path)
    ctx = context.WithValue(ctx, plugins.RouteCtxKey, routePath)
    r = r.WithContext(ctx)

    c := &core.ConsumerFile{}
    for i := range endpointEntry.AuthPluginInstances {
        // all authplugins succeed, the setting of the consumer is the success message
        endpointEntry.AuthPluginInstances[i].ExecutePlugin(c, w, r)

        // check and exit if consumer is set in plugin
        if c.Username != "" {
            logger.Info("user authenticated", "user", c.Username)

            break
        }
    }

    // if user not authenticated and anonymous access not enabled
    if c.Username == "" && !endpointEntry.AllowAnonymous {
        plugins.ReportError(ctx, w, http.StatusUnauthorized, "no permission",
            fmt.Errorf("request not authorized"))
        routeLogger.Debug("user not authenticated")

        return
    }

    // set username Anonymous if allowed and not set via auth plugin
    if c.Username == "" {
        c.Username = anonymousUsername
    }

    // run inbound
    for i := range endpointEntry.InboundPluginInstances {
        if !endpointEntry.InboundPluginInstances[i].ExecutePlugin(c, w, r) {
            return
        }
    }

    // if there are outbound plugins the response writer is getting swapped
    // out because target plugins can do io.copy and set headers which would
    // go on the wire immediately.
    var buffered *DummyWriter
    targetWriter := w
    if len(endpointEntry.OutboundPluginInstances) > 0 {
        buffered = NewDummyWriter()
        targetWriter = buffered
    }

    // run target if it exists
    if endpointEntry.TargetPluginInstance != nil &&
        !endpointEntry.TargetPluginInstance.ExecutePlugin(c, targetWriter, r) {
        // the target wrote its (error) response into the buffer; without
        // this copy the client would receive an empty 200
        if buffered != nil {
            flushBufferedResponse(w, buffered, routeLogger)
        }

        return
    }

    // the loop body only runs when outbound plugins exist, i.e. buffered is set
    for i := range endpointEntry.OutboundPluginInstances {
        rin, err := swapRequestResponse(r, buffered)
        if err != nil {
            // failing to build the plugin request is an internal problem,
            // not an authorization problem
            plugins.ReportError(ctx, w, http.StatusInternalServerError, "output plugin failed",
                err)

            return
        }

        // in outbound we need to break and not return
        // to write the actual output
        if !executePlugin(c, buffered, rin, endpointEntry.OutboundPluginInstances[i].ExecutePlugin) {
            break
        }
    }

    // response already written, except if there are outbound plugins
    if buffered != nil {
        flushBufferedResponse(w, buffered, routeLogger)
    }
}

// flushBufferedResponse copies the headers, status code and body recorded in
// buffered to the real response writer w. A failed body write is logged on
// logger.
func flushBufferedResponse(w http.ResponseWriter, buffered *DummyWriter, logger *slog.Logger) {
    for key, values := range buffered.HeaderMap {
        for _, value := range values {
            w.Header().Add(key, value)
        }
    }

    w.WriteHeader(buffered.Code)

    if _, err := w.Write(buffered.Body.Bytes()); err != nil {
        logger.Error("Failed to write api response", "err", err)
    }
}

// executePlugin runs the plugin function fn with c, w and r unless the
// request's context is already done. In that case fn is not called; a
// 408 status with the body "request timed out" is written to w and false is
// returned. Otherwise the result of fn is returned.
func executePlugin(c *core.ConsumerFile, w http.ResponseWriter, r *http.Request,
    fn func(*core.ConsumerFile, http.ResponseWriter, *http.Request) bool,
) bool {
    // a non-nil Err means the context's Done channel is closed
    if r.Context().Err() != nil {
        w.WriteHeader(http.StatusRequestTimeout)
        // the write result is deliberately ignored, nothing more can be done
        _, _ = w.Write([]byte("request timed out"))

        return false
    }

    return fn(c, w, r)
}

// swapRequestResponse turns a recorded response into a request: it builds a
// GET request for "/writer" whose body is w's body buffer and whose Response
// field carries w's status code. The new request uses the context of rin.
func swapRequestResponse(rin *http.Request, w *DummyWriter) (*http.Request, error) {
    swapped, err := http.NewRequestWithContext(rin.Context(), http.MethodGet, "/writer", w.Body)
    if err != nil {
        return nil, err
    }

    swapped.Response = &http.Response{StatusCode: w.Code}

    return swapped, nil
}

// API functions.
// GetConsumers returns the consumers registered for namespace. An error is
// returned if the manager has no gateway for that namespace.
//
// The gateway lookup holds the manager's read lock because updateNamespace
// modifies the map under the write lock.
func (ep *gatewayManager) GetConsumers(namespace string) ([]*core.ConsumerFile, error) {
    ep.lock.RLock()
    g, ok := ep.nsGateways[namespace]
    ep.lock.RUnlock()

    if !ok {
        return nil, fmt.Errorf("no consumers for namespace %s", namespace)
    }

    return g.ConsumerList.GetConsumers(), nil
}

// GetRoutes returns the endpoints registered for namespace. If filteredPath is
// not empty, only endpoints whose Path equals filteredPath are returned. The
// returned slice is never nil. An error is returned if the manager has no
// gateway for that namespace.
//
// The gateway lookup holds the manager's read lock because updateNamespace
// modifies the map under the write lock.
func (ep *gatewayManager) GetRoutes(namespace, filteredPath string) ([]*core.Endpoint, error) {
    ep.lock.RLock()
    g, ok := ep.nsGateways[namespace]
    ep.lock.RUnlock()

    if !ok {
        return nil, fmt.Errorf("no routes for namespace %s", namespace)
    }

    result := make([]*core.Endpoint, 0)
    for _, v := range g.EndpointList.GetEndpoints() {
        if filteredPath == "" || v.Path == filteredPath {
            result = append(result, v)
        }
    }

    return result, nil
}