ratelimit.go
package ratelimit
import (
"context"
"errors"
"os"
"time"
"github.com/sirupsen/logrus"
)
type RateLimit struct {
d time.Duration
limit int
ch chan struct{}
ctx context.Context
t *time.Ticker
lastCall time.Time
log *logrus.Logger
}
// New returns a Ratelimit instance and initialize it
func New(ctx context.Context, d time.Duration, limit int) (*RateLimit, error) {
if limit <= 0 || d <= 0 {
return nil, errors.New("ratelimit: duration or limit cannot be <= 0")
}
r := RateLimit{
d: d,
limit: limit,
ch: make(chan struct{}, limit),
ctx: ctx,
log: initLog(os.Getenv("RATELIMIT_LOGLEVEL")),
lastCall: time.Now(),
}
r.backgroundRoutine()
r.handleCtx()
return &r, nil
}
// backgroundRoutine launches a goroutine to empty the channel every r.d duration
func (r *RateLimit) backgroundRoutine() {
r.log.Debugln("Start backgroundRoutine")
go func() {
r.t = time.NewTicker(r.d)
loop:
for {
select {
case <-r.t.C:
r.emptyChan()
case <-r.ctx.Done():
break loop
}
}
r.log.Debugln("Stop backgroundRoutine")
}()
}
func (r *RateLimit) handleCtx() {
go func() {
<-r.ctx.Done()
r.log.Debugln("Stop Ticker")
r.t.Stop()
r.log.Debugln("Empty chan")
r.emptyChan()
r.log.Debugln("End of handleCtx")
}()
}
// WaitIfLimitReached wait if limit has been reached
// do not use IsLimitReached and WaitIFLimitReached in the same algo
func (r *RateLimit) WaitIfLimitReached() {
r.lastCall = time.Now()
for {
select {
case <-r.ctx.Done():
r.log.Debugln("End WaitIfLimitReached")
return
case r.ch <- struct{}{}:
return
default:
time.Sleep(10 * time.Millisecond)
}
}
}
// IsLimitReached returns true if limit has been reached
// do not use IsLimitReached and WaitIFLimitReached in the same algo
func (r *RateLimit) IsLimitReached() bool {
r.lastCall = time.Now()
if r.ctx.Err() != nil {
// program is going to be terminated
return false
}
select {
case r.ch <- struct{}{}:
return false
default:
return true
}
}
func (r *RateLimit) GetLastCall() time.Time {
return r.lastCall
}
func (r *RateLimit) emptyChan() {
if r.ctx.Err() == nil {
length := len(r.ch)
for i := 0; i < length; i++ {
_, ok := <-r.ch
if !ok {
break // channel is closed
}
}
}
}
func initLog(debugLevel string) *logrus.Logger {
l := logrus.New()
// Log as JSON instead of the default ASCII formatter.
//log.SetFormatter(&log.JSONFormatter{})
l.SetFormatter(&logrus.TextFormatter{
DisableColors: false,
FullTimestamp: false,
DisableTimestamp: true,
})
// Output to stdout instead of the default stderr
// Can be any io.Writer, see below for File example
l.SetOutput(os.Stdout)
switch debugLevel {
case "debug":
l.SetLevel(logrus.DebugLevel)
case "info":
l.SetLevel(logrus.InfoLevel)
case "warn":
l.SetLevel(logrus.WarnLevel)
case "error":
l.SetLevel(logrus.ErrorLevel)
default:
l.SetLevel(logrus.InfoLevel)
}
return l
}
// Stop close background Goroutine
func (r *RateLimit) Stop() {
r.log.Debugln("Stop Ticker")
r.t.Stop()
r.log.Debugln("Empty chan")
r.emptyChan()
time.Sleep(100 * time.Millisecond)
}