config.go

Summary

Maintainability
A
0 mins
Test Coverage
A
100%
package retry

import (
    "errors"
    "fmt"
    "log"
    "math"
    "time"

    "golang.org/x/sync/errgroup"
)

// mode selects how the delay between two attempts grows with the attempt number.
type mode byte

const (
	// Simple mode waits sleep + jitter*attempt: the base sleep stays fixed and
	// only the jitter term grows.
	Simple mode = iota
	// Linear mode waits sleep*attempt + jitter.
	Linear
	// Exponential mode waits sleep*2^attempt + jitter.
	Exponential
	// Fibonacci mode waits sleep*fibonacci(attempt) + jitter.
	Fibonacci
)

// Lower bounds enforced by validate; two names the constant 2 used by the
// backoff helpers.
const (
	minCount    = 1
	minParallel = 0

	minSleep    = 500 * time.Millisecond
	minDuration = time.Duration(0)

	two = 2
)

// Step represents a single named execution step to re-try; Chain and
// Parallel run each Step through Single.
type Step struct {
	// Func is the operation to execute; returning nil means success.
	Func func() error
	// Name labels the step in the error returned by Single and in verbose logs.
	Name string
}

// Config holds retry configuration. Values are normally set through New,
// which applies options and then clamps out-of-range fields via validate.
//
// Fields:
//   - fatal: errors that stop retrying immediately (consulted by isFatal).
//   - sleep: base delay fed into stepDuration.
//   - jitter: additive delay term in stepDuration (deterministic, not random).
//   - count: bounds how many times Single calls the step function.
//   - parallelism: maximum number of concurrent steps in Parallel; 0 means no limit.
//   - mode: delay growth strategy used by stepDuration.
//   - verbose: when true, Single logs every failed attempt.
type Config struct {
	fatal       []error
	sleep       time.Duration
	jitter      time.Duration
	count       int
	parallelism int
	mode        mode
	verbose     bool
}

// New creates a new Config and applies the given options in order.
//
// Out-of-range settings are afterwards clamped by validate. With no options the
// result has a count of 1, a 500ms base sleep, no jitter, Simple mode, no
// parallelism limit and no fatal errors (see Single for how count translates
// into attempts).
func New(opts ...option) (c *Config) {
	c = &Config{}

	for _, o := range opts {
		o(c)
	}

	c.validate()

	return c
}

// Single executes fn until it returns nil, making at most count+1 attempts:
// one initial call plus up to `Count` retries (default is 1, so fn is executed
// at most 2 times). Before each retry it sleeps for a delay computed from the
// `Sleep` option (default is 500ms) according to the configured mode.
//
// Retrying stops early when fn returns an error matched by one of the
// configured fatal errors. On failure the last error is returned wrapped
// with name as prefix.
func (c *Config) Single(name string, fn func() error) (err error) {
	// n counts retries already made; n == c.count is the final attempt.
	for n := 0; n <= c.count; n++ {
		if err = fn(); err == nil {
			return nil
		}

		if c.isFatal(err) {
			break
		}

		if c.verbose {
			log.Printf("step %s:%d err: %v", name, n, err)
		}

		// Sleep only between attempts, never after the last one.
		if n < c.count {
			time.Sleep(c.stepDuration(n + 1))
		}
	}

	return fmt.Errorf("%s: %w", name, err)
}

// Chain executes several `steps` one by one, returning first error.
func (c *Config) Chain(steps ...Step) (err error) {
    var step *Step

    for i := 0; i < len(steps); i++ {
        step = &steps[i]

        if err = c.Single(step.Name, step.Func); err != nil {
            return fmt.Errorf("chain: %w", err)
        }
    }

    return nil
}

// Parallel runs every step concurrently, each through Single, and waits for
// all of them to finish. When parallelism is positive, at most that many steps
// run at the same time. The first error reported by the group is returned
// prefixed with "parallel: "; no context is used, so other steps keep running.
func (c *Config) Parallel(steps ...Step) (err error) {
	var group errgroup.Group

	if c.parallelism > 0 {
		group.SetLimit(c.parallelism)
	}

	for _, s := range steps {
		s := s // per-iteration copy captured by the closure (pre-Go 1.22 semantics)

		group.Go(func() error { return c.Single(s.Name, s.Func) })
	}

	err = group.Wait()
	if err == nil {
		return nil
	}

	return fmt.Errorf("parallel: %w", err)
}

// validate clamps out-of-range settings: a negative parallelism becomes
// minParallel, a negative jitter becomes zero, a non-positive sleep becomes
// minSleep and a count below minCount becomes minCount.
func (c *Config) validate() {
	if c.parallelism < minParallel {
		c.parallelism = minParallel
	}

	if c.jitter < minDuration {
		c.jitter = minDuration
	}

	if c.sleep <= minDuration {
		c.sleep = minSleep
	}

	if c.count < minCount {
		c.count = minCount
	}
}

// isFatal reports whether err matches any configured fatal error. Matching
// uses errors.Is(err, target), so a fatal sentinel wrapped by fn (for example
// via fmt.Errorf("...: %w", ErrFatal)) is still recognized.
func (c *Config) isFatal(err error) (yes bool) {
	for _, target := range c.fatal {
		// errors.Is walks err's chain looking for target; the previous
		// argument order searched the sentinel's chain for err instead.
		if errors.Is(err, target) {
			return true
		}
	}

	return false
}

// stepDuration returns the delay before retry number n according to the
// configured mode. jitter is a fixed additive term here, not a random one.
func (c *Config) stepDuration(n int) (d time.Duration) {
	attempt := time.Duration(n)

	switch c.mode {
	case Linear:
		d = c.sleep*attempt + c.jitter
	case Exponential:
		d = c.sleep*time.Duration(ipow2(n)) + c.jitter
	case Fibonacci:
		d = c.sleep*time.Duration(fibonacci(n)) + c.jitter
	default:
		// Simple (and any unrecognized mode): constant sleep, growing jitter.
		d = c.sleep + c.jitter*attempt
	}

	return d
}

// ipow2 returns 2 raised to the power v as an int64. Results that do not fit
// in an int64 (v >= 63) saturate at math.MaxInt64 instead of relying on the
// implementation-dependent conversion of an out-of-range float64. Negative v
// yields 0 because the fractional power truncates toward zero.
func ipow2(v int) (rv int64) {
	p := math.Pow(two, float64(v))
	if p >= math.MaxInt64 {
		return math.MaxInt64
	}

	return int64(p)
}

// fibonacci returns the n-th Fibonacci number, with fibonacci(0) == 0 and
// fibonacci(1) == 1, computed iteratively in O(n) time rather than by
// exponential-time recursion. For n < 1 it returns n itself. Results wrap
// around int64 for n > 92.
func fibonacci(n int) int64 {
	if n < 1 {
		return int64(n)
	}

	var prev, cur int64 = 0, 1
	for i := 1; i < n; i++ {
		prev, cur = cur, prev+cur
	}

	return cur
}