server/pkg/tasks_manager/actions.go

Summary

Maintainability
A
0 mins
Test Coverage
C
78%
package tasks_manager

import (
    "context"
    "errors"
    "fmt"
    "time"

    "github.com/hashicorp/vault/sdk/logical"

    "github.com/werf/trdl/server/pkg/tasks_manager/worker"
)

var (
    // ErrBusy is returned by RunTask when isBusy reports that storage
    // already holds a running or queued task.
    ErrBusy            = errors.New("busy")
    // ErrContextCanceled is returned by the function built in WrapTaskFunc
    // when the task context is done before the task delivers its result.
    ErrContextCanceled = errors.New("context canceled")
)

// taskReasonInvalidatedTask is the reason passed by invalidateStorage when it
// switches leftover running/queued tasks to the canceled status.
const taskReasonInvalidatedTask = "the task canceled due to restart of the plugin"

// RunTask queues taskFunc only when no other task is running or queued.
//
// It returns the UUID of the queued task, ErrBusy when isBusy reports an
// active task, or any error raised while preparing or queueing the task.
func (m *Manager) RunTask(ctx context.Context, reqStorage logical.Storage, taskFunc func(context.Context, logical.Storage) error) (string, error) {
    var queuedUUID string

    // Executed by doTaskWrap with the already wrapped worker function.
    enqueueIfIdle := func(wrapped func(ctx context.Context) error) error {
        switch busy, err := m.isBusy(ctx, reqStorage); {
        case err != nil:
            return err
        case busy:
            return ErrBusy
        }

        var err error
        queuedUUID, err = m.queueTask(ctx, wrapped)
        return err
    }

    err := m.doTaskWrap(ctx, reqStorage, taskFunc, enqueueIfIdle)
    return queuedUUID, err
}

// AddOptionalTask tries to queue taskFunc via RunTask and treats a busy
// manager as a non-error outcome.
//
// It returns:
//   - the task UUID and true when the task has been queued;
//   - an empty UUID and false (with a nil error) when RunTask reports ErrBusy;
//   - an empty UUID, false and the error for any other failure.
func (m *Manager) AddOptionalTask(ctx context.Context, reqStorage logical.Storage, taskFunc func(context.Context, logical.Storage) error) (string, bool, error) {
    taskUUID, err := m.RunTask(ctx, reqStorage, taskFunc)
    switch {
    case err == nil:
        return taskUUID, true, nil
    case errors.Is(err, ErrBusy):
        // errors.Is also matches a wrapped ErrBusy, unlike a plain == check.
        // No task was queued in this case, so there is no UUID to report.
        return "", false, nil
    default:
        return "", false, err
    }
}

// AddTask queues taskFunc unconditionally (no busy check) and returns the
// UUID of the queued task, or the error raised while preparing or queueing it.
func (m *Manager) AddTask(ctx context.Context, reqStorage logical.Storage, taskFunc func(context.Context, logical.Storage) error) (string, error) {
    var queuedUUID string

    // Executed by doTaskWrap with the already wrapped worker function.
    enqueue := func(wrapped func(ctx context.Context) error) (err error) {
        queuedUUID, err = m.queueTask(ctx, wrapped)
        return err
    }

    err := m.doTaskWrap(ctx, reqStorage, taskFunc, enqueue)
    return queuedUUID, err
}

// doTaskWrap prepares taskFunc for the worker and hands the prepared function
// to f while holding the manager mutex.
//
// Steps, all performed under m.mu:
//  1. On the first call (m.Storage is nil) leftover running/queued tasks are
//     invalidated in reqStorage and reqStorage is remembered as m.Storage.
//  2. The task timeout is read from the stored configuration, falling back to
//     defaultTaskTimeoutDuration when no configuration is stored.
//  3. taskFunc is wrapped with WrapTaskFunc and passed to f.
//
// It returns the error of f, or a wrapped error when invalidation or reading
// the configuration fails.
func (m *Manager) doTaskWrap(ctx context.Context, reqStorage logical.Storage, taskFunc func(context.Context, logical.Storage) error, f func(func(ctx context.Context) error) error) error {
    m.mu.Lock()
    defer m.mu.Unlock()

    // initialize on first task
    if m.Storage == nil {
        if err := m.invalidateStorage(ctx, reqStorage); err != nil {
            return fmt.Errorf("unable to invalidate storage: %w", err)
        }

        // Remember the storage only after a successful invalidation: setting
        // it earlier would make a failed invalidation look completed, so it
        // would never be retried and stale running/queued records would keep
        // isBusy reporting true.
        m.Storage = reqStorage
    }

    config, err := getConfiguration(ctx, reqStorage)
    if err != nil {
        return fmt.Errorf("unable to get tasks manager configuration: %w", err)
    }

    var taskTimeoutDuration time.Duration
    if config != nil {
        taskTimeoutDuration = config.TaskTimeout
    } else {
        taskTimeoutDuration = defaultTaskTimeoutDuration
    }

    workerTaskFunc := m.WrapTaskFunc(taskFunc, taskTimeoutDuration)

    return f(workerTaskFunc)
}

// WrapTaskFunc separates processing of the context and the taskFunc execution in the background.
//
// The returned function derives a context limited by taskTimeoutDuration,
// starts taskFunc with that context and m.Storage in a separate goroutine and
// waits for whichever happens first:
//   - the derived context is done (parent canceled or timeout elapsed):
//     ErrContextCanceled is returned without waiting for taskFunc to finish;
//   - taskFunc returns: its error (or nil) is returned.
func (m *Manager) WrapTaskFunc(taskFunc func(context.Context, logical.Storage) error, taskTimeoutDuration time.Duration) func(ctx context.Context) error {
    return func(ctx context.Context) error {
        ctxWithTimeout, ctxCancelFunc := context.WithTimeout(ctx, taskTimeoutDuration)
        defer ctxCancelFunc()

        // The buffer of one lets the background goroutine deliver its result
        // and exit even when nobody is receiving anymore (context done first),
        // so the channel never has to be closed by the receiver and no
        // "send on closed channel" panic has to be recovered.
        resCh := make(chan error, 1)
        go func() {
            resCh <- taskFunc(ctxWithTimeout, m.Storage)
        }()

        select {
        case <-ctxWithTimeout.Done():
            m.logger.Debug("task failed: context canceled")
            return ErrContextCanceled
        case err := <-resCh:
            if err != nil {
                m.logger.Debug(fmt.Sprintf("task failed: %s", err))
                return err
            }

            m.logger.Debug("task succeeded")
            return nil
        }
    }
}

// invalidateStorage switches every task stored under the running and queued
// prefixes of reqStorage to the completed state with the canceled status,
// recording taskReasonInvalidatedTask as the reason.
//
// All UUIDs are collected first (running, then queued) and processed
// afterwards; the first listing or switching failure aborts with a wrapped
// error.
func (m *Manager) invalidateStorage(ctx context.Context, reqStorage logical.Storage) error {
    staleStates := []taskState{taskStateRunning, taskStateQueued}

    var staleUUIDs []string
    for i := range staleStates {
        keyPrefix := taskStorageKeyPrefix(staleStates[i])

        uuids, err := reqStorage.List(ctx, keyPrefix)
        if err != nil {
            return fmt.Errorf("unable to list %q in storage: %w", keyPrefix, err)
        }

        staleUUIDs = append(staleUUIDs, uuids...)
    }

    cancelOpts := switchTaskToCompletedInStorageOptions{
        reason: taskReasonInvalidatedTask,
    }
    for _, staleUUID := range staleUUIDs {
        err := switchTaskToCompletedInStorage(ctx, reqStorage, taskStatusCanceled, staleUUID, cancelOpts)
        if err != nil {
            return fmt.Errorf("unable to invalidate task %q: %w", staleUUID, err)
        }
    }

    return nil
}

// queueTask registers a new task in m.Storage and then sends it to
// m.taskChan together with ctx and workerTaskFunc.
//
// It returns the UUID produced by addNewTaskToStorage, or an empty string and
// the error when the task could not be added to storage.
func (m *Manager) queueTask(ctx context.Context, workerTaskFunc func(context.Context) error) (string, error) {
    newTaskUUID, err := addNewTaskToStorage(ctx, m.Storage)
    if err != nil {
        return "", err
    }

    task := &worker.Task{
        Context: ctx,
        UUID:    newTaskUUID,
        Action:  workerTaskFunc,
    }
    m.taskChan <- task

    return newTaskUUID, nil
}

// isBusy reports whether reqStorage lists at least one entry under the
// running-task prefix or, failing that, under the queued-task prefix.
// A listing failure is returned as a wrapped error together with false.
func (m *Manager) isBusy(ctx context.Context, reqStorage logical.Storage) (bool, error) {
    activePrefixes := [...]string{storageKeyPrefixRunningTask, storageKeyPrefixQueuedTask}

    for _, keyPrefix := range activePrefixes {
        entries, err := reqStorage.List(ctx, keyPrefix)
        switch {
        case err != nil:
            return false, fmt.Errorf("unable to list %q in storage: %w", keyPrefix, err)
        case len(entries) > 0:
            // any running or queued task means busy
            return true, nil
        }
    }

    return false, nil
}