rueidislock/lock.go
package rueidislock
import (
	"context"
	crand "crypto/rand"
	"encoding/binary"
	"errors"
	"math/rand"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/redis/rueidis"
)
// sources is a pool of math/rand generators used by random. Whenever the pool
// is empty, its New function builds a generator seeded with the current
// wall-clock time in nanoseconds.
var sources sync.Pool

func init() {
	sources.New = func() any {
		seed := time.Now().UnixNano()
		return rand.New(rand.NewSource(seed))
	}
}
// LockerOption should be passed to NewLocker to construct a Locker.
// Fields left unset (an empty KeyPrefix, or a non-positive duration or
// KeyMajority) are replaced by the defaults noted on each field.
type LockerOption struct {
// ClientBuilder can be used to modify rueidis.Client used by Locker.
// When nil, rueidis.NewClient is used instead.
ClientBuilder func(option rueidis.ClientOption) (rueidis.Client, error)
// KeyPrefix is the prefix of redis key for locks. Default value is "rueidislock".
// Lock keys are named "<KeyPrefix>:<index>:<lock name>".
KeyPrefix string
// ClientOption is passed to rueidis.NewClient or LockerOption.ClientBuilder to build a rueidis.Client.
// NewLocker always sets its PipelineMultiplex to -1 and, unless DisableCache is set,
// overwrites its ClientTrackingOptions and OnInvalidations.
ClientOption rueidis.ClientOption
// KeyValidity is the validity duration of locks and will be extended periodically by the ExtendInterval. Default value is 5s.
KeyValidity time.Duration
// ExtendInterval is the interval to extend KeyValidity. Default value is KeyValidity/2 (2.5s with the default KeyValidity).
ExtendInterval time.Duration
// TryNextAfter is the timeout duration before trying the next redis key for locks. Default value is 20ms.
TryNextAfter time.Duration
// KeyMajority is at least how many redis keys in a total of KeyMajority*2-1 should be acquired to be a valid lock.
// Default value is 2.
KeyMajority int32
// NoLoopTracking will use NOLOOP in the CLIENT TRACKING command to avoid unnecessary notifications and thus have better performance.
// This can only be enabled if all your redis nodes >= 7.0.5. (https://github.com/redis/redis/pull/11052)
// It has no effect on tracking when ClientOption.DisableCache is set.
NoLoopTracking bool
// FallbackSETPX makes the Locker use SET PX instead of SET PXAT when acquiring locks, to be compatible with Redis < 6.2.
FallbackSETPX bool
}
// Locker is the interface of rueidislock.
//
// On success both acquire methods return a context that is cancelled when the
// lock can no longer be considered held, and a CancelFunc that releases the
// lock and waits until the locker has finished with every redis key of it.
type Locker interface {
// WithContext acquires a distributed redis lock by name by waiting for it.
// It may return ErrLockerClosed, or ctx's error if ctx is done while waiting.
WithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error)
// TryWithContext tries to acquire a distributed redis lock by name without waiting.
// It returns ErrNotLocked when the lock is held elsewhere, is already in use by
// another goroutine of this Locker, or the Locker is closed.
TryWithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error)
// Client exports the underlying rueidis.Client
Client() rueidis.Client
// Close closes the underlying rueidis.Client. Pending WithContext calls
// return ErrLockerClosed.
Close()
}
// NewLocker creates the distributed Locker backed by redis client side caching.
//
// Unset fields of option are first filled with defaults (see LockerOption).
// Unless ClientOption.DisableCache is set, client-side tracking is configured
// in OPTOUT mode (plus NOLOOP when NoLoopTracking is set) and invalidations are
// routed to the new locker. With DisableCache the locker never waits for
// invalidations after its own writes. PipelineMultiplex is always forced to -1
// so that client-side caching stays on one connection. Any error from building
// the client is returned unchanged.
func NewLocker(option LockerOption) (Locker, error) {
	if option.KeyPrefix == "" {
		option.KeyPrefix = "rueidislock"
	}
	if option.KeyValidity <= 0 {
		option.KeyValidity = 5 * time.Second
	}
	if option.ExtendInterval <= 0 {
		option.ExtendInterval = option.KeyValidity / 2
	}
	if option.TryNextAfter <= 0 {
		option.TryNextAfter = 20 * time.Millisecond
	}
	if option.KeyMajority <= 0 {
		option.KeyMajority = 2
	}

	impl := &locker{
		gates:    make(map[string]*gate),
		prefix:   option.KeyPrefix,
		validity: option.KeyValidity,
		interval: option.ExtendInterval,
		timeout:  option.TryNextAfter,
		majority: option.KeyMajority,
		totalcnt: 2*option.KeyMajority - 1,
		noloop:   option.NoLoopTracking,
		setpx:    option.FallbackSETPX,
	}

	switch {
	case option.ClientOption.DisableCache:
		// Without tracking there is no self-invalidation to wait for.
		impl.noloop = true
	case option.NoLoopTracking:
		option.ClientOption.ClientTrackingOptions = []string{"OPTOUT", "NOLOOP"}
		option.ClientOption.OnInvalidations = impl.onInvalidations
	default:
		option.ClientOption.ClientTrackingOptions = []string{"OPTOUT"}
		option.ClientOption.OnInvalidations = impl.onInvalidations
	}
	// Keep the client-side cache traffic on the same connection.
	option.ClientOption.PipelineMultiplex = -1

	var (
		client rueidis.Client
		err    error
	)
	if option.ClientBuilder == nil {
		client, err = rueidis.NewClient(option.ClientOption)
	} else {
		client, err = option.ClientBuilder(option.ClientOption)
	}
	if err != nil {
		return nil, err
	}
	impl.client = client
	return impl, nil
}
// locker implements Locker on top of a rueidis.Client.
type locker struct {
client rueidis.Client
// gates maps a lock name to the gate shared by the goroutines of this process
// that use that lock. It is set to nil by Close. Guarded by mu.
gates map[string]*gate
// prefix, validity, interval and timeout hold LockerOption.KeyPrefix,
// KeyValidity, ExtendInterval and TryNextAfter after defaults are applied.
prefix string
validity time.Duration
interval time.Duration
timeout time.Duration
// mu guards gates and the w field of every gate.
mu sync.RWMutex
// majority is LockerOption.KeyMajority; totalcnt = majority*2-1 is the number
// of redis keys backing each lock.
majority int32
totalcnt int32
// noloop is true when NOLOOP tracking is requested or client-side caching is
// disabled; monitors then do not wait for an invalidation after extending.
noloop bool
// setpx selects SET PX (acqms) instead of SET PXAT (acqat) when acquiring.
setpx bool
}
// gate coordinates the goroutines of this process that use the same lock name.
// Its w field is only read and written while holding locker.mu.
type gate struct {
	ch  chan struct{}   // passes the turn to the next waiter; closed by Close
	csc []chan struct{} // per-key invalidation signals, indexed like keyname
	w   int             // goroutines currently holding or waiting on this gate
}

// makegate returns a fresh gate with size invalidation channels. Every channel,
// ch included, has a buffer of one, so a signal can be left pending by the
// non-blocking senders.
func makegate(size int32) *gate {
	g := &gate{
		ch:  make(chan struct{}, 1),
		csc: make([]chan struct{}, size),
	}
	for idx := range g.csc {
		g.csc[idx] = make(chan struct{}, 1)
	}
	return g
}
// random returns a fresh 24-byte token that try stores as the value of every
// key of one acquisition attempt. The delkey and extend scripts only act when
// a key still holds this value, so the token must be unpredictable and unique
// across every client contending for a lock.
//
// The bytes come from crypto/rand. The previous approach drew them from pooled
// math/rand generators seeded with the wall clock. Per the math/rand docs, seeds
// that are equal modulo 2^31-1 yield the same sequence, so such a token carries
// at most ~31 bits of entropy despite its 192-bit width. The pool is kept only
// as a fallback if the system CSPRNG reports an error.
func random() string {
	val := make([]byte, 24)
	if _, err := crand.Read(val); err != nil {
		// Best-effort fallback: overwrite any partially read bytes.
		src := sources.Get().(rand.Source64)
		binary.BigEndian.PutUint64(val[0:8], src.Uint64())
		binary.BigEndian.PutUint64(val[8:16], src.Uint64())
		binary.BigEndian.PutUint64(val[16:24], src.Uint64())
		sources.Put(src)
	}
	return rueidis.BinaryString(val)
}
// keyname builds the redis key for the i-th key of the lock called name, in the
// form "<prefix>:<i>:<name>". The index comes before the name so that names
// containing ':' remain unambiguous.
func keyname(prefix, name string, i int32) string {
	index := strconv.FormatInt(int64(i), 10)
	return prefix + ":" + index + ":" + name
}
// acquire tries to take a single lock key with one of the SET NX Lua scripts,
// bounded by m.timeout (LockerOption.TryNextAfter). It returns ErrNotLocked
// when the script replies with a redis nil (the key was not set), and any
// other error from the call unchanged.
func (m *locker) acquire(ctx context.Context, key, val string, deadline time.Time) (err error) {
	attemptCtx, cancel := context.WithTimeout(ctx, m.timeout)
	defer cancel()

	lua, expiry := acqat, strconv.FormatInt(deadline.UnixMilli(), 10)
	if m.setpx {
		// Relative TTL for servers that lack PXAT (Redis < 6.2).
		lua, expiry = acqms, strconv.FormatInt(m.validity.Milliseconds(), 10)
	}

	err = lua.Exec(attemptCtx, m.client, []string{key}, []string{val, expiry}).Error()
	if rueidis.IsRedisNil(err) {
		return ErrNotLocked
	}
	return err
}
// script runs an ownership-checking Lua script (such as extend or delkey)
// against key. It passes val and deadline (as Unix milliseconds) as arguments,
// and the call itself is also bounded by deadline. An integer reply of 1 means
// success; any other integer reply maps to ErrNotLocked. Call or reply-type
// errors are returned as-is.
func (m *locker) script(ctx context.Context, script *rueidis.Lua, key, val string, deadline time.Time) error {
	callCtx, cancel := context.WithDeadline(ctx, deadline)
	defer cancel()

	args := []string{val, strconv.FormatInt(deadline.UnixMilli(), 10)}
	reply, err := script.Exec(callCtx, m.client, []string{key}, args).AsInt64()
	switch {
	case err != nil:
		return err
	case reply != 1:
		return ErrNotLocked
	default:
		return nil
	}
}
// waitgate registers the caller as a user of the gate for name and returns the
// gate once it is the caller's turn.
//
// If no gate exists yet, a new one is created, claimed and returned right away.
// Otherwise the caller waits on the gate's ch:
//   - if ctx finishes first, the caller deregisters (dropping the gate when it
//     was the last user) and ctx.Err() is returned;
//   - if the locker is closed (gates nil, or ch closed by Close),
//     ErrLockerClosed is returned.
func (m *locker) waitgate(ctx context.Context, name string) (g *gate, err error) {
	m.mu.Lock()
	existing, found := m.gates[name]
	if !found {
		if m.gates == nil {
			m.mu.Unlock()
			return nil, ErrLockerClosed
		}
		fresh := makegate(m.totalcnt)
		fresh.w++
		m.gates[name] = fresh
		m.mu.Unlock()
		return fresh, nil
	}
	existing.w++
	m.mu.Unlock()

	select {
	case <-ctx.Done():
		m.mu.Lock()
		existing.w--
		if existing.w == 0 && m.gates[name] == existing {
			delete(m.gates, name)
		}
		m.mu.Unlock()
		return nil, ctx.Err()
	case _, open := <-existing.ch:
		if !open {
			return nil, ErrLockerClosed
		}
		return existing, nil
	}
}
// trygate creates and claims the gate for name only when no goroutine of this
// locker is already using it. It returns nil when the gate is taken or the
// locker has been closed.
func (m *locker) trygate(name string) (g *gate) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.gates == nil {
		return nil
	}
	if _, busy := m.gates[name]; busy {
		return nil
	}
	g = makegate(m.totalcnt)
	g.w++
	m.gates[name] = g
	return g
}
// onInvalidations is installed as ClientOption.OnInvalidations by NewLocker
// (unless client-side caching is disabled). It turns invalidation messages for
// lock keys into non-blocking signals on the matching gate.csc channel. The
// per-key monitors in try use these signals to re-check ownership.
//
// A nil batch is treated as "anything may have changed", and every csc channel
// of every gate is signalled. Otherwise each message is parsed as a key built
// by keyname. A key is ignored when:
//   - it does not carry this locker's prefix;
//   - its index part is not a non-negative integer;
//   - its index does not fit the gate's csc slice.
//
// Previously such keys could panic with an out-of-range index or be misrouted
// to csc[0]. Previously a prefix containing ':' also meant no lock key ever
// matched.
func (m *locker) onInvalidations(messages []rueidis.RedisMessage) {
	if messages == nil {
		m.mu.RLock()
		for _, g := range m.gates {
			for _, ch := range g.csc {
				select {
				case ch <- struct{}{}:
				default: // a signal is already pending
				}
			}
		}
		m.mu.RUnlock()
		return
	}
	for _, msg := range messages {
		k, err := msg.ToString()
		if err != nil {
			continue
		}
		name, n, ok := splitkeyname(m.prefix, k)
		if !ok {
			continue
		}
		m.mu.RLock()
		if g, found := m.gates[name]; found && n < len(g.csc) {
			select {
			case g.csc[n] <- struct{}{}:
			default: // a signal is already pending
			}
		}
		m.mu.RUnlock()
	}
}

// splitkeyname is the inverse of keyname. It parses "<prefix>:<index>:<name>"
// and reports ok=false when key does not start with prefix followed by ':', or
// when its index is not a non-negative decimal integer. The name may itself
// contain ':'.
func splitkeyname(prefix, key string) (name string, index int, ok bool) {
	if len(key) <= len(prefix) || !strings.HasPrefix(key, prefix) || key[len(prefix)] != ':' {
		return "", 0, false
	}
	idx, name, found := strings.Cut(key[len(prefix)+1:], ":")
	if !found {
		return "", 0, false
	}
	index, err := strconv.Atoi(idx)
	if err != nil || index < 0 {
		return "", 0, false
	}
	return name, index, true
}
// try attempts to acquire the lock name on the m.totalcnt redis keys
// keyname(m.prefix, name, 0) .. keyname(m.prefix, name, m.totalcnt-1), all set
// to one fresh random value. The caller must already have claimed g (via
// waitgate or trygate). ctx is the lock's context and cancel cancels it.
//
// Keys are tried in order until m.majority of them are acquired or m.majority
// attempts have failed. Any remaining keys are then handled by a background
// goroutine. Every key, acquired or not, gets a monitoring goroutine:
//   - once at least m.majority monitors have finished, cancel is called, so
//     ctx is cancelled as soon as fewer than a majority of keys can be held;
//   - once all m.totalcnt monitors have finished, done is closed and g is
//     either removed from m.gates (no other users) or handed to the next
//     waiter through g.ch.
//
// try returns a CancelFunc that cancels ctx and then blocks until done is
// closed. It returns nil when the lock was not acquired. That includes the
// case where acquisition took longer than m.validity, in which the AfterFunc
// timer has already called cancel. Callers then cancel ctx themselves, so
// the monitors of any acquired keys stop and delete them.
func (m *locker) try(ctx context.Context, cancel context.CancelFunc, name string, g *gate) context.CancelFunc {
var err error
// One token for every key of this attempt; the scripts compare against it.
val := random()
// Initial expiry of every key.
deadline := time.Now().Add(m.validity)
// Cancel ctx if reaching a majority takes longer than the validity.
cacneltm := time.AfterFunc(m.validity, cancel)
// Number of per-key monitors that have finished; updated atomically.
released := int32(0)
// Closed when all m.totalcnt monitors have finished.
done := make(chan struct{})
// monitoring runs once per key. err is the result of acquiring key (nil when
// acquired), deadline is the key's current expiry, and csc is the key's
// invalidation channel from g.csc.
monitoring := func(err error, key string, deadline time.Time, csc chan struct{}) {
if err == nil {
// Keep the key alive until ctx is done or a script call fails
// (ErrNotLocked included, meaning the key no longer holds val).
for timer := time.NewTimer(m.interval); err == nil; {
select {
case <-ctx.Done():
err = ctx.Err()
case <-timer.C:
// Periodic extension: move the expiry forward by one interval.
deadline = deadline.Add(m.interval)
if err = m.script(ctx, extend, key, val, deadline); err == nil {
timer.Reset(m.interval)
if !m.noloop {
// NOTE(review): presumably consumes the invalidation caused by our
// own extend write, which arrives when NOLOOP tracking is not in use.
<-csc
}
}
case <-csc:
// The key was invalidated: re-check ownership without moving the expiry.
if err = m.script(ctx, extend, key, val, deadline); err == nil {
if !m.noloop {
<-csc
}
}
}
}
}
// Delete the key unless err is ErrNotLocked (the key is known not to hold
// val). delkey only deletes while the value still matches. It uses
// context.Background so the cleanup still runs after ctx is cancelled.
if err != ErrNotLocked {
_ = m.script(context.Background(), delkey, key, val, deadline)
}
// The inner released shadows the counter; &released in the init statement
// still refers to the outer counter.
if released := atomic.AddInt32(&released, 1); released >= m.majority {
// Fewer than a majority of keys can still be held: the lock is gone.
cancel()
if released == m.totalcnt {
// Last monitor: unblock the releaser, then drop or pass on the gate.
close(done)
m.mu.Lock()
if g.w--; g.w == 0 {
if m.gates[name] == g {
delete(m.gates, name)
}
} else if m.gates != nil {
// Wake one waiter in waitgate. This is skipped after Close,
// which has already closed g.ch.
select {
case g.ch <- struct{}{}:
default:
}
}
m.mu.Unlock()
}
}
}
// acquire tries one key and starts its monitor. err is the previous key's
// result: once a key has reported ErrNotLocked, later keys skip redis and also
// report ErrNotLocked (fail fast). The monitor is started either way so that
// released eventually reaches m.totalcnt.
acquire := func(err error, key string, ch chan struct{}) error {
// Discard any pending invalidation signal before using this key.
select {
case <-ch:
default:
}
if err != ErrNotLocked {
err = m.acquire(ctx, key, val, deadline)
}
go monitoring(err, key, deadline, ch)
return err
}
var i, acquired, failures int32
for ; acquired < m.majority && failures < m.majority; i++ {
if err = acquire(err, keyname(m.prefix, name, i), g.csc[i]); err == nil {
acquired++
} else {
failures++
}
}
// The outcome is decided; handle the remaining keys in the background so that
// each of them still gets a monitor.
if i < m.totalcnt {
go func(i int32, err error) {
for ; i < m.totalcnt; i++ {
err = acquire(err, keyname(m.prefix, name, i), g.csc[i])
}
}(i, err)
}
// Stop reports false if the validity timer already fired and cancelled ctx.
if cacneltm.Stop() && failures < m.majority {
return func() {
cancel()
<-done
}
}
return nil
}
// TryWithContext makes a single attempt to acquire the lock for name without
// waiting for other goroutines of this locker. On success it returns a context
// that is cancelled when the lock is lost, and a CancelFunc that releases the
// lock. Otherwise it returns an already-cancelled context, its CancelFunc and
// ErrNotLocked.
func (m *locker) TryWithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error) {
	lockCtx, cancelLock := context.WithCancel(ctx)

	var release context.CancelFunc
	if g := m.trygate(name); g != nil {
		release = m.try(lockCtx, cancelLock, name, g)
	}
	if release == nil {
		cancelLock()
		return lockCtx, cancelLock, ErrNotLocked
	}
	return lockCtx, release, nil
}
// WithContext blocks until the lock for name is acquired. Each attempt waits for
// its turn on the local gate (waitgate) and then runs try. A failed attempt
// cancels its own context and the loop starts over.
//
// On success it returns the attempt's context, which is cancelled when the lock
// is lost, together with the CancelFunc that releases the lock. If waiting fails
// (ctx done or locker closed), it returns the already-cancelled attempt context,
// its CancelFunc and the error.
func (m *locker) WithContext(ctx context.Context, name string) (context.Context, context.CancelFunc, error) {
	for {
		attemptCtx, cancelAttempt := context.WithCancel(ctx)
		g, waitErr := m.waitgate(attemptCtx, name)

		var release context.CancelFunc
		if g != nil {
			release = m.try(attemptCtx, cancelAttempt, name, g)
		}
		if release != nil {
			return attemptCtx, release, nil
		}

		cancelAttempt()
		if waitErr != nil {
			return attemptCtx, cancelAttempt, waitErr
		}
	}
}
// Client returns the rueidis.Client that NewLocker built for this locker.
// Closing the locker closes this client as well.
func (m *locker) Client() rueidis.Client {
	client := m.client
	return client
}
// Close shuts the locker down and closes the underlying rueidis.Client.
//
// Every gate's ch is closed, so goroutines blocked in WithContext return
// ErrLockerClosed. gates is set to nil, so later WithContext calls fail with
// ErrLockerClosed and later TryWithContext calls fail with ErrNotLocked.
func (m *locker) Close() {
	m.mu.Lock()
	gates := m.gates
	m.gates = nil
	for _, g := range gates {
		close(g.ch)
	}
	m.mu.Unlock()
	m.client.Close()
}
// Lua scripts used by the locker. ARGV[1] is always the lock value produced by
// random(); each script acts on the single key KEYS[1].
var (
// delkey deletes KEYS[1] only if it still holds ARGV[1]. It returns the DEL
// reply, or 0 otherwise.
delkey = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("DEL",KEYS[1]) end;return 0`)
// extend sets the expiry of KEYS[1] to ARGV[2] (absolute Unix milliseconds)
// only if it still holds ARGV[1]. It returns the PEXPIREAT reply, or 0
// otherwise. NOTE(review): the trailing GET presumably re-registers the key
// with client-side tracking so that later changes are notified — confirm.
extend = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then local r = redis.call("PEXPIREAT",KEYS[1],ARGV[2]);redis.call("GET",KEYS[1]);return r end;return 0`)
// acqms sets KEYS[1] to ARGV[1] only if absent (SET NX), expiring after ARGV[2]
// milliseconds (PX). A nil reply means the key was not set; acquire maps it to
// ErrNotLocked.
acqms = rueidis.NewLuaScript(`local r = redis.call("SET",KEYS[1],ARGV[1],"NX","PX",ARGV[2]);redis.call("GET",KEYS[1]);return r`)
// acqat is like acqms, but ARGV[2] is an absolute Unix time in milliseconds
// (PXAT, which needs Redis >= 6.2).
acqat = rueidis.NewLuaScript(`local r = redis.call("SET",KEYS[1],ARGV[1],"NX","PXAT",ARGV[2]);redis.call("GET",KEYS[1]);return r`)
)
var (
	// ErrNotLocked is returned from the Locker.TryWithContext when it fails
	// to acquire the lock.
	ErrNotLocked = errors.New("not locked")
	// ErrLockerClosed is returned from the Locker.WithContext when the Locker
	// is closed.
	ErrLockerClosed = errors.New("locker closed")
)