status-im/status-go

View on GitHub
protocol/encryption/sharedsecret/persistence.go

Summary

Maintainability
A
0 mins
Test Coverage
C
75%
package sharedsecret

import (
    "database/sql"
    "strings"
)

type Response struct {
    secret          []byte
    installationIDs map[string]bool
}

type sqlitePersistence struct {
    db *sql.DB
}

func newSQLitePersistence(db *sql.DB) *sqlitePersistence {
    return &sqlitePersistence{db: db}
}

func (s *sqlitePersistence) Add(identity []byte, secret []byte, installationID string) error {
    tx, err := s.db.Begin()
    if err != nil {
        return err
    }

    insertSecretStmt, err := tx.Prepare("INSERT INTO secrets(identity, secret) VALUES (?, ?)")
    if err != nil {
        _ = tx.Rollback()
        return err
    }
    defer insertSecretStmt.Close()

    _, err = insertSecretStmt.Exec(identity, secret)
    if err != nil {
        _ = tx.Rollback()
        return err
    }

    insertInstallationIDStmt, err := tx.Prepare("INSERT INTO secret_installation_ids(id, identity_id) VALUES (?, ?)")
    if err != nil {
        _ = tx.Rollback()
        return err
    }
    defer insertInstallationIDStmt.Close()

    _, err = insertInstallationIDStmt.Exec(installationID, identity)
    if err != nil {
        _ = tx.Rollback()
        return err
    }
    return tx.Commit()
}

func (s *sqlitePersistence) Get(identity []byte, installationIDs []string) (*Response, error) {
    response := &Response{
        installationIDs: make(map[string]bool),
    }
    args := make([]interface{}, len(installationIDs)+1)
    args[0] = identity
    for i, installationID := range installationIDs {
        args[i+1] = installationID
    }

    /* #nosec */
    query := `SELECT secret, id
            FROM secrets t
            JOIN
                  secret_installation_ids tid
              ON t.identity = tid.identity_id
              WHERE
                  t.identity = ?
              AND
                  tid.id IN (?` + strings.Repeat(",?", len(installationIDs)-1) + `)`

    rows, err := s.db.Query(query, args...)
    if err != nil && err != sql.ErrNoRows {
        return nil, err
    }
    defer rows.Close()

    for rows.Next() {
        var installationID string
        var secret []byte
        err = rows.Scan(&secret, &installationID)
        if err != nil {
            return nil, err
        }

        response.secret = secret
        response.installationIDs[installationID] = true
    }

    return response, nil
}

func (s *sqlitePersistence) All() ([][][]byte, error) {
    query := "SELECT identity, secret FROM secrets"

    var secrets [][][]byte

    rows, err := s.db.Query(query)
    if err != nil {
        return nil, err
    }
    defer rows.Close()

    for rows.Next() {
        var secret []byte
        var identity []byte
        err = rows.Scan(&identity, &secret)
        if err != nil {
            return nil, err
        }

        secrets = append(secrets, [][]byte{identity, secret})
    }

    return secrets, nil
}