status-im/status-go

View on GitHub
discovery/muxer_test.go

Summary

Maintainability
A
0 mins
Test Coverage
package discovery

import (
    "errors"
    "sync"
    "testing"
    "time"

    "github.com/stretchr/testify/require"

    "github.com/ethereum/go-ethereum/p2p/discv5"
)

// newRegistry returns an empty registry ready for use. The zero value of
// registry is not usable: its nil storage map would panic on Add.
func newRegistry() *registry {
	return &registry{
		storage: map[string][]int{},
	}
}

// registry is a mutex-guarded map from topic to the ids added under it.
// Tests hand a registry to fake discoveries so they can observe what was
// registered and what Discover should report.
type registry struct {
	mu      sync.Mutex
	storage map[string][]int
}

// Add appends id to the list of ids stored for topic.
func (r *registry) Add(topic string, id int) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.storage[topic] = append(r.storage[topic], id)
}

// Get returns the ids stored for topic, or nil if there are none.
// The result is a copy, so callers never touch the registry's storage
// outside of mu (reads or writes through the returned slice are safe).
func (r *registry) Get(topic string) []int {
	r.mu.Lock()
	defer r.mu.Unlock()
	return append([]int(nil), r.storage[topic]...)
}

// fake is an in-memory Discovery stand-in for the multiplexer tests.
// When err is non-nil, Start, Stop, Register and Discover all return it.
// Otherwise, Register records id under the topic in registry, and Discover
// emits one node per id stored in registry for that topic.
type fake struct {
	started  bool      // value reported by Running; tests may preset it
	err      error     // injected failure returned by the methods above
	id       int       // id that Register records
	registry *registry // topic -> ids store used by Register and Discover
}

// Start sets started unless an error is injected, then returns that error.
func (f *fake) Start() error {
	if f.err == nil {
		f.started = true
	}
	return f.err
}

// Stop always clears started (even on failure), then returns the injected error.
func (f *fake) Stop() error {
	f.started = false
	return f.err
}

// Running reports the current started flag.
func (f *fake) Running() bool {
	return f.started
}

// Register records f.id under topic unless an error is injected; stop is ignored.
func (f *fake) Register(topic string, stop chan struct{}) error {
	if f.err == nil {
		f.registry.Add(topic, f.id)
	}
	return f.err
}

// Discover synchronously sends one node per id stored for topic on found.
// Each id is truncated to a byte and used as the first byte of the NodeID.
// period and lookup are ignored. Sends block, so found needs enough buffer
// (or an active reader) for every node.
func (f *fake) Discover(topic string, period <-chan time.Duration, found chan<- *discv5.Node, lookup chan<- bool) error {
	if f.err != nil {
		return f.err
	}
	ids := f.registry.Get(topic)
	for i := 0; i < len(ids); i++ {
		node := discv5.NewNode(discv5.NodeID{byte(ids[i])}, nil, 0, 0)
		found <- node
	}
	return nil
}

// testErrorCase is one row of the error-injection table used by the
// Start/Stop tests. errors[i] becomes the err of the i-th fake discovery;
// nil means that fake succeeds.
type testErrorCase struct {
	desc   string  // subtest name
	errors []error // per-discovery injected error (nil = no error)
}

// errorCases returns the standard scenarios: one failing discovery, none
// failing, and all failing. Every failure slot gets its own error value.
func errorCases() []testErrorCase {
	fail := func() error { return errors.New("test") }
	return []testErrorCase{
		{desc: "SingleError", errors: []error{nil, fail()}},
		{desc: "NoErrors", errors: []error{nil, nil}},
		{desc: "AllErrors", errors: []error{fail(), fail()}},
	}
}

// TestMuxerStart checks that the multiplexer's Start fails when any child
// fails, and that afterwards every child is running only if none failed.
func TestMuxerStart(t *testing.T) {
	for _, tc := range errorCases() {
		t.Run(tc.desc, func(t *testing.T) {
			var (
				children = make([]Discovery, 0, len(tc.errors))
				anyFail  bool
			)
			for _, injected := range tc.errors {
				anyFail = anyFail || injected != nil
				children = append(children, &fake{err: injected})
			}
			err := NewMultiplexer(children).Start()
			if anyFail {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
			for _, child := range children {
				require.Equal(t, !anyFail, child.Running())
			}
		})
	}
}

// TestMuxerStop checks that the multiplexer's Stop reports a failure when any
// child fails, and that every child ends up stopped regardless.
func TestMuxerStop(t *testing.T) {
	for _, tc := range errorCases() {
		t.Run(tc.desc, func(t *testing.T) {
			var (
				children = make([]Discovery, 0, len(tc.errors))
				anyFail  bool
			)
			for _, injected := range tc.errors {
				anyFail = anyFail || injected != nil
				children = append(children, &fake{started: true, err: injected})
			}
			err := NewMultiplexer(children).Stop()
			if anyFail {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
			for _, child := range children {
				require.False(t, child.Running())
			}
		})
	}
}

// TestMuxerRunning checks that the multiplexer reports running as soon as at
// least one child is running.
func TestMuxerRunning(t *testing.T) {
	cases := []struct {
		desc    string
		started []bool
	}{
		{desc: "FirstRunning", started: []bool{false, true}},
		{desc: "SecondRunning", started: []bool{true, false}},
		{desc: "AllRunning", started: []bool{true, true}},
		{desc: "NoRunning", started: []bool{false, false}},
	}
	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			children := make([]Discovery, len(tc.started))
			anyRunning := false
			for i := range tc.started {
				children[i] = &fake{started: tc.started[i]}
				if tc.started[i] {
					anyRunning = true
				}
			}
			muxer := NewMultiplexer(children)
			require.Equal(t, anyRunning, muxer.Running())
		})
	}
}

// TestMuxerRegister checks that Register fails when any child fails, and that
// each topic ends up registered exactly once per non-failing child.
func TestMuxerRegister(t *testing.T) {
	cases := []struct {
		desc   string
		errors []error
		topics []string
	}{
		{desc: "NoErrors", errors: []error{nil, nil, nil}, topics: []string{"a"}},
		{desc: "MultipleTopics", errors: []error{nil, nil, nil}, topics: []string{"a", "b", "c"}},
		{desc: "SingleError", errors: []error{nil, errors.New("test"), nil}, topics: []string{"a"}},
		{desc: "AllErrors", errors: []error{errors.New("test"), errors.New("test"), errors.New("test")}, topics: []string{"a"}},
	}
	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			// All children share one registry so the total per topic is observable.
			reg := newRegistry()
			children := make([]Discovery, 0, len(tc.errors))
			failures := 0
			for id, injected := range tc.errors {
				if injected != nil {
					failures++
				}
				children = append(children, &fake{id: id, err: injected, registry: reg})
			}
			muxer := NewMultiplexer(children)
			wantRegistered := len(children) - failures
			for _, topic := range tc.topics {
				err := muxer.Register(topic, nil)
				if failures > 0 {
					require.Error(t, err)
				} else {
					require.NoError(t, err)
				}
				require.Equal(t, wantRegistered, len(reg.Get(topic)))
			}
		})
	}
}

func TestMuxerDiscovery(t *testing.T) {
    for _, tc := range []struct {
        desc   string
        errors []error
        topics []string
        ids    [][]int
    }{
        {"EqualNoErrors", []error{nil, nil}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}},
        {"MultiTopicsSingleSource", []error{nil, nil}, []string{"a", "b"}, [][]int{{11, 22, 33}, {}}},
        {"SingleError", []error{nil, errors.New("test")}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}},
        {"AllErrors", []error{errors.New("test"), errors.New("test")}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}},
    } {
        t.Run(tc.desc, func(t *testing.T) {
            discoveries := make([]Discovery, len(tc.errors))
            erred := false
            expected := 0
            for i := range discoveries {
                if tc.errors[i] == nil {
                    expected += len(tc.ids[i])
                } else {
                    erred = true
                }
                reg := newRegistry()
                discoveries[i] = &fake{id: i, err: tc.errors[i], registry: reg}
                for _, topic := range tc.topics {
                    for _, id := range tc.ids[i] {
                        reg.Add(topic, id)
                    }
                }
            }
            muxer := NewMultiplexer(discoveries)
            for _, topic := range tc.topics {
                found := make(chan *discv5.Node, expected)
                period := make(chan time.Duration)
                close(period)
                if erred {
                    // TODO test period channel
                    require.Error(t, muxer.Discover(topic, period, found, nil))
                } else {
                    require.NoError(t, muxer.Discover(topic, period, found, nil))
                }
                close(found)
                count := 0
                for range found {
                    count++
                }
                require.Equal(t, expected, count)
            }
        })
    }
}