synapsecns/sanguine

View on GitHub
services/cctp-relayer/db/sql/base/message.go

Summary

Maintainability
A
1 hr
Test Coverage
package base

import (
    "context"
    "database/sql"
    "fmt"

    "github.com/ethereum/go-ethereum/common"
    "gorm.io/gorm/clause"

    "github.com/synapsecns/sanguine/services/cctp-relayer/types"
)

// GetLastBlockNumber returns the largest block number stored for any message
// whose origin chain id equals chainID. A NULL aggregate result is reported
// as block 0 with a nil error.
func (s Store) GetLastBlockNumber(ctx context.Context, chainID uint32) (uint64, error) {
    var maxBlock sql.NullInt64

    chainFilter := fmt.Sprintf("%s = ?", OriginChainIDFieldName)
    maxExpr := fmt.Sprintf("MAX(%s)", BlockNumberFieldName)

    res := s.DB().WithContext(ctx).
        Model(&types.Message{}).
        Where(chainFilter, chainID).
        Select(maxExpr).
        Find(&maxBlock)
    if err := res.Error; err != nil {
        return 0, fmt.Errorf("failed to get last block number: %w", err)
    }

    if maxBlock.Valid {
        return uint64(maxBlock.Int64), nil
    }

    // the aggregate came back NULL: explicitly report 0
    return 0, nil
}

// StoreMessage writes msg to the database. How a conflict with an existing row
// is resolved depends on msg.State:
//
//   - Pending: the insert is skipped when the row already exists (ON CONFLICT
//     DO NOTHING on sqlite, INSERT IGNORE on every other dialect).
//   - Attested: on conflict, the state and attestation columns are overwritten.
//   - Submitted: on conflict, the destination tx hash, state and nonce columns
//     are overwritten.
//   - Complete: on conflict, the destination tx hash and state columns are
//     overwritten.
//
// Any other state results in a plain insert with no conflict clause attached.
// The returned error wraps the underlying database error, if any.
func (s Store) StoreMessage(ctx context.Context, msg types.Message) error {
    // upsert builds an ON CONFLICT clause keyed on the message hash column that
    // overwrites only the listed columns.
    upsert := func(columns ...string) clause.Expression {
        return clause.OnConflict{
            Columns:   []clause.Column{{Name: MessageHashFieldName}},
            DoUpdates: clause.AssignmentColumns(columns),
        }
    }

    var conflict clause.Expression

    switch msg.State {
    case types.Pending:
        // ignore queries don't work w/ sqlite so we need to adjust this to do nothing
        if s.db.Dialector.Name() == "sqlite" {
            conflict = clause.OnConflict{
                Columns:   []clause.Column{{Name: MessageHashFieldName}},
                DoNothing: true,
            }
        } else {
            conflict = clause.Insert{
                Modifier: "IGNORE",
            }
        }
    case types.Attested:
        conflict = upsert(StateFieldName, AttestationFieldName)
    case types.Submitted:
        conflict = upsert(DestTxHashFieldName, StateFieldName, NonceFieldName)
    case types.Complete:
        conflict = upsert(DestTxHashFieldName, StateFieldName)
    }

    query := s.DB().WithContext(ctx)
    // Only attach a clause when one was selected, so an unlisted state never
    // hands a nil expression to Clauses.
    if conflict != nil {
        query = query.Clauses(conflict)
    }

    if err := query.Create(&msg).Error; err != nil {
        return fmt.Errorf("failed to store message: %w", err)
    }
    return nil
}

// GetMessagesByState returns every stored message whose state column matches
// one of the given states. The states are converted to plain ints before being
// bound to the IN clause.
func (s Store) GetMessagesByState(ctx context.Context, states ...types.MessageState) ([]types.Message, error) {
    stateValues := make([]int, 0, len(states))
    for _, state := range states {
        stateValues = append(stateValues, int(state))
    }

    var found []types.Message
    stateFilter := fmt.Sprintf("%s IN ?", StateFieldName)

    res := s.DB().WithContext(ctx).Where(stateFilter, stateValues).Find(&found)
    if err := res.Error; err != nil {
        return nil, fmt.Errorf("failed to get messages by status: %w", err)
    }

    return found, nil
}

// GetMessageByOriginHash looks up a message by its origin transaction hash,
// matching on the hash's string form.
// TODO: the origin hash is not unique; only the first matching row is returned.
func (s Store) GetMessageByOriginHash(ctx context.Context, originHash common.Hash) (*types.Message, error) {
    found := new(types.Message)
    originFilter := fmt.Sprintf("%s = ?", OriginTxHashFieldName)

    res := s.DB().WithContext(ctx).
        Model(&types.Message{}).
        Where(originFilter, originHash.String()).
        First(found)
    if err := res.Error; err != nil {
        return nil, fmt.Errorf("failed to get message by hash: %w", err)
    }

    return found, nil
}

// GetMessageByRequestID looks up the first message whose request id column
// equals requestID. Any database error is returned wrapped.
func (s Store) GetMessageByRequestID(ctx context.Context, requestID string) (*types.Message, error) {
    found := new(types.Message)
    requestFilter := fmt.Sprintf("%s = ?", RequestIDFieldName)

    res := s.DB().WithContext(ctx).
        Model(&types.Message{}).
        Where(requestFilter, requestID).
        First(found)
    if err := res.Error; err != nil {
        return nil, fmt.Errorf("failed to get message by request id: %w", err)
    }

    return found, nil
}

// GetMessageByHash looks up the first message whose message hash column equals
// messageHash. Any database error is returned wrapped.
func (s Store) GetMessageByHash(ctx context.Context, messageHash common.Hash) (*types.Message, error) {
    var message types.Message

    // Bind the hash's string form, the same way GetMessageByOriginHash binds its
    // hash, rather than the common.Hash value itself.
    // NOTE(review): assumes the message hash column holds the hex string form —
    // confirm against types.Message.
    dbTx := s.DB().WithContext(ctx).
        Model(&types.Message{}).
        Where(fmt.Sprintf("%s = ?", MessageHashFieldName), messageHash.String()).
        First(&message)
    if dbTx.Error != nil {
        return nil, fmt.Errorf("failed to get message by hash: %w", dbTx.Error)
    }

    return &message, nil
}