vorteil/direktiv

View on GitHub
cmd/sidecar/local-server.go

Summary

Maintainability
B
6 hrs
Test Coverage
package sidecar

import (
    "bytes"
    "context"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "log/slog"
    "net/http"
    "os"
    "sync"
    "time"

    "github.com/direktiv/direktiv/pkg/core"
    "github.com/direktiv/direktiv/pkg/datastore"
    "github.com/direktiv/direktiv/pkg/engine"
    "github.com/direktiv/direktiv/pkg/tracing"
    "github.com/direktiv/direktiv/pkg/utils"
    "github.com/gorilla/mux"
)

// workerThreads is the number of inboundWorker goroutines that
// LocalServer.Start launches.
const workerThreads = 10

// LocalServer is the HTTP server the sidecar binds to 127.0.0.1:8889. It
// serves the /log and /var endpoints and keeps track of the actions that are
// currently in flight so those endpoints can resolve an action ID.
type LocalServer struct {
    // end is the function returned by threads.Register; wait calls it when
    // the server shuts down.
    end       func()
    // flowAddr is the "<host>:6665" address of flow, built by initFlow from
    // the environment.
    flowAddr  string
    // flowToken is the value of the API_KEY environment variable and is
    // passed along on requests to flow.
    flowToken string
    // queue buffers inbound requests (capacity 100); wait closes and drains
    // it on shutdown.
    queue     chan *inboundRequest
    // router dispatches /log and /var to their handlers.
    router    *mux.Router
    // stopper delivers the shutdown timestamp that wait blocks on.
    stopper   chan *time.Time
    // server is the underlying HTTP server.
    server    http.Server
    // workers holds the inbound workers created by Start.
    workers   []*inboundWorker

    // requestsLock guards requests.
    requestsLock sync.Mutex
    // requests maps an action ID to its active request.
    requests     map[string]*activeRequest
}

// initFlow derives the flow address and API token from the environment and
// stores them on the server. The printed address uses port 7777 while the
// stored flowAddr uses port 6665. It currently always returns nil.
func (srv *LocalServer) initFlow() error {
    endpoint := os.Getenv(direktivFlowEndpoint)

    fmt.Printf("flow server: %s\n", endpoint+":7777")

    srv.flowAddr = endpoint + ":6665"
    srv.flowToken = os.Getenv("API_KEY")

    return nil
}

// Start initialises the server state, registers the /log and /var routes,
// registers the server with the thread registry, and then launches the
// inbound workers, the HTTP listener and the shutdown watcher as goroutines.
// If initFlow fails it triggers Shutdown(ERROR) and returns without starting
// anything.
func (srv *LocalServer) Start() {
    if err := srv.initFlow(); err != nil {
        slog.Error("Localhost server unable to connect to flow", "error", err)
        Shutdown(ERROR)

        return
    }

    srv.requests = make(map[string]*activeRequest)
    srv.queue = make(chan *inboundRequest, 100)

    // TODO: Pass trace-id to user container
    router := mux.NewRouter()
    router.HandleFunc("/log", srv.logHandler)
    router.HandleFunc("/var", srv.varHandler)
    srv.router = router

    srv.server.Handler = router
    srv.server.Addr = "127.0.0.1:8889"

    // The stopper must exist before it is handed to the thread registry.
    srv.stopper = make(chan *time.Time, 1)
    srv.end = threads.Register(srv.stopper)

    slog.Debug("Localhost server thread registered.")

    //nolint:intrange
    for id := 0; id < workerThreads; id++ {
        w := &inboundWorker{id: id, srv: srv}
        srv.workers = append(srv.workers, w)

        go w.run()
    }

    go srv.run()
    go srv.wait()
}

// wait blocks until a shutdown timestamp arrives on the stopper channel and
// then tears the server down: it closes the queue, answers every request
// still queued via drainRequest, cancels all workers, and shuts the HTTP
// server down with a deadline 20 seconds after the received timestamp. On
// return, srv.end runs first and the HTTP server is closed afterwards.
func (srv *LocalServer) wait() {
    defer srv.server.Close()
    defer srv.end()

    stopTime := <-srv.stopper
    close(srv.queue)

    slog.Debug("Localhost server shutting down.")

    // The queue is closed, so this loop ends once the backlog is empty.
    for pending := range srv.queue {
        go srv.drainRequest(pending)
    }

    for i := range srv.workers {
        go srv.workers[i].Cancel()
    }

    deadline := stopTime.Add(20 * time.Second)

    ctx, cancel := context.WithDeadline(context.Background(), deadline)
    defer cancel()

    if err := srv.server.Shutdown(ctx); err != nil {
        slog.Error("Error shutting down localhost server", "error", err)
        Shutdown(ERROR)

        return
    }

    slog.Debug("Primary localhost server thread shut down successfully.")
}

// logHandler receives a log message for the action named by the "aid" query
// parameter and forwards it to flow as a JSON log entry.
//
// For POST requests the message is the request body (requests announcing more
// than 4 MiB are rejected with 413); for any other method it is the "log"
// query parameter. An empty message is ignored. If the action ID is not
// registered the handler answers with an error status and returns before the
// request entry is used.
func (srv *LocalServer) logHandler(w http.ResponseWriter, r *http.Request) {
    actionId := r.URL.Query().Get("aid")

    // Until the action has been resolved only the inbound request context is
    // available; it is enriched below once req is known to be valid.
    ctx := r.Context()

    reportError := func(code int, err error) {
        http.Error(w, err.Error(), code)
        slog.WarnContext(ctx, "Log handler error occurred.", "action", actionId, "action_err_code", code, "error", err)
    }

    srv.requestsLock.Lock()
    req, ok := srv.requests[actionId]
    srv.requestsLock.Unlock()

    // Validate the lookup before touching req: dereferencing a missing entry
    // would panic with a nil pointer dereference.
    if !ok {
        reportError(http.StatusInternalServerError, errors.New("the action id went missing"))
        return
    }

    if req == nil {
        reportError(http.StatusNotFound, fmt.Errorf("actionId %s not found", actionId))
        return
    }

    ctx = tracing.AddNamespace(ctx, req.Namespace)
    ctx = tracing.AddInstanceMemoryAttr(ctx, tracing.InstanceAttributes{
        Namespace:    req.Namespace,
        InstanceID:   req.Instance,
        Status:       core.LogUnknownStatus,
        WorkflowPath: req.Workflow,
        Callpath:     req.Callpath,
    }, req.State)
    ctx = tracing.WithTrack(ctx, tracing.BuildInstanceTrackViaCallpath(req.Callpath))
    ctx, span, err2 := tracing.InjectTraceParent(ctx, req.ActionContext.TraceParent, "writing logs in action: "+actionId+", workflow: "+req.Workflow)
    if err2 != nil {
        slog.Debug("Failed to populate trace information.", "action", actionId, "error", err2)
    }
    defer span.End()

    var msg string
    if r.Method == http.MethodPost {
        const capa = int64(0x400000) // 4 MiB
        if r.ContentLength > capa {
            reportError(http.StatusRequestEntityTooLarge, errors.New(http.StatusText(http.StatusRequestEntityTooLarge)))
            return
        }

        data, err := io.ReadAll(io.LimitReader(r.Body, capa))
        if err != nil {
            reportError(http.StatusBadRequest, err)
            return
        }
        msg = string(data)
    } else {
        msg = r.URL.Query().Get("log")
    }

    if len(msg) == 0 {
        slog.Debug("Log handler received an empty message body.", "action", actionId)
        return
    }

    entry := tracing.GetRawLogEntryWithStatus(ctx, tracing.LevelInfo, msg, core.LogRunningStatus)
    d, err := json.Marshal(entry)
    if err != nil {
        slog.Error("Failed to marshal log entry.", "action", actionId, "error", err)
        http.Error(w, "", http.StatusInternalServerError)

        return
    }
    slog.DebugContext(ctx, "redirect log entry to flow", "org-msg", msg)
    addr := fmt.Sprintf("http://%v/api/v2/namespaces/%v/logs?instance=%v", srv.flowAddr, req.Namespace, req.Instance)

    // The outbound call is bound to the action's own context (req.ctx), not
    // to the inbound HTTP request.
    resp, err := doRequest(req.ctx, http.MethodPost, srv.flowToken, addr, bytes.NewBuffer(d))
    if err != nil {
        slog.Error("Failed to forward log to Flow.", "action", actionId, "error", err)
        http.Error(w, "", http.StatusInternalServerError)

        return
    }

    if _, err := handleResponse(resp, nil); err != nil {
        slog.ErrorContext(ctx, "Failed to handle Flow response.", "action", actionId, "error", err)
        http.Error(w, "", http.StatusInternalServerError)

        return
    }

    slog.DebugContext(ctx, "Log handler successfully processed message.", "action", actionId)
}

// varHandler reads (GET) or writes (POST) a variable on behalf of the action
// named by the "aid" query parameter. The variable is selected by the "scope"
// and "key" query parameters; for POST the request body is the new value and
// the Content-Type header its MIME type. Any other method yields 405.
//
// If the action ID is not registered the handler answers with an error status
// and returns before the request entry is used.
//
// nolint:canonicalheader
func (srv *LocalServer) varHandler(w http.ResponseWriter, r *http.Request) {
    actionId := r.URL.Query().Get("aid")

    // Until the action has been resolved only the inbound request context is
    // available for logging.
    ctx := r.Context()

    reportError := func(code int, err error) {
        http.Error(w, err.Error(), code)
        slog.WarnContext(ctx, "Variable retrieval failed.", "action", actionId, "error", err)
    }

    srv.requestsLock.Lock()
    req, ok := srv.requests[actionId]
    srv.requestsLock.Unlock()

    // Validate the lookup before touching req: dereferencing a missing entry
    // would panic with a nil pointer dereference.
    if !ok {
        err := errors.New("the action id went missing")
        code := http.StatusInternalServerError
        reportError(code, err)

        return
    }

    if req == nil {
        code := http.StatusNotFound
        reportError(code, fmt.Errorf("actionId %s not found", actionId))

        return
    }

    // From here on everything runs on the action's own context, enriched with
    // the instance attributes for logging.
    ctx = tracing.AddNamespace(req.ctx, req.Namespace)
    ctx = tracing.AddInstanceMemoryAttr(ctx, tracing.InstanceAttributes{
        Namespace:    req.Namespace,
        InstanceID:   req.Instance,
        Status:       core.LogUnknownStatus,
        WorkflowPath: req.Workflow,
        Callpath:     req.Callpath,
    }, req.State)
    ctx = tracing.WithTrack(ctx, tracing.BuildInstanceTrackViaCallpath(req.Callpath))

    ir := req.functionRequest

    scope := r.URL.Query().Get("scope")
    key := r.URL.Query().Get("key")
    vMimeType := r.Header.Get("content-type")

    switch r.Method {
    case http.MethodGet:

        varMeta, statusCode, err := getVariableMetaFromFlow(ctx, srv.flowToken, srv.flowAddr, ir, scope, key)
        if err != nil {
            reportError(statusCode, err)
            slog.WarnContext(ctx, "Failed retrieving a Variable.", "action", actionId, "key", key, "scope", scope)

            return
        }

        varData, err := getVariableDataViaID(ctx, srv.flowToken, srv.flowAddr, ir.Namespace, varMeta.ID.String())
        if err != nil {
            reportError(http.StatusInternalServerError, err)
            slog.WarnContext(ctx, "Failed retrieving a Variable.", "action", actionId, "key", key, "scope", scope)

            return
        }
        _, err = io.Copy(w, bytes.NewReader(varData.Data))
        if err != nil {
            reportError(http.StatusInternalServerError, err)
            slog.Error("Failed retrieving a Variable.", "action", actionId, "key", key, "scope", scope)

            return
        }

        slog.DebugContext(ctx, "Variable successfully retrieved.", "action", actionId, "key", key, "scope", scope)

    case http.MethodPost:

        statusCode, err := srv.setVar(ctx, ir, r.Body, scope, key, vMimeType)
        if err != nil {
            reportError(statusCode, err)
            slog.WarnContext(ctx, "Failed to set a Variable.", "action", actionId, "key", key, "scope", scope)

            return
        }

        slog.DebugContext(ctx, "Variable successfully stored.", "action", actionId, "key", key, "scope", scope, "mime_type", vMimeType)

    default:
        code := http.StatusMethodNotAllowed
        reportError(code, errors.New(http.StatusText(code)))
        slog.Warn("Unsupported HTTP method for var handler.", "action", actionId, "method", r.Method)
    }
}

// activeRequest is the bookkeeping entry stored in LocalServer.requests for
// an action that is currently being served.
type activeRequest struct {
    // The embedded functionRequest carries the action's ID and parameters.
    *functionRequest
    // cancel is invoked by cancelActiveRequest when the action does not
    // finish within the grace period.
    cancel func()
    // ctx is the action's context; its Done channel signals completion.
    ctx    context.Context //nolint:containedctx
}

// registerActiveRequest records ir, together with its context and cancel
// function, in the active-request table under ir.actionId, replacing any
// entry already stored for that ID. The log line is written after the lock
// has been released.
func (srv *LocalServer) registerActiveRequest(ir *functionRequest, ctx context.Context, cancel func()) {
    entry := &activeRequest{
        functionRequest: ir,
        cancel:          cancel,
        ctx:             ctx,
    }

    srv.requestsLock.Lock()
    srv.requests[ir.actionId] = entry
    srv.requestsLock.Unlock()

    slog.InfoContext(ctx, "Serving.", "action", ir.actionId)
}

// deregisterActiveRequest removes the entry for actionId from the
// active-request table; removing an unknown ID is a no-op. The log line is
// written after the lock has been released.
func (srv *LocalServer) deregisterActiveRequest(actionId string) {
    func() {
        srv.requestsLock.Lock()
        defer srv.requestsLock.Unlock()

        delete(srv.requests, actionId)
    }()

    slog.Debug("Request deregistered.", "action", actionId)
}

// cancelActiveRequest asks the service to cancel the action actionId and then
// waits up to 10 seconds for the action's context to finish. If it has not
// finished by then, the action's cancel function is called. Unknown action
// IDs are ignored.
func (srv *LocalServer) cancelActiveRequest(ctx context.Context, actionId string) {
    srv.requestsLock.Lock()
    active, found := srv.requests[actionId]
    srv.requestsLock.Unlock()

    if !found || active == nil {
        return
    }

    slog.InfoContext(ctx, "Attempting to cancel.", "action", actionId)

    // The cancel notification is sent in the background so the wait below
    // starts immediately.
    go srv.sendCancelToService(ctx, active.functionRequest)

    grace := time.NewTimer(10 * time.Second)
    defer grace.Stop()

    select {
    case <-grace.C:
        slog.WarnContext(ctx, "Request failed to cancel punctually.", "action", actionId)
        active.cancel()
    case <-active.ctx.Done():
    }
}

// sendCancelToService issues an HTTP DELETE to http://localhost:8080 carrying
// the action ID in the actionIDHeader header. Failures are logged and not
// returned; a non-200 response is logged as a warning.
func (srv *LocalServer) sendCancelToService(ctx context.Context, ir *functionRequest) {
    const serviceAddr = "http://localhost:8080"

    cancelReq, err := http.NewRequestWithContext(ctx, http.MethodDelete, serviceAddr, nil)
    if err != nil {
        slog.ErrorContext(ctx, "Failed to create cancel request.", "action", ir.actionId, "error", err)
        return
    }

    cancelReq.Header.Set(actionIDHeader, ir.actionId)

    resp, err := http.DefaultClient.Do(cancelReq)
    if err != nil {
        slog.ErrorContext(ctx, "Failed to send cancel to service.", "action", ir.actionId, "error", err)
        return
    }
    defer resp.Body.Close()

    if resp.StatusCode == http.StatusOK {
        return
    }

    slog.WarnContext(ctx, "Service responded to cancel request.", "action", ir.actionId, "resp-code", resp.StatusCode)
}

// inboundRequest is one queued HTTP request waiting in LocalServer.queue.
type inboundRequest struct {
    // w is the writer for the response to r.
    w   http.ResponseWriter
    // r is the inbound HTTP request.
    r   *http.Request
    // end is closed by drainRequest once the request has been answered;
    // presumably it also signals completion on the normal path (verify
    // against the worker code).
    end chan bool
}

// drainRequest rejects a queued request with 503 Service Unavailable: it
// closes the request body, writes the error response, logs the rejection and
// finally closes req.end. A panic raised by that close (for example because
// the channel is already closed) is swallowed.
func (srv *LocalServer) drainRequest(req *inboundRequest) {
    _ = req.r.Body.Close()

    const status = http.StatusServiceUnavailable

    reason := http.StatusText(status)
    http.Error(req.w, reason, status)

    actionID := req.r.Header.Get(actionIDHeader)
    slog.Warn("Request aborted due to server unavailability", "action", actionID, "http_status_code", status, "reason", reason)

    func() {
        defer func() {
            _ = recover()
        }()

        close(req.end)
    }()
}

// run serves HTTP on the configured address until the server stops. A stop
// caused by a regular close (http.ErrServerClosed) is silent; any other
// error is logged and triggers Shutdown(ERROR).
func (srv *LocalServer) run() {
    slog.Info("Starting localhost HTTP server.", "addr", srv.server.Addr)

    serveErr := srv.server.ListenAndServe()
    if serveErr == nil || errors.Is(serveErr, http.ErrServerClosed) {
        return
    }

    slog.Error("Error running local server", "error", serveErr)
    Shutdown(ERROR)
}

// functionRequest describes a single action invocation handled by the
// sidecar.
type functionRequest struct {
    // actionId identifies the action; it is the key used in
    // LocalServer.requests.
    actionId string
    // The embedded ActionContext supplies the fields the handlers read
    // through this type (Namespace, Instance, Workflow, Callpath, State,
    // TraceParent).
    engine.ActionContext
    // deadline is not read in this part of the file; presumably the time by
    // which the action must finish (verify against the worker code).
    deadline time.Time
    // input is not read in this part of the file; presumably the raw action
    // payload (verify against the worker code).
    input    []byte
    // files lists the file descriptors attached to the request.
    files    []*functionFiles
}

// functionFiles is the JSON-serialisable description of one file attached to
// a functionRequest. The fields are not interpreted in this part of the
// file; their meaning is presumably defined by the code that prepares the
// files (verify there).
type functionFiles struct {
    Key         string `json:"key"`
    As          string `json:"as"`
    Scope       string `json:"scope"`
    Type        string `json:"type"`
    Permissions string `json:"permissions"`
}

// sharedDir is the filesystem path "/mnt/shared". It is not referenced in
// this part of the file; presumably a directory shared with the user
// container (verify against its users).
const (
    sharedDir = "/mnt/shared"
)

// setVar stores the contents of r as the variable key in the given scope.
//
// It first looks the variable up in flow. If it exists, its data and MIME
// type are patched and the status code of the lookup is returned. If the
// lookup fails with a RessourceNotFoundError, the variable is created for the
// requested scope (instance, workflow or namespace) and http.StatusOK is
// returned. Any other lookup error is returned with the lookup's status code.
//
// The returned int is the HTTP status code the caller should report alongside
// a non-nil error.
func (srv *LocalServer) setVar(ctx context.Context, ir *functionRequest, r io.Reader, scope, key, vMimeType string) (int, error) {
    varMeta, statusCode, lookupErr := getVariableMetaFromFlow(ctx, srv.flowToken, srv.flowAddr, ir, scope, key)

    // The variable already exists: overwrite it in place.
    if lookupErr == nil {
        payload, err := io.ReadAll(r)
        if err != nil {
            return http.StatusInternalServerError, fmt.Errorf("failed to read data from reader: %w", err)
        }

        patch := datastore.RuntimeVariablePatch{
            Name:     &key,
            MimeType: &vMimeType,
            Data:     payload,
        }

        if code, err := patchVarData(ctx, srv.flowToken, srv.flowAddr, ir.Namespace, varMeta.ID.String(), patch); err != nil {
            return code, fmt.Errorf("failed to patch variable data: %w", err)
        }

        return statusCode, nil
    }

    // Only a "not found" lookup result leads to creation; everything else is
    // reported as-is.
    var notFound *RessourceNotFoundError
    if !errors.As(lookupErr, &notFound) {
        return statusCode, fmt.Errorf("failed to get variable metadata: %w", lookupErr)
    }

    payload, err := io.ReadAll(r)
    if err != nil {
        return http.StatusInternalServerError, fmt.Errorf("failed to read data from reader: %w", err)
    }

    newVar := createVarRequest{
        Name:     key,
        MimeType: vMimeType,
        Data:     payload,
    }

    // Attach the field that binds the variable to its scope.
    switch scope {
    case utils.VarScopeNamespace:
        // Namespace scope requires no additional fields.
    case utils.VarScopeWorkflow:
        newVar.WorkflowPath = ir.Workflow
    case utils.VarScopeInstance:
        newVar.InstanceIDString = ir.Instance
    default:
        return http.StatusBadRequest, fmt.Errorf("unknown scope: %s", scope)
    }

    if code, err := postVarData(ctx, srv.flowToken, srv.flowAddr, ir.Namespace, newVar); err != nil {
        return code, fmt.Errorf("failed to post variable data: %w", err)
    }

    return http.StatusOK, nil
}