status-im/status-go

View on GitHub
waku/common/rate_limiter_test.go

Summary

Maintainability
A
0 mins
Test Coverage
// Copyright 2019 The Waku Library Authors.
//
// The Waku library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The Waku library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the Waku library. If not, see <http://www.gnu.org/licenses/>.
//
// This software uses the go-ethereum library, which is licensed
// under the GNU Lesser General Public License, version 3 or any later.

package common

import (
    "bytes"
    "net"
    "testing"
    "time"

    "github.com/stretchr/testify/require"

    "github.com/ethereum/go-ethereum/p2p"
    "github.com/ethereum/go-ethereum/p2p/enode"
)

// TestPeerRateLimiterDecorator writes a single message into one end of a
// p2p.MsgPipe, hands the other end to Decorate, and checks that the run loop
// receives a message with the same code and payload.
func TestPeerRateLimiterDecorator(t *testing.T) {
    in, out := p2p.MsgPipe()
    payload := []byte{0x01, 0x02, 0x03}
    msg := p2p.Msg{
        Code:       1,
        Size:       uint32(len(payload)),
        Payload:    bytes.NewReader(payload),
        ReceivedAt: time.Now(),
    }

    go func() {
        // require.* ends in t.FailNow, which may only be called from the
        // goroutine running the test; report with t.Errorf here instead.
        if err := in.WriteMsg(msg); err != nil {
            t.Errorf("writing message to pipe: %v", err)
        }
    }()

    // Buffered so the run loop can hand over the message and return without
    // waiting for the receive below.
    messages := make(chan p2p.Msg, 1)
    runLoop := func(rw p2p.MsgReadWriter) error {
        msg, err := rw.ReadMsg()
        if err != nil {
            return err
        }
        messages <- msg
        return nil
    }

    r := NewPeerRateLimiter(nil, &mockRateLimiterHandler{})
    err := r.Decorate(nil, out, runLoop)
    require.NoError(t, err)

    receivedMsg := <-messages
    receivedPayload := make([]byte, receivedMsg.Size)
    _, err = receivedMsg.Payload.Read(receivedPayload)
    require.NoError(t, err)
    require.Equal(t, msg.Code, receivedMsg.Code)
    require.Equal(t, payload, receivedPayload)
}

// TestPeerLimiterThrottlingWithZeroLimit builds a limiter from an empty
// PeerRateLimiterConfig and checks that neither the IP check nor the peer
// check reports throttling across 1000 consecutive calls.
func TestPeerLimiterThrottlingWithZeroLimit(t *testing.T) {
    limiter := NewPeerRateLimiter(&PeerRateLimiterConfig{}, &mockRateLimiterHandler{})

    const attempts = 1000
    for n := 0; n < attempts; n++ {
        ipThrottled := limiter.throttleIP("<nil>", 0)
        require.False(t, ipThrottled)

        peerThrottled := limiter.throttlePeer([]byte{0x01, 0x02, 0x03}, 0)
        require.False(t, peerThrottled)
    }
}

// TestPeerPacketLimiterHandler pushes count empty packets through a limiter
// created with a nil config, echoes them back, and expects the handler
// callbacks to have fired once for every packet beyond the packet limits in
// defaultPeerRateLimiterConfig.
func TestPeerPacketLimiterHandler(t *testing.T) {
    // Untyped constant: serves as the loop bound and takes part in the
    // arithmetic against the configured limits in the final assertions.
    const count = 100

    h := &mockRateLimiterHandler{}
    r := NewPeerRateLimiter(nil, h)
    p := &TestWakuPeer{
        peer: p2p.NewPeer(enode.ID{0xaa, 0xbb, 0xcc}, "test-peer", nil),
    }
    rw1, rw2 := p2p.MsgPipe()

    go func() {
        // require.* ends in t.FailNow, which may only be called from the
        // goroutine running the test; report with t.Errorf here instead.
        if err := echoMessages(r, p, rw2); err != nil {
            t.Errorf("echoing messages: %v", err)
        }
    }()

    done := make(chan struct{})
    go func() {
        // Always release the waiter below, even when bailing out early.
        defer close(done)
        for i := 0; i < count; i++ {
            msg, err := rw1.ReadMsg()
            if err != nil {
                t.Errorf("reading echoed message %d: %v", i, err)
                return
            }
            if msg.Code != 101 {
                t.Errorf("echoed message %d: got code %d, want 101", i, msg.Code)
                return
            }
        }
    }()

    for i := 0; i < count; i++ {
        err := rw1.WriteMsg(p2p.Msg{Code: 101})
        require.NoError(t, err)
    }

    <-done

    require.EqualValues(t, count-defaultPeerRateLimiterConfig.PacketLimitPerSecIP, h.exceedIPLimit)
    require.EqualValues(t, count-defaultPeerRateLimiterConfig.PacketLimitPerSecPeerID, h.exceedPeerLimit)
}

// TestPeerBytesLimiterHandler sends six 10-byte messages through a limiter
// configured with a 30 byte limit for both IP and peer ID, echoes them back,
// and expects each handler callback to have fired three times.
func TestPeerBytesLimiterHandler(t *testing.T) {
    h := &mockRateLimiterHandler{}
    r := NewPeerRateLimiter(&PeerRateLimiterConfig{
        BytesLimitPerSecIP:     30,
        BytesLimitPerSecPeerID: 30,
    }, h)
    p := &TestWakuPeer{
        peer: p2p.NewPeer(enode.ID{0xaa, 0xbb, 0xcc}, "test-peer", nil),
    }
    rw1, rw2 := p2p.MsgPipe()
    count := 6

    go func() {
        // require.* ends in t.FailNow, which may only be called from the
        // goroutine running the test; report with t.Errorf here instead.
        if err := echoMessages(r, p, rw2); err != nil {
            t.Errorf("echoing messages: %v", err)
        }
    }()

    done := make(chan struct{})
    go func() {
        // Always release the waiter below, even when bailing out early.
        defer close(done)
        for i := 0; i < count; i++ {
            msg, err := rw1.ReadMsg()
            if err != nil {
                t.Errorf("reading echoed message %d: %v", i, err)
                return
            }
            if msg.Code != 101 {
                t.Errorf("echoed message %d: got code %d, want 101", i, msg.Code)
                return
            }
            if err := msg.Discard(); err != nil {
                t.Errorf("discarding echoed message %d: %v", i, err)
                return
            }
        }
    }()

    for i := 0; i < count; i++ {
        payload := make([]byte, 10)
        msg := p2p.Msg{
            Code:    101,
            Size:    uint32(len(payload)),
            Payload: bytes.NewReader(payload),
        }

        err := rw1.WriteMsg(msg)
        require.NoError(t, err)
    }

    <-done

    require.EqualValues(t, 3, h.exceedIPLimit)
    require.EqualValues(t, 3, h.exceedPeerLimit)
}

// TestPeerPacketLimiterHandlerWithWhitelisting configures packet limits of 1
// for both IP and peer ID, whitelists the test peer's IP string and ID, sends
// 100 packets, and expects neither handler callback to have fired.
func TestPeerPacketLimiterHandlerWithWhitelisting(t *testing.T) {
    h := &mockRateLimiterHandler{}
    r := NewPeerRateLimiter(&PeerRateLimiterConfig{
        PacketLimitPerSecIP:     1,
        PacketLimitPerSecPeerID: 1,
        WhitelistedIPs:          []string{"<nil>"}, // no IP is represented as <nil> string
        WhitelistedPeerIDs:      []enode.ID{{0xaa, 0xbb, 0xcc}},
    }, h)
    p := &TestWakuPeer{
        peer: p2p.NewPeer(enode.ID{0xaa, 0xbb, 0xcc}, "test-peer", nil),
    }
    rw1, rw2 := p2p.MsgPipe()
    count := 100

    go func() {
        // require.* ends in t.FailNow, which may only be called from the
        // goroutine running the test; report with t.Errorf here instead.
        if err := echoMessages(r, p, rw2); err != nil {
            t.Errorf("echoing messages: %v", err)
        }
    }()

    done := make(chan struct{})
    go func() {
        // Always release the waiter below, even when bailing out early.
        defer close(done)
        for i := 0; i < count; i++ {
            msg, err := rw1.ReadMsg()
            if err != nil {
                t.Errorf("reading echoed message %d: %v", i, err)
                return
            }
            if msg.Code != 101 {
                t.Errorf("echoed message %d: got code %d, want 101", i, msg.Code)
                return
            }
        }
    }()

    for i := 0; i < count; i++ {
        err := rw1.WriteMsg(p2p.Msg{Code: 101})
        require.NoError(t, err)
    }

    <-done

    require.Equal(t, 0, h.exceedIPLimit)
    require.Equal(t, 0, h.exceedPeerLimit)
}

// echoMessages runs r.Decorate for peer p over rw with a loop that writes
// every message it reads straight back to the same read-writer. It returns
// whatever Decorate returns; the loop itself only stops on a read or write
// error.
func echoMessages(r *PeerRateLimiter, p RateLimiterPeer, rw p2p.MsgReadWriter) error {
    echo := func(pipe p2p.MsgReadWriter) error {
        for {
            received, readErr := pipe.ReadMsg()
            if readErr != nil {
                return readErr
            }
            if writeErr := pipe.WriteMsg(received); writeErr != nil {
                return writeErr
            }
        }
    }
    return r.Decorate(p, rw, echo)
}

// mockRateLimiterHandler is a test double that counts how many times each
// limit-exceeded callback has been invoked.
type mockRateLimiterHandler struct {
    exceedPeerLimit int // number of ExceedPeerLimit calls so far
    exceedIPLimit   int // number of ExceedIPLimit calls so far
}

// ExceedPeerLimit records one more peer-limit callback and returns nil.
func (m *mockRateLimiterHandler) ExceedPeerLimit() error {
    m.exceedPeerLimit++
    return nil
}

// ExceedIPLimit records one more IP-limit callback and returns nil.
func (m *mockRateLimiterHandler) ExceedIPLimit() error {
    m.exceedIPLimit++
    return nil
}

// TestWakuPeer is a minimal peer used by the rate limiter tests. It exposes
// the IP and ID of the p2p.Peer it wraps.
type TestWakuPeer struct {
    peer *p2p.Peer
}

// IP returns the IP address reported by the wrapped peer's node.
func (p *TestWakuPeer) IP() net.IP {
    node := p.peer.Node()
    return node.IP()
}

// ID returns the wrapped peer's node ID as a byte slice.
func (p *TestWakuPeer) ID() []byte {
    nodeID := p.peer.ID()
    return nodeID[:]
}