docker/swarmkit

View on GitHub
protobuf/plugin/raftproxy/raftproxy.go

Summary

Maintainability
A
1 hr
Test Coverage
package raftproxy

import (
    "strings"

    "github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
    "github.com/gogo/protobuf/protoc-gen-gogo/generator"
)

// raftProxyGen is a generator plugin that, for every service in a proto file,
// emits a server wrapper type ("raftProxy<Service>Server") together with its
// constructor, helpers and one proxy method per RPC.
type raftProxyGen struct {
    // gen is the generator supplied to Init; all output is printed through it.
    gen *generator.Generator
}

// init registers the raft proxy plugin with the generator package when this
// package is loaded.
func init() {
    generator.RegisterPlugin(&raftProxyGen{})
}

// Init stores the generator that every later gen* method prints through.
func (g *raftProxyGen) Init(gen *generator.Generator) {
    g.gen = gen
}

// Name returns the identifier of this plugin.
func (g *raftProxyGen) Name() string {
    const pluginName = "raftproxy"
    return pluginName
}

// genProxyStruct prints the declaration of the proxy type for service s. The
// emitted struct holds the local server implementation, the connection
// selector and the two lists of context modifiers.
func (g *raftProxyGen) genProxyStruct(s *descriptor.ServiceDescriptorProto) {
    fields := []string{
        "\tlocal " + s.GetName() + "Server",
        "\tconnSelector raftselector.ConnProvider",
        "\tlocalCtxMods, remoteCtxMods []func(context.Context)(context.Context, error)",
    }

    g.gen.P("type " + serviceTypeName(s) + " struct {")
    for _, field := range fields {
        g.gen.P(field)
    }
    g.gen.P("}")
}

// genProxyConstructor prints the NewRaftProxy<Service>Server constructor for
// service s.
//
// The emitted constructor builds two lists of context modifiers:
//   - remoteCtxMods always starts with a generated redirectChecker, which
//     rejects a request that already carries "redirect" metadata and otherwise
//     records the caller's peer address under that key in the outgoing
//     metadata; the caller-supplied remoteCtxMod is appended after it.
//   - localCtxMods contains the caller-supplied localCtxMod, if any.
//
// Both caller-supplied modifiers are optional: a nil modifier is skipped, so
// the generated runCtxMods never invokes a nil function.
func (g *raftProxyGen) genProxyConstructor(s *descriptor.ServiceDescriptorProto) {
    g.gen.P("func NewRaftProxy" + s.GetName() + "Server(local " + s.GetName() + "Server, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context)(context.Context, error)) " + s.GetName() + "Server {")
    // remoteCtxMod is nil-checked just like localCtxMod; appending it
    // unconditionally would make the first forwarded call panic when nil.
    g.gen.P(`redirectChecker := func(ctx context.Context)(context.Context, error) {
        p, ok := peer.FromContext(ctx)
        if !ok {
            return ctx, status.Errorf(codes.InvalidArgument, "remote addr is not found in context")
        }
        addr := p.Addr.String()
        md, ok := metadata.FromIncomingContext(ctx)
        if ok && len(md["redirect"]) != 0 {
            return ctx, status.Errorf(codes.ResourceExhausted, "more than one redirect to leader from: %s", md["redirect"])
        }
        if !ok {
            md = metadata.New(map[string]string{})
        }
        md["redirect"] = append(md["redirect"], addr)
        return metadata.NewOutgoingContext(ctx, md), nil
    }
    remoteMods := []func(context.Context)(context.Context, error){redirectChecker}
    if remoteCtxMod != nil {
        remoteMods = append(remoteMods, remoteCtxMod)
    }

    var localMods []func(context.Context)(context.Context, error)
    if localCtxMod != nil {
        localMods = []func(context.Context)(context.Context, error){localCtxMod}
    }
    `)
    g.gen.P("return &" + serviceTypeName(s) + `{
        local: local,
        connSelector: connSelector,
        localCtxMods: localMods,
        remoteCtxMods: remoteMods,
    }`)
    g.gen.P("}")
}

// genRunCtxMods prints the runCtxMods helper method of the proxy type for
// service s. The emitted helper applies each modifier in order to the context
// and stops at the first modifier that returns an error.
func (g *raftProxyGen) genRunCtxMods(s *descriptor.ServiceDescriptorProto) {
    receiver := "func (p *" + serviceTypeName(s) + ") "
    g.gen.P(receiver + `runCtxMods(ctx context.Context, ctxMods []func(context.Context)(context.Context, error)) (context.Context, error) {
    var err error
    for _, mod := range ctxMods {
        ctx, err = mod(ctx)
        if err != nil {
            return ctx, err
        }
    }
    return ctx, nil
}`)
}

// getInputTypeName returns the part of the method's input type name that
// follows the last "." (the whole name when it contains no dot).
func getInputTypeName(m *descriptor.MethodDescriptorProto) string {
    qualified := m.GetInputType()
    // LastIndex yields -1 when there is no dot, so the slice starts at 0.
    return qualified[strings.LastIndex(qualified, ".")+1:]
}

// getOutputTypeName returns the part of the method's output type name that
// follows the last "." (the whole name when it contains no dot).
func getOutputTypeName(m *descriptor.MethodDescriptorProto) string {
    qualified := m.GetOutputType()
    // LastIndex yields -1 when there is no dot, so the slice starts at 0.
    return qualified[strings.LastIndex(qualified, ".")+1:]
}

// serviceTypeName returns the name of the generated proxy type for service s:
// "raftProxy" + service name + "Server".
func serviceTypeName(s *descriptor.ServiceDescriptorProto) string {
    const (
        prefix = "raftProxy"
        suffix = "Server"
    )
    return prefix + s.GetName() + suffix
}

// sigPrefix returns the beginning of a generated method declaration on the
// proxy type of service s, up to and including the opening parenthesis of the
// parameter list of method m.
func sigPrefix(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) string {
    receiverType := serviceTypeName(s)
    methodName := m.GetName()
    return "func (p *" + receiverType + ") " + methodName + "("
}

// genStreamWrapper prints a "<streamType>Wrapper" struct that embeds the given
// server stream type and carries its own context, plus a Context method that
// returns that stored context instead of the embedded stream's.
func (g *raftProxyGen) genStreamWrapper(streamType string) {
    wrapperType := streamType + "Wrapper"

    // Wrapper type declaration.
    g.gen.P("type " + wrapperType + ` struct {
    ` + streamType + `
    ctx context.Context
}
`)

    // Context accessor overriding the embedded stream's method.
    g.gen.P("func (s " + wrapperType + `) Context() context.Context {
    return s.ctx
}
`)
}

// genClientStreamingMethod prints the proxy implementation of the
// client-streaming RPC m of service s.
//
// The emitted method asks the connection selector for a leader connection. If
// the selector reports ErrIsLeader, the local context modifiers are applied and
// the call is handed to the local server through a stream wrapper carrying the
// modified context. Otherwise the remote context modifiers are applied, a
// client stream to the leader is opened, every message received from the
// caller is sent on it until io.EOF, and the leader's single reply is returned
// to the caller.
func (g *raftProxyGen) genClientStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
    service, method := s.GetName(), m.GetName()
    streamType := service + "_" + method + "Server"

    // Wrapper type whose Context() returns the modified context.
    g.genStreamWrapper(streamType)

    // Signature, leader lookup and the local (leader) path.
    prologue := sigPrefix(s, m) + "stream " + streamType + `) error {
    ctx := stream.Context()
    conn, err := p.connSelector.LeaderConn(ctx)
    if err != nil {
        if err == raftselector.ErrIsLeader {
            ctx, err = p.runCtxMods(ctx, p.localCtxMods)
            if err != nil {
                return err
            }
            streamWrapper := ` + streamType + `Wrapper{
                ` + streamType + `: stream,
                ctx: ctx,
            }
            return p.local.` + method + `(streamWrapper)
        }
        return err
    }
    ctx, err = p.runCtxMods(ctx, p.remoteCtxMods)
    if err != nil {
        return err
    }`
    g.gen.P(prologue)

    // Open the stream to the leader.
    openStream := "clientStream, err := New" + service + "Client(conn)." + method + "(ctx)"
    g.gen.P(openStream)
    g.gen.P(`
    if err != nil {
            return err
    }`)

    // Forward caller messages, then relay the leader's reply.
    g.gen.P(`
    for {
        msg, err := stream.Recv()
        if err == io.EOF {
            break
        }
        if err != nil {
            return err
        }
        if err := clientStream.Send(msg); err != nil {
            return err
        }
    }

    reply, err := clientStream.CloseAndRecv()
    if err != nil {
        return err
    }

    return stream.SendAndClose(reply)`)
    g.gen.P("}")
}

// genServerStreamingMethod prints the proxy implementation of the
// server-streaming RPC m of service s.
//
// The emitted method asks the connection selector for a leader connection. If
// the selector reports ErrIsLeader, the local context modifiers are applied and
// the call is handed to the local server through a stream wrapper carrying the
// modified context. Otherwise the remote context modifiers are applied, the
// request is sent to the leader, and every message received from the leader is
// relayed to the caller until io.EOF.
func (g *raftProxyGen) genServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
    service, method := s.GetName(), m.GetName()
    streamType := service + "_" + method + "Server"

    // Wrapper type whose Context() returns the modified context.
    g.genStreamWrapper(streamType)

    // Signature, leader lookup and the local (leader) path.
    prologue := sigPrefix(s, m) + "r *" + getInputTypeName(m) + ", stream " + streamType + `) error {
    ctx := stream.Context()
    conn, err := p.connSelector.LeaderConn(ctx)
    if err != nil {
        if err == raftselector.ErrIsLeader {
            ctx, err = p.runCtxMods(ctx, p.localCtxMods)
            if err != nil {
                return err
            }
            streamWrapper := ` + streamType + `Wrapper{
                ` + streamType + `: stream,
                ctx: ctx,
            }
            return p.local.` + method + `(r, streamWrapper)
        }
        return err
    }
    ctx, err = p.runCtxMods(ctx, p.remoteCtxMods)
    if err != nil {
        return err
    }`
    g.gen.P(prologue)

    // Open the stream to the leader with the caller's request.
    openStream := "clientStream, err := New" + service + "Client(conn)." + method + "(ctx, r)"
    g.gen.P(openStream)
    g.gen.P(`
    if err != nil {
            return err
    }`)

    // Relay the leader's messages to the caller.
    g.gen.P(`
    for {
        msg, err := clientStream.Recv()
        if err == io.EOF {
            break
        }
        if err != nil {
            return err
        }
        if err := stream.Send(msg); err != nil {
            return err
        }
    }
    return nil`)
    g.gen.P("}")
}

// genClientServerStreamingMethod prints the proxy implementation of the
// bidirectional-streaming RPC m of service s.
//
// The emitted method asks the connection selector for a leader connection. If
// the selector reports ErrIsLeader, the local context modifiers are applied and
// the call is handed to the local server through a stream wrapper carrying the
// modified context. Otherwise the remote context modifiers are applied and a
// client stream to the leader is opened; then:
//
//   - a forwarding goroutine copies every message from the caller to the
//     leader until the caller half-closes (io.EOF), at which point it calls
//     CloseSend so the leader observes the end of input. A failure is recorded
//     in the buffered channel errc (capacity 1, at most one send, so the
//     goroutine never blocks on it);
//   - the calling goroutine relays every message from the leader back to the
//     caller until io.EOF, then returns the forwarder's error if one has been
//     recorded, or nil.
//
// Send and CloseSend on the leader stream are only ever called from the
// forwarding goroutine, so they are never invoked concurrently.
func (g *raftProxyGen) genClientServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
    streamType := s.GetName() + "_" + m.GetName() + "Server"

    g.genStreamWrapper(streamType)

    g.gen.P(sigPrefix(s, m) + "stream " + streamType + `) error {
    ctx := stream.Context()
    conn, err := p.connSelector.LeaderConn(ctx)
    if err != nil {
        if err == raftselector.ErrIsLeader {
            ctx, err = p.runCtxMods(ctx, p.localCtxMods)
            if err != nil {
                return err
            }
            streamWrapper := ` + streamType + `Wrapper{
                ` + streamType + `: stream,
                ctx: ctx,
            }
            return p.local.` + m.GetName() + `(streamWrapper)
        }
        return err
    }
    ctx, err = p.runCtxMods(ctx, p.remoteCtxMods)
    if err != nil {
        return err
    }`)
    g.gen.P("clientStream, err := New" + s.GetName() + "Client(conn)." + m.GetName() + "(ctx)")
    g.gen.P(`
    if err != nil {
            return err
    }`)
    // Caller -> leader. The loop is essential: without it only the first
    // message would be forwarded. An io.EOF returned by Send only signals that
    // the leader stream has already ended; the actual status is surfaced by
    // clientStream.Recv in the relay loop below, so it is not reported here.
    g.gen.P(`errc := make(chan error, 1)
    go func() {
        for {
            msg, err := stream.Recv()
            if err == io.EOF {
                if err := clientStream.CloseSend(); err != nil {
                    errc <- err
                }
                return
            }
            if err != nil {
                errc <- err
                return
            }
            if err := clientStream.Send(msg); err != nil {
                if err != io.EOF {
                    errc <- err
                }
                return
            }
        }
    }()`)
    // Leader -> caller. Once the leader has finished the RPC the result must
    // not depend on the caller half-closing, so errc is checked without
    // blocking.
    g.gen.P(`
    for {
        msg, err := clientStream.Recv()
        if err == io.EOF {
            break
        }
        if err != nil {
            return err
        }
        if err := stream.Send(msg); err != nil {
            return err
        }
    }
    select {
    case err := <-errc:
        return err
    default:
        return nil
    }`)
    g.gen.P("}")
}

// genSimpleMethod prints the proxy implementation of the unary RPC m of
// service s.
//
// The emitted method asks the connection selector for a leader connection. If
// the selector reports ErrIsLeader, the local context modifiers are applied and
// the call is served by the local server. Otherwise the remote context
// modifiers are applied and the call is forwarded to the leader.
//
// If the forwarded call fails with an error whose text contains "is closing",
// "the connection is unavailable" or "connection error", the method polls for
// a new leader connection and retries once. When polling reports that this
// node is now the leader, the call is served locally, again after applying the
// local context modifiers so that both local paths hand the same kind of
// context to the local server.
func (g *raftProxyGen) genSimpleMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
    g.gen.P(sigPrefix(s, m) + "ctx context.Context, r *" + getInputTypeName(m) + ") (*" + getOutputTypeName(m) + ", error) {")
    g.gen.P(`
    conn, err := p.connSelector.LeaderConn(ctx)
    if err != nil {
        if err == raftselector.ErrIsLeader {
            ctx, err = p.runCtxMods(ctx, p.localCtxMods)
            if err != nil {
                return nil, err
            }
            return p.local.` + m.GetName() + `(ctx, r)
        }
        return nil, err
    }
    modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods)
    if err != nil {
        return nil, err
    }`)
    // Retry path: the ErrIsLeader branch runs localCtxMods just like the
    // primary local path above.
    g.gen.P(`
    resp, err := New` + s.GetName() + `Client(conn).` + m.GetName() + `(modCtx, r)
    if err != nil {
        if !strings.Contains(err.Error(), "is closing") && !strings.Contains(err.Error(), "the connection is unavailable") && !strings.Contains(err.Error(), "connection error") {
            return resp, err
        }
        conn, err := p.pollNewLeaderConn(ctx)
        if err != nil {
            if err == raftselector.ErrIsLeader {
                ctx, err = p.runCtxMods(ctx, p.localCtxMods)
                if err != nil {
                    return nil, err
                }
                return p.local.` + m.GetName() + `(ctx, r)
            }
            return nil, err
        }
        return New` + s.GetName() + `Client(conn).` + m.GetName() + `(modCtx, r)
    }`)
    g.gen.P("return resp, err")
    g.gen.P("}")
}

// genProxyMethod prints the proxy implementation of method m of service s,
// surrounded by blank lines, choosing the template from the method's
// streaming flags: bidirectional, server-streaming, client-streaming, or
// unary when neither flag is set.
func (g *raftProxyGen) genProxyMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) {
    g.gen.P()

    fromClient, fromServer := m.GetClientStreaming(), m.GetServerStreaming()
    if fromClient && fromServer {
        g.genClientServerStreamingMethod(s, m)
    } else if fromServer {
        g.genServerStreamingMethod(s, m)
    } else if fromClient {
        g.genClientStreamingMethod(s, m)
    } else {
        g.genSimpleMethod(s, m)
    }

    g.gen.P()
}

// genPollNewLeaderConn prints the pollNewLeaderConn helper method of the proxy
// type for service s.
//
// The emitted helper wakes up on a 500ms ticker. On each tick it asks the
// connection selector for a leader connection (returning any error from the
// selector immediately) and runs a health check for the "Raft" service over
// it; the connection is returned once the check reports SERVING, otherwise
// the helper waits for the next tick. It gives up with ctx.Err() when the
// context is done.
func (g *raftProxyGen) genPollNewLeaderConn(s *descriptor.ServiceDescriptorProto) {
    receiver := "func (p *" + serviceTypeName(s) + ") "
    g.gen.P(receiver + `pollNewLeaderConn(ctx context.Context) (*grpc.ClientConn, error) {
        ticker := rafttime.NewTicker(500 * rafttime.Millisecond)
        defer ticker.Stop()
        for {
            select {
            case <-ticker.C:
                conn, err := p.connSelector.LeaderConn(ctx)
                if err != nil {
                    return nil, err
                }

                client := NewHealthClient(conn)

                resp, err := client.Check(ctx, &HealthCheckRequest{Service: "Raft"})
                if err != nil || resp.Status != HealthCheckResponse_SERVING {
                    continue
                }
                return conn, nil
            case <-ctx.Done():
                return nil, ctx.Err()
            }
        }
    }`)
}

// Generate prints the raft proxy code for every service declared in file: the
// proxy struct, its constructor, the runCtxMods and pollNewLeaderConn helpers,
// and one proxy method per RPC. Files without services produce no output.
func (g *raftProxyGen) Generate(file *generator.FileDescriptor) {
    services := file.Service
    if len(services) == 0 {
        return
    }

    g.gen.P()
    for _, svc := range services {
        g.genProxyStruct(svc)
        g.genProxyConstructor(svc)
        g.genRunCtxMods(svc)
        g.genPollNewLeaderConn(svc)
        for _, method := range svc.Method {
            g.genProxyMethod(svc, method)
        }
    }
    g.gen.P()
}

// GenerateImports prints the imports required by the code emitted from
// Generate. Nothing is printed for files that declare no services, matching
// Generate, which emits no code for them.
//
// NOTE(review): the emitted code also references the context, io, grpc and
// strings packages, which are not printed here — presumably they are imported
// by other plugins running on the same file; confirm before adding them.
func (g *raftProxyGen) GenerateImports(file *generator.FileDescriptor) {
    if len(file.Service) == 0 {
        return
    }
    // Leader connection selection (ConnProvider, ErrIsLeader).
    g.gen.PrintImport("raftselector", "github.com/moby/swarmkit/v2/manager/raftselector")
    // Used by the generated redirectChecker in the constructor.
    g.gen.PrintImport("codes", "google.golang.org/grpc/codes")
    g.gen.PrintImport("status", "google.golang.org/grpc/status")
    g.gen.PrintImport("metadata", "google.golang.org/grpc/metadata")
    g.gen.PrintImport("peer", "google.golang.org/grpc/peer")
    // don't conflict with import added by ptypes
    g.gen.PrintImport("rafttime", "time")
}