streamdal/go-sdk

View on GitHub
register.go

Summary

Maintainability
B
6 hrs
Test Coverage
F
50%
package streamdal

import (
    "context"
    "fmt"
    "runtime"
    "strings"
    "time"

    "github.com/pkg/errors"
    "github.com/relistan/go-director"

    "github.com/streamdal/streamdal/libs/protos/build/go/protos"
    "github.com/streamdal/streamdal/libs/protos/build/go/protos/shared"

    "github.com/streamdal/go-sdk/validate"
)

// Sentinel errors describing pipeline state.
var (
	// ErrPipelineNotActive reports that a pipeline is not active or does not
	// exist.
	ErrPipelineNotActive = errors.New("pipeline not active or does not exist")

	// ErrPipelineNotPaused reports that a pipeline is not paused.
	ErrPipelineNotPaused = errors.New("pipeline not paused")
)

// genClientInfo builds the ClientInfo describing this SDK instance: the
// configured client type, the library name and (hard-coded) version, the
// language, and the GOARCH/GOOS of the running binary.
func (s *Streamdal) genClientInfo() *protos.ClientInfo {
	info := new(protos.ClientInfo)

	info.ClientType = protos.ClientType(s.config.ClientType)
	info.LibraryName = "go-sdk"
	info.LibraryVersion = "v0.0.86"
	info.Language = "go"

	// Platform the SDK is running on.
	info.Arch = runtime.GOARCH
	info.Os = runtime.GOOS

	return info
}

// register announces this client to the streamdal server and then runs the
// command-receive loop on the resulting stream until the looper quits.
//
// The RegisterRequest carries the service name, session ID, client info,
// configured audiences and dry-run flag. The configured audiences are also
// recorded in s.audiences before the request is sent.
//
// Failure handling:
//   - If the initial Register call errors, that error is returned immediately.
//   - If the first Recv on the stream fails and IgnoreStartupError is false,
//     the loop is stopped and the Recv error is returned (wrapped).
//   - Any other stream error drops the stream; the next iteration reconnects
//     and re-registers, waiting ReconnectSleep between failed attempts. The
//     wait is cut short once ShutdownCtx is done.
//   - Once ShutdownCtx is done (or a cancelled context is reported while
//     registering/receiving), the loop stops and register returns nil.
func (s *Streamdal) register(looper director.Looper) error {
	req := &protos.RegisterRequest{
		ServiceName: s.config.ServiceName,
		SessionId:   s.sessionID,
		ClientInfo:  s.genClientInfo(),
		Audiences:   make([]*protos.Audience, 0),
		DryRun:      s.config.DryRun,
	}

	// Track configured audiences locally and include them in the request.
	s.audiencesMtx.Lock()
	for _, aud := range s.config.Audiences {
		pAud := aud.toProto(s.config.ServiceName)
		req.Audiences = append(req.Audiences, pAud)
		s.audiences[audToStr(pAud)] = struct{}{}
	}
	s.audiencesMtx.Unlock()

	var (
		stream             protos.Internal_RegisterClient
		quit               bool
		initialRegister    = true
		initialRegisterErr error
	)

	// stop marks the loop as finished and asks the looper to exit.
	stop := func() {
		quit = true
		looper.Quit()
	}

	// backoff waits ReconnectSleep before the next attempt, but returns early
	// once ShutdownCtx is done so shutdown is not held up by a pending retry.
	backoff := func() {
		t := time.NewTimer(ReconnectSleep)
		defer t.Stop()

		select {
		case <-t.C:
		case <-s.config.ShutdownCtx.Done():
		}
	}

	// This might not error even if the handler returns an err - need to attempt
	// to perform a recv to verify.
	srv, err := s.serverClient.Register(s.config.ShutdownCtx, req)
	if err != nil {
		return errors.Wrap(err, "unable to complete initial registration with streamdal server")
	}

	stream = srv

	looper.Loop(func() error {
		if quit {
			time.Sleep(time.Millisecond * 100)
			return nil
		}

		// This is here to enable reconnects; no way to hit this case for a
		// "first register attempt" because "stream" won't be nil on initial launch.
		if stream == nil {
			// Shutdown was requested while disconnected: stop instead of
			// retrying Reconnect() indefinitely.
			if s.config.ShutdownCtx.Err() != nil {
				s.config.Logger.Debug("shutdown context is done, not attempting to re-register")
				stop()
				return nil
			}

			s.config.Logger.Debug("stream is nil, attempting to register")

			if err := s.serverClient.Reconnect(); err != nil {
				s.config.Logger.Errorf("Failed to reconnect with streamdal server: %s, retrying in '%s'", err, ReconnectSleep.String())
				backoff()
				return nil
			}

			s.config.Logger.Debug("successfully reconnected to streamdal server")

			newStream, err := s.serverClient.Register(s.config.ShutdownCtx, req)
			if err != nil {
				if strings.Contains(err.Error(), context.Canceled.Error()) {
					s.config.Logger.Debug("context cancelled during connect")
					stop()
					return nil
				}

				s.config.Logger.Errorf("Failed to re-register with streamdal server: %s, retrying in '%s'", err, ReconnectSleep.String())
				backoff()
				return nil
			}

			s.config.Logger.Debug("successfully re-registered to streamdal server")

			stream = newStream

			// Re-announce audience (if we had any) - this is needed so that
			// streamdal server repopulates live entry in live:* prefix (which is used
			// for DetachPipeline())
			s.addAudiences(s.config.ShutdownCtx)
		}

		// Blocks until something is received
		cmd, err := stream.Recv()
		if err != nil {
			// This is the first registration attempt and it has failed.
			// Depending on IgnoreStartupError, we may need to stop the loop
			// and tell the caller that we failed to complete registration.
			if initialRegister && !s.config.IgnoreStartupError {
				initialRegisterErr = err
				stop()
				return nil
			}

			// The stream was opened with ShutdownCtx, so a done ShutdownCtx
			// means we are shutting down. The exact-string match is kept as a
			// fallback for the gRPC client-side cancellation error.
			if s.config.ShutdownCtx.Err() != nil || err.Error() == "rpc error: code = Canceled desc = context canceled" {
				s.config.Logger.Debugf("context cancelled during recv: %s", err)
				stop()
				return nil
			}

			// Reset stream - cause re-register on error
			stream = nil

			// Nicer reconnect messages
			switch {
			case strings.Contains(err.Error(), "reading from server: EOF"):
				s.config.Logger.Warnf("streamdal server is unavailable, retrying in %s...", ReconnectSleep.String())
			case strings.Contains(err.Error(), "server shutting down"):
				s.config.Logger.Warnf("streamdal server is shutting down, retrying in %s...", ReconnectSleep.String())
			default:
				s.config.Logger.Warnf("error receiving message, retrying in %s: %s", ReconnectSleep.String(), err)
			}

			backoff()

			return nil
		}

		// Initial registration has succeeded - no longer need to bail out if we
		// encounter any errors
		initialRegister = false

		// Command failures are logged (with the actual error) but never stop
		// the receive loop.
		if err := s.handleCommand(stream.Context(), cmd); err != nil {
			s.config.Logger.Errorf("Failed to handle command '%T': %s", cmd.Command, err)
		}

		return nil
	})

	if initialRegister && initialRegisterErr != nil {
		return errors.Wrap(initialRegisterErr,
			"failed to complete initial registration with streamdal server (and IgnoreStartupError is set to 'false')",
		)
	}

	return nil
}

// handleCommand routes one command received on the register stream to the
// handler for its type (set pipelines, KV, or tail).
//
// nil commands, keep-alives, and commands whose audience names a different
// service are logged at debug level and ignored (nil is returned). A command
// of an unrecognized type yields an error.
func (s *Streamdal) handleCommand(ctx context.Context, cmd *protos.Command) error {
	// Commands that need no handling.
	switch {
	case cmd == nil:
		s.config.Logger.Debug("Received nil command, ignoring")
		return nil
	case cmd.GetKeepAlive() != nil:
		s.config.Logger.Debug("Received keep alive")
		return nil
	case cmd.Audience != nil && cmd.Audience.ServiceName != s.config.ServiceName:
		s.config.Logger.Debugf("Received command for different service name: %s, ignoring command", cmd.Audience.ServiceName)
		return nil
	}

	switch cmd.Command.(type) {
	case *protos.Command_SetPipelines:
		s.config.Logger.Debug("Received set pipelines command")
		return s.setPipelines(ctx, cmd)
	case *protos.Command_Kv:
		s.config.Logger.Debug("Received kv command")
		return s.handleKVCommand(ctx, cmd.GetKv())
	case *protos.Command_Tail:
		s.config.Logger.Debug("Received tail command")
		return s.handleTailCommand(ctx, cmd)
	default:
		return fmt.Errorf("unknown command type: %+v", cmd.Command)
	}
}

// handleTailCommand dispatches a tail command to the start, stop, pause or
// resume handler matching its request type.
//
// A command with no tail payload, or a tail with no request, is logged at
// error level and ignored (nil is returned). An unrecognized request type is
// returned as an error.
//
// NOTE(review): the incoming ctx is ignored and every handler receives
// context.Background(); presumably so tails are not bound to the stream's
// context - confirm.
func (s *Streamdal) handleTailCommand(_ context.Context, cmd *protos.Command) error {
	tail := cmd.GetTail()
	if tail == nil {
		s.config.Logger.Errorf("Received tail command with nil tail; full cmd: %+v", cmd)
		return nil
	}

	tailReq := tail.GetRequest()
	if tailReq == nil {
		s.config.Logger.Errorf("Received tail command with nil Request; full cmd: %+v", cmd)
		return nil
	}

	audStr := audToStr(tailReq.Audience)
	bg := context.Background()

	switch tailReq.Type {
	case protos.TailRequestType_TAIL_REQUEST_TYPE_START:
		s.config.Logger.Debugf("Received start tail command for audience '%s'", audStr)
		return s.startTailHandler(bg, cmd)
	case protos.TailRequestType_TAIL_REQUEST_TYPE_STOP:
		s.config.Logger.Debugf("Received stop tail command for audience '%s'", audStr)
		return s.stopTailHandler(bg, cmd)
	case protos.TailRequestType_TAIL_REQUEST_TYPE_PAUSE:
		s.config.Logger.Debugf("Received pause tail command for audience '%s'", audStr)
		return s.pauseTailHandler(bg, cmd)
	case protos.TailRequestType_TAIL_REQUEST_TYPE_RESUME:
		s.config.Logger.Debugf("Received resume tail command for audience '%s'", audStr)
		return s.resumeTailHandler(bg, cmd)
	default:
		return fmt.Errorf("unknown tail command type: %s", tailReq.Type)
	}
}

// handleKVCommand applies a KV command from the streamdal server to the local
// KV store (s.kv).
//
// The command as a whole is validated first; a failure there is returned as
// an error. Individual instructions that fail validation, or carry an action
// other than create/update/delete/delete-all, are logged at debug level and
// skipped so the remaining instructions are still applied.
func (s *Streamdal) handleKVCommand(_ context.Context, kv *protos.KVCommand) error {
	if err := validate.KVCommand(kv); err != nil {
		return errors.Wrap(err, "failed to validate kv command")
	}

	for _, i := range kv.Instructions {
		if err := validate.KVInstruction(i); err != nil {
			// The instruction that failed validation may itself be nil, so use
			// the nil-safe getter rather than dereferencing i.Action.
			s.config.Logger.Debugf("KV instruction '%s' failed validate: %s (skipping)", i.GetAction(), err)
			continue
		}

		switch i.Action {
		case shared.KVAction_KV_ACTION_CREATE, shared.KVAction_KV_ACTION_UPDATE:
			s.config.Logger.Debugf("attempting to perform '%s' KV instruction for key '%s'", i.Action, i.Object.Key)
			s.kv.Set(i.Object.Key, string(i.Object.Value))
		case shared.KVAction_KV_ACTION_DELETE:
			s.config.Logger.Debugf("attempting to perform '%s' KV instruction for key '%s'", i.Action, i.Object.Key)
			s.kv.Delete(i.Object.Key)
		case shared.KVAction_KV_ACTION_DELETE_ALL:
			s.config.Logger.Debugf("attempting to perform '%s' KV instruction", i.Action)
			s.kv.Purge()
		default:
			s.config.Logger.Debugf("invalid KV action '%s' - skipping", i.Action)
		}
	}

	return nil
}

// setPipelines replaces the pipelines stored for a command's audience with the
// ones carried by its SetPipelines payload.
//
// It returns ErrEmptyCommand for a nil command, and a wrapped error if the
// command fails validate.SetPipelinesCommand. The pipelines map is written
// while holding pipelinesMtx.
func (s *Streamdal) setPipelines(_ context.Context, cmd *protos.Command) error {
	if cmd == nil {
		return ErrEmptyCommand
	}

	if err := validate.SetPipelinesCommand(cmd); err != nil {
		return errors.Wrap(err, "failed to validate set pipelines command")
	}

	audKey := audToStr(cmd.Audience)
	incoming := cmd.GetSetPipelines().Pipelines

	s.pipelinesMtx.Lock()
	defer s.pipelinesMtx.Unlock()

	s.pipelines[audKey] = incoming

	s.config.Logger.Debugf("saved '%d' pipelines for audience '%s'", len(incoming), audKey)

	return nil
}