Fs02/grimoire

View on GitHub
adapter/sql/bulder.go

Summary

Maintainability
A
2 hrs
Test Coverage
A
100%
package sql

import (
    "bytes"
    "strconv"
    "strings"
    "sync"

    "github.com/Fs02/grimoire"
    "github.com/Fs02/grimoire/c"
)

// UnescapeCharacter disables escaping for a field: when a field name starts
// with this character, the character is stripped and the rest of the field is
// emitted verbatim instead of being wrapped in the escape character.
var UnescapeCharacter byte = '^'

// fieldCache memoizes the results of Builder.escape. Keys are the raw field
// name concatenated with the escape character; values are the escaped strings.
var fieldCache sync.Map

// Builder defines information of query builder.
// It renders queries and change sets into SQL strings together with their
// positional arguments, using the settings held in config.
type Builder struct {
    config      *Config // settings read while building: EscapeChar, Placeholder, Ordinal, InsertDefaultValues
    returnField string  // column named in the RETURNING clause of insert queries; empty omits the clause
    count       int     // number of ordinal placeholders issued so far; only advanced when config.Ordinal is set
}

// Find renders q as a SELECT statement and returns it with its arguments.
// The clause tail is built first so that any placeholders are numbered before
// the select list is prepended.
func (builder *Builder) Find(q grimoire.Query) (string, []interface{}) {
    tail, args := builder.query(q)
    head := builder.fields(q.AsDistinct, q.Fields...)

    return head + tail, args
}

// Aggregate generates query for aggregation.
// The select list consists of the query's group fields followed by
// "<mode>(<field>) AS <mode>"; the rest of the statement comes from query.
// It returns the SQL string and its positional arguments.
func (builder *Builder) Aggregate(q grimoire.Query) (string, []interface{}) {
    qs, args := builder.query(q)
    field := q.AggregateMode + "(" + q.AggregateField + ") AS " + q.AggregateMode

    // Build the select list in a fresh slice: appending directly to
    // q.GroupFields could write into spare capacity of the caller's backing
    // array and clobber data shared with other slices.
    selected := make([]string, 0, len(q.GroupFields)+1)
    selected = append(selected, q.GroupFields...)
    selected = append(selected, field)

    return builder.fields(false, selected...) + qs, args
}

// query renders everything that follows the select list: FROM, JOIN, WHERE,
// GROUP BY (with HAVING only when a GROUP BY was written), ORDER BY,
// LIMIT/OFFSET and the lock clause, terminated by ";". Every non-empty clause
// is preceded by a single space. It returns the SQL tail and the arguments
// collected from the clauses, in clause order.
func (builder *Builder) query(q grimoire.Query) (string, []interface{}) {
    var (
        out    bytes.Buffer
        params []interface{}
    )

    // emit writes one clause and collects its arguments; empty clauses are
    // skipped entirely. It reports whether anything was written.
    emit := func(clause string, values []interface{}) bool {
        if clause == "" {
            return false
        }

        out.WriteString(" ")
        out.WriteString(clause)
        params = append(params, values...)
        return true
    }

    emit(builder.from(q.Collection), nil)
    emit(builder.join(q.JoinClause...))
    emit(builder.where(q.Condition))

    if emit(builder.groupBy(q.GroupFields...), nil) {
        emit(builder.having(q.HavingCondition))
    }

    emit(builder.orderBy(q.OrderClause...), nil)
    emit(builder.limitOffset(q.LimitResult, q.OffsetResult), nil)
    emit(q.LockClause, nil)

    out.WriteString(";")

    return out.String(), params
}

// Insert renders an INSERT statement for a single record.
// Column names and argument values are taken from changes in the same
// iteration, so their order always matches. When changes is empty and the
// config enables InsertDefaultValues, "DEFAULT VALUES" is written instead of
// the column and value lists. A RETURNING clause is added when a return field
// has been set. It returns the SQL string and its arguments.
func (builder *Builder) Insert(collection string, changes map[string]interface{}) (string, []interface{}) {
    var (
        stmt   bytes.Buffer
        quote  = builder.config.EscapeChar
        total  = len(changes)
        values = make([]interface{}, 0, total)
    )

    stmt.WriteString("INSERT INTO ")
    stmt.WriteString(quote + collection + quote)

    switch {
    case total == 0 && builder.config.InsertDefaultValues:
        stmt.WriteString(" DEFAULT VALUES")
    default:
        stmt.WriteString(" (")
        for column, value := range changes {
            // values grows by one per column, so it doubles as the "not first" marker.
            if len(values) > 0 {
                stmt.WriteString(",")
            }

            stmt.WriteString(quote + column + quote)
            values = append(values, value)
        }

        stmt.WriteString(") VALUES (")
        for i := 0; i < total; i++ {
            if i > 0 {
                stmt.WriteString(",")
            }

            stmt.WriteString(builder.ph())
        }
        stmt.WriteString(")")
    }

    if builder.returnField != "" {
        stmt.WriteString(" RETURNING ")
        stmt.WriteString(quote + builder.returnField + quote)
    }

    stmt.WriteString(";")

    return stmt.String(), values
}

// InsertAll renders a multi-row INSERT statement.
// fields fixes the column list and its order. For each entry of allchanges a
// value tuple is written: a placeholder (and argument) when the entry contains
// the field, the keyword DEFAULT when it does not. A RETURNING clause is added
// when a return field has been set. It returns the SQL string and its arguments.
func (builder *Builder) InsertAll(collection string, fields []string, allchanges []map[string]interface{}) (string, []interface{}) {
    var (
        stmt   bytes.Buffer
        quote  = builder.config.EscapeChar
        values = make([]interface{}, 0, len(fields)*len(allchanges))
    )

    stmt.WriteString("INSERT INTO ")
    stmt.WriteString(quote + collection + quote)

    stmt.WriteString(" (")
    stmt.WriteString(quote + strings.Join(fields, quote+","+quote) + quote)
    stmt.WriteString(") VALUES ")

    for rowIdx, row := range allchanges {
        if rowIdx > 0 {
            stmt.WriteString(",")
        }

        stmt.WriteString("(")
        for colIdx, name := range fields {
            if colIdx > 0 {
                stmt.WriteString(",")
            }

            value, present := row[name]
            if !present {
                stmt.WriteString("DEFAULT")
                continue
            }

            stmt.WriteString(builder.ph())
            values = append(values, value)
        }
        stmt.WriteString(")")
    }

    if builder.returnField != "" {
        stmt.WriteString(" RETURNING ")
        stmt.WriteString(quote + builder.returnField + quote)
    }

    stmt.WriteString(";")

    return stmt.String(), values
}

// Update renders an UPDATE statement.
// Each entry of changes becomes a "column=placeholder" assignment; the column
// and its argument are taken in the same iteration so their order matches. The
// WHERE clause built from cond, if any, follows the assignments and its
// arguments come after the assignment arguments.
func (builder *Builder) Update(collection string, changes map[string]interface{}, cond c.Condition) (string, []interface{}) {
    var (
        stmt   bytes.Buffer
        quote  = builder.config.EscapeChar
        values = make([]interface{}, 0, len(changes))
    )

    stmt.WriteString("UPDATE ")
    stmt.WriteString(quote + collection + quote)
    stmt.WriteString(" SET ")

    first := true
    for column, value := range changes {
        if !first {
            stmt.WriteString(",")
        }
        first = false

        stmt.WriteString(quote + column + quote)
        stmt.WriteString("=")
        stmt.WriteString(builder.ph())
        values = append(values, value)
    }

    if clause, condArgs := builder.where(cond); clause != "" {
        stmt.WriteString(" ")
        stmt.WriteString(clause)
        values = append(values, condArgs...)
    }

    stmt.WriteString(";")

    return stmt.String(), values
}

// Delete renders a DELETE statement for collection, followed by the WHERE
// clause built from cond when the condition is not empty. It returns the SQL
// string and the arguments of the condition.
func (builder *Builder) Delete(collection string, cond c.Condition) (string, []interface{}) {
    var (
        stmt   bytes.Buffer
        quote  = builder.config.EscapeChar
        values []interface{}
    )

    stmt.WriteString("DELETE FROM ")
    stmt.WriteString(quote + collection + quote)

    if clause, condArgs := builder.where(cond); clause != "" {
        stmt.WriteString(" ")
        stmt.WriteString(clause)
        values = append(values, condArgs...)
    }

    stmt.WriteString(";")

    return stmt.String(), values
}

// fields renders the select list: "SELECT", an optional "DISTINCT", and the
// escaped fields joined by commas. With no fields it selects "*".
//
// The DISTINCT keyword is honoured even when no fields are given, yielding
// "SELECT DISTINCT *"; previously it was silently dropped in that case.
func (builder *Builder) fields(distinct bool, fields ...string) string {
    var buffer bytes.Buffer

    buffer.WriteString("SELECT ")

    if distinct {
        buffer.WriteString("DISTINCT ")
    }

    if len(fields) == 0 {
        buffer.WriteString("*")
        return buffer.String()
    }

    for i, f := range fields {
        if i > 0 {
            buffer.WriteString(",")
        }

        buffer.WriteString(builder.escape(f))
    }

    return buffer.String()
}

// from renders the FROM clause, wrapping collection in the configured escape
// character.
func (builder *Builder) from(collection string) string {
    quote := builder.config.EscapeChar

    return "FROM " + quote + collection + quote
}

// join renders the given join clauses as "<mode> <collection> ON <condition>",
// separated by single spaces, and returns them with the arguments of their
// conditions in clause order. With no joins it returns an empty string and nil.
func (builder *Builder) join(join ...c.Join) (string, []interface{}) {
    if len(join) == 0 {
        return "", nil
    }

    var (
        buffer bytes.Buffer
        args   []interface{}
        quote  = builder.config.EscapeChar
    )

    for i, j := range join {
        if i > 0 {
            buffer.WriteString(" ")
        }

        on, onArgs := builder.condition(j.Condition)
        buffer.WriteString(j.Mode + " " + quote + j.Collection + quote + " ON " + on)
        args = append(args, onArgs...)
    }

    return buffer.String(), args
}

// where renders the WHERE clause for condition together with its arguments.
// It returns an empty string and nil when condition.None() reports true.
func (builder *Builder) where(condition c.Condition) (string, []interface{}) {
    if !condition.None() {
        expr, args := builder.condition(condition)
        return "WHERE " + expr, args
    }

    return "", nil
}

// groupBy renders the GROUP BY clause from the escaped fields, joined by
// commas. It returns an empty string when no fields are given.
func (builder *Builder) groupBy(fields ...string) string {
    if len(fields) == 0 {
        return ""
    }

    var clause bytes.Buffer
    clause.WriteString("GROUP BY ")

    for idx, name := range fields {
        if idx > 0 {
            clause.WriteString(",")
        }

        clause.WriteString(builder.escape(name))
    }

    return clause.String()
}

// having renders the HAVING clause for condition together with its arguments.
// It returns an empty string and nil when condition.None() reports true.
func (builder *Builder) having(condition c.Condition) (string, []interface{}) {
    if !condition.None() {
        expr, args := builder.condition(condition)
        return "HAVING " + expr, args
    }

    return "", nil
}

// orderBy renders the ORDER BY clause: each escaped field followed by ASC or
// DESC, with entries separated by ", ". It returns an empty string when no
// orders are given.
func (builder *Builder) orderBy(orders ...c.Order) string {
    if len(orders) == 0 {
        return ""
    }

    var clause bytes.Buffer
    clause.WriteString("ORDER BY ")

    for idx, order := range orders {
        if idx > 0 {
            clause.WriteString(", ")
        }

        direction := " DESC"
        if order.Asc() {
            direction = " ASC"
        }

        clause.WriteString(builder.escape(string(order.Field)))
        clause.WriteString(direction)
    }

    return clause.String()
}

// limitOffset renders "LIMIT n" and, when offset is positive, " OFFSET m".
// Nothing is rendered unless limit is positive, so an offset without a limit
// is ignored.
func (builder *Builder) limitOffset(limit int, offset int) string {
    if limit <= 0 {
        return ""
    }

    clause := "LIMIT " + strconv.Itoa(limit)
    if offset <= 0 {
        return clause
    }

    return clause + " OFFSET " + strconv.Itoa(offset)
}

// condition renders a single condition node and returns the SQL fragment with
// its arguments. AND, OR and NOT nodes are rendered from their inner
// conditions (NOT joins them with AND and prefixes the result), comparison and
// inclusion nodes are delegated to their helpers, and a fragment node is
// emitted verbatim. An unrecognised type yields an empty string and nil.
func (builder *Builder) condition(cond c.Condition) (string, []interface{}) {
    switch cond.Type {
    case c.ConditionAnd:
        return builder.build("AND", cond.Inner)
    case c.ConditionOr:
        return builder.build("OR", cond.Inner)
    case c.ConditionNot:
        inner, args := builder.build("AND", cond.Inner)
        return "NOT " + inner, args
    case c.ConditionEq, c.ConditionNe, c.ConditionLt, c.ConditionLte, c.ConditionGt, c.ConditionGte:
        return builder.buildComparison(cond)
    case c.ConditionIn, c.ConditionNin:
        return builder.buildInclusion(cond)
    case c.ConditionNil, c.ConditionNotNil:
        suffix := " IS NULL"
        if cond.Type == c.ConditionNotNil {
            suffix = " IS NOT NULL"
        }

        return builder.escape(string(cond.Left.Column)) + suffix, cond.Right.Values
    case c.ConditionLike, c.ConditionNotLike:
        operator := " LIKE "
        if cond.Type == c.ConditionNotLike {
            operator = " NOT LIKE "
        }

        column := builder.escape(string(cond.Left.Column))
        return column + operator + builder.ph(), cond.Right.Values
    case c.ConditionFragment:
        return string(cond.Left.Column), cond.Right.Values
    default:
        return "", nil
    }
}

// build renders the inner conditions joined by " <op> " and collects their
// arguments in order. The result is wrapped in parentheses only when there is
// more than one inner condition.
func (builder *Builder) build(op string, inner []c.Condition) (string, []interface{}) {
    var (
        expr      bytes.Buffer
        args      []interface{}
        wrap      = len(inner) > 1
        separator = " " + op + " "
    )

    if wrap {
        expr.WriteString("(")
    }

    for idx, child := range inner {
        if idx > 0 {
            expr.WriteString(separator)
        }

        childExpr, childArgs := builder.condition(child)
        expr.WriteString(childExpr)
        args = append(args, childArgs...)
    }

    if wrap {
        expr.WriteString(")")
    }

    return expr.String(), args
}

// buildComparison renders a binary comparison (=, <>, <, <=, >, >=).
// Each operand is rendered as its escaped column when it names one, otherwise
// as a placeholder. The left operand is rendered before the right one, so
// ordinal placeholders are numbered in the same order as the returned
// arguments (left values followed by right values).
func (builder *Builder) buildComparison(cond c.Condition) (string, []interface{}) {
    var (
        cs string
        op string
    )

    switch cond.Type {
    case c.ConditionEq:
        op = "="
    case c.ConditionNe:
        op = "<>"
    case c.ConditionLt:
        op = "<"
    case c.ConditionLte:
        op = "<="
    case c.ConditionGt:
        op = ">"
    case c.ConditionGte:
        op = ">="
    }

    if cond.Left.Column != "" {
        cs = builder.escape(string(cond.Left.Column)) + op
    } else {
        cs = builder.ph() + op
    }

    if cond.Right.Column != "" {
        cs += builder.escape(string(cond.Right.Column))
    } else {
        cs += builder.ph()
    }

    // Cap the capacity of the left values before appending so that append has
    // to allocate instead of writing the right values into spare capacity of
    // the condition's own backing array, which other users of the condition
    // may share.
    left := cond.Left.Values
    return cs, append(left[:len(left):len(left)], cond.Right.Values...)
}

// buildInclusion renders an IN / NOT IN condition with one placeholder per
// value on the right-hand side and returns it with those values as arguments.
//
// An empty value list is rendered as a constant predicate instead: "1=0" for
// IN (nothing can match) and "1=1" for NOT IN (nothing is excluded). Emitting
// "IN (?)" in that case, as was done before, produced a placeholder without a
// matching argument and advanced the ordinal placeholder counter.
func (builder *Builder) buildInclusion(cond c.Condition) (string, []interface{}) {
    values := cond.Right.Values

    if len(values) == 0 {
        if cond.Type == c.ConditionIn {
            return "1=0", values
        }

        return "1=1", values
    }

    var buffer bytes.Buffer
    buffer.WriteString(builder.escape(string(cond.Left.Column)))

    if cond.Type == c.ConditionIn {
        buffer.WriteString(" IN (")
    } else {
        buffer.WriteString(" NOT IN (")
    }

    for i := range values {
        if i > 0 {
            buffer.WriteString(",")
        }

        buffer.WriteString(builder.ph())
    }
    buffer.WriteString(")")

    return buffer.String(), values
}

// ph returns the next placeholder. With ordinal placeholders enabled in the
// config it advances the builder's counter and appends the new count to the
// placeholder string; otherwise the placeholder string is returned unchanged.
func (builder *Builder) ph() string {
    marker := builder.config.Placeholder
    if !builder.config.Ordinal {
        return marker
    }

    builder.count++

    return marker + strconv.Itoa(builder.count)
}

// escape wraps a field name in the configured escape character and caches the
// result in fieldCache. The field is returned untouched when no escape
// character is configured or when it is exactly "*". Otherwise:
//
//   - a field starting with UnescapeCharacter has that character stripped and
//     is emitted verbatim;
//   - a field with an empty pair of parentheses (e.g. "now()") is left as is,
//     since there is no argument to escape; previously this produced an empty
//     escaped identifier inside the parentheses;
//   - a field containing "(...)" has only the text between the first "(" and
//     the first ")" escaped;
//   - a field ending in "*" (e.g. "users.*") has only the part before the
//     first "." escaped;
//   - any other field is escaped as a whole, with the first "." treated as a
//     qualifier separator.
func (builder *Builder) escape(field string) string {
    quote := builder.config.EscapeChar
    if quote == "" || field == "*" {
        return field
    }

    key := field + quote
    if cached, ok := fieldCache.Load(key); ok {
        return cached.(string)
    }

    var (
        escaped string
        start   = strings.IndexRune(field, '(')
        end     = strings.IndexRune(field, ')')
    )

    switch {
    case len(field) > 0 && field[0] == UnescapeCharacter:
        escaped = field[1:]
    case start >= 0 && end == start+1:
        // Empty argument list: nothing to escape.
        escaped = field
    case start >= 0 && end > start:
        escaped = field[:start+1] + builder.escape(field[start+1:end]) + field[end:]
    case strings.HasSuffix(field, "*"):
        escaped = quote + strings.Replace(field, ".", quote+".", 1)
    default:
        escaped = quote + strings.Replace(field, ".", quote+"."+quote, 1) + quote
    }

    fieldCache.Store(key, escaped)

    return escaped
}

// Returning sets the field to be emitted in the RETURNING clause of insert
// queries built by Insert and InsertAll. It returns the same builder so calls
// can be chained.
func (builder *Builder) Returning(field string) *Builder {
    builder.returnField = field
    return builder
}

// NewBuilder creates a new SQL builder that uses the given config. The return
// field is empty and the placeholder counter starts at zero.
func NewBuilder(config *Config) *Builder {
    builder := new(Builder)
    builder.config = config

    return builder
}