waku-org/go-waku

View on GitHub
waku/v2/rendezvous/db.go

Summary

Maintainability
B
5 hrs
Test Coverage
D
60%
package rendezvous

import (
    "bytes"
    "context"
    "crypto/rand"
    "crypto/sha256"
    "database/sql"
    "encoding/binary"
    "errors"
    "fmt"
    "time"

    "github.com/libp2p/go-libp2p/core/peer"
    dbi "github.com/waku-org/go-libp2p-rendezvous/db"
    "github.com/waku-org/go-waku/waku/v2/utils"
    "go.uber.org/zap"
)

// DB stores rendezvous peer registrations in a SQL database. Create it with
// NewDB, then call Start before issuing any queries and Close when done.
type DB struct {
	// db is the underlying handle; logger is the "rendezvous/db" child logger
	// created by NewDB.
	db     *sql.DB
	logger *zap.Logger

	// Prepared statements, populated by prepareStmts during Start.
	insertPeerRegistration     *sql.Stmt // insert one registration row
	deletePeerRegistrations    *sql.Stmt // delete every row of a peer
	deletePeerRegistrationsNs  *sql.Stmt // delete a peer's rows in one namespace
	countPeerRegistrations     *sql.Stmt // count a peer's rows
	selectPeerRegistrations    *sql.Stmt // unexpired rows, any namespace
	selectPeerRegistrationsNS  *sql.Stmt // unexpired rows in one namespace
	selectPeerRegistrationsC   *sql.Stmt // unexpired rows above a counter
	selectPeerRegistrationsNSC *sql.Stmt // unexpired rows above a counter, one namespace
	deleteExpiredRegistrations *sql.Stmt // purge rows whose expiry has passed
	getCounter                 *sql.Stmt // MAX(counter) over all registrations

	// nonce is the secret mixed into pagination-cookie hashes; it is loaded
	// from (or first written to) the nonce table by loadNonce.
	nonce []byte

	// cancel stops the background cleanup goroutine; nil until Start succeeds.
	cancel func()
}

// NewDB wraps an already-opened database handle for rendezvous registration
// storage, logging under the "rendezvous/db" name. No queries are issued here;
// call Start to load the nonce, prepare statements and begin cleanup.
func NewDB(db *sql.DB, logger *zap.Logger) *DB {
	return &DB{
		logger: logger.Named("rendezvous/db"),
		db:     db,
	}
}

// Start makes the DB ready for use. It:
//   - loads the cookie nonce, creating one if the nonce table is empty;
//   - prepares every SQL statement;
//   - launches a goroutine that purges expired registrations until ctx is
//     canceled or Close is called.
//
// If loading the nonce or preparing statements fails, the underlying *sql.DB
// is closed and that error is returned.
func (db *DB) Start(ctx context.Context) error {
	if err := db.loadNonce(); err != nil {
		// db.cancel is only assigned at the end of Start, so release just the
		// SQL handle here rather than going through the full shutdown path.
		_ = db.db.Close()
		return err
	}

	if err := db.prepareStmts(); err != nil {
		_ = db.db.Close()
		return err
	}

	bgctx, cancel := context.WithCancel(ctx)
	db.cancel = cancel
	go db.background(bgctx)

	return nil
}

// Close stops the background cleanup goroutine, if Start launched one, and
// closes the underlying database handle, returning its error.
//
// cancel is nil when Start was never called or failed before launching the
// goroutine. It is checked so Close does not panic in those cases.
func (db *DB) Close() error {
	if db.cancel != nil {
		db.cancel()
	}
	return db.db.Close()
}

// insertNonce generates a fresh 32-byte random nonce, persists it in the
// nonce table and, only once the insert succeeded, caches it on db.
func (db *DB) insertNonce() error {
	fresh := make([]byte, 32)
	if _, err := rand.Read(fresh); err != nil {
		return err
	}

	if _, err := db.db.Exec("INSERT INTO nonce VALUES (?)", fresh); err != nil {
		return err
	}

	db.nonce = fresh
	return nil
}

// loadNonce reads the persisted cookie nonce into db.nonce. When the nonce
// table has no row yet, a new nonce is generated and stored via insertNonce.
func (db *DB) loadNonce() error {
	var stored []byte
	err := db.db.QueryRow("SELECT nonce FROM nonce").Scan(&stored)

	switch {
	case err == nil:
		db.nonce = stored
		return nil
	case errors.Is(err, sql.ErrNoRows):
		return db.insertNonce()
	default:
		return err
	}
}

// prepareStmts prepares every SQL statement used by DB and stores each on its
// field. Statements are prepared in table order. On the first failure it
// returns an error wrapping the driver error and naming the offending query;
// statements prepared before the failure stay assigned.
//
// The paginated registration SELECTs order by counter. Discover resumes the
// next page after the counter of the last row it read, which is only correct
// when rows arrive in ascending counter order. Without ORDER BY, SQL leaves
// both the row order and the subset kept by LIMIT unspecified.
func (db *DB) prepareStmts() error {
	stmts := []struct {
		dst   **sql.Stmt
		query string
	}{
		{&db.insertPeerRegistration, "INSERT INTO registrations VALUES (NULL, ?, ?, ?, ?)"},
		{&db.deletePeerRegistrations, "DELETE FROM registrations WHERE peer = ?"},
		{&db.deletePeerRegistrationsNs, "DELETE FROM registrations WHERE peer = ? AND ns = ?"},
		{&db.countPeerRegistrations, "SELECT COUNT(*) FROM registrations WHERE peer = ?"},
		{&db.selectPeerRegistrations, "SELECT * FROM registrations WHERE expire > ? ORDER BY counter LIMIT ?"},
		{&db.selectPeerRegistrationsNS, "SELECT * FROM registrations WHERE ns = ? AND expire > ? ORDER BY counter LIMIT ?"},
		{&db.selectPeerRegistrationsC, "SELECT * FROM registrations WHERE counter > ? AND expire > ? ORDER BY counter LIMIT ?"},
		{&db.selectPeerRegistrationsNSC, "SELECT * FROM registrations WHERE counter > ? AND ns = ? AND expire > ? ORDER BY counter LIMIT ?"},
		{&db.deleteExpiredRegistrations, "DELETE FROM registrations WHERE expire < ?"},
		{&db.getCounter, "SELECT MAX(counter) FROM registrations"},
	}

	for _, s := range stmts {
		stmt, err := db.db.Prepare(s.query)
		if err != nil {
			return fmt.Errorf("preparing %q: %w", s.query, err)
		}
		*s.dst = stmt
	}

	return nil
}

// Register stores the signed peer record of p under namespace ns, expiring
// ttl seconds from now. Any previous registration of p in ns is removed
// first. Delete, insert and the counter read run in one transaction.
//
// The returned value is MAX(counter) over all registrations, read after the
// insert. If Commit fails, that value is returned together with Commit's error.
func (db *DB) Register(p peer.ID, ns string, signedPeerRecord []byte, ttl int) (uint64, error) {
	pid := p.String()
	expire := time.Now().Unix() + int64(ttl)

	tx, err := db.db.Begin()
	if err != nil {
		return 0, err
	}
	// Undo everything on an early return. After a successful Commit this
	// Rollback is a no-op (sql.ErrTxDone), which is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	var (
		delOld     = tx.Stmt(db.deletePeerRegistrationsNs)
		insertNew  = tx.Stmt(db.insertPeerRegistration)
		maxCounter = tx.Stmt(db.getCounter)
	)

	if _, err = delOld.Exec(pid, ns); err != nil {
		return 0, err
	}
	if _, err = insertNew.Exec(pid, ns, expire, signedPeerRecord); err != nil {
		return 0, err
	}

	var counter uint64
	if err = maxCounter.QueryRow().Scan(&counter); err != nil {
		return 0, err
	}

	return counter, tx.Commit()
}

// CountRegistrations returns how many registration rows exist for peer p,
// across all namespaces and regardless of expiry.
func (db *DB) CountRegistrations(p peer.ID) (int, error) {
	var n int
	err := db.countPeerRegistrations.QueryRow(p.String()).Scan(&n)
	return n, err
}

// Unregister deletes the registrations of peer p. An empty ns removes the
// peer from every namespace; otherwise only its row in ns is removed.
func (db *DB) Unregister(p peer.ID, ns string) error {
	pid := p.String()

	if ns != "" {
		_, err := db.deletePeerRegistrationsNs.Exec(pid, ns)
		return err
	}

	_, err := db.deletePeerRegistrations.Exec(pid)
	return err
}

// Discover returns up to limit unexpired registrations, optionally restricted
// to namespace ns, plus a pagination cookie for the next call.
//
// Cookie handling:
//   - A non-nil cookie resumes the listing after the counter it encodes. Its
//     authenticity is not checked here; ValidCookie does that.
//   - The returned cookie encodes the counter of the last row read. It is only
//     re-packed when that counter is positive; otherwise the input cookie is
//     returned as-is.
//
// Record contents:
//   - Ns is filled only when ns is empty, i.e. when results may span several
//     namespaces.
//   - Ttl is the remaining lifetime in seconds.
//   - Rows whose peer ID fails to decode are logged and skipped.
func (db *DB) Discover(ns string, cookie []byte, limit int) ([]dbi.RegistrationRecord, []byte, error) {
	// Upper bound on the capacity pre-allocated for the result slice, so a
	// very large limit does not force an equally large up-front allocation.
	const maxPrealloc = 1024

	now := time.Now().Unix()

	var (
		counter int64
		rows    *sql.Rows
		err     error
	)

	if cookie != nil {
		counter, err = unpackCookie(cookie)
		if err != nil {
			db.logger.Error("unpacking cookie", zap.Error(err))
			return nil, nil, err
		}
	}

	if counter > 0 {
		if ns == "" {
			rows, err = db.selectPeerRegistrationsC.Query(counter, now, limit)
		} else {
			rows, err = db.selectPeerRegistrationsNSC.Query(counter, ns, now, limit)
		}
	} else {
		if ns == "" {
			rows, err = db.selectPeerRegistrations.Query(now, limit)
		} else {
			rows, err = db.selectPeerRegistrationsNS.Query(ns, now, limit)
		}
	}

	if err != nil {
		db.logger.Error("query", zap.Error(err))
		return nil, nil, err
	}

	defer rows.Close()

	// limit is not validated here, and make panics on a negative capacity,
	// so clamp the pre-allocation hint to [0, maxPrealloc]. The query above
	// still received limit unchanged.
	capHint := limit
	if capHint < 0 {
		capHint = 0
	} else if capHint > maxPrealloc {
		capHint = maxPrealloc
	}
	regs := make([]dbi.RegistrationRecord, 0, capHint)

	for rows.Next() {
		var (
			reg              dbi.RegistrationRecord
			rid              string
			rns              string
			expire           int64
			signedPeerRecord []byte
			p                peer.ID
		)

		// counter is overwritten on every row, so after the loop it holds the
		// last row's counter, which becomes the next page's resume point.
		err = rows.Scan(&counter, &rid, &rns, &expire, &signedPeerRecord)
		if err != nil {
			db.logger.Error("row scan error", zap.Error(err))
			return nil, nil, err
		}

		p, err = peer.Decode(rid)
		if err != nil {
			db.logger.Error("error decoding peer id", zap.Error(err))
			continue
		}

		reg.Id = p
		reg.SignedPeerRecord = signedPeerRecord
		reg.Ttl = int(expire - now)

		if ns == "" {
			reg.Ns = rns
		}

		regs = append(regs, reg)
	}

	err = rows.Err()
	if err != nil {
		return nil, nil, err
	}

	if counter > 0 {
		cookie = packCookie(counter, ns, db.nonce)
	}

	return regs, cookie, nil
}

// ValidCookie reports whether cookie was produced by this DB's nonce for
// namespace ns (see validCookie for the exact format check).
func (db *DB) ValidCookie(ns string, cookie []byte) bool {
	secret := db.nonce
	return validCookie(cookie, ns, secret)
}

func (db *DB) background(ctx context.Context) {
    defer utils.LogOnPanic()
    for {
        db.cleanupExpired()

        select {
        case <-time.After(15 * time.Minute):
        case <-ctx.Done():
            return
        }
    }
}

// cleanupExpired deletes registrations whose expiry (Unix seconds) is before
// the current time. Failures are logged, not returned.
func (db *DB) cleanupExpired() {
	if _, err := db.deleteExpiredRegistrations.Exec(time.Now().Unix()); err != nil {
		db.logger.Error("deleting expired registrations", zap.Error(err))
	}
}

// packCookie builds a pagination cookie for namespace ns. The layout is the
// counter as 8 big-endian bytes followed by SHA256(nonce || ns || counter
// bytes), so the result is always 8+sha256.Size (40) bytes long.
func packCookie(counter int64, ns string, nonce []byte) []byte {
	prefix := make([]byte, 8, 8+sha256.Size)
	binary.BigEndian.PutUint64(prefix, uint64(counter))

	digest := sha256.New()
	for _, chunk := range [][]byte{nonce, []byte(ns), prefix} {
		if _, err := digest.Write(chunk); err != nil {
			panic(err)
		}
	}

	// Sum appends the digest after the 8 counter bytes: counter || hash.
	return digest.Sum(prefix)
}

func unpackCookie(cookie []byte) (int64, error) {
    if len(cookie) < 8 {
        return 0, fmt.Errorf("bad packed cookie: not enough bytes: %v", cookie)
    }

    counter := binary.BigEndian.Uint64(cookie[:8])
    return int64(counter), nil
}

func validCookie(cookie []byte, ns string, nonce []byte) bool {
    if len(cookie) != 40 {
        return false
    }

    cbits := cookie[:8]
    hash := sha256.New()
    _, err := hash.Write(nonce)
    if err != nil {
        panic(err)
    }
    _, err = hash.Write([]byte(ns))
    if err != nil {
        panic(err)
    }
    _, err = hash.Write(cbits)
    if err != nil {
        panic(err)
    }
    hbits := hash.Sum(nil)

    return bytes.Equal(cookie[8:], hbits)
}