status-im/status-go

View on GitHub
protocol/common/message_segmentation.go

Summary

Maintainability
A
0 mins
Test Coverage
C
79%
package common

import (
    "bytes"
    "math"
    "time"

    "github.com/golang/protobuf/proto"
    "github.com/jinzhu/copier"
    "github.com/klauspost/reedsolomon"
    "github.com/pkg/errors"
    "go.uber.org/zap"

    "github.com/status-im/status-go/eth-node/crypto"
    "github.com/status-im/status-go/eth-node/types"
    "github.com/status-im/status-go/protocol/protobuf"
    v1protocol "github.com/status-im/status-go/protocol/v1"
)

// Sentinel errors reported while reassembling segmented messages.
var (
    // ErrMessageSegmentsIncomplete reports that the stored segments do not yet
    // allow the entire message to be rebuilt.
    ErrMessageSegmentsIncomplete = errors.New("message segments incomplete")
    // ErrMessageSegmentsAlreadyCompleted reports that the message a segment
    // belongs to has already been marked as completed.
    ErrMessageSegmentsAlreadyCompleted = errors.New("message segments already completed")
    // ErrMessageSegmentsInvalidCount reports a segment whose declared counts
    // are not acceptable.
    ErrMessageSegmentsInvalidCount = errors.New("invalid segments count")
    // ErrMessageSegmentsHashMismatch reports that the rebuilt payload does not
    // hash to the EntireMessageHash announced by the segment.
    ErrMessageSegmentsHashMismatch = errors.New("hash of entire payload does not match")
    // ErrMessageSegmentsInvalidParity reports that parity verification of the
    // segments failed.
    ErrMessageSegmentsInvalidParity = errors.New("invalid parity segments")
)

const (
    // segmentsParityRate is the number of parity segments generated per data
    // segment; segmentMessage rounds the resulting count down.
    segmentsParityRate          = 0.125
    // segmentsReedsolomonMaxCount is the largest combined number of data and
    // parity segments for which segmentMessage still produces parity segments.
    segmentsReedsolomonMaxCount = 256
)

// SegmentMessage embeds the protobuf SegmentMessage so that helper methods
// such as IsValid and IsParityMessage can be defined on it.
type SegmentMessage struct {
    *protobuf.SegmentMessage
}

// IsValid reports whether the segment declares a usable layout: either it
// announces parity segments, or it announces at least two data segments.
func (s *SegmentMessage) IsValid() bool {
    if s.SegmentsCount >= 2 {
        return true
    }
    return s.ParitySegmentsCount > 0
}

// IsParityMessage reports whether the segment carries parity data, which is
// signalled by a zero SegmentsCount together with a non-zero
// ParitySegmentsCount.
func (s *SegmentMessage) IsParityMessage() bool {
    if s.SegmentsCount != 0 {
        return false
    }
    return s.ParitySegmentsCount > 0
}

// segmentMessage splits newMessage into segments sized for the sender's
// transport and returns the resulting messages.
//
// On failure it returns a nil slice and the error; the debug log line is only
// written when segmentation succeeded, so it never reports a failed attempt as
// "message segmented".
func (s *MessageSender) segmentMessage(newMessage *types.NewMessage) ([]*types.NewMessage, error) {
    // We set the max message size to 3/4 of the allowed message size, to leave
    // room for segment message metadata.
    newMessages, err := segmentMessage(newMessage, int(s.transport.MaxMessageSize()/4*3))
    if err != nil {
        return nil, err
    }

    s.logger.Debug("message segmented", zap.Int("segments", len(newMessages)))
    return newMessages, nil
}

// replicateMessageWithNewPayload builds a new message from message using
// copier.Copy, then replaces its Payload with payload and sets PowTarget to
// the value calculatePoW returns for that payload.
//
// It returns the error from copier.Copy, if any, together with a nil message.
func replicateMessageWithNewPayload(message *types.NewMessage, payload []byte) (*types.NewMessage, error) {
    // Named "replica" rather than "copy" so the predeclared copy builtin is
    // not shadowed.
    replica := &types.NewMessage{}
    if err := copier.Copy(replica, message); err != nil {
        return nil, err
    }

    replica.Payload = payload
    replica.PowTarget = calculatePoW(payload)
    return replica, nil
}

// segmentMessage splits newMessage into smaller messages when its payload is
// larger than segmentSize.
//
// A payload that already fits is returned unchanged as a single-element slice.
// Otherwise the payload is cut into chunks of at most segmentSize bytes, each
// wrapped in a protobuf.SegmentMessage carrying the hash of the entire
// payload, its index and the total number of data segments. When parity is
// applicable, Reed-Solomon parity segments are appended after the data
// segments; they are marked by SegmentsCount == 0.
//
// It returns an error when segmentSize is not positive but the payload needs
// segmenting, or when marshaling, message replication or parity encoding
// fails.
func segmentMessage(newMessage *types.NewMessage, segmentSize int) ([]*types.NewMessage, error) {
    if len(newMessage.Payload) <= segmentSize {
        return []*types.NewMessage{newMessage}, nil
    }
    // A non-positive size cannot hold any payload bytes; without this guard the
    // segment count below would be computed from a division by zero (or a
    // negative divisor) and the allocation would panic.
    if segmentSize <= 0 {
        return nil, errors.New("invalid segment size")
    }

    entireMessageHash := crypto.Keccak256(newMessage.Payload)
    entirePayloadSize := len(newMessage.Payload)

    segmentsCount := int(math.Ceil(float64(entirePayloadSize) / float64(segmentSize)))
    paritySegmentsCount := int(math.Floor(float64(segmentsCount) * segmentsParityRate))

    // wrap marshals the segment metadata and embeds it in a replica of newMessage.
    wrap := func(segment *protobuf.SegmentMessage) (*types.NewMessage, error) {
        marshaled, err := proto.Marshal(segment)
        if err != nil {
            return nil, err
        }
        return replicateMessageWithNewPayload(newMessage, marshaled)
    }

    segmentPayloads := make([][]byte, segmentsCount+paritySegmentsCount)
    segmentMessages := make([]*types.NewMessage, 0, segmentsCount+paritySegmentsCount)

    for index := 0; index < segmentsCount; index++ {
        start := index * segmentSize
        end := start + segmentSize
        if end > entirePayloadSize {
            end = entirePayloadSize
        }

        segmentPayload := newMessage.Payload[start:end]
        wrapped, err := wrap(&protobuf.SegmentMessage{
            EntireMessageHash: entireMessageHash,
            Index:             uint32(index),
            SegmentsCount:     uint32(segmentsCount),
            Payload:           segmentPayload,
        })
        if err != nil {
            return nil, err
        }

        segmentPayloads[index] = segmentPayload
        segmentMessages = append(segmentMessages, wrapped)
    }

    // Skip reedsolomon if the combined total of data and parity segments exceeds the predefined limit of segmentsReedsolomonMaxCount.
    // Exceeding this limit necessitates shard sizes to be multiples of 64, which are incompatible with clients that do not support forward error correction.
    if paritySegmentsCount == 0 || segmentsCount+paritySegmentsCount > segmentsReedsolomonMaxCount {
        return segmentMessages, nil
    }

    enc, err := reedsolomon.New(segmentsCount, paritySegmentsCount)
    if err != nil {
        return nil, err
    }

    // Align the size of the last segment payload. The padded copy is only used
    // as encoder input; the data segment message built above keeps the
    // original, unpadded payload.
    lastSegmentPayload := segmentPayloads[segmentsCount-1]
    segmentPayloads[segmentsCount-1] = make([]byte, segmentSize)
    copy(segmentPayloads[segmentsCount-1], lastSegmentPayload)

    // Make space for parity data.
    for i := segmentsCount; i < segmentsCount+paritySegmentsCount; i++ {
        segmentPayloads[i] = make([]byte, segmentSize)
    }

    if err := enc.Encode(segmentPayloads); err != nil {
        return nil, err
    }

    // Create parity messages.
    for parityIndex := 0; parityIndex < paritySegmentsCount; parityIndex++ {
        wrapped, err := wrap(&protobuf.SegmentMessage{
            EntireMessageHash:   entireMessageHash,
            SegmentsCount:       0, // indicates parity message
            ParitySegmentIndex:  uint32(parityIndex),
            ParitySegmentsCount: uint32(paritySegmentsCount),
            Payload:             segmentPayloads[segmentsCount+parityIndex],
        })
        if err != nil {
            return nil, err
        }

        segmentMessages = append(segmentMessages, wrapped)
    }

    return segmentMessages, nil
}

// handleSegmentationLayerV1 stores the segment carried by message and, once
// every segment of the entire message has been stored, replaces the transport
// payload of message with the reassembled payload.
//
// SegmentationLayerV1 reconstructs the message only when all segments have been successfully retrieved.
// It lacks the capability to perform forward error correction.
// Kept to test forward compatibility.
//
// It returns ErrMessageSegmentsAlreadyCompleted, ErrMessageSegmentsInvalidCount,
// ErrMessageSegmentsIncomplete or ErrMessageSegmentsHashMismatch for the
// corresponding conditions, and passes through unmarshal and persistence errors.
func (s *MessageSender) handleSegmentationLayerV1(message *v1protocol.StatusMessage) error {
    logger := s.logger.With(zap.String("site", "handleSegmentationLayerV1"))
    logger = logger.With(zap.String("hash", types.HexBytes(message.TransportLayer.Hash).String()))

    received := &SegmentMessage{SegmentMessage: new(protobuf.SegmentMessage)}
    if err := proto.Unmarshal(message.TransportLayer.Payload, received.SegmentMessage); err != nil {
        return errors.Wrap(err, "failed to unmarshal SegmentMessage")
    }

    logger.Debug("handling message segment",
        zap.String("EntireMessageHash", types.HexBytes(received.EntireMessageHash).String()),
        zap.Uint32("Index", received.Index),
        zap.Uint32("SegmentsCount", received.SegmentsCount),
    )

    completed, err := s.persistence.IsMessageAlreadyCompleted(received.EntireMessageHash)
    switch {
    case err != nil:
        return err
    case completed:
        return ErrMessageSegmentsAlreadyCompleted
    case received.SegmentsCount < 2:
        return ErrMessageSegmentsInvalidCount
    }

    if err := s.persistence.SaveMessageSegment(received, message.TransportLayer.SigPubKey, time.Now().Unix()); err != nil {
        return err
    }

    stored, err := s.persistence.GetMessageSegments(received.EntireMessageHash, message.TransportLayer.SigPubKey)
    if err != nil {
        return err
    }

    // Every segment must be present before the payload can be assembled.
    if int(received.SegmentsCount) != len(stored) {
        return ErrMessageSegmentsIncomplete
    }

    // Concatenate the stored payloads in the order they were returned.
    var assembled bytes.Buffer
    for i := range stored {
        if _, err := assembled.Write(stored[i].Payload); err != nil {
            return errors.Wrap(err, "failed to write segment payload")
        }
    }

    // The assembled payload must hash to the value announced by the segment.
    if hash := crypto.Keccak256(assembled.Bytes()); !bytes.Equal(hash, received.EntireMessageHash) {
        return ErrMessageSegmentsHashMismatch
    }

    if err := s.persistence.CompleteMessageSegments(received.EntireMessageHash, message.TransportLayer.SigPubKey, time.Now().Unix()); err != nil {
        return err
    }

    message.TransportLayer.Payload = assembled.Bytes()
    return nil
}

// handleSegmentationLayerV2 stores the segment carried by message and, once
// enough segments have been stored, replaces the transport payload of message
// with the reassembled payload.
//
// SegmentationLayerV2 is capable of reconstructing the message from both complete and partial sets of data segments.
// It has capability to perform forward error correction.
//
// Segment metadata comes from the network, so segment indices and the parity
// count are range-checked before they are used to size or index the shard
// slice; out-of-range values yield ErrMessageSegmentsInvalidCount (data
// segments) or ErrMessageSegmentsInvalidParity (parity segments) instead of a
// panic.
//
// It also returns ErrMessageSegmentsAlreadyCompleted,
// ErrMessageSegmentsIncomplete and ErrMessageSegmentsHashMismatch for the
// corresponding conditions, and passes through unmarshal, persistence and
// reedsolomon errors.
func (s *MessageSender) handleSegmentationLayerV2(message *v1protocol.StatusMessage) error {
    logger := s.logger.With(zap.String("site", "handleSegmentationLayerV2")).With(zap.String("hash", types.HexBytes(message.TransportLayer.Hash).String()))

    segmentMessage := &SegmentMessage{
        SegmentMessage: &protobuf.SegmentMessage{},
    }
    err := proto.Unmarshal(message.TransportLayer.Payload, segmentMessage.SegmentMessage)
    if err != nil {
        return errors.Wrap(err, "failed to unmarshal SegmentMessage")
    }

    logger.Debug("handling message segment",
        zap.String("EntireMessageHash", types.HexBytes(segmentMessage.EntireMessageHash).String()),
        zap.Uint32("Index", segmentMessage.Index),
        zap.Uint32("SegmentsCount", segmentMessage.SegmentsCount),
        zap.Uint32("ParitySegmentIndex", segmentMessage.ParitySegmentIndex),
        zap.Uint32("ParitySegmentsCount", segmentMessage.ParitySegmentsCount))

    alreadyCompleted, err := s.persistence.IsMessageAlreadyCompleted(segmentMessage.EntireMessageHash)
    if err != nil {
        return err
    }
    if alreadyCompleted {
        return ErrMessageSegmentsAlreadyCompleted
    }

    if !segmentMessage.IsValid() {
        return ErrMessageSegmentsInvalidCount
    }

    err = s.persistence.SaveMessageSegment(segmentMessage, message.TransportLayer.SigPubKey, time.Now().Unix())
    if err != nil {
        return err
    }

    segments, err := s.persistence.GetMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey)
    if err != nil {
        return err
    }

    if len(segments) == 0 {
        return errors.New("unexpected state: no segments found after save operation") // This should theoretically never occur.
    }

    // NOTE(review): the logic below presumably relies on GetMessageSegments
    // returning data segments ordered by index, followed by parity segments —
    // confirm against the persistence layer.
    firstSegmentMessage := segments[0]
    lastSegmentMessage := segments[len(segments)-1]

    // First segment message must not be a parity message.
    if firstSegmentMessage.IsParityMessage() || len(segments) != int(firstSegmentMessage.SegmentsCount) {
        return ErrMessageSegmentsIncomplete
    }

    // Equal to len(segments) after the check above, hence positive and bounded
    // by what is actually stored.
    dataCount := int(firstSegmentMessage.SegmentsCount)

    var payloads [][]byte
    if !lastSegmentMessage.IsParityMessage() {
        // No parity segment involved: use the stored payloads positionally.
        payloads = make([][]byte, dataCount)
        for i, segment := range segments {
            payloads[i] = segment.Payload
        }
    } else {
        parityCount := lastSegmentMessage.ParitySegmentsCount

        // The sender never emits parity segments beyond this total, so a larger
        // value can only come from a malformed segment. Checking in uint64
        // avoids wrap-around and keeps the allocation below small.
        if uint64(dataCount)+uint64(parityCount) > segmentsReedsolomonMaxCount {
            return ErrMessageSegmentsInvalidParity
        }

        enc, err := reedsolomon.New(dataCount, int(parityCount))
        if err != nil {
            return err
        }

        payloads = make([][]byte, dataCount+int(parityCount))
        payloadSize := len(firstSegmentMessage.Payload)

        var lastNonParitySegmentPayload []byte
        for _, segment := range segments {
            if segment.IsParityMessage() {
                if segment.ParitySegmentIndex >= parityCount {
                    return ErrMessageSegmentsInvalidParity
                }
                payloads[dataCount+int(segment.ParitySegmentIndex)] = segment.Payload
                continue
            }

            if segment.Index >= firstSegmentMessage.SegmentsCount {
                return ErrMessageSegmentsInvalidCount
            }
            if segment.Index == firstSegmentMessage.SegmentsCount-1 {
                // Ensure last segment is aligned to payload size, as it is required by reedsolomon.
                payloads[segment.Index] = make([]byte, payloadSize)
                copy(payloads[segment.Index], segment.Payload)
                lastNonParitySegmentPayload = segment.Payload
            } else {
                payloads[segment.Index] = segment.Payload
            }
        }

        err = enc.Reconstruct(payloads)
        if err != nil {
            return err
        }

        ok, err := enc.Verify(payloads)
        if err != nil {
            return err
        }
        if !ok {
            return ErrMessageSegmentsInvalidParity
        }

        // NOTE(review): when the last data segment itself had to be
        // reconstructed, it keeps its alignment padding, so the hash check
        // below can only pass if the original payload length was a multiple of
        // the shard size.
        if lastNonParitySegmentPayload != nil {
            payloads[dataCount-1] = lastNonParitySegmentPayload // Bring back last segment with original length.
        }
    }

    // Combine payload.
    var entirePayload bytes.Buffer
    for i := 0; i < dataCount; i++ {
        _, err := entirePayload.Write(payloads[i])
        if err != nil {
            return errors.Wrap(err, "failed to write segment payload")
        }
    }

    // Sanity check.
    entirePayloadHash := crypto.Keccak256(entirePayload.Bytes())
    if !bytes.Equal(entirePayloadHash, segmentMessage.EntireMessageHash) {
        return ErrMessageSegmentsHashMismatch
    }

    err = s.persistence.CompleteMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey, time.Now().Unix())
    if err != nil {
        return err
    }

    message.TransportLayer.Payload = entirePayload.Bytes()

    return nil
}

// CleanupSegments removes stored message segments, and records of completed
// segmented messages, that are older than one month. It returns the first
// error reported by the persistence layer.
func (s *MessageSender) CleanupSegments() error {
    cutoff := time.Now().AddDate(0, -1, 0).Unix()

    if err := s.persistence.RemoveMessageSegmentsOlderThan(cutoff); err != nil {
        return err
    }
    if err := s.persistence.RemoveMessageSegmentsCompletedOlderThan(cutoff); err != nil {
        return err
    }
    return nil
}