status-im/status-go

View on GitHub
protocol/encryption/encryption_multi_device_test.go

Summary

Maintainability
A
0 mins
Test Coverage
package encryption

import (
    "crypto/ecdsa"
    "fmt"
    "testing"

    "github.com/status-im/status-go/appdatabase"
    "github.com/status-im/status-go/protocol/sqlite"
    "github.com/status-im/status-go/protocol/tt"
    "github.com/status-im/status-go/t/helpers"

    "github.com/stretchr/testify/suite"
    "go.uber.org/zap"

    "github.com/status-im/status-go/eth-node/crypto"

    "github.com/status-im/status-go/protocol/encryption/multidevice"
)

const (
    aliceUser = "alice"
    bobUser   = "bob"
)

func TestEncryptionServiceMultiDeviceSuite(t *testing.T) {
    suite.Run(t, new(EncryptionServiceMultiDeviceSuite))
}

type serviceAndKey struct {
    services []*Protocol
    key      *ecdsa.PrivateKey
}

type EncryptionServiceMultiDeviceSuite struct {
    suite.Suite
    services map[string]*serviceAndKey
    logger   *zap.Logger
}

func setupUser(user string, s *EncryptionServiceMultiDeviceSuite, n int) error {
    key, err := crypto.GenerateKey()
    if err != nil {
        return err
    }

    s.services[user] = &serviceAndKey{
        key:      key,
        services: make([]*Protocol, n),
    }

    for i := 0; i < n; i++ {
        installationID := fmt.Sprintf("%s%d", user, i+1)

        db, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{})
        if err != nil {
            return err
        }
        err = sqlite.Migrate(db)
        if err != nil {
            return err
        }

        protocol := New(
            db,
            installationID,
            s.logger.With(zap.String("user", user)),
        )
        s.services[user].services[i] = protocol
    }

    return nil
}

func (s *EncryptionServiceMultiDeviceSuite) SetupTest() {
    s.logger = tt.MustCreateTestLogger()

    s.services = make(map[string]*serviceAndKey)
    err := setupUser(aliceUser, s, 4)
    s.Require().NoError(err)

    err = setupUser(bobUser, s, 4)
    s.Require().NoError(err)
}

func (s *EncryptionServiceMultiDeviceSuite) TearDownTest() {
    _ = s.logger.Sync()
}

func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundle() {
    aliceKey := s.services[aliceUser].key

    alice2Bundle, err := s.services[aliceUser].services[1].GetBundle(aliceKey)
    s.Require().NoError(err)

    alice2IdentityPK, err := ExtractIdentity(alice2Bundle)
    s.Require().NoError(err)

    alice2Identity := fmt.Sprintf("0x%x", crypto.FromECDSAPub(alice2IdentityPK))

    alice3Bundle, err := s.services[aliceUser].services[2].GetBundle(aliceKey)
    s.Require().NoError(err)

    alice3IdentityPK, err := ExtractIdentity(alice2Bundle)
    s.Require().NoError(err)

    alice3Identity := fmt.Sprintf("0x%x", crypto.FromECDSAPub(alice3IdentityPK))

    // Add alice2 bundle
    response, err := s.services[aliceUser].services[0].ProcessPublicBundle(aliceKey, alice2Bundle)
    s.Require().NoError(err)
    s.Require().Equal(multidevice.Installation{
        Identity: alice2Identity,
        Version:  1,
        ID:       "alice2",
    }, *response[0])

    // Add alice3 bundle
    response, err = s.services[aliceUser].services[0].ProcessPublicBundle(aliceKey, alice3Bundle)
    s.Require().NoError(err)
    s.Require().Equal(multidevice.Installation{
        Identity: alice3Identity,
        Version:  1,
        ID:       "alice3",
    }, *response[0])

    // No installation is enabled
    alice1MergedBundle1, err := s.services[aliceUser].services[0].GetBundle(aliceKey)
    s.Require().NoError(err)

    s.Require().Equal(1, len(alice1MergedBundle1.GetSignedPreKeys()))
    s.Require().NotNil(alice1MergedBundle1.GetSignedPreKeys()["alice1"])

    // We enable the installations
    err = s.services[aliceUser].services[0].EnableInstallation(&aliceKey.PublicKey, "alice2")
    s.Require().NoError(err)

    err = s.services[aliceUser].services[0].EnableInstallation(&aliceKey.PublicKey, "alice3")
    s.Require().NoError(err)

    alice1MergedBundle2, err := s.services[aliceUser].services[0].GetBundle(aliceKey)
    s.Require().NoError(err)

    // We get back a bundle with all the installations
    s.Require().Equal(3, len(alice1MergedBundle2.GetSignedPreKeys()))
    s.Require().NotNil(alice1MergedBundle2.GetSignedPreKeys()["alice1"])
    s.Require().NotNil(alice1MergedBundle2.GetSignedPreKeys()["alice2"])
    s.Require().NotNil(alice1MergedBundle2.GetSignedPreKeys()["alice3"])

    // We disable the installations
    err = s.services[aliceUser].services[0].DisableInstallation(&aliceKey.PublicKey, "alice2")
    s.Require().NoError(err)

    alice1MergedBundle3, err := s.services[aliceUser].services[0].GetBundle(aliceKey)
    s.Require().NoError(err)

    // We get back a bundle with all the installations
    s.Require().Equal(2, len(alice1MergedBundle3.GetSignedPreKeys()))
    s.Require().NotNil(alice1MergedBundle3.GetSignedPreKeys()["alice1"])
    s.Require().NotNil(alice1MergedBundle3.GetSignedPreKeys()["alice3"])
}

func (s *EncryptionServiceMultiDeviceSuite) TestProcessPublicBundleOutOfOrder() {
    aliceKey, err := crypto.GenerateKey()
    s.Require().NoError(err)

    // Alice1 creates a bundle
    alice1Bundle, err := s.services[aliceUser].services[0].GetBundle(aliceKey)
    s.Require().NoError(err)

    // Alice2 Receives the bundle
    _, err = s.services[aliceUser].services[1].ProcessPublicBundle(aliceKey, alice1Bundle)
    s.Require().NoError(err)

    // Alice2 Creates a Bundle
    _, err = s.services[aliceUser].services[1].GetBundle(aliceKey)
    s.Require().NoError(err)

    // We enable the installation
    err = s.services[aliceUser].services[1].EnableInstallation(&aliceKey.PublicKey, "alice1")
    s.Require().NoError(err)

    // It should contain both bundles
    alice2MergedBundle1, err := s.services[aliceUser].services[1].GetBundle(aliceKey)
    s.Require().NoError(err)

    s.Require().NotNil(alice2MergedBundle1.GetSignedPreKeys()["alice1"])
    s.Require().NotNil(alice2MergedBundle1.GetSignedPreKeys()["alice2"])
}

func pairDevices(s *serviceAndKey, target int) error {
    device := s.services[target]
    for i := 0; i < len(s.services); i++ {
        b, err := s.services[i].GetBundle(s.key)

        if err != nil {
            return err
        }

        _, err = device.ProcessPublicBundle(s.key, b)
        if err != nil {
            return err
        }

        err = device.EnableInstallation(&s.key.PublicKey, s.services[i].encryptor.config.InstallationID)
        if err != nil {
            return nil
        }
    }
    return nil
}

func (s *EncryptionServiceMultiDeviceSuite) TestMaxDevices() {
    err := pairDevices(s.services[aliceUser], 0)
    s.Require().NoError(err)
    alice1 := s.services[aliceUser].services[0]
    bob1 := s.services[bobUser].services[0]
    aliceKey := s.services[aliceUser].key
    bobKey := s.services[bobUser].key

    // Check bundle is ok
    // No installation is enabled
    aliceBundle, err := alice1.GetBundle(aliceKey)
    s.Require().NoError(err)

    // Check all installations are correctly working, and that the oldest device is not there
    preKeys := aliceBundle.GetSignedPreKeys()
    s.Require().Equal(3, len(preKeys))
    s.Require().NotNil(preKeys["alice1"])
    // alice2 being the oldest device is rotated out, as we reached the maximum
    s.Require().Nil(preKeys["alice2"])
    s.Require().NotNil(preKeys["alice3"])
    s.Require().NotNil(preKeys["alice4"])

    // We propagate this to bob
    _, err = bob1.ProcessPublicBundle(bobKey, aliceBundle)
    s.Require().NoError(err)

    // Bob sends a message to alice
    msg, err := bob1.BuildEncryptedMessage(bobKey, &aliceKey.PublicKey, []byte("test"))
    s.Require().NoError(err)
    payload := msg.Message.GetEncryptedMessage()
    s.Require().Equal(3, len(payload))
    s.Require().NotNil(payload["alice1"])
    s.Require().NotNil(payload["alice3"])
    s.Require().NotNil(payload["alice4"])

    // We disable the last installation
    err = s.services[aliceUser].services[0].DisableInstallation(&aliceKey.PublicKey, "alice4")
    s.Require().NoError(err)

    // We check the bundle is updated
    aliceBundle, err = alice1.GetBundle(aliceKey)
    s.Require().NoError(err)

    // Check all installations are there
    preKeys = aliceBundle.GetSignedPreKeys()
    s.Require().Equal(3, len(preKeys))
    s.Require().NotNil(preKeys["alice1"])
    s.Require().NotNil(preKeys["alice2"])
    s.Require().NotNil(preKeys["alice3"])
    // alice4 is disabled at this point, alice2 is back in
    s.Require().Nil(preKeys["alice4"])

    // We propagate this to bob
    _, err = bob1.ProcessPublicBundle(bobKey, aliceBundle)
    s.Require().NoError(err)

    // Bob sends a message to alice
    msg, err = bob1.BuildEncryptedMessage(bobKey, &aliceKey.PublicKey, []byte("test"))
    s.Require().NoError(err)
    payload = msg.Message.GetEncryptedMessage()
    s.Require().Equal(3, len(payload))
    s.Require().NotNil(payload["alice1"])
    s.Require().NotNil(payload["alice2"])
    s.Require().NotNil(payload["alice3"])
}

func (s *EncryptionServiceMultiDeviceSuite) TestMaxDevicesRefreshedBundle() {
    alice1 := s.services[aliceUser].services[0]
    alice2 := s.services[aliceUser].services[1]
    alice3 := s.services[aliceUser].services[2]
    alice4 := s.services[aliceUser].services[3]
    bob1 := s.services[bobUser].services[0]
    bobKey := s.services[bobUser].key
    aliceKey := s.services[aliceUser].key

    // We create alice bundles, in order
    alice1Bundle, err := alice1.GetBundle(aliceKey)
    s.Require().NoError(err)

    alice2Bundle, err := alice2.GetBundle(aliceKey)
    s.Require().NoError(err)

    alice3Bundle, err := alice3.GetBundle(aliceKey)
    s.Require().NoError(err)

    alice4Bundle, err := alice4.GetBundle(aliceKey)
    s.Require().NoError(err)

    // We send all the bundles to bob
    _, err = bob1.ProcessPublicBundle(bobKey, alice1Bundle)
    s.Require().NoError(err)

    _, err = bob1.ProcessPublicBundle(bobKey, alice2Bundle)
    s.Require().NoError(err)

    _, err = bob1.ProcessPublicBundle(bobKey, alice3Bundle)
    s.Require().NoError(err)

    _, err = bob1.ProcessPublicBundle(bobKey, alice4Bundle)
    s.Require().NoError(err)

    // Bob sends a message to alice
    msg1, err := bob1.BuildEncryptedMessage(bobKey, &aliceKey.PublicKey, []byte("test"))
    s.Require().NoError(err)
    payload := msg1.Message.GetEncryptedMessage()
    s.Require().Equal(3, len(payload))
    // Alice1 is the oldest bundle and is rotated out
    // as we send maximum to 3 devices
    s.Require().Nil(payload["alice1"])
    s.Require().NotNil(payload["alice2"])
    s.Require().NotNil(payload["alice3"])
    s.Require().NotNil(payload["alice4"])

    // We send a message to bob from alice1, the timestamp should be refreshed
    msg2, err := alice1.BuildEncryptedMessage(aliceKey, &bobKey.PublicKey, []byte("test"))
    s.Require().NoError(err)

    alice1Bundle = msg2.Message.GetBundles()[0]

    // Bob processes the bundle
    _, err = bob1.ProcessPublicBundle(bobKey, alice1Bundle)
    s.Require().NoError(err)

    // Bob sends a message to alice
    msg3, err := bob1.BuildEncryptedMessage(bobKey, &aliceKey.PublicKey, []byte("test"))
    s.Require().NoError(err)
    payload = msg3.Message.GetEncryptedMessage()
    s.Require().Equal(3, len(payload))
    // Alice 1 is added back to the list of active devices
    s.Require().NotNil(payload["alice1"])
    // Alice 2 is rotated out as the oldest device in terms of activity
    s.Require().Nil(payload["alice2"])
    // Alice 3, 4 are still in
    s.Require().NotNil(payload["alice3"])
    s.Require().NotNil(payload["alice4"])
}