status-im/status-go

View on GitHub
wakuv2/common/filter_test.go

Summary

Maintainability
A
50 mins
Test Coverage
package common

import (
    crand "crypto/rand"
    mrand "math/rand"
    "testing"
    "time"

    "github.com/stretchr/testify/require"
    "go.uber.org/zap"
    "golang.org/x/exp/maps"
    "google.golang.org/protobuf/proto"

    "github.com/waku-org/go-waku/waku/v2/payload"
    "github.com/waku-org/go-waku/waku/v2/protocol"
    "github.com/waku-org/go-waku/waku/v2/protocol/pb"

    "github.com/ethereum/go-ethereum/common"
    "github.com/ethereum/go-ethereum/crypto"
)

const testShard = "/waku/2/rs/16/32"

type FilterTestCase struct {
    f      *Filter
    id     string
    alive  bool
    msgCnt int
}

func createLogger(t *testing.T) *zap.Logger {
    config := zap.NewDevelopmentConfig()
    config.Level = zap.NewAtomicLevelAt(zap.DebugLevel)
    logger, err := config.Build()
    require.NoError(t, err)
    return logger
}

func generateFilter(t *testing.T, symmetric bool) (*Filter, error) {
    var f Filter
    f.Messages = NewMemoryMessageStore()

    f.PubsubTopic = "test"

    const topicNum = 8
    f.ContentTopics = make(TopicSet, topicNum)
    for i := 0; i < topicNum; i++ {
        topic := make([]byte, 4)
        _, err := crand.Read(topic) // nolint: gosec
        require.NoError(t, err)
        topic[0] = 0x01

        f.ContentTopics[BytesToTopic(topic)] = struct{}{}
    }

    key, err := crypto.GenerateKey()
    require.NoError(t, err)

    f.Src = &key.PublicKey

    if symmetric {
        f.KeySym = make([]byte, AESKeyLength)
        _, err := crand.Read(f.KeySym) // nolint: gosec
        require.NoError(t, err)
        f.SymKeyHash = crypto.Keccak256Hash(f.KeySym)
    } else {
        f.KeyAsym, err = crypto.GenerateKey()
        require.NoError(t, err)
    }

    return &f, nil
}

func generateTestCases(t *testing.T, SizeTestFilters int) []FilterTestCase {
    cases := make([]FilterTestCase, SizeTestFilters)
    for i := 0; i < SizeTestFilters; i++ {
        f, _ := generateFilter(t, true)
        cases[i].f = f
        cases[i].alive = mrand.Int()&1 == 0 // nolint: gosec
    }
    return cases
}

func TestInstallFilters(t *testing.T) {
    const SizeTestFilters = 256
    filters := NewFilters(testShard, createLogger(t))
    tst := generateTestCases(t, SizeTestFilters)

    var err error
    var j string
    for i := 0; i < SizeTestFilters; i++ {
        j, err = filters.Install(tst[i].f)
        require.NoError(t, err)

        tst[i].id = j
        require.Len(t, j, KeyIDSize*2)
    }

    for _, testCase := range tst {
        if !testCase.alive {
            filters.Uninstall(testCase.id)
        }
    }

    for _, testCase := range tst {
        fil := filters.Get(testCase.id)
        exist := fil != nil
        require.Equal(t, exist, testCase.alive)
    }
}

func TestInstallSymKeyGeneratesHash(t *testing.T) {
    filters := NewFilters(testShard, createLogger(t))
    filter, _ := generateFilter(t, true)

    // save the current SymKeyHash for comparison
    initialSymKeyHash := filter.SymKeyHash

    // ensure the SymKeyHash is invalid, for Install to recreate it
    var invalid common.Hash
    filter.SymKeyHash = invalid

    _, err := filters.Install(filter)
    require.NoError(t, err)

    for i, b := range filter.SymKeyHash {
        require.Equal(t, b, initialSymKeyHash[i])
    }
}

func TestInstallIdenticalFilters(t *testing.T) {
    filters := NewFilters(testShard, createLogger(t))
    filter1, _ := generateFilter(t, true)

    // Copy the first filter since some of its fields
    // are randomly gnerated.
    filter2 := &Filter{
        KeySym:        filter1.KeySym,
        PubsubTopic:   filter1.PubsubTopic,
        ContentTopics: filter1.ContentTopics,
        Messages:      NewMemoryMessageStore(),
    }

    _, err := filters.Install(filter1)
    require.NoError(t, err)

    _, err = filters.Install(filter2)
    require.NoError(t, err)

    recvMessage := generateCompatibleReceivedMessage(t, filter1)
    msg := recvMessage.Open(filter1)
    require.NotNil(t, msg)
}

func TestInstallFilterWithSymAndAsymKeys(t *testing.T) {
    filters := NewFilters(testShard, createLogger(t))
    filter1, _ := generateFilter(t, true)

    asymKey, err := crypto.GenerateKey()
    require.NoError(t, err)

    // Copy the first filter since some of its fields
    // are randomly gnerated.
    filter := &Filter{
        KeySym:        filter1.KeySym,
        KeyAsym:       asymKey,
        PubsubTopic:   filter1.PubsubTopic,
        ContentTopics: filter1.ContentTopics,
        Messages:      NewMemoryMessageStore(),
    }

    _, err = filters.Install(filter)
    require.Error(t, err)
}

func cloneFilter(orig *Filter) *Filter {
    var clone Filter
    clone.Messages = NewMemoryMessageStore()
    clone.Src = orig.Src
    clone.KeyAsym = orig.KeyAsym
    clone.KeySym = orig.KeySym
    clone.PubsubTopic = orig.PubsubTopic
    clone.ContentTopics = orig.ContentTopics
    clone.SymKeyHash = orig.SymKeyHash
    return &clone
}

func generateCompatibleReceivedMessage(t *testing.T, f *Filter) *ReceivedMessage {
    keyInfo := &payload.KeyInfo{}
    keyInfo.Kind = payload.Symmetric
    keyInfo.SymKey = f.KeySym

    var version uint32 = 1
    p := new(payload.Payload)
    p.Data = make([]byte, 20)
    _, err := crand.Read(p.Data) // nolint: gosec
    require.NoError(t, err)
    p.Key = keyInfo
    payload, err := p.Encode(version)
    require.NoError(t, err)

    msg := &pb.WakuMessage{
        Payload:      payload,
        Version:      &version,
        ContentTopic: maps.Keys(f.ContentTopics)[2].ContentTopic(),
        Timestamp:    proto.Int64(time.Now().UnixNano()),
        Meta:         []byte{},
    }
    envelope := protocol.NewEnvelope(msg, time.Now().UnixNano(), f.PubsubTopic)

    result := NewReceivedMessage(envelope, "test")
    result.SymKeyHash = crypto.Keccak256Hash(f.KeySym)

    return result
}

func TestWatchers(t *testing.T) {
    const NumFilters = 16
    const NumMessages = 256
    var i int
    var j uint32
    var e *ReceivedMessage
    var x, firstID string
    var err error

    filters := NewFilters("/waku/2/rs/16/32", createLogger(t))
    tst := generateTestCases(t, NumFilters)
    for i = 0; i < NumFilters; i++ {
        tst[i].f.Src = nil
        x, err = filters.Install(tst[i].f)
        require.NoError(t, err)

        tst[i].id = x
        if len(firstID) == 0 {
            firstID = x
        }
    }

    lastID := x

    var envelopes [NumMessages]*ReceivedMessage
    for i = 0; i < NumMessages; i++ {
        j = mrand.Uint32() % NumFilters // nolint: gosec
        e = generateCompatibleReceivedMessage(t, tst[j].f)
        envelopes[i] = e
        tst[j].msgCnt++
    }

    for i = 0; i < NumMessages; i++ {
        filters.NotifyWatchers(envelopes[i])
    }

    var total int
    var mail []*ReceivedMessage
    var count [NumFilters]int

    for i = 0; i < NumFilters; i++ {
        mail = tst[i].f.Retrieve()
        count[i] = len(mail)
        total += len(mail)
    }
    require.Equal(t, total, NumMessages)

    for i = 0; i < NumFilters; i++ {
        mail = tst[i].f.Retrieve()
        require.Zero(t, len(mail))
        require.Equal(t, tst[i].msgCnt, count[i])
    }

    // another round with a cloned filter

    clone := cloneFilter(tst[0].f)
    filters.Uninstall(lastID)
    total = 0
    last := NumFilters - 1
    tst[last].f = clone
    _, err = filters.Install(clone)
    require.NoError(t, err)

    for i = 0; i < NumFilters; i++ {
        tst[i].msgCnt = 0
        count[i] = 0
    }

    // make sure that the first watcher receives at least one message
    e = generateCompatibleReceivedMessage(t, tst[0].f)
    envelopes[0] = e
    tst[0].msgCnt++
    for i = 1; i < NumMessages; i++ {
        j = mrand.Uint32() % NumFilters // nolint: gosec
        e = generateCompatibleReceivedMessage(t, tst[j].f)
        envelopes[i] = e
        tst[j].msgCnt++
    }

    for i = 0; i < NumMessages; i++ {
        filters.NotifyWatchers(envelopes[i])
    }

    for i = 0; i < NumFilters; i++ {
        mail = tst[i].f.Retrieve()
        count[i] = len(mail)
        total += len(mail)
    }

    combined := tst[0].msgCnt + tst[last].msgCnt
    require.Equal(t, total, NumMessages+count[0])
    require.Equal(t, combined, count[0])
    require.Equal(t, combined, count[last])

    for i = 1; i < NumFilters-1; i++ {
        mail = tst[i].f.Retrieve()
        require.Zero(t, len(mail))
        require.Equal(t, tst[i].msgCnt, count[i])
    }
}