status-im/status-go

View on GitHub
protocol/encryption/persistence.go

Summary

Maintainability
A
0 mins
Test Coverage
B
83%
package encryption

import (
    "context"
    "crypto/ecdsa"
    "database/sql"
    "errors"
    "strings"

    dr "github.com/status-im/doubleratchet"

    "github.com/status-im/status-go/eth-node/crypto"

    "github.com/status-im/status-go/protocol/encryption/multidevice"
)

// RatchetInfo holds the current ratchet state as read back from the
// ratchet_info_v2 table joined with the bundle it was created from.
type RatchetInfo struct {
	// ID is BundleID with the InstallationID bytes appended.
	ID []byte
	// Sk is the symmetric key persisted for the ratchet.
	Sk []byte
	// PrivateKey and PublicKey are the private key and signed pre-key of the
	// bundle the ratchet was created from.
	PrivateKey, PublicKey []byte
	// Identity is the interlocutor's identity key.
	Identity []byte
	// BundleID identifies the bundle by its signed pre-key.
	BundleID []byte
	// EphemeralKey is set to NULL once the ratchet is confirmed.
	EphemeralKey []byte
	// InstallationID is the interlocutor installation the ratchet belongs to.
	InstallationID string
}

// maxNumberOfRows is a row count large enough to act as "no limit" in SQL
// statements that need a LIMIT clause only so that they can use OFFSET.
const (
	maxNumberOfRows = 100000000
)

// sqlitePersistence is the SQLite-backed store for bundles, ratchet info,
// hash ratchet keys and the double ratchet key/session storages.
type sqlitePersistence struct {
	DB             *sql.DB
	keysStorage    dr.KeysStorage
	sessionStorage dr.SessionStorage
}

// newSQLitePersistence wraps db and builds the double ratchet keys and
// session storages on top of the same database handle.
func newSQLitePersistence(db *sql.DB) *sqlitePersistence {
	p := &sqlitePersistence{DB: db}
	p.keysStorage = newSQLiteKeysStorage(db)
	p.sessionStorage = newSQLiteSessionStorage(db)
	return p
}

// KeysStorage returns the double ratchet KeysStorage associated with this
// persistence.
func (s *sqlitePersistence) KeysStorage() dr.KeysStorage {
	return s.keysStorage
}

// SessionStorage returns the double ratchet SessionStorage associated with
// this persistence.
func (s *sqlitePersistence) SessionStorage() dr.SessionStorage {
	return s.sessionStorage
}

// AddPrivateBundle stores the signed pre-keys of bc together with its private
// signed pre-key, one row per installation, inside a single transaction.
//
// Each row is inserted with a version one higher than the latest version
// already stored for the same (installation_id, identity) pair, or 1 when no
// such row exists. Any failure rolls the whole transaction back.
func (s *sqlitePersistence) AddPrivateBundle(bc *BundleContainer) (err error) {
	tx, err := s.DB.Begin()
	if err != nil {
		return err
	}
	// err is the named result: roll back on every failure path (including a
	// failed Commit) so the transaction never keeps its connection open.
	defer func() {
		if err != nil {
			_ = tx.Rollback()
		}
	}()

	// Prepare once and reuse for every installation rather than preparing
	// (and deferring Close) on each loop iteration.
	selectStmt, err := tx.Prepare(`SELECT version
                     FROM bundles
                     WHERE installation_id = ? AND identity = ?
                     ORDER BY version DESC
                     LIMIT 1`)
	if err != nil {
		return err
	}
	defer selectStmt.Close()

	insertStmt, err := tx.Prepare(`INSERT INTO bundles(identity, private_key, signed_pre_key, installation_id, version, timestamp)
                    VALUES(?, ?, ?, ?, ?, ?)`)
	if err != nil {
		return err
	}
	defer insertStmt.Close()

	bundle := bc.GetBundle()
	for installationID, signedPreKey := range bundle.GetSignedPreKeys() {
		// Stays 0 when there is no previous row, so the first version is 1.
		var version uint32
		err = selectStmt.QueryRow(installationID, bundle.GetIdentity()).Scan(&version)
		if err != nil && err != sql.ErrNoRows {
			return err
		}

		_, err = insertStmt.Exec(
			bundle.GetIdentity(),
			bc.GetPrivateSignedPreKey(),
			signedPreKey.GetSignedPreKey(),
			installationID,
			version+1,
			bundle.GetTimestamp(),
		)
		if err != nil {
			return err
		}
	}

	return tx.Commit()
}

// AddPublicBundle stores the signed pre-key of every installation in b and
// marks every older version for the same identity and installation as
// expired. All writes happen in one transaction, which is rolled back on any
// error.
func (s *sqlitePersistence) AddPublicBundle(b *Bundle) (err error) {
	tx, err := s.DB.Begin()
	if err != nil {
		return err
	}
	// err is the named result: roll back on every failure path so the
	// transaction is never left open.
	defer func() {
		if err != nil {
			_ = tx.Rollback()
		}
	}()

	insertStmt, err := tx.Prepare(`INSERT INTO bundles(identity, signed_pre_key, installation_id, version, timestamp)
                           VALUES( ?, ?, ?, ?, ?)`)
	if err != nil {
		return err
	}
	defer insertStmt.Close()

	// Mark old bundles as expired
	expireStmt, err := tx.Prepare(`UPDATE bundles
                           SET expired = 1
                           WHERE identity = ? AND installation_id = ? AND version < ?`)
	if err != nil {
		return err
	}
	defer expireStmt.Close()

	for installationID, signedPreKeyContainer := range b.GetSignedPreKeys() {
		version := signedPreKeyContainer.GetVersion()

		_, err = insertStmt.Exec(
			b.GetIdentity(),
			signedPreKeyContainer.GetSignedPreKey(),
			installationID,
			version,
			b.GetTimestamp(),
		)
		if err != nil {
			return err
		}

		_, err = expireStmt.Exec(
			b.GetIdentity(),
			installationID,
			version,
		)
		if err != nil {
			return err
		}
	}

	return tx.Commit()
}

// GetAnyPrivateBundle returns the non-expired bundle rows for myIdentityKey,
// restricted to the given installations and merged into one BundleContainer.
//
// It returns (nil, nil) in three cases:
//   - installations is empty
//   - no row matches
//   - none of the matching rows carries a private key
//
// The container's timestamp is taken only from rows that carry a private key.
func (s *sqlitePersistence) GetAnyPrivateBundle(myIdentityKey []byte, installations []*multidevice.Installation) (*BundleContainer, error) {
	// An empty IN list cannot match anything, and building the placeholder
	// list below would call strings.Repeat with a negative count (panic).
	if len(installations) == 0 {
		return nil, nil
	}

	/* #nosec */
	statement := `SELECT identity, private_key, signed_pre_key, installation_id, timestamp, version
                  FROM bundles
              WHERE expired = 0 AND identity = ? AND installation_id IN (?` + strings.Repeat(",?", len(installations)-1) + ")"
	stmt, err := s.DB.Prepare(statement)
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	// Protocol version advertised by each installation, keyed by installation ID.
	versions := make(map[string]uint32, len(installations))
	args := make([]interface{}, len(installations)+1)
	args[0] = myIdentityKey
	for i, installation := range installations {
		versions[installation.ID] = installation.Version
		args[i+1] = installation.ID
	}

	rows, err := stmt.Query(args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	bundle := &Bundle{
		SignedPreKeys: make(map[string]*SignedPreKey),
	}

	var (
		timestamp  int64
		identity   []byte
		privateKey []byte
		version    uint32
		rowCount   int
	)
	for rows.Next() {
		var signedPreKey []byte
		var installationID string
		rowCount++
		err = rows.Scan(
			&identity,
			&privateKey,
			&signedPreKey,
			&installationID,
			&timestamp,
			&version,
		)
		if err != nil {
			return nil, err
		}
		// If there is a private key, we set the timestamp of the bundle container
		if privateKey != nil {
			bundle.Timestamp = timestamp
		}

		bundle.SignedPreKeys[installationID] = &SignedPreKey{
			SignedPreKey:    signedPreKey,
			Version:         version,
			ProtocolVersion: versions[installationID],
		}
		bundle.Identity = identity
	}
	// Report iteration failures instead of returning a truncated bundle.
	if err = rows.Err(); err != nil {
		return nil, err
	}

	// If no records are found or no record with private key, return nil
	if rowCount == 0 || bundle.Timestamp == 0 {
		return nil, nil
	}

	return &BundleContainer{Bundle: bundle}, nil
}

// GetPrivateKeyBundle returns the private key stored in the bundle row whose
// signed pre-key equals bundleID, or (nil, nil) when there is no such row.
func (s *sqlitePersistence) GetPrivateKeyBundle(bundleID []byte) ([]byte, error) {
	query, err := s.DB.Prepare(`SELECT private_key
                   FROM bundles
                   WHERE signed_pre_key = ? LIMIT 1`)
	if err != nil {
		return nil, err
	}
	defer query.Close()

	var key []byte
	err = query.QueryRow(bundleID).Scan(&key)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return key, nil
}

// MarkBundleExpired sets expired = 1 on every bundle row of identity that
// stores a private key.
func (s *sqlitePersistence) MarkBundleExpired(identity []byte) error {
	update, err := s.DB.Prepare(`UPDATE bundles
                   SET expired = 1
                   WHERE identity = ? AND private_key IS NOT NULL`)
	if err != nil {
		return err
	}
	defer update.Close()

	if _, err = update.Exec(identity); err != nil {
		return err
	}
	return nil
}

// GetPublicBundle builds a Bundle for publicKey from the non-expired bundle
// rows of the given installations. When an installation has several active
// rows, the one with the highest version is used.
//
// It returns (nil, nil) when installations is empty or no row matches.
func (s *sqlitePersistence) GetPublicBundle(publicKey *ecdsa.PublicKey, installations []*multidevice.Installation) (*Bundle, error) {
	if len(installations) == 0 {
		return nil, nil
	}

	identity := crypto.CompressPubkey(publicKey)

	/* #nosec */
	statement := `SELECT signed_pre_key,installation_id, version
              FROM bundles
              WHERE expired = 0 AND identity = ? AND installation_id IN (?` + strings.Repeat(",?", len(installations)-1) + `)
              ORDER BY version DESC`
	stmt, err := s.DB.Prepare(statement)
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	// Protocol version advertised by each installation, keyed by installation ID.
	versions := make(map[string]uint32, len(installations))
	args := make([]interface{}, len(installations)+1)
	args[0] = identity
	for i, installation := range installations {
		versions[installation.ID] = installation.Version
		args[i+1] = installation.ID
	}

	rows, err := stmt.Query(args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	bundle := &Bundle{
		Identity:      identity,
		SignedPreKeys: make(map[string]*SignedPreKey),
	}

	for rows.Next() {
		var (
			signedPreKey   []byte
			installationID string
			version        uint32
		)
		if err = rows.Scan(&signedPreKey, &installationID, &version); err != nil {
			return nil, err
		}

		// Rows arrive ordered by version DESC, so the first row seen for an
		// installation is its newest one; older rows must not overwrite it.
		if _, seen := bundle.SignedPreKeys[installationID]; seen {
			continue
		}
		bundle.SignedPreKeys[installationID] = &SignedPreKey{
			SignedPreKey:    signedPreKey,
			Version:         version,
			ProtocolVersion: versions[installationID],
		}
	}
	// Report iteration failures instead of returning a truncated bundle.
	if err = rows.Err(); err != nil {
		return nil, err
	}

	if len(bundle.SignedPreKeys) == 0 {
		return nil, nil
	}

	return bundle, nil
}

// AddRatchetInfo inserts a ratchet_info_v2 row with the symmetric key,
// interlocutor identity, bundle ID, ephemeral key and installation ID.
func (s *sqlitePersistence) AddRatchetInfo(key []byte, identity []byte, bundleID []byte, ephemeralKey []byte, installationID string) error {
	insert, err := s.DB.Prepare(`INSERT INTO ratchet_info_v2(symmetric_key, identity, bundle_id, ephemeral_key, installation_id)
                   VALUES(?, ?, ?, ?, ?)`)
	if err != nil {
		return err
	}
	defer insert.Close()

	// Values follow the column order of the INSERT above.
	values := []interface{}{key, identity, bundleID, ephemeralKey, installationID}
	_, err = insert.Exec(values...)
	return err
}

// GetRatchetInfo returns the ratchet_info_v2 row for the interlocutor
// theirIdentity, installationID and bundleID, joined with its bundle.
// It returns (nil, nil) when no such row exists.
//
// The returned RatchetInfo.ID is a newly allocated slice containing bundleID
// followed by the installation ID bytes.
func (s *sqlitePersistence) GetRatchetInfo(bundleID []byte, theirIdentity []byte, installationID string) (*RatchetInfo, error) {
	stmt, err := s.DB.Prepare(`SELECT ratchet_info_v2.identity, ratchet_info_v2.symmetric_key, bundles.private_key, bundles.signed_pre_key, ratchet_info_v2.ephemeral_key, ratchet_info_v2.installation_id
                   FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key
                   WHERE ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ? AND bundle_id = ?
                   LIMIT 1`)
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	ratchetInfo := &RatchetInfo{
		BundleID: bundleID,
	}

	err = stmt.QueryRow(theirIdentity, installationID, bundleID).Scan(
		&ratchetInfo.Identity,
		&ratchetInfo.Sk,
		&ratchetInfo.PrivateKey,
		&ratchetInfo.PublicKey,
		&ratchetInfo.EphemeralKey,
		&ratchetInfo.InstallationID,
	)
	switch err {
	case sql.ErrNoRows:
		return nil, nil
	case nil:
		// Build ID in its own buffer: appending to bundleID directly could
		// write into spare capacity of the caller's slice, which
		// ratchetInfo.BundleID also shares.
		id := make([]byte, 0, len(bundleID)+len(ratchetInfo.InstallationID))
		id = append(id, bundleID...)
		ratchetInfo.ID = append(id, ratchetInfo.InstallationID...)
		return ratchetInfo, nil
	default:
		return nil, err
	}
}

// GetAnyRatchetInfo returns one ratchet_info_v2 row for the interlocutor
// identity and installationID, joined with its bundle and filtered on
// expired = 0. It returns (nil, nil) when there is none. The result's ID is
// BundleID followed by the installationID bytes.
func (s *sqlitePersistence) GetAnyRatchetInfo(identity []byte, installationID string) (*RatchetInfo, error) {
	query, err := s.DB.Prepare(`SELECT symmetric_key, bundles.private_key, signed_pre_key, bundle_id, ephemeral_key
                   FROM ratchet_info_v2 JOIN bundles ON bundle_id = signed_pre_key
                   WHERE expired = 0 AND ratchet_info_v2.identity = ? AND ratchet_info_v2.installation_id = ?
                   LIMIT 1`)
	if err != nil {
		return nil, err
	}
	defer query.Close()

	info := &RatchetInfo{Identity: identity, InstallationID: installationID}
	err = query.QueryRow(identity, installationID).Scan(
		&info.Sk,
		&info.PrivateKey,
		&info.PublicKey,
		&info.BundleID,
		&info.EphemeralKey,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	info.ID = append(info.BundleID, installationID...)
	return info, nil
}

// RatchetInfoConfirmed sets ephemeral_key to NULL on the ratchet_info_v2
// row that matches theirIdentity, bundleID and installationID.
func (s *sqlitePersistence) RatchetInfoConfirmed(bundleID []byte, theirIdentity []byte, installationID string) error {
	update, err := s.DB.Prepare(`UPDATE ratchet_info_v2
                               SET ephemeral_key = NULL
                   WHERE identity = ? AND bundle_id = ? AND installation_id = ?`)
	if err != nil {
		return err
	}
	defer update.Close()

	if _, err = update.Exec(theirIdentity, bundleID, installationID); err != nil {
		return err
	}
	return nil
}

// sqliteKeysStorage stores double ratchet message keys in the "keys" table;
// it is used as the dr.KeysStorage of sqlitePersistence.
type sqliteKeysStorage struct {
	db *sql.DB
}

// newSQLiteKeysStorage returns a keys storage that reads and writes db.
func newSQLiteKeysStorage(db *sql.DB) *sqliteKeysStorage {
	ks := new(sqliteKeysStorage)
	ks.db = db
	return ks
}

// Get returns the message key stored for pubKey and msgNum. The boolean
// reports whether a row was found; a missing row is not an error.
func (s *sqliteKeysStorage) Get(pubKey dr.Key, msgNum uint) (dr.Key, bool, error) {
	var mk []byte
	query, err := s.db.Prepare(`SELECT message_key
                               FROM keys
                   WHERE public_key = ? AND msg_num = ?
                   LIMIT 1`)

	if err != nil {
		return mk, false, err
	}
	defer query.Close()

	err = query.QueryRow(pubKey, msgNum).Scan(&mk)
	found := err == nil
	if err == sql.ErrNoRows {
		// Absence is reported through found, not as an error.
		err = nil
	}
	return mk, found, err
}

// Put inserts message key mk for pubKey and msgNum into the keys table,
// tagged with sessionID and seqNum.
func (s *sqliteKeysStorage) Put(sessionID []byte, pubKey dr.Key, msgNum uint, mk dr.Key, seqNum uint) error {
	insert, err := s.db.Prepare(`INSERT INTO keys(session_id, public_key, msg_num, message_key, seq_num)
                               VALUES(?, ?, ?, ?, ?)`)
	if err != nil {
		return err
	}
	defer insert.Close()

	if _, err = insert.Exec(sessionID, pubKey, msgNum, mk, seqNum); err != nil {
		return err
	}
	return nil
}

// DeleteOldMks deletes every key of sessionID whose seq_num is less than or
// equal to deleteUntil.
func (s *sqliteKeysStorage) DeleteOldMks(sessionID []byte, deleteUntil uint) error {
	del, err := s.db.Prepare(`DELETE FROM keys
                               WHERE session_id = ? AND seq_num <= ?`)
	if err != nil {
		return err
	}
	defer del.Close()

	_, err = del.Exec(sessionID, deleteUntil)
	return err
}

// TruncateMks keeps only the maxKeysPerSession keys of sessionID with the
// highest seq_num and deletes the rest (oldest first out).
func (s *sqliteKeysStorage) TruncateMks(sessionID []byte, maxKeysPerSession int) error {
	del, err := s.db.Prepare(`DELETE FROM keys
                   WHERE rowid IN (SELECT rowid FROM keys WHERE session_id = ? ORDER BY seq_num DESC LIMIT ? OFFSET ?)`)
	if err != nil {
		return err
	}
	defer del.Close()

	// OFFSET needs a LIMIT, so maxNumberOfRows serves as an effectively
	// unbounded limit while OFFSET skips the newest maxKeysPerSession rows.
	limit, offset := maxNumberOfRows, maxKeysPerSession
	_, err = del.Exec(sessionID, limit, offset)
	return err
}

// DeleteMk deletes the key stored for pubKey and message number msgNum.
func (s *sqliteKeysStorage) DeleteMk(pubKey dr.Key, msgNum uint) error {
	del, err := s.db.Prepare(`DELETE FROM keys
                   WHERE public_key = ? AND msg_num = ?`)
	if err != nil {
		return err
	}
	defer del.Close()

	_, err = del.Exec(pubKey, msgNum)
	return err
}

// Count returns how many keys are stored for pubKey.
func (s *sqliteKeysStorage) Count(pubKey dr.Key) (uint, error) {
	query, err := s.db.Prepare(`SELECT COUNT(1)
                   FROM keys
                   WHERE public_key = ?`)
	if err != nil {
		return 0, err
	}
	defer query.Close()

	var n uint
	if err = query.QueryRow(pubKey).Scan(&n); err != nil {
		return 0, err
	}
	return n, nil
}

// CountAll returns the total number of rows in the keys table, across all
// public keys and sessions.
func (s *sqliteKeysStorage) CountAll() (uint, error) {
	query, err := s.db.Prepare(`SELECT COUNT(1)
                   FROM keys`)
	if err != nil {
		return 0, err
	}
	defer query.Close()

	var total uint
	if err = query.QueryRow().Scan(&total); err != nil {
		return 0, err
	}
	return total, nil
}

// All is not implemented by this storage: it always returns a nil map and a
// nil error.
func (s *sqliteKeysStorage) All() (map[string]map[uint]dr.Key, error) {
	var none map[string]map[uint]dr.Key
	return none, nil
}

// sqliteSessionStorage stores double ratchet session state in the
// "sessions" table; it is used as the dr.SessionStorage of
// sqlitePersistence.
type sqliteSessionStorage struct {
	db *sql.DB
}

// newSQLiteSessionStorage returns a session storage that reads and writes db.
func newSQLiteSessionStorage(db *sql.DB) *sqliteSessionStorage {
	ss := new(sqliteSessionStorage)
	ss.db = db
	return ss
}

// Save inserts the double ratchet state under id into the sessions table.
// It stores the DH keys, the root/send/receive chain keys and counters, PN,
// Step and KeysCount.
func (s *sqliteSessionStorage) Save(id []byte, state *dr.State) error {
	// Column values, in the same order as the INSERT column list below.
	dhs := state.DHs
	values := []interface{}{
		id,
		state.DHr,
		dhs.PublicKey(),
		dhs.PrivateKey(),
		state.RootCh.CK,
		state.SendCh.CK,
		state.SendCh.N,
		state.RecvCh.CK,
		state.RecvCh.N,
		state.PN,
		state.Step,
		state.KeysCount,
	}

	insert, err := s.db.Prepare(`INSERT INTO sessions(id, dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count)
                   VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
	if err != nil {
		return err
	}
	defer insert.Close()

	_, err = insert.Exec(values...)
	return err
}

// Load reads the double ratchet state saved under id. It returns (nil, nil)
// when no session with that id exists.
func (s *sqliteSessionStorage) Load(id []byte) (*dr.State, error) {
	query, err := s.db.Prepare(`SELECT dhr, dhs_public, dhs_private, root_chain_key, send_chain_key, send_chain_n, recv_chain_key, recv_chain_n, pn, step, keys_count
                   FROM sessions
                   WHERE id = ?`)
	if err != nil {
		return nil, err
	}
	defer query.Close()

	var (
		dhr, dhsPublic, dhsPrivate []byte
		rootCK, sendCK, recvCK     []byte
		sendN, recvN               uint
		pn, step, keysCount        uint
	)
	err = query.QueryRow(id).Scan(
		&dhr, &dhsPublic, &dhsPrivate,
		&rootCK,
		&sendCK, &sendN,
		&recvCK, &recvN,
		&pn, &step, &keysCount,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	// Start from the library's default state for the root key, then
	// overwrite the persisted fields.
	state := dr.DefaultState(rootCK)
	state.PN = uint32(pn)
	state.Step = step
	state.KeysCount = keysCount
	state.DHs = crypto.DHPair{PrvKey: dhsPrivate, PubKey: dhsPublic}
	state.DHr = dhr
	state.SendCh.CK = sendCK
	state.SendCh.N = uint32(sendN)
	state.RecvCh.CK = recvCK
	state.RecvCh.N = uint32(recvN)

	return &state, nil
}

// HRCache is a hash ratchet key together with its cached hash and sequence
// number, as returned by GetHashRatchetCache.
type HRCache struct {
	GroupID, KeyID  []byte
	DeprecatedKeyID uint32
	Key, Hash       []byte
	SeqNo           uint32
}

// GetHashRatchetCache loads the hash ratchet key described by ratchet together
// with its cached hash and sequence number.
//
// The key is looked up by its key ID or by (deprecated_key_id, group_id). If
// no key is stored, it returns (nil, nil). When seqNo is 0, the cache entry
// with the highest seq_no is used; otherwise only the entry whose seq_no
// equals seqNo is used. If no cache entry matches, the returned HRCache still
// carries the key, but Hash is nil and SeqNo is 0.
//
// On the success path ratchet.Key is set to the stored key. Both reads run in
// one transaction, and a Commit failure is reported to the caller.
func (s *sqlitePersistence) GetHashRatchetCache(ratchet *HashRatchetKeyCompatibility, seqNo uint32) (_ *HRCache, err error) {
	tx, err := s.DB.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return nil, err
	}
	// err is the named result, so the Commit error reaches the caller
	// instead of being assigned to a local that is never read.
	defer func() {
		if err == nil {
			err = tx.Commit()
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()

	var keyID []byte
	if !ratchet.IsOldFormat() {
		if keyID, err = ratchet.GetKeyID(); err != nil {
			return nil, err
		}
	}

	var key []byte
	err = tx.QueryRow("SELECT key FROM hash_ratchet_encryption WHERE key_id = ? OR (deprecated_key_id = ? AND group_id = ?)",
		keyID,
		ratchet.DeprecatedKeyID(),
		ratchet.GroupID,
	).Scan(&key)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	args := []interface{}{ratchet.GroupID, keyID, ratchet.DeprecatedKeyID()}
	var query string
	if seqNo == 0 {
		query = "SELECT seq_no, hash FROM hash_ratchet_encryption_cache WHERE group_id = ? AND (key_id = ? OR key_id = ?) ORDER BY seq_no DESC limit 1"
	} else {
		query = "SELECT seq_no, hash FROM hash_ratchet_encryption_cache WHERE group_id = ? AND (key_id = ? OR key_id = ?) AND seq_no == ? ORDER BY seq_no DESC limit 1"
		args = append(args, seqNo)
	}

	var (
		hash     []byte
		seqNoPtr *uint32
	)
	// A missing cache entry is not an error: hash stays nil and seqNo 0.
	err = tx.QueryRow(query, args...).Scan(&seqNoPtr, &hash)
	if err != nil && err != sql.ErrNoRows {
		return nil, err
	}

	var seqNoResult uint32
	if seqNoPtr != nil {
		seqNoResult = *seqNoPtr
	}

	ratchet.Key = key
	resolvedKeyID, err := ratchet.GetKeyID()
	if err != nil {
		return nil, err
	}

	return &HRCache{
		KeyID: resolvedKeyID,
		Key:   key,
		Hash:  hash,
		SeqNo: seqNoResult,
	}, nil
}

// HashRatchetKeyCompatibility describes a hash ratchet key of a group. Its
// key ID is either set explicitly (keyID) or derived from GroupID, Timestamp
// and Key.
type HashRatchetKeyCompatibility struct {
	// GroupID is the group the key belongs to.
	GroupID []byte
	// keyID, when non-empty, is returned as-is by GetKeyID.
	keyID []byte
	// Timestamp is the key timestamp; its low 32 bits are DeprecatedKeyID.
	Timestamp uint64
	// Key is the raw key material.
	Key []byte
}

// DeprecatedKeyID returns the legacy 32-bit key identifier: Timestamp
// truncated to its low 32 bits.
func (h *HashRatchetKeyCompatibility) DeprecatedKeyID() uint32 {
	ts := h.Timestamp
	return uint32(ts)
}

// IsOldFormat reports whether the ratchet has neither a stored keyID nor
// key bytes, which callers treat as the old (deprecated ID) format.
func (h *HashRatchetKeyCompatibility) IsOldFormat() bool {
	if len(h.keyID) != 0 {
		return false
	}
	return len(h.Key) == 0
}

// GetKeyID returns the explicit keyID when set. Otherwise it derives the ID
// with generateHashRatchetKeyID from GroupID, Timestamp and Key, and fails if
// any of those is empty or zero.
func (h *HashRatchetKeyCompatibility) GetKeyID() ([]byte, error) {
	switch {
	case len(h.keyID) != 0:
		return h.keyID, nil
	case len(h.GroupID) == 0, h.Timestamp == 0, len(h.Key) == 0:
		return nil, errors.New("could not create key")
	default:
		return generateHashRatchetKeyID(h.GroupID, h.Timestamp, h.Key), nil
	}
}

// GenerateNext returns a new hash ratchet key for the same group, with
// freshly generated random key material.
//
// The new Timestamp depends on h:
//   - bumpKeyID(GetCurrentTime()) when h.Timestamp is older than the current time
//   - h.Timestamp+1 otherwise
//
// It fails if key generation fails or no key ID can be derived from the
// result.
func (h *HashRatchetKeyCompatibility) GenerateNext() (*HashRatchetKeyCompatibility, error) {
	// Randomly generate a hash ratchet key
	hrKey, err := crypto.GenerateKey()
	if err != nil {
		return nil, err
	}

	ratchet := &HashRatchetKeyCompatibility{
		GroupID: h.GroupID,
		Key:     crypto.FromECDSA(hrKey),
	}

	currentTime := GetCurrentTime()
	if h.Timestamp < currentTime {
		ratchet.Timestamp = bumpKeyID(currentTime)
	} else {
		ratchet.Timestamp = h.Timestamp + 1
	}

	// Make sure a key ID can be derived from the new key.
	if _, err = ratchet.GetKeyID(); err != nil {
		return nil, err
	}

	return ratchet, nil
}

// GetCurrentKeyForGroup returns the hash ratchet key of groupID with the
// latest key_timestamp (assuming the group shares a single current key).
// When the group has no key yet, it returns a ratchet carrying only GroupID.
func (s *sqlitePersistence) GetCurrentKeyForGroup(groupID []byte) (*HashRatchetKeyCompatibility, error) {
	query, err := s.DB.Prepare(`SELECT key_id, key_timestamp, key
                   FROM hash_ratchet_encryption
                     WHERE group_id = ? order by key_timestamp desc limit 1`)
	if err != nil {
		return nil, err
	}
	defer query.Close()

	ratchet := &HashRatchetKeyCompatibility{GroupID: groupID}

	var (
		storedKeyID, storedKey []byte
		storedTimestamp        uint64
	)
	err = query.QueryRow(groupID).Scan(&storedKeyID, &storedTimestamp, &storedKey)
	if err == sql.ErrNoRows {
		return ratchet, nil
	}
	if err != nil {
		return nil, err
	}

	ratchet.Key = storedKey
	ratchet.Timestamp = storedTimestamp
	// Fail early if no key ID can be derived from the stored values.
	if _, err = ratchet.GetKeyID(); err != nil {
		return nil, err
	}
	return ratchet, nil
}

// GetKeysForGroup returns every hash ratchet key stored for groupID, ordered
// by key_timestamp descending. It returns a nil slice when there are none.
func (s *sqlitePersistence) GetKeysForGroup(groupID []byte) ([]*HashRatchetKeyCompatibility, error) {
	stmt, err := s.DB.Prepare(`SELECT key_id, key_timestamp, key
                   FROM hash_ratchet_encryption
                     WHERE group_id = ? order by key_timestamp desc`)
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	rows, err := stmt.Query(groupID)
	if err != nil {
		return nil, err
	}
	// Close rows on every path; otherwise an early return after a Scan
	// error keeps the result set, and its connection, open.
	defer rows.Close()

	var ratchets []*HashRatchetKeyCompatibility
	for rows.Next() {
		ratchet := &HashRatchetKeyCompatibility{GroupID: groupID}
		if err = rows.Scan(&ratchet.keyID, &ratchet.Timestamp, &ratchet.Key); err != nil {
			return nil, err
		}
		ratchets = append(ratchets, ratchet)
	}
	// Report iteration failures instead of returning a truncated list.
	if err = rows.Err(); err != nil {
		return nil, err
	}

	return ratchets, nil
}

// SaveHashRatchetKeyHash inserts a hash_ratchet_encryption_cache row with
// the ratchet's group ID, its key ID (from GetKeyID), hash and seqNo.
func (s *sqlitePersistence) SaveHashRatchetKeyHash(
	ratchet *HashRatchetKeyCompatibility,
	hash []byte,
	seqNo uint32,
) error {
	insert, err := s.DB.Prepare(`INSERT INTO hash_ratchet_encryption_cache(group_id, key_id, hash, seq_no)
           VALUES(?, ?, ?, ?)`)
	if err != nil {
		return err
	}
	defer insert.Close()

	id, err := ratchet.GetKeyID()
	if err != nil {
		return err
	}

	if _, err = insert.Exec(ratchet.GroupID, id, hash, seqNo); err != nil {
		return err
	}
	return nil
}

// SaveHashRatchetKey inserts a hash_ratchet_encryption row with the
// ratchet's group ID, key ID (from GetKeyID), timestamp, deprecated key ID
// and key bytes.
func (s *sqlitePersistence) SaveHashRatchetKey(ratchet *HashRatchetKeyCompatibility) error {
	insert, err := s.DB.Prepare(`INSERT INTO hash_ratchet_encryption(group_id, key_id, key_timestamp, deprecated_key_id, key)
           VALUES(?,?,?,?,?)`)
	if err != nil {
		return err
	}
	defer insert.Close()

	id, err := ratchet.GetKeyID()
	if err != nil {
		return err
	}

	if _, err = insert.Exec(ratchet.GroupID, id, ratchet.Timestamp, ratchet.DeprecatedKeyID(), ratchet.Key); err != nil {
		return err
	}
	return nil
}

// GetHashRatchetKeyByID loads the hash ratchet key stored under keyID,
// or returns (nil, nil) when no such key exists.
func (s *sqlitePersistence) GetHashRatchetKeyByID(keyID []byte) (*HashRatchetKeyCompatibility, error) {
	found := &HashRatchetKeyCompatibility{keyID: keyID}

	row := s.DB.QueryRow(`
        SELECT group_id, key_timestamp, key
        FROM hash_ratchet_encryption
        WHERE key_id = ?`, keyID)

	switch err := row.Scan(&found.GroupID, &found.Timestamp, &found.Key); err {
	case nil:
		return found, nil
	case sql.ErrNoRows:
		return nil, nil
	default:
		return nil, err
	}
}