casey-chow/tigertrade

View on GitHub
server/models/users.go

Summary

Maintainability
A
3 hrs
Test Coverage
package models

import (
    "database/sql"
    sq "github.com/Masterminds/squirrel"
    log "github.com/Sirupsen/logrus"
    "github.com/guregu/null"
)

// A User is a record type storing a row of the users table.
//
// Values are populated by scanning the key_id, net_id, creation_date and
// last_modification_date columns. The JSON tags set the field names used
// when a User is marshaled to JSON.
type User struct {
    // KeyID holds the key_id column; GetUserByID looks users up by it.
    KeyID                int       `json:"keyId"`
    // CreationDate holds creation_date. As a null.Time it can represent SQL NULL.
    // NOTE(review): GetOrCreateUser inserts only net_id, so this value presumably
    // comes from a database default — confirm against the schema.
    CreationDate         null.Time `json:"creationDate"`
    // LastModificationDate holds last_modification_date; it can also be NULL.
    LastModificationDate null.Time `json:"lastModificationDate"`
    // NetID holds net_id, the identifier GetUser and GetOrCreateUser look users
    // up by (GetOrCreateUser assumes it originates from CAS).
    NetID                string    `json:"netId"`
}

// GetOrCreateUser makes sure the netID exists in the db, creating it if it doesn't already.
//
// It first looks the user up with GetUser. A lookup error other than
// sql.ErrNoRows is logged and returned. Otherwise a users row containing only
// net_id is inserted and the freshly stored row is read back and returned.
//
// The lookup and the insert are separate statements, so a concurrent call for
// the same netID can create the row in between. If our insert then fails, the
// lookup is retried once and the existing row is returned when found. If the
// row still cannot be read, the original insert error is returned.
//
// Security Note: DO NOT allow user-generated data into this function. This assumes the netID is from CAS
func GetOrCreateUser(db *sql.DB, netID string) (*User, error) {
	user, err := GetUser(db, netID)
	if err == nil {
		return user, nil
	}
	if err != sql.ErrNoRows {
		log.
			WithField("netID", netID).
			WithError(err).
			Error("error while getting user")
		return nil, err
	}

	log.
		WithField("netID", netID).
		Print("creating user")
	insert := psql.Insert("users").
		Columns("net_id").
		Values(netID)

	if _, err = insert.RunWith(db).Exec(); err != nil {
		// Another request may have inserted this netID after our lookup
		// (presumably rejected here by a uniqueness constraint on net_id —
		// not visible in this file). If the row exists now, use it.
		if existing, getErr := GetUser(db, netID); getErr == nil {
			return existing, nil
		}
		log.
			WithField("netID", netID).
			WithError(err).
			Error("error while creating user")
		return nil, err
	}

	return GetUser(db, netID)
}

// GetUser gets the user with the given netID. If no user matches, the error
// from the row scan (sql.ErrNoRows from database/sql) is returned unchanged
// along with a nil *User.
func GetUser(db *sql.DB, netID string) (*User, error) {
	return getUserWhere(db, sq.Eq{"net_id": netID})
}

// GetUserByID gets the user whose key_id equals id. If no user matches, the
// error from the row scan (sql.ErrNoRows from database/sql) is returned
// unchanged along with a nil *User.
func GetUserByID(db *sql.DB, id int) (*User, error) {
	return getUserWhere(db, sq.Eq{"key_id": id})
}

// getUserWhere runs a single-row SELECT on the users table filtered by where
// and scans the result into a new User. Any query or scan error is returned
// unchanged with a nil *User.
func getUserWhere(db *sql.DB, where sq.Eq) (*User, error) {
	query := psql.
		Select("key_id", "net_id", "creation_date", "last_modification_date").
		From("users").
		Where(where).
		Limit(1)

	user := new(User)
	// The Scan destinations must stay in the same order as the Select columns
	// above; keeping both in this one helper is what guarantees that.
	err := query.RunWith(db).
		QueryRow().
		Scan(
			&user.KeyID,
			&user.NetID,
			&user.CreationDate,
			&user.LastModificationDate,
		)
	if err != nil {
		return nil, err
	}

	return user, nil
}