bnkamalesh/verifier

View on GitHub
stores/postgres.go

Summary

Maintainability
A
1 hr
Test Coverage
package stores

import (
    "context"
    "database/sql"
    "fmt"
    "net"
    "net/url"
    "strings"
    "time"

    "github.com/Masterminds/squirrel"
    "github.com/fatih/structs"
    "github.com/jackc/pgx/v4/pgxpool"

    "github.com/bnkamalesh/verifier"
)

// structToMapStringWithTag converts a struct (or a non-nil pointer to a struct)
// to map[string]interface{}, using the value of the given struct tag as each key.
//
// It returns an error instead of panicking when source is not a struct, e.g. a
// nil *verifier.Request handed to Create or Update.
func structToMapStringWithTag(tag string, source interface{}) (map[string]interface{}, error) {
    // structs.New panics on non-struct input, so validate before calling it.
    if !structs.IsStruct(source) {
        return nil, fmt.Errorf("structToMapStringWithTag: expected a struct or non-nil struct pointer, got %T", source)
    }

    converter := structs.New(source)
    converter.TagName = tag
    return converter.Map(), nil
}

// PostgresConfig holds all configuration required to connect to PostgreSQL and
// to locate the table used for verification requests.
type PostgresConfig struct {
    Host      string `json:"host,omitempty"`
    Port      string `json:"port,omitempty"`
    Username  string `json:"username,omitempty"`
    Password  string `json:"password,omitempty"`
    // StoreName is the database name placed in the connection URL path.
    StoreName string `json:"storeName,omitempty"`
    // PoolSize is used as the connection pool's MaxConns.
    PoolSize  int    `json:"poolSize,omitempty"`
    // SSLMode is sent as the sslmode URL parameter; blank means "disable".
    SSLMode   string `json:"sslMode,omitempty"`

    // NOTE(review): despite the "Secs" suffix these fields are time.Duration,
    // i.e. nanoseconds. A bare JSON number such as 5 decodes to 5ns, not 5s, so
    // callers must supply full durations (e.g. 5 * time.Second).
    DialTimeoutSecs  time.Duration `json:"dialTimeoutSecs,omitempty"`
    ReadTimeoutSecs  time.Duration `json:"readTimeoutSecs,omitempty"`
    WriteTimeoutSecs time.Duration `json:"writeTimeoutSecs,omitempty"`
    IdleTimeoutSecs  time.Duration `json:"idleTimeoutSecs,omitempty"`

    // TableName is the table that holds verification requests.
    TableName string `json:"tableName,omitempty"`
}

// ConnURL returns a connection URL of the form
// postgres://user:password@host:port/database?sslmode=mode.
//
// The username, password and database name are percent-escaped, so credentials
// containing characters such as '@', ':', '/' or '?' do not corrupt the URL.
// An IPv6 literal host is wrapped in brackets. SSLMode defaults to "disable"
// when it is empty or only whitespace.
func (pgcfg *PostgresConfig) ConnURL() string {
    sslMode := strings.TrimSpace(pgcfg.SSLMode)
    if sslMode == "" {
        sslMode = "disable"
    }

    userinfo := url.UserPassword(pgcfg.Username, pgcfg.Password)
    query := url.Values{"sslmode": []string{sslMode}}

    return fmt.Sprintf(
        "postgres://%s@%s/%s?%s",
        userinfo.String(),
        net.JoinHostPort(pgcfg.Host, pgcfg.Port),
        url.PathEscape(pgcfg.StoreName),
        query.Encode(),
    )
}

// Postgres implements the verifier store functions using PostgreSQL as the persistence layer.
// Instances are built by NewPostgres.
type Postgres struct {
    // cfg is the configuration the store was built from; its Read/Write timeouts bound queries.
    cfg       *PostgresConfig
    // tableName is the table all statements target (copied from cfg.TableName).
    tableName string
    // pqdriver is the pgx connection pool used to execute statements.
    pqdriver  *pgxpool.Pool
    // qbuilder builds SQL using PostgreSQL-style $N placeholders.
    qbuilder  squirrel.StatementBuilderType
}

// ctxWithTimeout derives a cancellable context from ctx, or from
// context.Background() when ctx is nil.
//
// A positive timeout adds a deadline of that duration. A zero or negative
// timeout yields a context with no deadline, so an unset timeout in
// PostgresConfig does not make every query fail immediately. The returned
// CancelFunc must always be called to release the context's resources.
func ctxWithTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
    if ctx == nil {
        ctx = context.Background()
    }
    if timeout <= 0 {
        return context.WithCancel(ctx)
    }
    return context.WithTimeout(ctx, timeout)
}

// Create inserts req as a new row in the configured table and returns req.
//
// Column names come from the `json` struct tags of verifier.Request (via
// structToMapStringWithTag). The insert is bounded by cfg.WriteTimeoutSecs.
// Errors are returned unwrapped.
func (pgs *Postgres) Create(req *verifier.Request) (*verifier.Request, error) {
    reqmap, err := structToMapStringWithTag("json", req)
    if err != nil {
        return nil, err
    }

    query, args, err := pgs.qbuilder.Insert(pgs.tableName).SetMap(reqmap).ToSql()
    if err != nil {
        return nil, err
    }

    ctx, cancel := ctxWithTimeout(context.Background(), pgs.cfg.WriteTimeoutSecs)
    // Release the context's timer as soon as the insert returns instead of
    // leaving it alive until the deadline fires.
    defer cancel()

    if _, err = pgs.pqdriver.Exec(ctx, query, args...); err != nil {
        return nil, err
    }

    return req, nil
}

// ReadLastPending reads the last pending verification request of the given
// communication type that is addressed to recipient.
//
// "Last" means the row with the highest autoID among rows whose type, recipient
// and status (verifier.VerStatusPending) match; presumably that is the most
// recently created one. The query is bounded by cfg.ReadTimeoutSecs. Errors
// from building the query and from Scan are returned unwrapped, so callers can
// compare them directly.
func (pgs *Postgres) ReadLastPending(ctype verifier.CommType, recipient string) (*verifier.Request, error) {
    query, args, err := pgs.qbuilder.Select(
        "id",
        "type",
        "sender",
        "recipient",
        "data",
        "secret",
        "secretExpiry",
        "attempts",
        "commStatus",
        "status",
        "createdAt",
        "updatedAt",
    ).From(
        pgs.tableName,
    ).Where(squirrel.Eq{
        "type": ctype,
        // Without this filter the query could return another recipient's
        // pending request, including its secret.
        "recipient": recipient,
        "status":    verifier.VerStatusPending,
    }).OrderBy(
        "autoID DESC",
    ).Limit(
        1,
    ).ToSql()
    if err != nil {
        return nil, err
    }

    ctx, cancel := ctxWithTimeout(context.Background(), pgs.cfg.ReadTimeoutSecs)
    // The row is consumed by Scan below; cancel only once this function returns.
    defer cancel()

    row := pgs.pqdriver.QueryRow(ctx, query, args...)

    req := &verifier.Request{
        SecretExpiry: new(time.Time),
        CreatedAt:    new(time.Time),
        UpdatedAt:    new(time.Time),
        Data:         map[string]string{},
        CommStatus:   make([]verifier.CommStatus, 0, 10),
    }

    // Nullable scalar columns are scanned into sql.Null* holders; a NULL ends
    // up as the zero value on req.
    var (
        id              sql.NullString
        commtype        sql.NullString
        sender          sql.NullString
        storedRecipient sql.NullString
        secret          sql.NullString
        attempts        sql.NullInt32
    )

    err = row.Scan(
        &id,
        &commtype,
        &sender,
        &storedRecipient,
        &req.Data,
        &secret,
        req.SecretExpiry,
        &attempts,
        &req.CommStatus,
        &req.Status,
        req.CreatedAt,
        req.UpdatedAt,
    )
    if err != nil {
        return nil, err
    }

    req.ID = id.String
    req.Type = verifier.CommType(commtype.String)
    req.Sender = sender.String
    req.Recipient = storedRecipient.String
    req.Secret = secret.String
    req.Attempts = int(attempts.Int32)

    return req, nil
}

// Update updates the row whose id equals verID. The columns it sets are
// derived from req's `json` struct tags, and it returns req.
//
// When cfg.WriteTimeoutSecs is positive the statement is bounded by it. A zero
// or negative value leaves the statement without a deadline. Errors are
// returned unwrapped.
func (pgs *Postgres) Update(verID string, req *verifier.Request) (*verifier.Request, error) {
    vermap, err := structToMapStringWithTag("json", req)
    if err != nil {
        return nil, err
    }

    query, args, err := pgs.qbuilder.Update(
        pgs.tableName,
    ).SetMap(
        vermap,
    ).Where(
        squirrel.Eq{"id": verID},
    ).ToSql()
    if err != nil {
        return nil, err
    }

    ctx := context.Background()
    if timeout := pgs.cfg.WriteTimeoutSecs; timeout > 0 {
        var cancel context.CancelFunc
        ctx, cancel = context.WithTimeout(ctx, timeout)
        defer cancel()
    }

    if _, err = pgs.pqdriver.Exec(ctx, query, args...); err != nil {
        return nil, err
    }

    return req, nil
}

// NewPostgres connects a pgx pool using cfg and returns a Postgres store that
// targets cfg.TableName and builds SQL with PostgreSQL-style ($N) placeholders.
//
// Settings left at their zero value keep the defaults chosen by
// pgxpool.ParseConfig instead of being overwritten with 0. NOTE(review): with
// pgxpool v4, MaxConns = 0 leaves the pool no connection slots, and
// MaxConnLifetime = 0 makes every connection expire on release. Confirm
// against the pinned pgx version.
func NewPostgres(cfg *PostgresConfig) (*Postgres, error) {
    poolcfg, err := pgxpool.ParseConfig(cfg.ConnURL())
    if err != nil {
        return nil, err
    }

    if cfg.DialTimeoutSecs > 0 {
        // Bounds establishing each new server connection.
        poolcfg.ConnConfig.ConnectTimeout = cfg.DialTimeoutSecs
    }
    // NOTE(review): an "idle timeout" arguably belongs in MaxConnIdleTime
    // rather than MaxConnLifetime. The existing mapping is kept so configured
    // deployments behave as before.
    if cfg.IdleTimeoutSecs > 0 {
        poolcfg.MaxConnLifetime = cfg.IdleTimeoutSecs
    }
    if cfg.PoolSize > 0 {
        poolcfg.MaxConns = int32(cfg.PoolSize)
    }

    pool, err := pgxpool.ConnectConfig(context.Background(), poolcfg)
    if err != nil {
        return nil, err
    }

    pg := &Postgres{
        cfg:       cfg,
        tableName: cfg.TableName,
        pqdriver:  pool,
        qbuilder:  squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar),
    }

    return pg, nil
}