status-im/status-go

View on GitHub
mailserver/mailserver_test.go

Summary

Maintainability
A
0 mins
Test Coverage
// Copyright 2017 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package mailserver

import (
    "crypto/ecdsa"
    "encoding/binary"
    "errors"
    "fmt"
    "testing"
    "time"

    "github.com/stretchr/testify/suite"

    "github.com/ethereum/go-ethereum/common"
    "github.com/ethereum/go-ethereum/crypto"
    "github.com/ethereum/go-ethereum/rlp"

    "github.com/status-im/status-go/eth-node/types"
    "github.com/status-im/status-go/params"
    waku "github.com/status-im/status-go/waku"
    wakucommon "github.com/status-im/status-go/waku/common"
)

// powRequirement is the proof-of-work value shared by this file: it is handed
// to the mail server as MinimumPoW and is the basis for the PoW of the
// messages the helpers generate.
const powRequirement = 0.00001

var (
    // keyID is assigned in setupServer from AddSymKeyFromPassword and read
    // back in createEnvelope to fetch the symmetric key.
    keyID string

    // seed is captured once when the package is initialised; the helpers only
    // echo it in failure messages.
    seed = time.Now().Unix()

    // testPayload is the payload of envelopes built by generateEnvelopeWithKeys.
    testPayload = []byte("test payload")
)

// ServerTestParams bundles the values the tests use to build a mail request
// envelope (see createRequest) and to check the decoded request afterwards.
type ServerTestParams struct {
    // topic of the request envelope; createRequest also derives the bloom from it.
    topic types.TopicType
    // birth is env.Expiry - env.TTL of the reference envelope.
    birth uint32
    // low is the lower bound written into the request payload.
    low   uint32
    // upp is the upper bound written into the request payload.
    upp   uint32
    // limit is appended to the request payload only when it is non-zero.
    limit uint32
    // key is passed as the envelope source; its public key is used as peer ID.
    key   *ecdsa.PrivateKey
}

func TestMailserverSuite(t *testing.T) {
    suite.Run(t, new(MailserverSuite))
}

// MailserverSuite groups the mail server tests. Its fields are re-created by
// SetupTest before each test method.
type MailserverSuite struct {
    suite.Suite
    // server is the mail server under test.
    server  *WakuMailServer
    // shh is the waku instance the server is registered with.
    shh     *waku.Waku
    // config is the base configuration (temporary DataDir plus a password).
    config  *params.WakuConfig
    // dataDir is the per-test temporary directory, reused by setupServer.
    dataDir string
}

// SetupTest gives every test a fresh, not yet initialised mail server that is
// registered with a new waku instance, plus a base config pointing at a
// per-test temporary directory.
func (s *MailserverSuite) SetupTest() {
    s.server = &WakuMailServer{}
    s.shh = waku.New(&waku.DefaultConfig, nil)
    s.shh.RegisterMailServer(s.server)

    s.dataDir = s.T().TempDir()
    s.config = &params.WakuConfig{
        DataDir:            s.dataDir,
        MailServerPassword: "testpassword",
    }
}

// TestInit runs WakuMailServer.Init against several configurations and checks
// the returned error, whether the database handle was created, and whether a
// rate limiter is installed when a rate limit is configured.
func (s *MailserverSuite) TestInit() {
    dataDir := s.config.DataDir

    testCases := []struct {
        config        params.WakuConfig
        expectedError error
        info          string
    }{
        {
            info:          "config with empty DataDir",
            config:        params.WakuConfig{DataDir: ""},
            expectedError: errDirectoryNotProvided,
        },
        {
            info: "config with correct DataDir and Password",
            config: params.WakuConfig{
                DataDir:            dataDir,
                MailServerPassword: "pwd",
            },
        },
        {
            info: "config with rate limit",
            config: params.WakuConfig{
                DataDir:             dataDir,
                MailServerPassword:  "pwd",
                MailServerRateLimit: 5,
            },
        },
    }

    for i := range testCases {
        // Work on a copy so that &tc.config below does not alias shared
        // storage (gosec G601).
        tc := testCases[i]
        s.T().Run(tc.info, func(*testing.T) {
            wakuNode := waku.New(&waku.DefaultConfig, nil)
            mailServer := &WakuMailServer{}
            wakuNode.RegisterMailServer(mailServer)

            err := mailServer.Init(wakuNode, &tc.config)
            s.Require().Equal(tc.expectedError, err)
            if err == nil {
                defer mailServer.Close()
            }

            // The db must be open only when no error was expected.
            if tc.expectedError != nil {
                s.Nil(mailServer.ms)
            } else {
                s.NotNil(mailServer.ms.db)
            }

            if tc.config.MailServerRateLimit > 0 {
                s.NotNil(mailServer.ms.rateLimiter)
            }
        })
    }
}

// TestArchive archives one envelope and checks that the bytes stored under
// its DB key equal the RLP encoding of that envelope.
func (s *MailserverSuite) TestArchive() {
    config := *s.config

    err := s.server.Init(s.shh, &config)
    s.Require().NoError(err)
    defer s.server.Close()

    // The envelope is dereferenced below, so a failure here has to stop the
    // test instead of letting it run into a nil pointer.
    env, err := generateEnvelope(time.Now())
    s.Require().NoError(err)
    rawEnvelope, err := rlp.EncodeToBytes(env)
    s.Require().NoError(err)

    s.server.Archive(env)
    key := NewDBKey(env.Expiry-env.TTL, types.TopicType(env.Topic), types.Hash(env.Hash()))
    archivedEnvelope, err := s.server.ms.db.GetEnvelope(key)
    s.Require().NoError(err)

    s.Equal(rawEnvelope, archivedEnvelope)
}

// TestManageLimits installs a 5ms rate limiter and checks that the first
// request of a peer is accepted and recorded, while an immediate second
// request is rejected without adding or changing a limiter entry.
func (s *MailserverSuite) TestManageLimits() {
    err := s.server.Init(s.shh, s.config)
    // s.server.ms is dereferenced right below; stop here if Init failed.
    s.Require().NoError(err)
    // Release the server like every other test in this suite does.
    defer s.server.Close()

    peerID := types.BytesToHash([]byte("peerID"))

    s.server.ms.rateLimiter = newRateLimiter(5 * time.Millisecond)
    s.False(s.server.ms.exceedsPeerRequests(peerID))
    s.Equal(1, len(s.server.ms.rateLimiter.db))
    // NOTE(review): the limiter is indexed with the literal "peerID" while the
    // call above passes a hash. If the limiter keys entries by a string form
    // of the hash, both lookups yield the zero value and the last assertion
    // is trivially true — confirm the key format.
    firstSaved := s.server.ms.rateLimiter.db["peerID"]

    // An immediate second call is rejected and must not store a new limit.
    s.True(s.server.ms.exceedsPeerRequests(peerID))
    s.Equal(1, len(s.server.ms.rateLimiter.db))
    s.Equal(firstSaved, s.server.ms.rateLimiter.db["peerID"])
}

// TestDBKey checks the length of a DB key and that its first four bytes hold
// the timestamp in big-endian order.
func (s *MailserverSuite) TestDBKey() {
    var h types.Hash
    var emptyTopic types.TopicType
    i := uint32(time.Now().Unix())
    k := NewDBKey(i, emptyTopic, h)

    raw := k.Bytes()
    // Fatal on purpose: the index expressions below would panic on a short
    // key. Len also reports expected and actual the right way round.
    s.Require().Len(raw, DBKeyLength, "wrong DB key length")
    s.Equal(byte(i%0x100), raw[3], "raw representation should be big endian")
    s.Equal(byte(i/0x1000000), raw[0], "big endian expected")
}

// TestRequestPaginationLimit archives ten envelopes, requests them with a
// limit of six and checks that the first page holds the six oldest envelopes
// with a cursor pointing at the last one, and that the second page holds the
// remaining four with no cursor.
func (s *MailserverSuite) TestRequestPaginationLimit() {
    s.setupServer(s.server)
    defer s.server.Close()

    var (
        sentEnvelopes []*wakucommon.Envelope
        sentHashes    []common.Hash
        archiveKeys   []string
    )

    now := time.Now()
    count := uint32(10)

    for i := count; i > 0; i-- {
        // i is unsigned, so it must be converted before it is negated:
        // time.Duration(-i) would wrap around to roughly 2^32 seconds in the
        // future instead of i seconds in the past.
        sentTime := now.Add(-time.Duration(i) * time.Second)
        env, err := generateEnvelope(sentTime)
        // env is dereferenced right below.
        s.Require().NoError(err)
        s.server.Archive(env)
        key := NewDBKey(env.Expiry-env.TTL, types.TopicType(env.Topic), types.Hash(env.Hash()))
        archiveKeys = append(archiveKeys, fmt.Sprintf("%x", key.Cursor()))
        sentEnvelopes = append(sentEnvelopes, env)
        sentHashes = append(sentHashes, env.Hash())
    }

    reqLimit := uint32(6)
    peerID, request, err := s.prepareRequest(sentEnvelopes, reqLimit)
    s.Require().NoError(err)
    payload, err := s.server.decompositeRequest(peerID, request)
    s.Require().NoError(err)
    s.Nil(payload.Cursor)
    // payload.Limit is used as a slice bound and index below.
    s.Require().Equal(reqLimit, payload.Limit)

    receivedHashes, cursor, _ := processRequestAndCollectHashes(s.server, payload)

    // 10 envelopes sent
    s.Equal(count, uint32(len(sentEnvelopes)))
    // 6 envelopes received
    s.Len(receivedHashes, int(payload.Limit))
    // the 6 envelopes received should be in forward order
    s.Equal(sentHashes[:payload.Limit], receivedHashes)
    // cursor should be the key of the last envelope of the last page
    s.Equal(archiveKeys[payload.Limit-1], fmt.Sprintf("%x", cursor))

    // second page
    payload.Cursor = cursor
    receivedHashes, cursor, _ = processRequestAndCollectHashes(s.server, payload)

    // 4 envelopes received
    s.Equal(int(count-payload.Limit), len(receivedHashes))
    // cursor is nil because there are no other pages
    s.Nil(cursor)
}

// TestMailServer archives one envelope and then decomposes a series of
// requests with different time bounds and topics. For each request it checks
// whether decomposition succeeds, that the decoded payload mirrors the request
// parameters, and whether the archived envelope is found.
func (s *MailserverSuite) TestMailServer() {
    s.setupServer(s.server)
    defer s.server.Close()

    env, err := generateEnvelope(time.Now())
    // env is dereferenced by defaultServerParams below.
    s.Require().NoError(err)

    s.server.Archive(env)

    testCases := []struct {
        params *ServerTestParams
        expect bool // whether the archived envelope must be returned
        isOK   bool // whether decompositeRequest must succeed
        info   string
    }{
        {
            params: s.defaultServerParams(env),
            expect: true,
            isOK:   true,
            info:   "Processing a request where from and to are equal to an existing register, should provide results",
        },
        {
            params: func() *ServerTestParams {
                params := s.defaultServerParams(env)
                params.low = params.birth + 1
                params.upp = params.birth + 1

                return params
            }(),
            expect: false,
            isOK:   true,
            info:   "Processing a request where from and to are greater than any existing register, should not provide results",
        },
        {
            params: func() *ServerTestParams {
                params := s.defaultServerParams(env)
                params.upp = params.birth + 1
                params.topic[0] = 0xFF

                return params
            }(),
            expect: false,
            isOK:   true,
            info:   "Processing a request where to is greater than any existing register and with a specific topic, should not provide results",
        },
        {
            params: func() *ServerTestParams {
                params := s.defaultServerParams(env)
                params.low = params.birth
                params.upp = params.birth - 1

                return params
            }(),
            isOK: false,
            info: "Processing a request where to is lower than from should fail",
        },
        {
            params: func() *ServerTestParams {
                params := s.defaultServerParams(env)
                params.low = 0
                params.upp = params.birth + 24

                return params
            }(),
            isOK: false,
            info: "Processing a request where difference between from and to is > 24 should fail",
        },
    }
    for _, testCase := range testCases {
        // Copy the loop variable before taking addresses of its fields
        // (gosec G601).
        tc := testCase
        s.T().Run(tc.info, func(*testing.T) {
            request := s.createRequest(tc.params)
            src := crypto.FromECDSAPub(&tc.params.key.PublicKey)
            payload, err := s.server.decompositeRequest(src, request)
            s.Equal(tc.isOK, err == nil)
            if err == nil {
                s.Equal(tc.params.low, payload.Lower)
                s.Equal(tc.params.upp, payload.Upper)
                s.Equal(tc.params.limit, payload.Limit)
                s.Equal(types.TopicToBloom(tc.params.topic), payload.Bloom)
                s.Equal(tc.expect, s.messageExists(env, tc.params.low, tc.params.upp, payload.Bloom, tc.params.limit))

                // The same request presented with an altered peer ID must
                // still be accepted.
                src[0]++
                _, err = s.server.decompositeRequest(src, request)
                s.NoError(err)
            }
        })
    }
}

// TestDecodeRequest RLP-encodes a fully populated MessagesRequestPayload,
// wraps it into an envelope and expects decodeRequest to return an equal
// payload.
func (s *MailserverSuite) TestDecodeRequest() {
    s.setupServer(s.server)
    defer s.server.Close()

    want := MessagesRequestPayload{
        Lower:  50,
        Upper:  100,
        Bloom:  []byte{0x01},
        Topics: [][]byte{},
        Limit:  10,
        Cursor: []byte{},
        Batch:  true,
    }
    encoded, err := rlp.EncodeToBytes(want)
    s.Require().NoError(err)

    keyPairID, err := s.shh.NewKeyPair()
    s.Require().NoError(err)
    privateKey, err := s.shh.GetPrivateKey(keyPairID)
    s.Require().NoError(err)

    request := s.createEnvelope(types.TopicType{0x01}, encoded, privateKey)

    got, err := s.server.decodeRequest(nil, request)
    s.Require().NoError(err)
    s.Equal(want, got)
}

// TestDecodeRequestNoUpper sends a request payload without an Upper bound and
// expects decodeRequest to return a payload whose Upper is non-zero.
func (s *MailserverSuite) TestDecodeRequestNoUpper() {
    s.setupServer(s.server)
    defer s.server.Close()

    payload := MessagesRequestPayload{
        Lower:  50,
        Bloom:  []byte{0x01},
        Limit:  10,
        Cursor: []byte{},
        Batch:  true,
    }
    data, err := rlp.EncodeToBytes(payload)
    s.Require().NoError(err)

    id, err := s.shh.NewKeyPair()
    s.Require().NoError(err)
    srcKey, err := s.shh.GetPrivateKey(id)
    s.Require().NoError(err)

    env := s.createEnvelope(types.TopicType{0x01}, data, srcKey)

    decodedPayload, err := s.server.decodeRequest(nil, env)
    s.Require().NoError(err)
    // NotZero instead of NotEqual(0, ...): testify compares values together
    // with their types, so the untyped 0 (an int) never equals the unsigned
    // Upper field and the old assertion could not fail.
    s.NotZero(decodedPayload.Upper)
}

// TestProcessRequestDeadlockHandling checks that processRequestInBundles
// returns even though nobody reads from its output channel: once because the
// `done` channel is closed, and once because its own timeout expires.
func (s *MailserverSuite) TestProcessRequestDeadlockHandling() {
    s.setupServer(s.server)
    defer s.server.Close()

    var archivedEnvelopes []*wakucommon.Envelope

    now := time.Now()
    count := uint32(10)

    // Archive some envelopes.
    for i := count; i > 0; i-- {
        // i is unsigned, so it must be converted before it is negated:
        // time.Duration(-i) would wrap around to roughly 2^32 seconds in the
        // future instead of i seconds in the past.
        sentTime := now.Add(-time.Duration(i) * time.Second)
        env, err := generateEnvelope(sentTime)
        // A nil envelope must not reach Archive.
        s.Require().NoError(err)
        s.server.Archive(env)
        archivedEnvelopes = append(archivedEnvelopes, env)
    }

    // Prepare a request.
    peerID, request, err := s.prepareRequest(archivedEnvelopes, 5)
    s.Require().NoError(err)
    payload, err := s.server.decompositeRequest(peerID, request)
    s.Require().NoError(err)

    testCases := []struct {
        Name    string
        Timeout time.Duration
        Verify  func(
            Iterator,
            time.Duration, // processRequestInBundles timeout
            chan []rlp.RawValue,
        )
    }{
        {
            Name:    "finish processing using `done` channel",
            Timeout: time.Second * 5,
            Verify: func(
                iter Iterator,
                timeout time.Duration,
                bundles chan []rlp.RawValue,
            ) {
                done := make(chan struct{})
                processFinished := make(chan struct{})

                go func() {
                    s.server.ms.processRequestInBundles(iter, payload.Bloom, payload.Topics, int(payload.Limit), timeout, "req-01", bundles, done)
                    close(processFinished)
                }()
                go close(done)

                // Closing `done` must end processing well before the 5s
                // processing timeout.
                select {
                case <-processFinished:
                case <-time.After(time.Second):
                    s.FailNow("waiting for processing finish timed out")
                }
            },
        },
        {
            Name:    "finish processing due to timeout",
            Timeout: time.Second,
            Verify: func(
                iter Iterator,
                timeout time.Duration,
                bundles chan []rlp.RawValue,
            ) {
                done := make(chan struct{}) // won't be closed because we test timeout of `processRequestInBundles()`
                processFinished := make(chan struct{})

                go func() {
                    // Use the timeout of the test case rather than a second,
                    // hard-coded copy of it.
                    s.server.ms.processRequestInBundles(iter, payload.Bloom, payload.Topics, int(payload.Limit), timeout, "req-01", bundles, done)
                    close(processFinished)
                }()

                select {
                case <-processFinished:
                case <-time.After(time.Second * 5):
                    s.FailNow("waiting for processing finish timed out")
                }
            },
        },
    }

    for _, testCase := range testCases {
        // Copy the loop variable, as the other table-driven tests in this
        // file do.
        tc := testCase
        s.T().Run(tc.Name, func(*testing.T) {
            iter, err := s.server.ms.createIterator(payload)
            s.Require().NoError(err)

            defer func() { _ = iter.Release() }()

            // Nothing reads from this unbuffered channel which simulates a situation
            // when a connection between a peer and mail server was dropped.
            bundles := make(chan []rlp.RawValue)

            tc.Verify(iter, tc.Timeout, bundles)
        })
    }
}

// messageExists processes a request built from the given bounds, bloom and
// limit and reports whether the hash of envelope is among the hashes that
// come back.
func (s *MailserverSuite) messageExists(envelope *wakucommon.Envelope, low, upp uint32, bloom []byte, limit uint32) bool {
    request := MessagesRequestPayload{Lower: low, Upper: upp, Bloom: bloom, Limit: limit}
    hashes, _, _ := processRequestAndCollectHashes(s.server, request)

    found := false
    for i := 0; i < len(hashes) && !found; i++ {
        found = hashes[i] == envelope.Hash()
    }
    return found
}

// setupServer replaces s.shh with a fresh waku instance, registers and
// initialises server against it using the suite's data directory, and stores
// the ID of a password-derived symmetric key in the package-level keyID.
// Any failure aborts the running test.
func (s *MailserverSuite) setupServer(server *WakuMailServer) {
    const password = "password_for_this_test"

    s.shh = waku.New(&waku.DefaultConfig, nil)
    s.shh.RegisterMailServer(server)

    serverConfig := &params.WakuConfig{
        DataDir:            s.dataDir,
        MailServerPassword: password,
        MinimumPoW:         powRequirement,
    }
    if err := server.Init(s.shh, serverConfig); err != nil {
        s.T().Fatal(err)
    }

    var err error
    if keyID, err = s.shh.AddSymKeyFromPassword(password); err != nil {
        s.T().Fatalf("failed to create symmetric key for mail request: %s", err)
    }
}

// prepareRequest builds a request envelope whose time range starts
// len(envelopes) seconds before the current time and ends at the current
// time, with the given limit. It returns the requesting peer's ID (the
// serialised public key of a freshly generated key pair) together with the
// request, or an error when envelopes is empty.
func (s *MailserverSuite) prepareRequest(envelopes []*wakucommon.Envelope, limit uint32) (
    []byte, *wakucommon.Envelope, error,
) {
    n := len(envelopes)
    if n == 0 {
        return nil, nil, errors.New("envelopes is empty")
    }

    upper := time.Now()
    lower := upper.Add(time.Duration(-n) * time.Second)

    reqParams := s.defaultServerParams(envelopes[0])
    reqParams.low = uint32(lower.Unix())
    reqParams.upp = uint32(upper.Unix())
    reqParams.limit = limit

    request := s.createRequest(reqParams)
    peerID := crypto.FromECDSAPub(&reqParams.key.PublicKey)

    return peerID, request, nil
}

// defaultServerParams generates a new key pair and returns request parameters
// centred on env: the topic of env, a range from one second before to one
// second after env.Expiry-env.TTL, and no limit. Key failures abort the test.
func (s *MailserverSuite) defaultServerParams(env *wakucommon.Envelope) *ServerTestParams {
    keyPairID, err := s.shh.NewKeyPair()
    if err != nil {
        s.T().Fatalf("failed to generate new key pair with seed %d: %s.", seed, err)
    }

    privateKey, err := s.shh.GetPrivateKey(keyPairID)
    if err != nil {
        s.T().Fatalf("failed to retrieve new key pair with seed %d: %s.", seed, err)
    }

    created := env.Expiry - env.TTL

    // limit is left at its zero value.
    p := &ServerTestParams{
        topic: types.TopicType(env.Topic),
        birth: created,
        key:   privateKey,
    }
    p.low = created - 1
    p.upp = created + 1
    return p
}

// createRequest serialises p into a request payload — big-endian lower bound,
// big-endian upper bound, the bloom of the topic and, only when non-zero, the
// big-endian limit — and wraps it into an envelope sent from p.key.
func (s *MailserverSuite) createRequest(p *ServerTestParams) *wakucommon.Envelope {
    var bounds [8]byte
    binary.BigEndian.PutUint32(bounds[:4], p.low)
    binary.BigEndian.PutUint32(bounds[4:], p.upp)

    payload := append(bounds[:], types.TopicToBloom(p.topic)...)

    if p.limit != 0 {
        var limitBytes [4]byte
        binary.BigEndian.PutUint32(limitBytes[:], p.limit)
        payload = append(payload, limitBytes[:]...)
    }

    return s.createEnvelope(p.topic, payload, p.key)
}

// createEnvelope wraps data into an envelope for topic, encrypted with the
// symmetric key registered under the package-level keyID and using srcKey as
// the message source. Any failure aborts the running test.
func (s *MailserverSuite) createEnvelope(topic types.TopicType, data []byte, srcKey *ecdsa.PrivateKey) *wakucommon.Envelope {
    symKey, err := s.shh.GetSymKey(keyID)
    if err != nil {
        s.T().Fatalf("failed to retrieve sym key with seed %d: %s.", seed, err)
    }

    msgParams := &wakucommon.MessageParams{
        KeySym:   symKey,
        Topic:    wakucommon.TopicType(topic),
        Payload:  data,
        PoW:      powRequirement * 2,
        WorkTime: 2,
        Src:      srcKey,
    }

    sentMsg, err := wakucommon.NewSentMessage(msgParams)
    if err != nil {
        s.T().Fatalf("failed to create new message with seed %d: %s.", seed, err)
    }

    envelope, err := sentMsg.Wrap(msgParams, time.Now())
    if err != nil {
        s.T().Fatalf("failed to wrap with seed %d: %s.", seed, err)
    }

    return envelope
}

// generateEnvelopeWithKeys builds an envelope carrying testPayload on a fixed
// topic, wrapped with sentTime. A non-empty keySym is used as symmetric key;
// otherwise a non-nil keyAsym is used as destination key. Errors wrap the
// underlying cause, so callers can still inspect it with errors.Is/As.
func generateEnvelopeWithKeys(sentTime time.Time, keySym []byte, keyAsym *ecdsa.PublicKey) (*wakucommon.Envelope, error) {
    params := &wakucommon.MessageParams{
        Topic:    wakucommon.TopicType{0x1F, 0x7E, 0xA1, 0x7F},
        Payload:  testPayload,
        PoW:      powRequirement,
        WorkTime: 2,
    }

    // The symmetric key takes precedence when both keys are supplied.
    switch {
    case len(keySym) > 0:
        params.KeySym = keySym
    case keyAsym != nil:
        params.Dst = keyAsym
    }

    msg, err := wakucommon.NewSentMessage(params)
    if err != nil {
        return nil, fmt.Errorf("failed to create new message with seed %d: %w", seed, err)
    }
    env, err := msg.Wrap(params, sentTime)
    if err != nil {
        return nil, fmt.Errorf("failed to wrap with seed %d: %w", seed, err)
    }

    return env, nil
}

// generateEnvelope builds a test envelope for sentTime that is encrypted with
// a fixed symmetric key: the Keccak-256 hash of a constant string.
func generateEnvelope(sentTime time.Time) (*wakucommon.Envelope, error) {
    symKey := crypto.Keccak256Hash([]byte("test sample data"))

    return generateEnvelopeWithKeys(sentTime, symKey[:], nil)
}

// processRequestAndCollectHashes runs payload through the mail server and
// returns the hashes of all envelopes it emitted, in emission order, together
// with the cursor and last hash reported by processRequestInBundles.
//
// It panics when the iterator cannot be created or an emitted envelope cannot
// be decoded.
func processRequestAndCollectHashes(server *WakuMailServer, payload MessagesRequestPayload) ([]common.Hash, []byte, types.Hash) {
    iter, err := server.ms.createIterator(payload)
    if err != nil {
        // Fail with the real cause instead of a nil dereference in the
        // deferred Release below.
        panic(err)
    }
    defer func() { _ = iter.Release() }()
    bundles := make(chan []rlp.RawValue, 10)
    done := make(chan struct{})

    var hashes []common.Hash
    go func() {
        // The loop ends once bundles is closed — presumably by
        // processRequestInBundles when it finishes; confirm there.
        for bundle := range bundles {
            for _, rawEnvelope := range bundle {
                var env *wakucommon.Envelope
                if err := rlp.DecodeBytes(rawEnvelope, &env); err != nil {
                    panic(err)
                }
                hashes = append(hashes, env.Hash())
            }
        }
        close(done)
    }()

    cursor, lastHash := server.ms.processRequestInBundles(iter, payload.Bloom, payload.Topics, int(payload.Limit), time.Minute, "req-01", bundles, done)

    // Wait for the collector so that hashes is complete before it is read.
    <-done

    return hashes, cursor, lastHash
}