status-im/status-go

View on GitHub
waku/common/rate_limiter.go

Summary

Maintainability
A
0 mins
Test Coverage
D
62%
// Copyright 2019 The Waku Library Authors.
//
// The Waku library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The Waku library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the Waku library. If not, see <http://www.gnu.org/licenses/>.
//
// This software uses the go-ethereum library, which is licensed
// under the GNU Lesser General Public License, version 3 or any later.

package common

import (
    "bytes"
    "errors"
    "fmt"
    "net"
    "time"

    "github.com/tsenart/tb"

    "github.com/status-im/status-go/common"

    "github.com/ethereum/go-ethereum/p2p"
    "github.com/ethereum/go-ethereum/p2p/enode"
)

// errRateLimitExceeded is the sentinel error that DropPeerRateLimiterHandler
// returns once one of its exceed counters has reached the configured Tolerance.
var errRateLimitExceeded = errors.New("rate limit has been exceeded")

// runLoop is the signature of the function that PeerRateLimiter.Decorate runs
// against the rate-limited end of the message pipe. The error it returns is
// offered as the result of Decorate.
type runLoop func(rw p2p.MsgReadWriter) error

// RateLimiterPeer interface represents a Peer that is capable of being rate limited.
// PeerRateLimiter uses the two accessors to derive the keys under which traffic
// is accounted: the peer ID for per-peer limits and the IP for per-IP limits.
type RateLimiterPeer interface {
    // ID returns the peer identifier as raw bytes. The rate limiter copies it
    // into an enode.ID before using it as a key.
    ID() []byte
    // IP returns the peer's IP address. The rate limiter uses the address's
    // String() form as a key and tolerates a nil address.
    IP() net.IP
}

// RateLimiterHandler interface represents handler functionality for a Rate Limiter in the cases of
// exceeding a peer limit and exceeding an IP limit.
//
// PeerRateLimiter.Decorate calls the matching method on every registered
// handler each time a packet trips a limit. A non-nil error makes Decorate
// stop reading from the peer and report that error; a nil error lets the
// packet be forwarded as usual.
type RateLimiterHandler interface {
    // ExceedPeerLimit is called when a packet exceeds the per-peer-ID limit.
    ExceedPeerLimit() error
    // ExceedIPLimit is called when a packet exceeds the per-IP limit.
    ExceedIPLimit() error
}

// MetricsRateLimiterHandler implements RateLimiterHandler. It only reports
// rate limit violations to the metrics collection service (currently
// prometheus); both of its methods always return nil.
type MetricsRateLimiterHandler struct{}

// ExceedPeerLimit increments the RateLimitsExceeded counter labelled "peer_id"
// and returns nil.
func (MetricsRateLimiterHandler) ExceedPeerLimit() error {
    peerCounter := RateLimitsExceeded.WithLabelValues("peer_id")
    peerCounter.Inc()
    return nil
}

// ExceedIPLimit increments the RateLimitsExceeded counter labelled "ip" and
// returns nil.
func (MetricsRateLimiterHandler) ExceedIPLimit() error {
    ipCounter := RateLimitsExceeded.WithLabelValues("ip")
    ipCounter.Inc()
    return nil
}

// RateLimits contains information about rate limit settings.
// It is agnostic about the unit being limited (currently bytes or number of packets).
// It is exchanged with the status-update packet code.
type RateLimits struct {
    // IPLimits is the amount per second from a single IP (default 0, no limits).
    IPLimits uint64
    // PeerIDLimits is the amount per second from a single peer ID (default 0, no limits).
    PeerIDLimits uint64
    // TopicLimits is the amount per second from a single topic (default 0, no limits).
    TopicLimits uint64
}

// IsZero reports whether r is the zero value of RateLimits, that is, whether
// every one of its limits is 0.
func (r RateLimits) IsZero() bool {
    if r.IPLimits != 0 {
        return false
    }
    if r.PeerIDLimits != 0 {
        return false
    }
    return r.TopicLimits == 0
}

// DropPeerRateLimiterHandler implements RateLimiterHandler. It tolerates a
// configurable number of limit violations and starts returning
// errRateLimitExceeded once that number has been reached.
//
// NOTE(review): the counters below are plain int64 values incremented without
// any synchronisation; confirm that a single handler is never invoked from
// several goroutines at once.
type DropPeerRateLimiterHandler struct {
    // Tolerance is a number by which a limit must be exceeded before a peer is dropped.
    // A value of zero or less disables dropping: the methods then always return nil.
    Tolerance int64

    // peerLimitExceeds counts calls to ExceedPeerLimit.
    peerLimitExceeds int64
    // ipLimitExceeds counts calls to ExceedIPLimit.
    ipLimitExceeds int64
}

// ExceedPeerLimit records one more per-peer violation. It returns
// errRateLimitExceeded when Tolerance is positive and the number of recorded
// per-peer violations has reached it, and nil otherwise.
func (h *DropPeerRateLimiterHandler) ExceedPeerLimit() error {
    h.peerLimitExceeds++
    if h.Tolerance <= 0 || h.peerLimitExceeds < h.Tolerance {
        return nil
    }
    return errRateLimitExceeded
}

// ExceedIPLimit records one more per-IP violation. It returns
// errRateLimitExceeded when Tolerance is positive and the number of recorded
// per-IP violations has reached it, and nil otherwise.
func (h *DropPeerRateLimiterHandler) ExceedIPLimit() error {
    h.ipLimitExceeds++
    if h.Tolerance <= 0 || h.ipLimitExceeds < h.Tolerance {
        return nil
    }
    return errRateLimitExceeded
}

// PeerRateLimiterConfig represents configurations for initialising a PeerRateLimiter.
// A limit of 0 disables the corresponding check in the limiter.
type PeerRateLimiterConfig struct {
    // Packet-count limits, keyed by peer IP and by peer ID respectively.
    PacketLimitPerSecIP     int64
    PacketLimitPerSecPeerID int64
    // Byte-count limits, keyed by peer IP and by peer ID respectively.
    BytesLimitPerSecIP      int64
    BytesLimitPerSecPeerID  int64
    // WhitelistedIPs are exempt from the per-IP limits. Entries are compared
    // verbatim against the String() form of the peer's IP.
    WhitelistedIPs          []string
    // WhitelistedPeerIDs are exempt from the per-peer-ID limits.
    WhitelistedPeerIDs      []enode.ID
}

// defaultPeerRateLimiterConfig holds the settings NewPeerRateLimiter falls back
// to when it is called with a nil configuration. Nothing is whitelisted.
var defaultPeerRateLimiterConfig = PeerRateLimiterConfig{
    PacketLimitPerSecIP:     10,
    PacketLimitPerSecPeerID: 5,
    BytesLimitPerSecIP:      1048576, // 1 MiB (1024 * 1024 bytes)
    BytesLimitPerSecPeerID:  1048576, // 1 MiB (1024 * 1024 bytes)
    WhitelistedIPs:          nil,
    WhitelistedPeerIDs:      nil,
}

// PeerRateLimiter represents a rate limiter that limits communication between Peers.
// Incoming packets are accounted twice, once under the sender's IP and once
// under the sender's peer ID, both by packet count and by byte size.
type PeerRateLimiter struct {
    // packetThrottler accounts one unit per packet; bytesThrottler accounts
    // the packet size. Both are shared between IP keys and peer-ID keys.
    packetThrottler *tb.Throttler
    bytesThrottler  *tb.Throttler

    // Packet-count limits; 0 disables the corresponding check.
    PacketLimitPerSecIP     int64
    PacketLimitPerSecPeerID int64

    // Byte-count limits; 0 disables the corresponding check.
    BytesLimitPerSecIP     int64
    BytesLimitPerSecPeerID int64

    // Peers and IPs listed here bypass the respective limits entirely.
    whitelistedPeerIDs []enode.ID
    whitelistedIPs     []string

    // handlers are notified, in order, whenever a packet exceeds a limit.
    handlers []RateLimiterHandler
}

// NewPeerRateLimiter builds a PeerRateLimiter from cfg and registers handlers
// to be notified when a limit is exceeded. When cfg is nil the values of
// defaultPeerRateLimiterConfig are used. Each of the two throttlers is created
// by tb.NewThrottler with a 100ms argument.
func NewPeerRateLimiter(cfg *PeerRateLimiterConfig, handlers ...RateLimiterHandler) *PeerRateLimiter {
    // Work on a value copy so that neither the caller's struct nor the
    // package-level default is referenced after construction.
    settings := defaultPeerRateLimiterConfig
    if cfg != nil {
        settings = *cfg
    }

    limiter := new(PeerRateLimiter)
    limiter.packetThrottler = tb.NewThrottler(100 * time.Millisecond)
    limiter.bytesThrottler = tb.NewThrottler(100 * time.Millisecond)

    limiter.PacketLimitPerSecIP = settings.PacketLimitPerSecIP
    limiter.PacketLimitPerSecPeerID = settings.PacketLimitPerSecPeerID
    limiter.BytesLimitPerSecIP = settings.BytesLimitPerSecIP
    limiter.BytesLimitPerSecPeerID = settings.BytesLimitPerSecPeerID

    limiter.whitelistedPeerIDs = settings.WhitelistedPeerIDs
    limiter.whitelistedIPs = settings.WhitelistedIPs
    limiter.handlers = handlers

    return limiter
}

// Decorate runs runLoop behind the rate limiter.
//
// An in-memory message pipe is placed between the peer connection rw and
// runLoop. Three goroutines are started:
//
//   - one reads packets from rw, accounts them against the per-IP and
//     per-peer limits, notifies the handlers when a limit is exceeded and
//     forwards the packet into the pipe;
//   - one reads packets from the pipe and writes them to rw;
//   - one executes runLoop against the other end of the pipe.
//
// Decorate blocks until the first of these goroutines reports a result and
// returns it. The result is nil when runLoop finished first and returned nil;
// otherwise it is the error from runLoop, from a handler, or from a failed
// read or write. Both ends of the pipe are closed when Decorate returns.
//
// p may be nil, in which case the empty string is used as the IP key and an
// empty peer ID is used as the peer key.
func (r *PeerRateLimiter) Decorate(p RateLimiterPeer, rw p2p.MsgReadWriter, runLoop runLoop) error {
    // Capacity one lets the first goroutine that finishes deliver its result
    // even if Decorate has not reached the receive at the bottom yet.
    errC := make(chan error, 1)

    // report offers err on errC without ever blocking. Only the first result
    // is consumed, so later results are dropped when the buffer is occupied;
    // blocking instead would leak the reporting goroutine.
    report := func(err error) {
        select {
        case errC <- err:
        default:
        }
    }

    in, out := p2p.MsgPipe()
    defer func() {
        if err := in.Close(); err != nil {
            report(err)
        }
    }()
    defer func() {
        if err := out.Close(); err != nil {
            // This runs after the receive below has completed, so nothing
            // reads from errC any more. A blocking send here could hang
            // Decorate forever if the buffer is already full.
            report(err)
        }
    }()

    // Read from the original reader and write to the message pipe.
    go func() {
        defer common.LogOnPanic()
        for {
            packet, err := rw.ReadMsg()
            if err != nil {
                report(fmt.Errorf("failed to read packet: %v", err))
                return
            }

            RateLimitsProcessed.Inc()

            var ip string
            if p != nil {
                // this relies on <nil> being the string representation of nil
                // as IP() might return a nil value
                ip = p.IP().String()
            }
            if halted := r.throttleIP(ip, packet.Size); halted {
                // A handler returning nil only observes the violation; the
                // packet is still forwarded. The first handler error stops
                // reading from this peer.
                for _, h := range r.handlers {
                    if err := h.ExceedIPLimit(); err != nil {
                        report(fmt.Errorf("exceed rate limit by IP: %v", err))
                        return
                    }
                }
            }

            var peerID []byte
            if p != nil {
                peerID = p.ID()
            }
            if halted := r.throttlePeer(peerID, packet.Size); halted {
                for _, h := range r.handlers {
                    if err := h.ExceedPeerLimit(); err != nil {
                        report(fmt.Errorf("exceeded rate limit by peer: %v", err))
                        return
                    }
                }
            }

            if err := in.WriteMsg(packet); err != nil {
                report(fmt.Errorf("failed to write packet to pipe: %v", err))
                return
            }
        }
    }()

    // Read from the message pipe and write to the original writer.
    go func() {
        defer common.LogOnPanic()
        for {
            packet, err := in.ReadMsg()
            if err != nil {
                report(fmt.Errorf("failed to read packet from pipe: %v", err))
                return
            }
            if err := rw.WriteMsg(packet); err != nil {
                report(fmt.Errorf("failed to write packet: %v", err))
                return
            }
        }
    }()

    go func() {
        defer common.LogOnPanic()
        // runLoop runs to completion first; only then is its result offered.
        report(runLoop(out))
    }()

    return <-errC
}

// throttleIP accounts one packet of the given size against the limits for ip
// and reports whether either the packet-count throttler or the byte-count
// throttler halted it. Whitelisted IPs are never accounted and always yield
// false. A limit of 0 disables the corresponding throttler.
func (r *PeerRateLimiter) throttleIP(ip string, size uint32) bool {
    if stringSliceContains(r.whitelistedIPs, ip) {
        return false
    }

    // Both throttlers are consulted before the results are combined so that
    // the packet is accounted in each of them regardless of the other's verdict.
    packetsHalted := r.PacketLimitPerSecIP != 0 &&
        r.packetThrottler.Halt(ip, 1, r.PacketLimitPerSecIP)
    bytesHalted := r.BytesLimitPerSecIP != 0 &&
        r.bytesThrottler.Halt(ip, int64(size), r.BytesLimitPerSecIP)

    return packetsHalted || bytesHalted
}

// throttlePeer accounts one packet of the given size against the limits for
// the peer identified by peerID and reports whether either the packet-count
// throttler or the byte-count throttler halted it. Whitelisted peer IDs are
// never accounted and always yield false. A limit of 0 disables the
// corresponding throttler.
func (r *PeerRateLimiter) throttlePeer(peerID []byte, size uint32) bool {
    // Copy the raw bytes into a fixed-size enode.ID; copy transfers at most
    // as many bytes as the shorter of the two holds.
    var id enode.ID
    copy(id[:], peerID)

    if enodeIDSliceContains(r.whitelistedPeerIDs, id) {
        return false
    }

    // Both throttlers are consulted before the results are combined so that
    // the packet is accounted in each of them regardless of the other's verdict.
    packetsHalted := r.PacketLimitPerSecPeerID != 0 &&
        r.packetThrottler.Halt(id.String(), 1, r.PacketLimitPerSecPeerID)
    bytesHalted := r.BytesLimitPerSecPeerID != 0 &&
        r.bytesThrottler.Halt(id.String(), int64(size), r.BytesLimitPerSecPeerID)

    return packetsHalted || bytesHalted
}

// stringSliceContains reports whether searched is an element of s. The
// comparison is exact (case-sensitive); a nil or empty slice contains nothing.
func stringSliceContains(s []string, searched string) bool {
    for i := 0; i < len(s); i++ {
        if s[i] != searched {
            continue
        }
        return true
    }
    return false
}

// enodeIDSliceContains reports whether searched is an element of s, comparing
// the IDs byte by byte. A nil or empty slice contains nothing.
func enodeIDSliceContains(s []enode.ID, searched enode.ID) bool {
    for i := 0; i < len(s); i++ {
        if !bytes.Equal(s[i].Bytes(), searched.Bytes()) {
            continue
        }
        return true
    }
    return false
}