docker/swarmkit

View on GitHub
manager/orchestrator/update/updater.go

Summary

Maintainability
F
3 days
Test Coverage
package update

import (
    "context"
    "errors"
    "fmt"
    "reflect"
    "sync"
    "time"

    "github.com/docker/go-events"
    gogotypes "github.com/gogo/protobuf/types"
    "github.com/moby/swarmkit/v2/api"
    "github.com/moby/swarmkit/v2/api/defaults"
    "github.com/moby/swarmkit/v2/log"
    "github.com/moby/swarmkit/v2/manager/orchestrator"
    "github.com/moby/swarmkit/v2/manager/orchestrator/restart"
    "github.com/moby/swarmkit/v2/manager/state"
    "github.com/moby/swarmkit/v2/manager/state/store"
    "github.com/moby/swarmkit/v2/protobuf/ptypes"
    "github.com/moby/swarmkit/v2/watch"
)

// Supervisor supervises a set of updates. It's responsible for keeping track of updates,
// shutting them down and replacing them.
type Supervisor struct {
    // store and restarts are passed on to every Updater this supervisor creates.
    store    *store.MemoryStore
    restarts *restart.Supervisor
    // updates holds the tracked Updater for each service, keyed by service ID.
    updates  map[string]*Updater
    // l guards updates.
    l        sync.Mutex
}

// NewSupervisor returns a Supervisor that runs updates against store and
// hands restartSupervisor to every Updater it starts. The returned
// Supervisor tracks no updates yet.
func NewSupervisor(store *store.MemoryStore, restartSupervisor *restart.Supervisor) *Supervisor {
    supervisor := new(Supervisor)
    supervisor.store = store
    supervisor.restarts = restartSupervisor
    supervisor.updates = map[string]*Updater{}
    return supervisor
}

// Update kicks off a background update of `slots` for `service` and returns
// without waiting for it to finish. A slot is the group of tasks sharing one
// slot (replicated service) or one node (global service); it can hold several
// tasks when a previous update started a replacement before stopping the old
// task.
//
// If an update for the same service is already tracked and targets an
// identical spec, the call is a no-op. If it targets a different spec, that
// update is cancelled (blocking until it has stopped) before the new one is
// started.
func (u *Supervisor) Update(ctx context.Context, cluster *api.Cluster, service *api.Service, slots []orchestrator.Slot) {
    u.l.Lock()
    defer u.l.Unlock()

    serviceID := service.ID

    if running, exists := u.updates[serviceID]; exists {
        // An update that already works towards this exact spec is left alone.
        if reflect.DeepEqual(service.Spec, running.newService.Spec) {
            return
        }
        running.Cancel()
    }

    updater := NewUpdater(u.store, u.restarts, cluster, service)
    u.updates[serviceID] = updater

    go func() {
        updater.Run(ctx, slots)

        // Stop tracking the updater, unless a newer update for the same
        // service has replaced it in the meantime.
        u.l.Lock()
        defer u.l.Unlock()
        if current, tracked := u.updates[serviceID]; tracked && current == updater {
            delete(u.updates, serviceID)
        }
    }()
}

// CancelAll cancels every update currently tracked by the supervisor,
// waiting for each one to stop in turn. The supervisor lock is held for the
// whole call.
func (u *Supervisor) CancelAll() {
    u.l.Lock()
    defer u.l.Unlock()

    for serviceID := range u.updates {
        u.updates[serviceID].Cancel()
    }
}

// Updater updates a set of tasks to a new version.
type Updater struct {
    // store is where tasks and the service's UpdateStatus are read and written.
    store      *store.MemoryStore
    // watchQueue is the store's watch queue, used to observe task events.
    watchQueue *watch.Queue
    // restarts is used to delay-start replacement tasks.
    restarts   *restart.Supervisor

    // cluster and newService are private copies taken when the Updater is
    // created; newService carries the spec the tasks are updated to.
    cluster    *api.Cluster
    newService *api.Service

    // updatedTasks maps the ID of every task created by this update to the
    // time it was observed at RUNNING or beyond. The entry holds the zero
    // time from creation until that happens.
    updatedTasks   map[string]time.Time // task ID to creation time
    // updatedTasksMu guards updatedTasks.
    updatedTasksMu sync.Mutex

    // stopChan signals to the state machine to stop running.
    stopChan chan struct{}
    // doneChan is closed when the state machine terminates.
    doneChan chan struct{}
}

// NewUpdater builds an Updater that moves the tasks of newService to its
// current spec. The cluster and service objects are copied, so later changes
// made by the caller do not affect the update. The update does not start
// until Run is called.
func NewUpdater(store *store.MemoryStore, restartSupervisor *restart.Supervisor, cluster *api.Cluster, newService *api.Service) *Updater {
    updater := &Updater{
        store:    store,
        restarts: restartSupervisor,
    }
    updater.watchQueue = store.WatchQueue()
    updater.cluster = cluster.Copy()
    updater.newService = newService.Copy()
    updater.updatedTasks = map[string]time.Time{}
    updater.stopChan = make(chan struct{})
    updater.doneChan = make(chan struct{})
    return updater
}

// Cancel cancels the current update immediately. It blocks until the
// cancellation is confirmed, i.e. until Run has returned and closed doneChan.
//
// Cancel may be called again on an Updater that has already been cancelled:
// the stop signal is only sent once, and later calls just wait for (or
// observe) termination. This matters because a cancelled Updater can remain
// reachable through the Supervisor until its goroutine has unregistered it.
//
// The check below is not atomic, so Cancel must not be called from several
// goroutines at once; the Supervisor serializes its calls with its mutex.
// Cancel blocks forever if Run is never called, since only Run closes
// doneChan.
func (u *Updater) Cancel() {
    select {
    case <-u.stopChan:
        // Already closed by an earlier Cancel; closing again would panic.
    default:
        close(u.stopChan)
    }
    <-u.doneChan
}

// Run starts the update and returns only once it's complete or cancelled.
// doneChan is closed when Run returns, which is what unblocks Cancel.
//
// The update proceeds as follows:
//
//  1. Nothing is done if the service's update is PAUSED or ROLLBACK_PAUSED.
//  2. Slots that are not dirty are skipped. If no slot is dirty, an update
//     that is UPDATING or ROLLBACK_STARTED is marked complete and Run returns.
//  3. Dirty slots are fed to `parallelism` workers (one per dirty slot when
//     parallelism is 0), which replace the tasks in each slot.
//  4. Unless the failure action is CONTINUE, failures of tasks created by
//     this update are counted. When the ratio of failures to dirty slots
//     exceeds MaxFailureRatio, the update is paused or rolled back.
//  5. After all slots have been handed out and the workers are done, task
//     failures are monitored for one more monitoring period before the update
//     is marked complete.
//
// The update is not marked complete if it was cancelled, paused or rolled
// back.
func (u *Updater) Run(ctx context.Context, slots []orchestrator.Slot) {
    defer close(u.doneChan)

    service := u.newService

    // If the update is in a PAUSED state, we should not do anything.
    if service.UpdateStatus != nil &&
        (service.UpdateStatus.State == api.UpdateStatus_PAUSED ||
            service.UpdateStatus.State == api.UpdateStatus_ROLLBACK_PAUSED) {
        return
    }

    var dirtySlots []orchestrator.Slot
    for _, slot := range slots {
        if u.isSlotDirty(slot) {
            dirtySlots = append(dirtySlots, slot)
        }
    }
    // Abort immediately if all tasks are clean.
    if len(dirtySlots) == 0 {
        if service.UpdateStatus != nil &&
            (service.UpdateStatus.State == api.UpdateStatus_UPDATING ||
                service.UpdateStatus.State == api.UpdateStatus_ROLLBACK_STARTED) {
            u.completeUpdate(ctx, service.ID)
        }
        return
    }

    // If there's no update in progress, we are starting one.
    if service.UpdateStatus == nil {
        u.startUpdate(ctx, service.ID)
    }

    var (
        monitoringPeriod time.Duration
        updateConfig     *api.UpdateConfig
    )

    // Pick the rollback or the update configuration, falling back to the
    // package defaults when the spec does not provide one. The default
    // monitoring period is used unless the chosen config overrides it below.
    if service.UpdateStatus != nil && service.UpdateStatus.State == api.UpdateStatus_ROLLBACK_STARTED {
        monitoringPeriod, _ = gogotypes.DurationFromProto(defaults.Service.Rollback.Monitor)
        updateConfig = service.Spec.Rollback
        if updateConfig == nil {
            updateConfig = defaults.Service.Rollback
        }
    } else {
        monitoringPeriod, _ = gogotypes.DurationFromProto(defaults.Service.Update.Monitor)
        updateConfig = service.Spec.Update
        if updateConfig == nil {
            updateConfig = defaults.Service.Update
        }
    }

    parallelism := int(updateConfig.Parallelism)
    if updateConfig.Monitor != nil {
        // An unparsable Monitor value is ignored and the default is kept.
        newMonitoringPeriod, err := gogotypes.DurationFromProto(updateConfig.Monitor)
        if err == nil {
            monitoringPeriod = newMonitoringPeriod
        }
    }

    if parallelism == 0 {
        // TODO(aluzzardi): We could try to optimize unlimited parallelism by performing updates in a single
        // goroutine using a batch transaction.
        parallelism = len(dirtySlots)
    }

    // Start the workers.
    slotQueue := make(chan orchestrator.Slot)
    wg := sync.WaitGroup{}
    wg.Add(parallelism)
    for i := 0; i < parallelism; i++ {
        go func() {
            u.worker(ctx, slotQueue, updateConfig)
            wg.Done()
        }()
    }

    // failedTaskWatch stays nil when the failure action is CONTINUE. Receiving
    // from a nil channel blocks forever, so the corresponding select cases
    // below are simply never taken in that mode.
    var failedTaskWatch chan events.Event

    if updateConfig.FailureAction != api.UpdateConfig_CONTINUE {
        var cancelWatch func()
        failedTaskWatch, cancelWatch = state.Watch(
            u.store.WatchQueue(),
            api.EventUpdateTask{
                Task:   &api.Task{ServiceID: service.ID, Status: api.TaskStatus{State: api.TaskStateRunning}},
                Checks: []api.TaskCheckFunc{api.TaskCheckServiceID, state.TaskCheckStateGreaterThan},
            },
        )
        defer cancelWatch()
    }

    // stopped records that the update ended without completing (cancelled,
    // paused or rolled back), in which case it must not be marked complete.
    stopped := false
    failedTasks := make(map[string]struct{})
    totalFailures := 0

    // failureTriggersAction records the failure of failedTask and reports
    // whether it caused the update to be paused or rolled back. Whenever it
    // returns true it has also set stopped.
    failureTriggersAction := func(failedTask *api.Task) bool {
        // Ignore tasks we have already seen as failures.
        if _, found := failedTasks[failedTask.ID]; found {
            return false
        }

        // If this failed/completed task is one that we
        // created as part of this update, we should
        // follow the failure action.
        u.updatedTasksMu.Lock()
        startedAt, found := u.updatedTasks[failedTask.ID]
        u.updatedTasksMu.Unlock()

        // A zero startedAt means the task never reached RUNNING; otherwise
        // the failure only counts within the monitoring period.
        if found && (startedAt.IsZero() || time.Since(startedAt) <= monitoringPeriod) {
            failedTasks[failedTask.ID] = struct{}{}
            totalFailures++
            if float32(totalFailures)/float32(len(dirtySlots)) > updateConfig.MaxFailureRatio {
                switch updateConfig.FailureAction {
                case api.UpdateConfig_PAUSE:
                    stopped = true
                    message := fmt.Sprintf("update paused due to failure or early termination of task %s", failedTask.ID)
                    u.pauseUpdate(ctx, service.ID, message)
                    return true
                case api.UpdateConfig_ROLLBACK:
                    // Never roll back a rollback
                    if service.UpdateStatus != nil && service.UpdateStatus.State == api.UpdateStatus_ROLLBACK_STARTED {
                        // The rollback is paused, not finished: mark the
                        // run as stopped so that it is not subsequently
                        // reported as completed.
                        stopped = true
                        message := fmt.Sprintf("rollback paused due to failure or early termination of task %s", failedTask.ID)
                        u.pauseUpdate(ctx, service.ID, message)
                        return true
                    }
                    stopped = true
                    message := fmt.Sprintf("update rolled back due to failure or early termination of task %s", failedTask.ID)
                    u.rollbackUpdate(ctx, service.ID, message)
                    return true
                }
            }
        }

        return false
    }

slotsLoop:
    for _, slot := range dirtySlots {
    retryLoop:
        for {
            // Wait for a worker to pick up the task or abort the update, whichever comes first.
            select {
            case <-u.stopChan:
                stopped = true
                break slotsLoop
            case ev := <-failedTaskWatch:
                if failureTriggersAction(ev.(api.EventUpdateTask).Task) {
                    break slotsLoop
                }
            case slotQueue <- slot:
                break retryLoop
            }
        }
    }

    // Closing the queue lets the workers exit once they finish their
    // current slot.
    close(slotQueue)
    wg.Wait()

    if !stopped {
        // if a delay is set we need to monitor for a period longer than the delay
        // otherwise we will leave the monitorLoop before the task is done delaying
        if updateConfig.Delay >= monitoringPeriod {
            monitoringPeriod = updateConfig.Delay + 1*time.Second
        }
        // Keep watching for task failures for one more monitoringPeriod,
        // before declaring the update complete.
        doneMonitoring := time.After(monitoringPeriod)
    monitorLoop:
        for {
            select {
            case <-u.stopChan:
                stopped = true
                break monitorLoop
            case <-doneMonitoring:
                break monitorLoop
            case ev := <-failedTaskWatch:
                if failureTriggersAction(ev.(api.EventUpdateTask).Task) {
                    break monitorLoop
                }
            }
        }
    }

    // TODO(aaronl): Potentially roll back the service if not enough tasks
    // have reached RUNNING by this point.

    if !stopped {
        u.completeUpdate(ctx, service.ID)
    }
}

// worker consumes slots from queue until the queue is closed or the update
// is cancelled, bringing each slot to the new service spec.
//
// For every slot it prefers reusing a task that already matches the new spec:
// one whose desired state is RUNNING, or failing that one whose desired state
// is below RUNNING. Only when no such task exists is a replacement task
// created. Errors are logged and do not stop the worker.
//
// After each slot the worker waits for updateConfig.Delay (if non-zero)
// before taking the next one; it returns early if the update is cancelled
// during that wait.
func (u *Updater) worker(ctx context.Context, queue <-chan orchestrator.Slot, updateConfig *api.UpdateConfig) {
    for slot := range queue {
        // Do we have a task with the new spec in desired state = RUNNING?
        // If so, all we have to do to complete the update is remove the
        // other tasks. Or if we have a task with the new spec that has
        // desired state < RUNNING, advance it to running and remove the
        // other tasks.
        var (
            runningTask *api.Task
            cleanTask   *api.Task
        )
        for _, t := range slot {
            if u.isTaskDirty(t) {
                continue
            }
            if t.DesiredState == api.TaskStateRunning {
                runningTask = t
                break
            }
            if t.DesiredState < api.TaskStateRunning {
                cleanTask = t
            }
        }

        switch {
        case runningTask != nil:
            if err := u.useExistingTask(ctx, slot, runningTask); err != nil {
                log.G(ctx).WithError(err).Error("update failed")
            }
        case cleanTask != nil:
            if err := u.useExistingTask(ctx, slot, cleanTask); err != nil {
                log.G(ctx).WithError(err).Error("update failed")
            }
        default:
            // Build the replacement task once. For a global service it is
            // given the node ID of the task it replaces; otherwise the node
            // ID is left empty.
            nodeID := ""
            if orchestrator.IsGlobalService(u.newService) {
                nodeID = slot[0].NodeID
            }
            updated := orchestrator.NewTask(u.cluster, u.newService, slot[0].Slot, nodeID)
            updated.DesiredState = api.TaskStateReady

            if err := u.updateTask(ctx, slot, updated, updateConfig.Order); err != nil {
                log.G(ctx).WithError(err).WithField("task.id", updated.ID).Error("update failed")
            }
        }

        if updateConfig.Delay != 0 {
            select {
            case <-time.After(updateConfig.Delay):
            case <-u.stopChan:
                return
            }
        }
    }
}

// updateTask replaces the tasks in slot with the newly built task updated,
// then waits until updated is observed at RUNNING or beyond.
//
// With order START_FIRST the new task is started first and the old tasks are
// shut down only once the new task is observed in state RUNNING. With any
// other order the old tasks are shut down in the same batch that creates the
// new task.
//
// It returns an error if the task cannot be created (including when the
// service no longer exists) or, for orders other than START_FIRST, if
// shutting down the old tasks fails. It returns nil if the update is
// cancelled through stopChan while waiting.
func (u *Updater) updateTask(ctx context.Context, slot orchestrator.Slot, updated *api.Task, order api.UpdateConfig_UpdateOrder) error {
    // Kick off the watch before even creating the updated task. This is in order to avoid missing any event.
    taskUpdates, cancel := state.Watch(u.watchQueue, api.EventUpdateTask{
        Task:   &api.Task{ID: updated.ID},
        Checks: []api.TaskCheckFunc{api.TaskCheckID},
    })
    defer cancel()

    // Create an empty entry for this task, so the updater knows a failure
    // should count towards the failure count. The timestamp is added
    // if/when the task reaches RUNNING.
    u.updatedTasksMu.Lock()
    u.updatedTasks[updated.ID] = time.Time{}
    u.updatedTasksMu.Unlock()

    // startThenStop is set for START_FIRST: the old tasks are then removed
    // further down, after the new task is running.
    startThenStop := false
    var delayStartCh <-chan struct{}
    // Atomically create the updated task and bring down the old one.
    err := u.store.Batch(func(batch *store.Batch) error {
        err := batch.Update(func(tx store.Tx) error {
            if store.GetService(tx, updated.ServiceID) == nil {
                return errors.New("service was deleted")
            }

            return store.CreateTask(tx, updated)
        })
        if err != nil {
            return err
        }

        if order == api.UpdateConfig_START_FIRST {
            // NOTE(review): DelayStart is presumably what moves the new task
            // towards RUNNING; here it is called without an old task to wait
            // on. Confirm against restart.Supervisor.
            delayStartCh = u.restarts.DelayStart(ctx, nil, nil, updated.ID, 0, false)
            startThenStop = true
        } else {
            oldTask, err := u.removeOldTasks(ctx, batch, slot)
            if err != nil {
                return err
            }
            delayStartCh = u.restarts.DelayStart(ctx, nil, oldTask, updated.ID, 0, true)
        }

        return nil

    })
    if err != nil {
        return err
    }

    // Wait for the delayed start to finish, unless the update is cancelled.
    if delayStartCh != nil {
        select {
        case <-delayStartCh:
        case <-u.stopChan:
            return nil
        }
    }

    // Wait for the new task to come up.
    // TODO(aluzzardi): Consider adding a timeout here.
    for {
        select {
        case e := <-taskUpdates:
            updated = e.(api.EventUpdateTask).Task
            if updated.Status.State >= api.TaskStateRunning {
                // Record when the task was seen at RUNNING or beyond.
                u.updatedTasksMu.Lock()
                u.updatedTasks[updated.ID] = time.Now()
                u.updatedTasksMu.Unlock()

                // For START_FIRST, the old tasks are only shut down if the
                // new task is actually RUNNING (not in a later state). A
                // failure to remove them is logged as a warning and not
                // returned from the batch callback.
                if startThenStop && updated.Status.State == api.TaskStateRunning {
                    err := u.store.Batch(func(batch *store.Batch) error {
                        _, err := u.removeOldTasks(ctx, batch, slot)
                        if err != nil {
                            log.G(ctx).WithError(err).WithField("task.id", updated.ID).Warning("failed to remove old task after starting replacement")
                        }
                        return nil
                    })
                    return err
                }
                return nil
            }
        case <-u.stopChan:
            return nil
        }
    }
}

// useExistingTask completes the update of slot by keeping existing, a task
// that already matches the new spec, and shutting down every other task in
// the slot. If existing is not yet desired to be RUNNING it is delay-started,
// and the call waits for that start to finish or for the update to be
// cancelled.
//
// It returns an error only if the store batch fails, which includes a
// failure to shut down the other tasks.
func (u *Updater) useExistingTask(ctx context.Context, slot orchestrator.Slot, existing *api.Task) error {
    var others []*api.Task
    for _, candidate := range slot {
        if candidate == existing {
            continue
        }
        others = append(others, candidate)
    }

    // Nothing to remove and the task is already meant to run: we're done.
    if len(others) == 0 && existing.DesiredState == api.TaskStateRunning {
        return nil
    }

    var started <-chan struct{}
    if err := u.store.Batch(func(batch *store.Batch) error {
        var shutDown *api.Task
        if len(others) != 0 {
            removed, err := u.removeOldTasks(ctx, batch, others)
            if err != nil {
                return err
            }
            shutDown = removed
        }

        if existing.DesiredState != api.TaskStateRunning {
            started = u.restarts.DelayStart(ctx, nil, shutDown, existing.ID, 0, true)
        }
        return nil
    }); err != nil {
        return err
    }

    if started == nil {
        return nil
    }

    // Wait for the delayed start, giving up if the update is cancelled.
    select {
    case <-started:
    case <-u.stopChan:
    }
    return nil
}

// removeOldTasks sets the desired state of each task in removeTasks to
// SHUTDOWN, one batch update per task. Tasks whose desired state is already
// past RUNNING are skipped.
//
// It returns the last task that was successfully shut down. If none was, it
// returns nil together with the last error encountered; both are nil when
// every task was skipped.
func (u *Updater) removeOldTasks(ctx context.Context, batch *store.Batch, removeTasks []*api.Task) (*api.Task, error) {
    var (
        shutDown *api.Task
        failure  error
    )

    for _, candidate := range removeTasks {
        if candidate.DesiredState > api.TaskStateRunning {
            continue
        }

        if err := batch.Update(func(tx store.Tx) error {
            // Re-read the task: the stored copy may have changed since
            // the slot was assembled.
            current := store.GetTask(tx, candidate.ID)
            switch {
            case current == nil:
                return fmt.Errorf("task %s not found while trying to shut it down", candidate.ID)
            case current.DesiredState > api.TaskStateRunning:
                return fmt.Errorf(
                    "task %s was already shut down when reached by updater (state: %v)",
                    candidate.ID, current.DesiredState,
                )
            }
            current.DesiredState = api.TaskStateShutdown
            return store.UpdateTask(tx, current)
        }); err != nil {
            failure = err
            continue
        }
        shutDown = candidate
    }

    if shutDown != nil {
        return shutDown, nil
    }
    return nil, failure
}

// isTaskDirty reports whether t is dirty with respect to the new service,
// as decided by orchestrator.IsTaskDirty. The node the task is assigned to
// is looked up in the store and passed along; it is nil if the lookup finds
// nothing.
func (u *Updater) isTaskDirty(t *api.Task) bool {
    var node *api.Node
    lookup := func(tx store.ReadTx) {
        node = store.GetNode(tx, t.NodeID)
    }
    u.store.View(lookup)

    return orchestrator.IsTaskDirty(u.newService, t, node)
}

// isSlotDirty reports whether slot needs work from the updater: a slot with
// more than one task is always dirty, a slot with exactly one task is dirty
// when that task is, and an empty slot is never dirty.
func (u *Updater) isSlotDirty(slot orchestrator.Slot) bool {
    switch len(slot) {
    case 0:
        return false
    case 1:
        return u.isTaskDirty(slot[0])
    default:
        return true
    }
}

// startUpdate records in the store that an update of the service is in
// progress, by giving it an UPDATING status stamped with the current time.
// It does nothing if the service no longer exists or already carries an
// update status. Store errors are logged, not returned.
func (u *Updater) startUpdate(ctx context.Context, serviceID string) {
    markInProgress := func(tx store.Tx) error {
        current := store.GetService(tx, serviceID)
        if current == nil || current.UpdateStatus != nil {
            return nil
        }

        status := &api.UpdateStatus{
            State:     api.UpdateStatus_UPDATING,
            Message:   "update in progress",
            StartedAt: ptypes.MustTimestampProto(time.Now()),
        }
        current.UpdateStatus = status

        return store.UpdateService(tx, current)
    }

    if err := u.store.Update(markInProgress); err != nil {
        log.G(ctx).WithError(err).Errorf("failed to mark update of service %s in progress", serviceID)
    }
}

// pauseUpdate marks the service's update as paused in the store and stores
// message as the status message. A rollback in progress (ROLLBACK_STARTED)
// becomes ROLLBACK_PAUSED; any other state becomes PAUSED. It does nothing if
// the service no longer exists or has no update status. Store errors are
// logged, not returned.
func (u *Updater) pauseUpdate(ctx context.Context, serviceID, message string) {
    log.G(ctx).Debugf("pausing update of service %s", serviceID)

    markPaused := func(tx store.Tx) error {
        current := store.GetService(tx, serviceID)
        // A nil status means the service was updated since we started this
        // update.
        if current == nil || current.UpdateStatus == nil {
            return nil
        }

        pausedState := api.UpdateStatus_PAUSED
        if current.UpdateStatus.State == api.UpdateStatus_ROLLBACK_STARTED {
            pausedState = api.UpdateStatus_ROLLBACK_PAUSED
        }
        current.UpdateStatus.State = pausedState
        current.UpdateStatus.Message = message

        return store.UpdateService(tx, current)
    }

    if err := u.store.Update(markPaused); err != nil {
        log.G(ctx).WithError(err).Errorf("failed to pause update of service %s", serviceID)
    }
}

// rollbackUpdate switches the service back to its previous spec and marks
// its update status as ROLLBACK_STARTED with the given message. The previous
// spec and spec version are cleared once they have been restored. It does
// nothing if the service no longer exists or has no update status, and fails
// if the service has no previous spec. Store errors are logged, not returned.
func (u *Updater) rollbackUpdate(ctx context.Context, serviceID, message string) {
    log.G(ctx).Debugf("starting rollback of service %s", serviceID)

    startRollback := func(tx store.Tx) error {
        current := store.GetService(tx, serviceID)
        // A nil status means the service was updated since we started this
        // update.
        if current == nil || current.UpdateStatus == nil {
            return nil
        }

        current.UpdateStatus.State = api.UpdateStatus_ROLLBACK_STARTED
        current.UpdateStatus.Message = message

        previous := current.PreviousSpec
        if previous == nil {
            return errors.New("cannot roll back service because no previous spec is available")
        }
        current.Spec = *previous
        current.SpecVersion = current.PreviousSpecVersion.Copy()
        current.PreviousSpec, current.PreviousSpecVersion = nil, nil

        return store.UpdateService(tx, current)
    }

    if err := u.store.Update(startRollback); err != nil {
        log.G(ctx).WithError(err).Errorf("failed to start rollback of service %s", serviceID)
    }
}

// completeUpdate marks the service's update as finished in the store and
// stamps the completion time. A rollback in progress (ROLLBACK_STARTED)
// becomes ROLLBACK_COMPLETED; any other state becomes COMPLETED. It does
// nothing if the service no longer exists or has no update status. Store
// errors are logged, not returned.
func (u *Updater) completeUpdate(ctx context.Context, serviceID string) {
    log.G(ctx).Debugf("update of service %s complete", serviceID)

    markComplete := func(tx store.Tx) error {
        current := store.GetService(tx, serviceID)
        // A nil status means the service was changed since we started this
        // update.
        if current == nil || current.UpdateStatus == nil {
            return nil
        }

        status := current.UpdateStatus
        switch status.State {
        case api.UpdateStatus_ROLLBACK_STARTED:
            status.State = api.UpdateStatus_ROLLBACK_COMPLETED
            status.Message = "rollback completed"
        default:
            status.State = api.UpdateStatus_COMPLETED
            status.Message = "update completed"
        }
        status.CompletedAt = ptypes.MustTimestampProto(time.Now())

        return store.UpdateService(tx, current)
    }

    if err := u.store.Update(markComplete); err != nil {
        log.G(ctx).WithError(err).Errorf("failed to mark update of service %s complete", serviceID)
    }
}