status-im/status-go

View on GitHub
protocol/common/raw_messages_persistence.go

Summary

Maintainability
A
0 mins
Test Coverage
B
87%
package common

import (
    "bytes"
    "context"
    "crypto/ecdsa"
    "database/sql"
    "encoding/gob"
    "errors"
    "strings"
    "time"

    "github.com/status-im/status-go/eth-node/crypto"
    "github.com/status-im/status-go/eth-node/types"
    "github.com/status-im/status-go/protocol/protobuf"
)

// RawMessageConfirmation links a datasync message to the application message
// it carries and the receiver it was sent to. Its first three fields are the
// columns InsertPendingConfirmation writes into raw_message_confirmations.
type RawMessageConfirmation struct {
    // DataSyncID is the ID of the datasync message sent
    DataSyncID []byte
    // MessageID is the message id of the message
    MessageID []byte
    // PublicKey is the compressed receiver public key
    PublicKey []byte
    // ConfirmedAt is the unix timestamp in seconds of when the message was confirmed.
    // InsertPendingConfirmation does not write it; MarkAsConfirmed treats a
    // stored value of 0 as "not yet confirmed".
    ConfirmedAt int64
}

// RawMessagesPersistence wraps a *sql.DB handle; its methods read and write
// tables such as raw_messages, raw_message_confirmations,
// hash_ratchet_encrypted_messages, message_segments and
// message_segments_completed.
type RawMessagesPersistence struct {
    db *sql.DB
}

// NewRawMessagesPersistence returns a fresh RawMessagesPersistence that uses
// db for all of its queries. The handle is stored as given; nothing is opened
// or validated here.
func NewRawMessagesPersistence(db *sql.DB) *RawMessagesPersistence {
    p := new(RawMessagesPersistence)
    p.db = db
    return p
}

// SaveRawMessage inserts message into the raw_messages table inside a single
// transaction.
//
// Recipients are stored as a gob-encoded list of compressed public keys, and
// the sender key, when present, is serialised with crypto.FromECDSA. If
// message.Sent is false and a row with the same ID already exists, the stored
// sent flag is copied onto message (mutating the caller's value), so the
// existing sent state is carried into the new insert.
//
// err is a named result so that the deferred handler observes every return
// path. Any error, including ones raised in inner scopes, rolls the
// transaction back, and a failing Commit is returned to the caller instead of
// being silently dropped.
func (db RawMessagesPersistence) SaveRawMessage(message *RawMessage) (err error) {
    tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
    if err != nil {
        return err
    }
    defer func() {
        if err == nil {
            err = tx.Commit()
            return
        }
        // don't shadow original error
        _ = tx.Rollback()
    }()

    pubKeys := make([][]byte, 0, len(message.Recipients))
    for _, pk := range message.Recipients {
        pubKeys = append(pubKeys, crypto.CompressPubkey(pk))
    }
    // Encode recipients
    var encodedRecipients bytes.Buffer
    if err = gob.NewEncoder(&encodedRecipients).Encode(pubKeys); err != nil {
        return err
    }

    // If the message is not sent, we check whether there's a record
    // in the database already and preserve the state
    if !message.Sent {
        oldMessage, lookupErr := db.rawMessageByID(tx, message.ID)
        if lookupErr != nil && !errors.Is(lookupErr, sql.ErrNoRows) {
            return lookupErr
        }
        if oldMessage != nil {
            message.Sent = oldMessage.Sent
        }
    }

    var sender []byte
    if message.Sender != nil {
        sender = crypto.FromECDSA(message.Sender)
    }
    _, err = tx.Exec(`
         INSERT INTO
         raw_messages
         (
           id,
           local_chat_id,
           last_sent,
           send_count,
           sent,
           message_type,
           recipients,
           skip_encryption,
           send_push_notification,
           skip_group_message_wrap,
           send_on_personal_topic,
           payload,
           sender,
           community_id,
           resend_type,
           pubsub_topic,
           hash_ratchet_group_id,
           community_key_ex_msg_type,
           resend_method
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
        message.ID,
        message.LocalChatID,
        message.LastSent,
        message.SendCount,
        message.Sent,
        message.MessageType,
        encodedRecipients.Bytes(),
        message.SkipEncryptionLayer,
        message.SendPushNotification,
        message.SkipGroupMessageWrap,
        message.SendOnPersonalTopic,
        message.Payload,
        sender,
        message.CommunityID,
        message.ResendType,
        message.PubsubTopic,
        message.HashRatchetGroupID,
        message.CommunityKeyExMsgType,
        message.ResendMethod,
    )
    return err
}

// RawMessageByID loads the raw message with the given id inside a
// transaction. The lookup error is returned unchanged, so callers get
// sql.ErrNoRows (from the underlying Scan) when no row matches.
//
// The results are named so the deferred handler sees the lookup outcome. A
// failed lookup rolls the transaction back, and a failed Commit is returned
// to the caller with a nil message instead of being silently ignored.
func (db RawMessagesPersistence) RawMessageByID(id string) (message *RawMessage, err error) {
    tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
    if err != nil {
        return nil, err
    }
    defer func() {
        if err == nil {
            if err = tx.Commit(); err != nil {
                message = nil
            }
            return
        }
        // don't shadow original error
        _ = tx.Rollback()
    }()

    return db.rawMessageByID(tx, id)
}

// rawMessageByID reads the raw_messages row with the given id through tx and
// rebuilds a RawMessage from it.
//
// It decodes the recipients column as a gob-encoded list of compressed public
// keys and parses the sender column back with crypto.ToECDSA. A NULL in the
// nullable skip_group_message_wrap / send_on_personal_topic columns leaves the
// matching flag false. Any Scan, decode or key-parsing error is returned
// together with a nil message.
func (db RawMessagesPersistence) rawMessageByID(tx *sql.Tx, id string) (*RawMessage, error) {
    var (
        recipientsBlob []byte
        senderBlob     []byte
        skipWrap       sql.NullBool
        personalTopic  sql.NullBool
    )
    msg := &RawMessage{}

    row := tx.QueryRow(`
            SELECT
              id,
              local_chat_id,
              last_sent,
              send_count,
              sent,
              message_type,
              recipients,
              skip_encryption,
              send_push_notification,
              skip_group_message_wrap,
              send_on_personal_topic,
              payload,
              sender,
              community_id,
              resend_type,
              pubsub_topic,
              hash_ratchet_group_id,
              community_key_ex_msg_type,
              resend_method
            FROM
                raw_messages
            WHERE
                id = ?`,
        id,
    )
    if err := row.Scan(
        &msg.ID,
        &msg.LocalChatID,
        &msg.LastSent,
        &msg.SendCount,
        &msg.Sent,
        &msg.MessageType,
        &recipientsBlob,
        &msg.SkipEncryptionLayer,
        &msg.SendPushNotification,
        &skipWrap,
        &personalTopic,
        &msg.Payload,
        &senderBlob,
        &msg.CommunityID,
        &msg.ResendType,
        &msg.PubsubTopic,
        &msg.HashRatchetGroupID,
        &msg.CommunityKeyExMsgType,
        &msg.ResendMethod,
    ); err != nil {
        return nil, err
    }

    // A NULL recipients column scans to a nil slice; only decode when present.
    if recipientsBlob != nil {
        var compressed [][]byte
        if err := gob.NewDecoder(bytes.NewBuffer(recipientsBlob)).Decode(&compressed); err != nil {
            return nil, err
        }
        for _, raw := range compressed {
            pk, err := crypto.DecompressPubkey(raw)
            if err != nil {
                return nil, err
            }
            msg.Recipients = append(msg.Recipients, pk)
        }
    }

    msg.SkipGroupMessageWrap = skipWrap.Valid && skipWrap.Bool
    msg.SendOnPersonalTopic = personalTopic.Valid && personalTopic.Bool

    if senderBlob == nil {
        return msg, nil
    }
    senderKey, err := crypto.ToECDSA(senderBlob)
    if err != nil {
        return nil, err
    }
    msg.Sender = senderKey
    return msg, nil
}

// RawMessagesIDsByType returns the ids of all raw messages whose message_type
// equals t. The returned slice is never nil: it is empty when nothing
// matches, and on error it holds whatever ids were read before the failure.
func (db RawMessagesPersistence) RawMessagesIDsByType(t protobuf.ApplicationMetadataMessage_Type) ([]string, error) {
    ids := []string{}

    rows, err := db.db.Query(`
            SELECT
              id
            FROM
                raw_messages
            WHERE
            message_type = ?`,
        t)
    if err != nil {
        return ids, err
    }
    defer rows.Close()

    for rows.Next() {
        var id string
        if err := rows.Scan(&id); err != nil {
            return ids, err
        }
        ids = append(ids, id)
    }
    // rows.Next reports false both at the end of the result set and on an
    // iteration error; without this check a failed read would look like a
    // short but successful result.
    if err := rows.Err(); err != nil {
        return ids, err
    }

    return ids, nil
}

// MarkAsConfirmed stamps every still-pending (confirmed_at = 0) confirmation
// row for dataSyncID with the current unix time. It then reports whether the
// message that dataSyncID belongs to can be considered confirmed.
//
// All confirmation rows sharing that message_id are inspected:
//   - if atLeastOne is true, the message id is returned as soon as any of
//     them has been confirmed;
//   - otherwise the message id is returned only if all of them are confirmed.
//
// A nil messageID with a nil error means the message is not (yet) considered
// confirmed, or that no confirmation row exists for dataSyncID.
func (db RawMessagesPersistence) MarkAsConfirmed(dataSyncID []byte, atLeastOne bool) (messageID types.HexBytes, err error) {
    tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
    if err != nil {
        return nil, err
    }
    // The results are named so this handler sees the final outcome: commit on
    // success, roll back otherwise. If Commit fails, the id is cleared so
    // callers never act on an update that was not persisted.
    defer func() {
        if err == nil {
            if err = tx.Commit(); err != nil {
                messageID = nil
            }
            return
        }
        // don't shadow original error
        _ = tx.Rollback()
    }()

    now := time.Now().Unix()
    _, err = tx.Exec(`UPDATE raw_message_confirmations SET confirmed_at = ? WHERE datasync_id = ? AND confirmed_at = 0`, now, dataSyncID)
    if err != nil {
        return nil, err
    }

    // Fetch every confirmation row of the message that dataSyncID belongs to.
    rows, err := tx.Query(`SELECT message_id,confirmed_at FROM raw_message_confirmations WHERE message_id = (SELECT message_id FROM raw_message_confirmations WHERE datasync_id = ? LIMIT 1)`, dataSyncID)
    if err != nil {
        return nil, err
    }
    defer rows.Close()

    allConfirmed := true
    for rows.Next() {
        var rowConfirmedAt int64
        if err = rows.Scan(&messageID, &rowConfirmedAt); err != nil {
            return nil, err
        }
        confirmed := rowConfirmedAt > 0

        if atLeastOne && confirmed {
            // We return, as at least one was confirmed
            return messageID, nil
        }

        allConfirmed = allConfirmed && confirmed
    }
    // An iteration error must not be mistaken for "all rows seen".
    if err = rows.Err(); err != nil {
        return nil, err
    }

    if !allConfirmed {
        return nil, nil
    }
    return messageID, nil
}

// InsertPendingConfirmation stores the DataSyncID, MessageID and PublicKey of
// confirmation in raw_message_confirmations. ConfirmedAt is not written.
// NOTE(review): the row presumably starts unconfirmed via the confirmed_at
// column default (MarkAsConfirmed looks for confirmed_at = 0); confirm
// against the schema.
func (db RawMessagesPersistence) InsertPendingConfirmation(confirmation *RawMessageConfirmation) error {
    const query = `INSERT INTO raw_message_confirmations
         (datasync_id, message_id, public_key)
         VALUES
         (?,?,?)`

    args := []interface{}{
        confirmation.DataSyncID,
        confirmation.MessageID,
        confirmation.PublicKey,
    }
    _, err := db.db.Exec(query, args...)
    return err
}

// SaveHashRatchetMessage stores m in hash_ratchet_encrypted_messages, tagged
// with groupID and keyID. The topic is stored in the byte form produced by
// types.TopicTypeToByteArray.
func (db RawMessagesPersistence) SaveHashRatchetMessage(groupID []byte, keyID []byte, m *types.Message) error {
    const query = `INSERT INTO hash_ratchet_encrypted_messages(hash, sig, TTL, timestamp, topic, payload, dst, p2p, padding, group_id, key_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

    _, err := db.db.Exec(query,
        m.Hash,
        m.Sig,
        m.TTL,
        m.Timestamp,
        types.TopicTypeToByteArray(m.Topic),
        m.Payload,
        m.Dst,
        m.P2P,
        m.Padding,
        groupID,
        keyID,
    )
    return err
}

// GetHashRatchetMessages returns every hash-ratchet message stored with
// keyID. The result is nil when nothing matches.
func (db RawMessagesPersistence) GetHashRatchetMessages(keyID []byte) ([]*types.Message, error) {
    rows, err := db.db.Query(`SELECT hash, sig, TTL, timestamp, topic, payload, dst, p2p, padding FROM hash_ratchet_encrypted_messages WHERE key_id = ?`, keyID)
    if err != nil {
        return nil, err
    }
    // Rows are only closed automatically when Next runs off the end of the
    // result set; an early return on a Scan error would otherwise keep the
    // underlying connection busy.
    defer rows.Close()

    var messages []*types.Message
    for rows.Next() {
        var topic []byte
        message := &types.Message{}

        if err := rows.Scan(&message.Hash, &message.Sig, &message.TTL, &message.Timestamp, &topic, &message.Payload, &message.Dst, &message.P2P, &message.Padding); err != nil {
            return nil, err
        }

        message.Topic = types.BytesToTopic(topic)
        messages = append(messages, message)
    }
    // Distinguish an iteration failure from the normal end of the results.
    if err := rows.Err(); err != nil {
        return nil, err
    }

    return messages, nil
}

// GetHashRatchetMessagesCountForGroup returns how many hash-ratchet messages
// are stored for groupID. sql.ErrNoRows is mapped to a count of 0.
func (db RawMessagesPersistence) GetHashRatchetMessagesCountForGroup(groupID []byte) (int, error) {
    var count int
    err := db.db.QueryRow(`SELECT count(*) FROM hash_ratchet_encrypted_messages WHERE group_id = ?`, groupID).Scan(&count)

    switch {
    case err == nil:
        return count, nil
    case errors.Is(err, sql.ErrNoRows):
        // Defensive: an aggregate count(*) query normally always yields a row.
        return 0, nil
    default:
        return 0, err
    }
}

// DeleteHashRatchetMessages removes the hash-ratchet messages whose hash is
// one of ids, using a single DELETE with one bound placeholder per id. An
// empty ids is a no-op that never touches the database.
func (db RawMessagesPersistence) DeleteHashRatchetMessages(ids [][]byte) error {
    n := len(ids)
    if n == 0 {
        return nil
    }

    args := make([]interface{}, n)
    for i := range ids {
        args[i] = ids[i]
    }

    // "?, ?, ..., ?" with exactly n placeholders; values are bound, never
    // interpolated, so the concatenation below is safe.
    placeholders := strings.TrimSuffix(strings.Repeat("?, ", n), ", ")
    query := "DELETE FROM hash_ratchet_encrypted_messages WHERE hash IN (" + placeholders + ")" // nolint: gosec

    _, err := db.db.Exec(query, args...)
    return err
}

// DeleteHashRatchetMessagesOlderThan removes every hash-ratchet message whose
// timestamp is strictly lower than timestamp.
func (db *RawMessagesPersistence) DeleteHashRatchetMessagesOlderThan(timestamp int64) error {
    const query = "DELETE FROM hash_ratchet_encrypted_messages WHERE timestamp < ?"
    if _, err := db.db.Exec(query, timestamp); err != nil {
        return err
    }
    return nil
}

// IsMessageAlreadyCompleted reports whether message_segments_completed holds
// at least one row for hash.
func (db *RawMessagesPersistence) IsMessageAlreadyCompleted(hash []byte) (bool, error) {
    var rowsFound int
    row := db.db.QueryRow("SELECT COUNT(*) FROM message_segments_completed WHERE hash = ?", hash)
    if err := row.Scan(&rowsFound); err != nil {
        return false, err
    }
    return rowsFound > 0, nil
}

// SaveMessageSegment stores one segment of a segmented message. The row is
// keyed by the hash of the entire message plus the compressed form of
// sigPubKey, and records the caller-supplied timestamp.
func (db *RawMessagesPersistence) SaveMessageSegment(segment *SegmentMessage, sigPubKey *ecdsa.PublicKey, timestamp int64) error {
    compressedKey := crypto.CompressPubkey(sigPubKey)

    const query = "INSERT INTO message_segments (hash, segment_index, segments_count, parity_segment_index, parity_segments_count, sig_pub_key, payload, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
    _, err := db.db.Exec(query,
        segment.EntireMessageHash,
        segment.Index,
        segment.SegmentsCount,
        segment.ParitySegmentIndex,
        segment.ParitySegmentsCount,
        compressedKey,
        segment.Payload,
        timestamp,
    )
    return err
}

// GetMessageSegments returns the stored segments of the message identified by
// hash and signed by sigPubKey.
//
// Rows with a non-zero segments_count come first, then rows are ordered by
// segment_index and parity_segment_index ascending. The result is nil when
// nothing matches.
func (db *RawMessagesPersistence) GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*SegmentMessage, error) {
    compressedKey := crypto.CompressPubkey(sigPubKey)

    rows, err := db.db.Query(`
        SELECT
            hash, segment_index, segments_count, parity_segment_index, parity_segments_count, payload
        FROM
            message_segments
        WHERE
            hash = ? AND sig_pub_key = ?
        ORDER BY
            (segments_count = 0) ASC, -- Prioritize segments_count > 0
            segment_index ASC,
            parity_segment_index ASC`,
        hash, compressedKey)
    if err != nil {
        return nil, err
    }
    defer rows.Close()

    var result []*SegmentMessage
    for rows.Next() {
        seg := &SegmentMessage{SegmentMessage: &protobuf.SegmentMessage{}}
        if err := rows.Scan(
            &seg.EntireMessageHash,
            &seg.Index,
            &seg.SegmentsCount,
            &seg.ParitySegmentIndex,
            &seg.ParitySegmentsCount,
            &seg.Payload,
        ); err != nil {
            return nil, err
        }
        result = append(result, seg)
    }
    if err := rows.Err(); err != nil {
        return nil, err
    }

    return result, nil
}

// RemoveMessageSegmentsOlderThan deletes every stored message segment whose
// timestamp is strictly lower than timestamp.
func (db *RawMessagesPersistence) RemoveMessageSegmentsOlderThan(timestamp int64) error {
    const query = "DELETE FROM message_segments WHERE timestamp < ?"
    if _, err := db.db.Exec(query, timestamp); err != nil {
        return err
    }
    return nil
}

// CompleteMessageSegments, within one transaction, deletes all stored
// segments of the message identified by hash and signed by sigPubKey. It then
// records that message as completed at timestamp in
// message_segments_completed.
//
// err is a named result so the deferred handler sees the outcome of every
// step. Any failure rolls back both statements, and a failed Commit is
// returned to the caller rather than discarded.
func (db *RawMessagesPersistence) CompleteMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey, timestamp int64) (err error) {
    tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
    if err != nil {
        return err
    }

    defer func() {
        if err == nil {
            err = tx.Commit()
            return
        }
        // don't shadow original error
        _ = tx.Rollback()
    }()

    sigPubKeyBlob := crypto.CompressPubkey(sigPubKey)

    if _, err = tx.Exec("DELETE FROM message_segments WHERE hash = ? AND sig_pub_key = ?", hash, sigPubKeyBlob); err != nil {
        return err
    }

    _, err = tx.Exec("INSERT INTO message_segments_completed (hash, sig_pub_key, timestamp) VALUES (?,?,?)", hash, sigPubKeyBlob, timestamp)
    return err
}

// RemoveMessageSegmentsCompletedOlderThan deletes every completion record in
// message_segments_completed whose timestamp is strictly lower than timestamp.
func (db *RawMessagesPersistence) RemoveMessageSegmentsCompletedOlderThan(timestamp int64) error {
    const query = "DELETE FROM message_segments_completed WHERE timestamp < ?"
    if _, err := db.db.Exec(query, timestamp); err != nil {
        return err
    }
    return nil
}

// UpdateRawMessageSent overwrites the sent flag of the raw message with the
// given id. The driver's error, if any, is returned unchanged.
func (db RawMessagesPersistence) UpdateRawMessageSent(id string, sent bool) error {
    const query = "UPDATE raw_messages SET sent = ? WHERE id = ?"
    if _, err := db.db.Exec(query, sent, id); err != nil {
        return err
    }
    return nil
}

// UpdateRawMessageLastSent overwrites the last_sent value of the raw message
// with the given id. The driver's error, if any, is returned unchanged.
func (db RawMessagesPersistence) UpdateRawMessageLastSent(id string, lastSent uint64) error {
    const query = "UPDATE raw_messages SET last_sent = ? WHERE id = ?"
    if _, err := db.db.Exec(query, lastSent, id); err != nil {
        return err
    }
    return nil
}