docker/swarmkit
manager/state/raft/transport/peer.go

package transport

import (
    "context"
    "fmt"
    "sync"
    "time"

    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"

    "github.com/moby/swarmkit/v2/api"
    "github.com/moby/swarmkit/v2/log"
    "github.com/moby/swarmkit/v2/manager/state/raft/membership"
    "github.com/pkg/errors"
    "go.etcd.io/etcd/raft/v3"
    "go.etcd.io/etcd/raft/v3/raftpb"
    "google.golang.org/grpc/status"
)

const (
    // GRPCMaxMsgSize is the max allowed gRPC message size for raft messages.
    GRPCMaxMsgSize = 4 << 20
)
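
// Note: 4 << 20 is 4 MiB, which matches gRPC's default maximum receive
// message size; a raft message at or under this limit fits in a single
// ProcessRaftMessageRequest.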

// peer manages the connection to a single remote raft member. Outgoing
// messages are queued on msgc and delivered by a dedicated goroutine (run);
// mu guards the connection, addresses, and activity state.
type peer struct {
    id uint64

    tr *Transport

    // msgc buffers outgoing raft messages for the run goroutine.
    msgc chan raftpb.Message

    ctx    context.Context
    cancel context.CancelFunc
    done   chan struct{} // closed when run returns

    mu      sync.Mutex
    cc      *grpc.ClientConn
    addr    string
    newAddr string // candidate address to try after a send failure

    active       bool
    becameActive time.Time
}

// newPeer dials addr and starts the peer's sender goroutine. Callers must
// eventually call stop to release the connection.
func newPeer(id uint64, addr string, tr *Transport) (*peer, error) {
    cc, err := tr.dial(addr)
    if err != nil {
        return nil, errors.Wrapf(err, "failed to create conn for %x with addr %s", id, addr)
    }
    ctx, cancel := context.WithCancel(tr.ctx)
    ctx = log.WithField(ctx, "peer_id", fmt.Sprintf("%x", id))
    p := &peer{
        id:     id,
        addr:   addr,
        cc:     cc,
        tr:     tr,
        ctx:    ctx,
        cancel: cancel,
        msgc:   make(chan raftpb.Message, 4096),
        done:   make(chan struct{}),
    }
    go p.run(ctx)
    return p, nil
}
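
// Typical lifecycle (a sketch; the Transport in this package is assumed to
// drive it): newPeer when a member joins, send for each outbound raft
// message, updateAddr when the member reports a new address, and stop on
// removal.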

// send queues m for delivery without blocking. If the peer is stopped or its
// queue is full, the message is dropped and raft is told the peer is
// unreachable.
func (p *peer) send(m raftpb.Message) (err error) {
    p.mu.Lock()
    defer func() {
        if err != nil {
            p.active = false
            p.becameActive = time.Time{}
        }
        p.mu.Unlock()
    }()
    select {
    case <-p.ctx.Done():
        return p.ctx.Err()
    default:
    }
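    // The queue is buffered (capacity 4096, set in newPeer) so raft never
    // blocks on network I/O: if the buffer is full, drop the message and
    // report the peer unreachable instead of waiting.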
    select {
    case p.msgc <- m:
    case <-p.ctx.Done():
        return p.ctx.Err()
    default:
        p.tr.config.ReportUnreachable(p.id)
        return errors.Errorf("peer is unreachable")
    }
    return nil
}

// update dials addr immediately and, on success, replaces the current
// connection. Contrast updateAddr, which only records a candidate address.
func (p *peer) update(addr string) error {
    p.mu.Lock()
    defer p.mu.Unlock()
    if p.addr == addr {
        return nil
    }
    cc, err := p.tr.dial(addr)
    if err != nil {
        return err
    }

    p.cc.Close()
    p.cc = cc
    p.addr = addr
    return nil
}

// updateAddr records addr as a candidate; it is dialed by handleAddressChange
// only after a send on the current connection fails.
func (p *peer) updateAddr(addr string) error {
    p.mu.Lock()
    defer p.mu.Unlock()
    if p.addr == addr {
        return nil
    }
    log.G(p.ctx).Debugf("peer %x updated to address %s, it will be used if the current one fails", p.id, addr)
    p.newAddr = addr
    return nil
}

func (p *peer) conn() *grpc.ClientConn {
    p.mu.Lock()
    defer p.mu.Unlock()
    return p.cc
}

func (p *peer) address() string {
    p.mu.Lock()
    defer p.mu.Unlock()
    return p.addr
}

func (p *peer) resolveAddr(ctx context.Context, id uint64) (string, error) {
    resp, err := api.NewRaftClient(p.conn()).ResolveAddress(ctx, &api.ResolveAddressRequest{RaftID: id})
    if err != nil {
        return "", errors.Wrap(err, "failed to resolve address")
    }
    return resp.Addr, nil
}

// raftMessageStructSize returns the size of the raft message struct (not
// including the payload) for the given raftpb.Message. The payload is
// typically the snapshot or append entries.
func raftMessageStructSize(m *raftpb.Message) int {
    return (&api.ProcessRaftMessageRequest{Message: m}).Size() - len(m.Snapshot.Data)
}

// raftMessagePayloadSize returns the maximum allowable payload size based on
// GRPCMaxMsgSize and the struct size for the given raftpb.Message.
func raftMessagePayloadSize(m *raftpb.Message) int {
    return GRPCMaxMsgSize - raftMessageStructSize(m)
}
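
// For example (illustrative numbers): if the request envelope around a
// MsgSnap serializes to 100 bytes, raftMessagePayloadSize leaves
// 4<<20 - 100 = 4194204 bytes of Snapshot.Data per chunk.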

// splitSnapshotData splits a large raft message into smaller messages.
// Currently this means splitting Snapshot.Data into chunks whose size is
// dictated by the payload budget derived from GRPCMaxMsgSize.
func splitSnapshotData(m *raftpb.Message) []api.StreamRaftMessageRequest {
    var messages []api.StreamRaftMessageRequest
    if m.Type != raftpb.MsgSnap {
        return messages
    }

    // Get the size of the data to be split.
    size := len(m.Snapshot.Data)

    // Get the max payload size.
    payloadSize := raftMessagePayloadSize(m)

    // Split the snapshot into smaller messages.
    for snapDataIndex := 0; snapDataIndex < size; {
        chunkSize := size - snapDataIndex
        if chunkSize > payloadSize {
            chunkSize = payloadSize
        }

        raftMsg := *m

        // Sub-slice for this snapshot chunk.
        raftMsg.Snapshot.Data = m.Snapshot.Data[snapDataIndex : snapDataIndex+chunkSize]

        snapDataIndex += chunkSize

        // Add message to the list of messages to be sent.
        msg := api.StreamRaftMessageRequest{Message: &raftMsg}
        messages = append(messages, msg)
    }

    return messages
}
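
// Illustrative sketch (not part of the original file): a receiver can
// reassemble the chunks produced by splitSnapshotData by concatenating
// Snapshot.Data in arrival order. A hypothetical helper:
//
//	func joinSnapshotData(msgs []api.StreamRaftMessageRequest) []byte {
//	    var data []byte
//	    for _, req := range msgs {
//	        data = append(data, req.Message.Snapshot.Data...)
//	    }
//	    return data
//	}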

// needsSplitting reports whether the message must be split before streaming
// (because it is larger than GRPCMaxMsgSize). It returns true if the message
// type is MsgSnap and its serialized size exceeds GRPCMaxMsgSize.
func needsSplitting(m *raftpb.Message) bool {
    raftMsg := api.ProcessRaftMessageRequest{Message: m}
    return m.Type == raftpb.MsgSnap && raftMsg.Size() > GRPCMaxMsgSize
}
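
// For instance, a MsgSnap carrying a 10 MiB snapshot serializes well past
// GRPCMaxMsgSize and must be split; other message types are never split here,
// regardless of size.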

// sendProcessMessage delivers a single raft message to the peer, streaming it
// in chunks via StreamRaftMessage and falling back to the unary
// ProcessRaftMessage if the receiver does not implement streaming.
func (p *peer) sendProcessMessage(ctx context.Context, m raftpb.Message) error {
    // These lines used to be in the code, but they've been removed. I'm
    // leaving them here in a comment just in case they cause some unforeseen
    // breakage later, to show why they were removed.
    //
    // ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
    // defer cancel()
    //
    // Basically, these lines created a timeout that applied not to each chunk
    // of a streaming message, but to the whole streaming process. With a
    // sufficiently large raft log, some connections simply do not have enough
    // bandwidth to finish within the default 2-second timeout. Further, it
    // seems that because of some gRPC magic, the timeout was getting
    // propagated to the stream *server*, meaning it wasn't even the sender
    // timing out, it was the receiver.
    //
    // It should be fine to remove this timeout. The whole purpose of this
    // method is to send very large raft messages that could take several
    // seconds to send.

    ctx, cancel := context.WithCancel(ctx)
    defer cancel()

    // This is a bootleg watchdog timer. If the timer elapses without bump
    // being called to reset it, it will cancel the context.
    //
    // We use this because the operations on this stream *must* either time out
    // or succeed for raft to function correctly. We can't just time out the
    // whole operation, because of the reasons stated above. But we also only
    // set the context once, when we create the stream, and so can't set an
    // individual timeout for each stream operation.
    //
    // By doing it as this watchdog-type structure, we can time out individual
    // operations by canceling the context on our own terms.
    t := time.AfterFunc(p.tr.config.SendTimeout, cancel)
    defer t.Stop()

    bump := func() { t.Reset(p.tr.config.SendTimeout) }
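
    // In isolation the pattern is: time.AfterFunc arms a timer that cancels
    // the context, and each successful operation resets it, so only a single
    // operation stalled for longer than SendTimeout trips the cancellation.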

    var err error
    var stream api.Raft_StreamRaftMessageClient
    stream, err = api.NewRaftClient(p.conn()).StreamRaftMessage(ctx)

    if err == nil {
        // Split the message if needed.
        // Currently only supported for MsgSnap.
        var msgs []api.StreamRaftMessageRequest
        if needsSplitting(&m) {
            msgs = splitSnapshotData(&m)
        } else {
            raftMsg := api.StreamRaftMessageRequest{Message: &m}
            msgs = append(msgs, raftMsg)
        }

        // Stream
        for _, msg := range msgs {
            err = stream.Send(&msg)
            if err != nil {
                log.G(ctx).WithError(err).Error("error streaming message to peer")
                stream.CloseAndRecv()
                break
            }

            // If the send succeeds, bump the watchdog timer.
            bump()
        }

        // Finished sending all the messages.
        // Close and receive response.
        if err == nil {
            _, err = stream.CloseAndRecv()

            if err != nil {
                log.G(ctx).WithError(err).Error("error receiving response")
            }
        }
    } else {
        log.G(ctx).WithError(err).Error("error sending message to peer")
    }

    // Try doing a regular rpc if the receiver doesn't support streaming.
    s, _ := status.FromError(err)
    if s.Code() == codes.Unimplemented {
        log.G(ctx).Info("sending message to raft peer using ProcessRaftMessage()")
        _, err = api.NewRaftClient(p.conn()).ProcessRaftMessage(ctx, &api.ProcessRaftMessageRequest{Message: &m})
    }

    // Handle errors. A NotFound status carrying the member-removed message
    // means this node has been removed from the cluster; snapshot sends must
    // also be reported so raft can track snapshot progress.
    s, _ = status.FromError(err)
    if s.Code() == codes.NotFound && s.Message() == membership.ErrMemberRemoved.Error() {
        p.tr.config.NodeRemoved()
    }
    if m.Type == raftpb.MsgSnap {
        if err != nil {
            p.tr.config.ReportSnapshot(m.To, raft.SnapshotFailure)
        } else {
            p.tr.config.ReportSnapshot(m.To, raft.SnapshotFinish)
        }
    }
    if err != nil {
        p.tr.config.ReportUnreachable(m.To)
        return err
    }
    return nil
}

func healthCheckConn(ctx context.Context, cc *grpc.ClientConn) error {
    resp, err := api.NewHealthClient(cc).Check(ctx, &api.HealthCheckRequest{Service: "Raft"})
    if err != nil {
        return errors.Wrap(err, "failed to check health")
    }
    if resp.Status != api.HealthCheckResponse_SERVING {
        return errors.Errorf("health check returned status %s", resp.Status)
    }
    return nil
}

func (p *peer) healthCheck(ctx context.Context) error {
    ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
    defer cancel()
    return healthCheckConn(ctx, p.conn())
}

func (p *peer) setActive() {
    p.mu.Lock()
    if !p.active {
        p.active = true
        p.becameActive = time.Now()
    }
    p.mu.Unlock()
}

func (p *peer) setInactive() {
    p.mu.Lock()
    p.active = false
    p.becameActive = time.Time{}
    p.mu.Unlock()
}

func (p *peer) activeTime() time.Time {
    p.mu.Lock()
    defer p.mu.Unlock()
    return p.becameActive
}

// drain delivers any messages still queued on msgc after the queue has been
// closed, giving up after a shared 16-second budget.
func (p *peer) drain() error {
    ctx, cancel := context.WithTimeout(context.Background(), 16*time.Second)
    defer cancel()
    for {
        select {
        case m, ok := <-p.msgc:
            if !ok {
                // all messages processed
                return nil
            }
            if err := p.sendProcessMessage(ctx, m); err != nil {
                return errors.Wrap(err, "send drain message")
            }
        case <-ctx.Done():
            return ctx.Err()
        }
    }
}

// handleAddressChange swaps the connection to the address recorded by
// updateAddr, if any, after dialing and health-checking it.
func (p *peer) handleAddressChange(ctx context.Context) error {
    p.mu.Lock()
    newAddr := p.newAddr
    p.newAddr = ""
    p.mu.Unlock()
    if newAddr == "" {
        return nil
    }
    cc, err := p.tr.dial(newAddr)
    if err != nil {
        return err
    }
    ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
    defer cancel()
    if err := healthCheckConn(ctx, cc); err != nil {
        cc.Close()
        return err
    }
    // There is a possibility of a race if the host changes addresses too
    // quickly, but that's unlikely, and things should eventually settle.
    p.mu.Lock()
    p.cc.Close()
    p.cc = cc
    p.addr = newAddr
    p.tr.config.UpdateNode(p.id, p.addr)
    p.mu.Unlock()
    return nil
}

// run is the peer's sender goroutine: it health-checks the connection once,
// then delivers queued messages until the context is canceled, trying a
// recorded new address after each failure.
func (p *peer) run(ctx context.Context) {
    defer func() {
        p.mu.Lock()
        p.active = false
        p.becameActive = time.Time{}
        // At this point we can be sure that nobody will write to msgc.
        if p.msgc != nil {
            close(p.msgc)
        }
        p.mu.Unlock()
        if err := p.drain(); err != nil {
            log.G(ctx).WithError(err).Error("failed to drain message queue")
        }
        close(p.done)
    }()
    if err := p.healthCheck(ctx); err == nil {
        p.setActive()
    }
    for {
        select {
        case <-ctx.Done():
            return
        default:
        }

        select {
        case m := <-p.msgc:
            // We do not propagate the context here, because this operation
            // must either finish or time out on its own for raft to work
            // correctly.
            err := p.sendProcessMessage(context.Background(), m)
            if err != nil {
                log.G(ctx).WithError(err).Debugf("failed to send message %s", m.Type)
                p.setInactive()
                if err := p.handleAddressChange(ctx); err != nil {
                    log.G(ctx).WithError(err).Error("failed to change address after failure")
                }
                continue
            }
            p.setActive()
        case <-ctx.Done():
            return
        }
    }
}

// stop cancels the peer's context and blocks until the run goroutine has
// drained its queue and exited.
func (p *peer) stop() {
    p.cancel()
    <-p.done
}