rueian/rueidis

View on GitHub
mux.go

Summary

Maintainability
A
3 hrs
Test Coverage
A
100%
package rueidis

import (
    "context"
    "math/rand"
    "net"
    "runtime"
    "sync"
    "sync/atomic"
    "time"

    "github.com/redis/rueidis/internal/cmds"
    "github.com/redis/rueidis/internal/util"
)

type connFn func(dst string, opt *ClientOption) conn
type dialFn func(dst string, opt *ClientOption) (net.Conn, error)
type wireFn func() wire

type singleconnect struct {
    w wire
    e error
    g sync.WaitGroup
}

type batchcache struct {
    cIndexes []int
    commands []CacheableTTL
}

func (r *batchcache) Capacity() int {
    return cap(r.commands)
}

func (r *batchcache) ResetLen(n int) {
    r.cIndexes = r.cIndexes[:n]
    r.commands = r.commands[:n]
}

var batchcachep = util.NewPool(func(capacity int) *batchcache {
    return &batchcache{
        cIndexes: make([]int, 0, capacity),
        commands: make([]CacheableTTL, 0, capacity),
    }
})

type conn interface {
    Do(ctx context.Context, cmd Completed) RedisResult
    DoCache(ctx context.Context, cmd Cacheable, ttl time.Duration) RedisResult
    DoMulti(ctx context.Context, multi ...Completed) *redisresults
    DoMultiCache(ctx context.Context, multi ...CacheableTTL) *redisresults
    Receive(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error
    DoStream(ctx context.Context, cmd Completed) RedisResultStream
    DoMultiStream(ctx context.Context, multi ...Completed) MultiRedisResultStream
    Info() map[string]RedisMessage
    Version() int
    Error() error
    Close()
    Dial() error
    Override(conn)
    Acquire() wire
    Store(w wire)
    Addr() string
    SetOnCloseHook(func(error))
}

var _ conn = (*mux)(nil)

type mux struct {
    init   wire
    dead   wire
    clhks  atomic.Value
    dpool  *pool
    spool  *pool
    wireFn wireFn
    dst    string
    wire   []atomic.Value
    sc     []*singleconnect
    mu     []sync.Mutex
    maxp   int
}

func makeMux(dst string, option *ClientOption, dialFn dialFn) *mux {
    dead := deadFn()
    connFn := func() (net.Conn, error) {
        return dialFn(dst, option)
    }
    wireFn := func(pipeFn pipeFn) func() wire {
        return func() (w wire) {
            w, err := pipeFn(connFn, option)
            if err != nil {
                dead.error.Store(&errs{error: err})
                w = dead
            }
            return w
        }
    }
    return newMux(dst, option, (*pipe)(nil), dead, wireFn(newPipe), wireFn(newPipeNoBg))
}

func newMux(dst string, option *ClientOption, init, dead wire, wireFn wireFn, wireNoBgFn wireFn) *mux {
    var multiplex int
    if option.PipelineMultiplex >= 0 {
        multiplex = 1 << option.PipelineMultiplex
    } else {
        multiplex = 1
    }
    m := &mux{dst: dst, init: init, dead: dead, wireFn: wireFn,
        wire: make([]atomic.Value, multiplex),
        mu:   make([]sync.Mutex, multiplex),
        sc:   make([]*singleconnect, multiplex),
        maxp: runtime.GOMAXPROCS(0),
    }
    m.clhks.Store(emptyclhks)
    for i := 0; i < len(m.wire); i++ {
        m.wire[i].Store(init)
    }
    m.dpool = newPool(option.BlockingPoolSize, dead, wireFn)
    m.spool = newPool(option.BlockingPoolSize, dead, wireNoBgFn)
    return m
}

func (m *mux) SetOnCloseHook(fn func(error)) {
    m.clhks.Store(fn)
}

func (m *mux) setCloseHookOnWire(i uint16, w wire) {
    if w != m.dead && w != m.init {
        w.SetOnCloseHook(func(err error) {
            if err != ErrClosing {
                if m.wire[i].CompareAndSwap(w, m.init) {
                    m.clhks.Load().(func(error))(err)
                }
            }
        })
    }
}

func (m *mux) Override(cc conn) {
    if m2, ok := cc.(*mux); ok {
        for i := 0; i < len(m.wire) && i < len(m2.wire); i++ {
            w := m2.wire[i].Load().(wire)
            m.setCloseHookOnWire(uint16(i), w) // bind the new m to the old w
            m.wire[i].CompareAndSwap(m.init, w)
        }
    }
}

func (m *mux) _pipe(i uint16) (w wire, err error) {
    if w = m.wire[i].Load().(wire); w != m.init {
        return w, nil
    }

    m.mu[i].Lock()
    sc := m.sc[i]
    if m.sc[i] == nil {
        m.sc[i] = &singleconnect{}
        m.sc[i].g.Add(1)
    }
    m.mu[i].Unlock()

    if sc != nil {
        sc.g.Wait()
        return sc.w, sc.e
    }

    if w = m.wire[i].Load().(wire); w == m.init {
        if w = m.wireFn(); w != m.dead {
            m.setCloseHookOnWire(i, w)
            m.wire[i].Store(w)
        } else {
            if err = w.Error(); err != ErrClosing {
                m.clhks.Load().(func(error))(err)
            }
        }
    }

    m.mu[i].Lock()
    sc = m.sc[i]
    m.sc[i] = nil
    m.mu[i].Unlock()

    sc.w = w
    sc.e = err
    sc.g.Done()

    return w, err
}

func (m *mux) pipe(i uint16) wire {
    w, _ := m._pipe(i)
    return w // this should never be nil
}

func (m *mux) Dial() error {
    _, err := m._pipe(0)
    return err
}

func (m *mux) Info() map[string]RedisMessage {
    return m.pipe(0).Info()
}

func (m *mux) Version() int {
    return m.pipe(0).Version()
}

func (m *mux) Error() error {
    return m.pipe(0).Error()
}

func (m *mux) DoStream(ctx context.Context, cmd Completed) RedisResultStream {
    wire := m.spool.Acquire()
    return wire.DoStream(ctx, m.spool, cmd)
}

func (m *mux) DoMultiStream(ctx context.Context, multi ...Completed) MultiRedisResultStream {
    wire := m.spool.Acquire()
    return wire.DoMultiStream(ctx, m.spool, multi...)
}

func (m *mux) Do(ctx context.Context, cmd Completed) (resp RedisResult) {
    if cmd.IsBlock() {
        resp = m.blocking(ctx, cmd)
    } else {
        resp = m.pipeline(ctx, cmd)
    }
    return resp
}

func (m *mux) DoMulti(ctx context.Context, multi ...Completed) (resp *redisresults) {
    for _, cmd := range multi {
        if cmd.IsBlock() {
            goto block
        }
    }
    return m.pipelineMulti(ctx, multi)
block:
    cmds.ToBlock(&multi[0]) // mark the first cmd as block if one of them is block to shortcut later check.
    return m.blockingMulti(ctx, multi)
}

func (m *mux) blocking(ctx context.Context, cmd Completed) (resp RedisResult) {
    wire := m.dpool.Acquire()
    resp = wire.Do(ctx, cmd)
    if resp.NonRedisError() != nil { // abort the wire if blocking command return early (ex. context.DeadlineExceeded)
        wire.Close()
    }
    m.dpool.Store(wire)
    return resp
}

func (m *mux) blockingMulti(ctx context.Context, cmd []Completed) (resp *redisresults) {
    wire := m.dpool.Acquire()
    resp = wire.DoMulti(ctx, cmd...)
    for _, res := range resp.s {
        if res.NonRedisError() != nil { // abort the wire if blocking command return early (ex. context.DeadlineExceeded)
            wire.Close()
            break
        }
    }
    m.dpool.Store(wire)
    return resp
}

func (m *mux) pipeline(ctx context.Context, cmd Completed) (resp RedisResult) {
    slot := slotfn(len(m.wire), cmd.Slot(), cmd.NoReply())
    wire := m.pipe(slot)
    if resp = wire.Do(ctx, cmd); isBroken(resp.NonRedisError(), wire) {
        m.wire[slot].CompareAndSwap(wire, m.init)
    }
    return resp
}

func (m *mux) pipelineMulti(ctx context.Context, cmd []Completed) (resp *redisresults) {
    slot := slotfn(len(m.wire), cmd[0].Slot(), cmd[0].NoReply())
    wire := m.pipe(slot)
    resp = wire.DoMulti(ctx, cmd...)
    for _, r := range resp.s {
        if isBroken(r.NonRedisError(), wire) {
            m.wire[slot].CompareAndSwap(wire, m.init)
            return resp
        }
    }
    return resp
}

func (m *mux) DoCache(ctx context.Context, cmd Cacheable, ttl time.Duration) RedisResult {
    slot := cmd.Slot() & uint16(len(m.wire)-1)
    wire := m.pipe(slot)
    resp := wire.DoCache(ctx, cmd, ttl)
    if isBroken(resp.NonRedisError(), wire) {
        m.wire[slot].CompareAndSwap(wire, m.init)
    }
    return resp
}

func (m *mux) DoMultiCache(ctx context.Context, multi ...CacheableTTL) (results *redisresults) {
    var slots *muxslots
    var mask = uint16(len(m.wire) - 1)

    if mask == 0 {
        return m.doMultiCache(ctx, 0, multi)
    }

    slots = muxslotsp.Get(len(m.wire), len(m.wire))
    for _, cmd := range multi {
        slots.s[cmd.Cmd.Slot()&mask]++
    }

    if slots.LessThen(2) {
        return m.doMultiCache(ctx, multi[0].Cmd.Slot()&mask, multi)
    }

    batches := batchcachemaps.Get(len(m.wire), len(m.wire))
    for slot, count := range slots.s {
        if count > 0 {
            batches.m[uint16(slot)] = batchcachep.Get(0, count)
        }
    }
    muxslotsp.Put(slots)

    for i, cmd := range multi {
        batch := batches.m[cmd.Cmd.Slot()&mask]
        batch.commands = append(batch.commands, cmd)
        batch.cIndexes = append(batch.cIndexes, i)
    }

    results = resultsp.Get(len(multi), len(multi))
    util.ParallelKeys(m.maxp, batches.m, func(slot uint16) {
        batch := batches.m[slot]
        resp := m.doMultiCache(ctx, slot, batch.commands)
        for i, r := range resp.s {
            results.s[batch.cIndexes[i]] = r
        }
        resultsp.Put(resp)
    })

    for _, batch := range batches.m {
        batchcachep.Put(batch)
    }
    batchcachemaps.Put(batches)

    return results
}

func (m *mux) doMultiCache(ctx context.Context, slot uint16, multi []CacheableTTL) (resps *redisresults) {
    wire := m.pipe(slot)
    resps = wire.DoMultiCache(ctx, multi...)
    for _, r := range resps.s {
        if isBroken(r.NonRedisError(), wire) {
            m.wire[slot].CompareAndSwap(wire, m.init)
            return resps
        }
    }
    return resps
}

func (m *mux) Receive(ctx context.Context, subscribe Completed, fn func(message PubSubMessage)) error {
    slot := slotfn(len(m.wire), subscribe.Slot(), subscribe.NoReply())
    wire := m.pipe(slot)
    err := wire.Receive(ctx, subscribe, fn)
    if isBroken(err, wire) {
        m.wire[slot].CompareAndSwap(wire, m.init)
    }
    return err
}

func (m *mux) Acquire() wire {
    return m.dpool.Acquire()
}

func (m *mux) Store(w wire) {
    w.SetPubSubHooks(PubSubHooks{})
    w.CleanSubscriptions()
    m.dpool.Store(w)
}

func (m *mux) Close() {
    for i := 0; i < len(m.wire); i++ {
        if prev := m.wire[i].Swap(m.dead).(wire); prev != m.init && prev != m.dead {
            prev.Close()
        }
    }
    m.dpool.Close()
    m.spool.Close()
}

func (m *mux) Addr() string {
    return m.dst
}

func isBroken(err error, w wire) bool {
    return err != nil && err != ErrClosing && w.Error() != nil
}

var rngPool = sync.Pool{
    New: func() any {
        return rand.New(rand.NewSource(time.Now().UnixNano()))
    },
}

func fastrand(n int) (r int) {
    s := rngPool.Get().(*rand.Rand)
    r = s.Intn(n)
    rngPool.Put(s)
    return
}

func slotfn(n int, ks uint16, noreply bool) uint16 {
    if n == 1 || ks == cmds.NoSlot || noreply {
        return 0
    }
    return uint16(fastrand(n))
}

type muxslots struct {
    s []int
}

func (r *muxslots) Capacity() int {
    return cap(r.s)
}

func (r *muxslots) ResetLen(n int) {
    r.s = r.s[:n]
    for i := 0; i < n; i++ {
        r.s[i] = 0
    }
}

func (r *muxslots) LessThen(n int) bool {
    count := 0
    for _, value := range r.s {
        if value > 0 {
            if count++; count == n {
                return false
            }
        }
    }
    return true
}

var muxslotsp = util.NewPool(func(capacity int) *muxslots {
    return &muxslots{s: make([]int, 0, capacity)}
})

type batchcachemap struct {
    m map[uint16]*batchcache
    n int
}

func (r *batchcachemap) Capacity() int {
    return r.n
}

func (r *batchcachemap) ResetLen(n int) {
    for k := range r.m {
        delete(r.m, k)
    }
}

var batchcachemaps = util.NewPool(func(capacity int) *batchcachemap {
    return &batchcachemap{m: make(map[uint16]*batchcache, capacity), n: capacity}
})