docker/swarmkit

View on GitHub
manager/logbroker/broker.go

Summary

Maintainability
B
6 hrs
Test Coverage
package logbroker

import (
    "context"
    "errors"
    "fmt"
    "io"
    "sync"

    "github.com/docker/go-events"
    "github.com/moby/swarmkit/v2/api"
    "github.com/moby/swarmkit/v2/ca"
    "github.com/moby/swarmkit/v2/identity"
    "github.com/moby/swarmkit/v2/log"
    "github.com/moby/swarmkit/v2/manager/state/store"
    "github.com/moby/swarmkit/v2/watch"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
)

var (
    // errAlreadyRunning is returned by Start when the broker has already
    // been started and not yet stopped.
    errAlreadyRunning = errors.New("broker is already running")
    // errNotRunning is returned by Stop when the broker is not running, and
    // by SubscribeLogs / ListenSubscriptions when Start has not yet set up
    // the broker's parent context.
    errNotRunning     = errors.New("broker is not running")
)

// logMessage is the event type published on the broker's log queue. It wraps
// a PublishLogsMessage and can additionally carry an end-of-subscription
// marker.
type logMessage struct {
    *api.PublishLogsMessage
    // completed flags the terminal event for a subscription: when
    // SubscribeLogs receives such an event it stops streaming and returns err.
    completed bool
    // err is the value SubscribeLogs returns when completed is set.
    err       error
}

// LogBroker coordinates log subscriptions to services and tasks. Clients can
// publish and subscribe to log channels.
//
// Log subscriptions are pushed to the worker nodes by creating log
// subscription tasks. As such, the LogBroker also acts as an orchestrator of
// these tasks.
type LogBroker struct {
    // mu is taken by the broker's methods around accesses to the queues, the
    // subscription maps and the lifecycle fields below.
    mu                sync.RWMutex
    // logQueue carries *logMessage events from publishers to subscribers.
    logQueue          *watch.Queue
    // subscriptionQueue carries *subscription events; a subscription is
    // published when it is registered and again when it is unregistered.
    subscriptionQueue *watch.Queue

    // registeredSubscriptions indexes live subscriptions by subscription ID.
    registeredSubscriptions map[string]*subscription
    // subscriptionsByNode has one entry per connected node (see
    // nodeConnected / nodeDisconnected), holding the set of subscriptions
    // that are still pending on that node.
    subscriptionsByNode     map[string]map[*subscription]struct{}

    // pctx is the parent context created by Start; cancelAll cancels it.
    // A nil cancelAll means the broker is not running.
    pctx      context.Context
    cancelAll context.CancelFunc

    // store is handed to every subscription created by newSubscription.
    store *store.MemoryStore
}

// New returns a LogBroker backed by the given store. Queues and subscription
// maps are not allocated here; that happens in Start.
func New(store *store.MemoryStore) *LogBroker {
    lb := new(LogBroker)
    lb.store = store
    return lb
}

// Start puts the broker into the running state: it derives a cancellable
// parent context from ctx and allocates fresh queues and subscription maps.
// It returns errAlreadyRunning if the broker is already running.
func (lb *LogBroker) Start(ctx context.Context) error {
    lb.mu.Lock()
    defer lb.mu.Unlock()

    // A non-nil cancelAll is the marker for "running".
    if running := lb.cancelAll != nil; running {
        return errAlreadyRunning
    }

    pctx, cancel := context.WithCancel(ctx)
    lb.pctx = pctx
    lb.cancelAll = cancel

    lb.logQueue = watch.NewQueue()
    lb.subscriptionQueue = watch.NewQueue()

    lb.registeredSubscriptions = map[string]*subscription{}
    lb.subscriptionsByNode = map[string]map[*subscription]struct{}{}
    return nil
}

// Stop cancels the parent context created by Start and closes both queues.
// It returns errNotRunning if the broker is not running. The subscription
// maps and pctx are left as they are.
func (lb *LogBroker) Stop() error {
    lb.mu.Lock()
    defer lb.mu.Unlock()

    cancel := lb.cancelAll
    if cancel == nil {
        return errNotRunning
    }

    // Cancel every context derived from pctx, then clear the running marker
    // so that Start may be called again.
    cancel()
    lb.cancelAll = nil

    lb.logQueue.Close()
    lb.subscriptionQueue.Close()
    return nil
}

// validateSelector rejects a nil selector and a selector that names no
// services, tasks or nodes. Both cases yield an InvalidArgument status error.
func validateSelector(selector *api.LogSelector) error {
    switch {
    case selector == nil:
        return status.Errorf(codes.InvalidArgument, "log selector must be provided")
    case len(selector.ServiceIDs)+len(selector.TaskIDs)+len(selector.NodeIDs) == 0:
        return status.Errorf(codes.InvalidArgument, "log selector must not be empty")
    default:
        return nil
    }
}

// newSubscription builds a subscription with a freshly generated ID for the
// given selector and options, wired to the broker's store and subscription
// queue. The subscription is not registered with the broker here.
func (lb *LogBroker) newSubscription(selector *api.LogSelector, options *api.LogSubscriptionOptions) *subscription {
    lb.mu.RLock()
    defer lb.mu.RUnlock()

    msg := &api.SubscriptionMessage{
        ID:       identity.NewID(),
        Selector: selector,
        Options:  options,
    }
    return newSubscription(lb.store, msg, lb.subscriptionQueue)
}

// getSubscription looks up a registered subscription by ID and returns nil
// when there is none.
func (lb *LogBroker) getSubscription(id string) *subscription {
    lb.mu.RLock()
    defer lb.mu.RUnlock()

    // Indexing a map with a missing key yields the zero value, which for a
    // pointer element type is nil.
    return lb.registeredSubscriptions[id]
}

// registerSubscription records the subscription under its ID, announces it on
// the subscription queue, and attaches it to every node it targets. Nodes
// with no entry in subscriptionsByNode are immediately marked done with a
// "not available" error.
func (lb *LogBroker) registerSubscription(subscription *subscription) {
    lb.mu.Lock()
    defer lb.mu.Unlock()

    lb.registeredSubscriptions[subscription.message.ID] = subscription
    lb.subscriptionQueue.Publish(subscription)

    for _, nodeID := range subscription.Nodes() {
        pending, connected := lb.subscriptionsByNode[nodeID]
        if !connected {
            // Mark nodes that won't receive the message as done.
            subscription.Done(nodeID, fmt.Errorf("node %s is not available", nodeID))
            continue
        }
        // The node is connected: track the subscription in its pending set.
        pending[subscription] = struct{}{}
    }
}

// unregisterSubscription removes the subscription from the ID index and from
// the pending set of every node it targets, closes it, and publishes it on
// the subscription queue once more so that watchers see the closed state.
func (lb *LogBroker) unregisterSubscription(subscription *subscription) {
    lb.mu.Lock()
    defer lb.mu.Unlock()

    delete(lb.registeredSubscriptions, subscription.message.ID)

    for _, nodeID := range subscription.Nodes() {
        // A node without an entry yields a nil inner map, and delete on a
        // nil map is a no-op, so no existence check is needed.
        delete(lb.subscriptionsByNode[nodeID], subscription)
    }

    subscription.Close()
    lb.subscriptionQueue.Publish(subscription)
}

// watchSubscriptions returns the subscriptions currently registered for
// nodeID, together with a channel of subsequent subscription events for that
// node and a function that cancels the watch.
//
// The watch is installed before the current set is collected, so a
// subscription can show up both in the returned slice and on the channel;
// callers have to de-duplicate.
func (lb *LogBroker) watchSubscriptions(nodeID string) ([]*subscription, chan events.Event, func()) {
    lb.mu.RLock()
    defer lb.mu.RUnlock()

    // Only events for subscriptions that involve this node are delivered.
    forNode := events.MatcherFunc(func(ev events.Event) bool {
        return ev.(*subscription).Contains(nodeID)
    })
    ch, cancel := lb.subscriptionQueue.CallbackWatch(forNode)

    // Collect what is registered right now.
    var current []*subscription
    for _, sub := range lb.registeredSubscriptions {
        if !sub.Contains(nodeID) {
            continue
        }
        current = append(current, sub)
    }

    return current, ch, cancel
}

// subscribe starts a watch on the log queue that only delivers events whose
// subscription ID equals id. It returns the event channel and a function
// that cancels the watch.
func (lb *LogBroker) subscribe(id string) (chan events.Event, func()) {
    lb.mu.RLock()
    defer lb.mu.RUnlock()

    forSubscription := events.MatcherFunc(func(ev events.Event) bool {
        return ev.(*logMessage).SubscriptionID == id
    })
    return lb.logQueue.CallbackWatch(forSubscription)
}

// publish wraps log in a logMessage and puts it on the log queue.
func (lb *LogBroker) publish(log *api.PublishLogsMessage) {
    lb.mu.RLock()
    defer lb.mu.RUnlock()

    envelope := new(logMessage)
    envelope.PublishLogsMessage = log
    lb.logQueue.Publish(envelope)
}

// markDone wraps (*subscription).Done so that the subscription is also
// dropped from the node's pending set in subscriptionsByNode.
func (lb *LogBroker) markDone(sub *subscription, nodeID string, err error) {
    lb.mu.Lock()
    defer lb.mu.Unlock()

    // If the node has no entry the inner map is nil, and delete on a nil map
    // does nothing.
    delete(lb.subscriptionsByNode[nodeID], sub)

    sub.Done(nodeID, err)
}

// SubscribeLogs creates a log subscription for the request's selector and
// streams the matching log messages back to the client.
//
// It returns an InvalidArgument status error for a nil or empty selector and
// errNotRunning if Start has not set up the broker's parent context. After
// that it runs until the client's stream context or the broker's parent
// context is done (returning that context's error), until a send on the
// stream fails, or until the subscription completes, in which case the
// subscription's own error is returned.
func (lb *LogBroker) SubscribeLogs(request *api.SubscribeLogsRequest, stream api.Logs_SubscribeLogsServer) error {
    ctx := stream.Context()

    if err := validateSelector(request.Selector); err != nil {
        return err
    }

    // Only a read of pctx is needed, so the shared lock is sufficient.
    lb.mu.RLock()
    pctx := lb.pctx
    lb.mu.RUnlock()
    if pctx == nil {
        return errNotRunning
    }

    subscription := lb.newSubscription(request.Selector, request.Options)
    subscription.Run(pctx)
    defer subscription.Stop()

    logger := log.G(ctx).WithFields(
        log.Fields{
            "method":          "(*LogBroker).SubscribeLogs",
            "subscription.id": subscription.message.ID,
        },
    )
    logger.Debug("subscribed")

    // Start watching the log queue before the subscription is announced, so
    // that logs published right after registration are not missed.
    publishCh, publishCancel := lb.subscribe(subscription.message.ID)
    defer publishCancel()

    lb.registerSubscription(subscription)
    defer lb.unregisterSubscription(subscription)

    completed := subscription.Wait(ctx)
    for {
        select {
        case <-ctx.Done():
            return ctx.Err()
        case <-pctx.Done():
            return pctx.Err()
        case event := <-publishCh:
            publish := event.(*logMessage)
            if publish.completed {
                return publish.err
            }
            if err := stream.Send(&api.SubscribeLogsMessage{
                Messages: publish.Messages,
            }); err != nil {
                return err
            }
        case <-completed:
            // A nil channel blocks forever, so this case fires only once.
            completed = nil

            // The completion marker goes through the log queue rather than
            // returning right away, so it is delivered on publishCh after
            // the log messages that were queued before it.
            done := &logMessage{
                PublishLogsMessage: &api.PublishLogsMessage{
                    SubscriptionID: subscription.message.ID,
                },
                completed: true,
                err:       subscription.Err(),
            }

            // lb.logQueue is reassigned by Start under lb.mu; read it under
            // the lock like every other access to the queue.
            lb.mu.RLock()
            lb.logQueue.Publish(done)
            lb.mu.RUnlock()
        }
    }
}

// nodeConnected makes sure nodeID has a (possibly empty) pending set in
// subscriptionsByNode. An existing set is left untouched.
func (lb *LogBroker) nodeConnected(nodeID string) {
    lb.mu.Lock()
    defer lb.mu.Unlock()

    if _, known := lb.subscriptionsByNode[nodeID]; known {
        return
    }
    lb.subscriptionsByNode[nodeID] = map[*subscription]struct{}{}
}

// nodeDisconnected marks every subscription still pending on nodeID as done
// with a "disconnected unexpectedly" error and then drops the node's entry
// from subscriptionsByNode.
func (lb *LogBroker) nodeDisconnected(nodeID string) {
    lb.mu.Lock()
    defer lb.mu.Unlock()

    pending := lb.subscriptionsByNode[nodeID]
    for sub := range pending {
        sub.Done(nodeID, fmt.Errorf("node %s disconnected unexpectedly", nodeID))
    }
    delete(lb.subscriptionsByNode, nodeID)
}

// ListenSubscriptions streams to the calling node the subscriptions that
// match it: first the ones already registered, then every later change.
//
// The node is registered as connected for the duration of the call. The call
// returns the stream context's error when the stream is done, nil when the
// broker's parent context is done, and the send error if a send fails.
func (lb *LogBroker) ListenSubscriptions(request *api.ListenSubscriptionsRequest, stream api.LogBroker_ListenSubscriptionsServer) error {
    remote, err := ca.RemoteNode(stream.Context())
    if err != nil {
        return err
    }
    nodeID := remote.NodeID

    lb.mu.Lock()
    pctx := lb.pctx
    lb.mu.Unlock()
    if pctx == nil {
        return errNotRunning
    }

    lb.nodeConnected(nodeID)
    defer lb.nodeDisconnected(nodeID)

    logger := log.G(stream.Context()).WithFields(
        log.Fields{
            "method": "(*LogBroker).ListenSubscriptions",
            "node":   nodeID,
        },
    )
    initial, updates, stopWatching := lb.watchSubscriptions(nodeID)
    defer stopWatching()

    logger.Debug("node registered")

    // active holds the subscriptions already sent down and not yet closed,
    // keyed by subscription ID. It is what filters out duplicate events.
    active := make(map[string]*subscription)

    // Phase one: send down everything that was registered when the watch
    // was set up.
    for _, sub := range initial {
        // Non-blocking check so a finished stream or broker stops the replay.
        select {
        case <-stream.Context().Done():
            return stream.Context().Err()
        case <-pctx.Done():
            return nil
        default:
        }

        if sendErr := stream.Send(sub.message); sendErr != nil {
            logger.Error(sendErr)
            return sendErr
        }
        active[sub.message.ID] = sub
    }

    // Phase two: forward subscription changes as they arrive.
    for {
        select {
        case ev := <-updates:
            sub := ev.(*subscription)
            id := sub.message.ID

            if sub.Closed() {
                // Closed subscriptions are forgotten but still forwarded.
                delete(active, id)
            } else if _, alreadySent := active[id]; alreadySent {
                // Avoid sending down the same subscription multiple times.
                continue
            } else {
                active[id] = sub
            }

            if sendErr := stream.Send(sub.message); sendErr != nil {
                logger.Error(sendErr)
                return sendErr
            }
        case <-stream.Context().Done():
            return stream.Context().Err()
        case <-pctx.Done():
            return nil
        }
    }
}

// PublishLogs receives log messages from a node for a single subscription
// and publishes them on the broker's log queue.
//
// The first message on the stream selects the subscription; every later
// message must carry the same subscription ID. The call ends when the node
// sends a message with Close set, on EOF, or on the first error. Once a
// subscription has been selected, it is marked done for the calling node on
// every exit path, with the returned error (nil on a clean close).
func (lb *LogBroker) PublishLogs(stream api.LogBroker_PublishLogsServer) (err error) {
    remote, err := ca.RemoteNode(stream.Context())
    if err != nil {
        return err
    }
    nodeID := remote.NodeID

    var current *subscription
    defer func() {
        // err is the named result, i.e. whatever the function is returning.
        if current != nil {
            lb.markDone(current, nodeID, err)
        }
    }()

    for {
        logMsg, recvErr := stream.Recv()
        switch {
        case recvErr == io.EOF:
            return stream.SendAndClose(&api.PublishLogsResponse{})
        case recvErr != nil:
            return recvErr
        }

        if logMsg.SubscriptionID == "" {
            return status.Errorf(codes.InvalidArgument, "missing subscription ID")
        }

        if current != nil {
            // The session is already bound to one subscription.
            if logMsg.SubscriptionID != current.message.ID {
                return status.Errorf(codes.InvalidArgument, "different subscription IDs in the same session")
            }
        } else {
            // First message: bind the session to its subscription.
            current = lb.getSubscription(logMsg.SubscriptionID)
            if current == nil {
                return status.Errorf(codes.NotFound, "unknown subscription ID")
            }
        }

        // If we have a close message, close out the subscription.
        if logMsg.Close {
            // Mark done without an error, then clear current so the deferred
            // function does not mark the subscription done a second time.
            lb.markDone(current, nodeID, nil)
            current = nil
            return nil
        }

        // Make sure logs are emitted using the right Node ID to avoid
        // impersonation.
        for _, msg := range logMsg.Messages {
            if msg.Context.NodeID != nodeID {
                return status.Errorf(codes.PermissionDenied, "invalid NodeID: expected=%s;received=%s", nodeID, msg.Context.NodeID)
            }
        }

        lb.publish(logMsg)
    }
}