getwtxt/getwtxt

View on GitHub
svc/query.go

Summary

Maintainability
A
0 mins
Test Coverage
/*
Copyright (c) 2019 Ben Morrison (gbmor)

This file is part of Getwtxt.

Getwtxt is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Getwtxt is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with Getwtxt.  If not, see <https://www.gnu.org/licenses/>.
*/

package svc // import "git.sr.ht/~gbmor/getwtxt/svc"

import (
    "crypto/sha256"
    "fmt"
    "log"
    "net/http"
    "strconv"
    "strings"
    "sync"

    "git.sr.ht/~gbmor/getwtxt/registry"
    "github.com/gorilla/mux"
)

// apiErrCheck logs err together with the requesting client's IP (read from
// the request context via getIPFromCtx), the HTTP method and the URL.
// It does nothing when err is nil.
func apiErrCheck(err error, r *http.Request) {
	if err == nil {
		return
	}
	ctx := r.Context()
	clientIP := getIPFromCtx(ctx)
	log.Printf("*** %v :: %v %v :: %v\n", clientIP, r.Method, r.URL, err.Error())
}

// dedupe returns a new slice holding the distinct strings of list.
// The first occurrence of each value is kept and relative order is
// preserved. The returned slice is never nil, even for empty input.
func dedupe(list []string) []string {
	seen := make(map[string]struct{}, len(list))
	uniq := make([]string, 0, len(list))
	for i := 0; i < len(list); i++ {
		item := list[i]
		if _, dup := seen[item]; dup {
			continue
		}
		seen[item] = struct{}{}
		uniq = append(uniq, item)
	}
	return uniq
}

// parseQueryOut concatenates query results into an HTTP response body.
// Entries are written in order. A newline separator is inserted after an
// entry only when it does not already end in "\n" and it is not the final
// entry, so the body carries no added trailing newline. The result is a
// non-nil (possibly empty) byte slice.
func parseQueryOut(out []string) []byte {
	last := len(out) - 1
	buf := make([]byte, 0)
	for idx := range out {
		entry := out[idx]
		buf = append(buf, entry...)
		needsBreak := idx < last && !strings.HasSuffix(entry, "\n")
		if needsBreak {
			buf = append(buf, '\n')
		}
	}
	return buf
}

// apiEndpointQuery answers search requests for the "users", "mentions" and
// "tweets" endpoints (taken from the mux route variable "endpoint").
// It reads these form values from r:
//
//   - q:    nickname search term ("users") or status search term ("tweets")
//   - url:  user URL search term ("users") or the mentioned URL ("mentions")
//   - page: optional page number handed to registry.ReduceToPage
//
// For "users", nickname matches for q come first, followed by URL matches
// for url. When both are supplied, the combined list is deduplicated.
//
// The selected page of results is written to w as newline-separated text.
// The ETag header is set to the hex SHA-256 of the body.
//
// An error is returned for an unknown endpoint, for a mention query without
// a URL, or when writing the response fails. Lookup errors are only logged.
func apiEndpointQuery(w http.ResponseWriter, r *http.Request) error {
	query := r.FormValue("q")
	urls := r.FormValue("url")
	pageVal := strings.TrimSpace(r.FormValue("page"))

	var out []string
	var err error

	// page stays 0 when the parameter is absent. Only parse it when it was
	// supplied, so that requests without a page don't log a parse error.
	var page int
	if pageVal != "" {
		page, err = strconv.Atoi(pageVal)
		errLog("", err)
	}

	vars := mux.Vars(r)
	endpoint := vars["endpoint"]

	// If we made it this far and 'default' is matched,
	// something went very wrong.
	switch endpoint {
	case "users":
		if query != "" {
			out, err = twtxtCache.QueryUser(query)
			apiErrCheck(err, r)
		}
		if urls != "" {
			// Merge URL matches into whatever the nickname query produced.
			// When q was empty, out is nil and only the URL matches remain.
			var urlOut []string
			urlOut, err = twtxtCache.QueryUser(urls)
			apiErrCheck(err, r)
			out = joinQueryOuts(out, urlOut)
		}

	case "mentions":
		if urls == "" {
			return fmt.Errorf("missing URL in mention query")
		}
		// Search for "<url>>" inside statuses. Presumably this targets the
		// twtxt mention syntax "@<nick url>"; confirm against the registry.
		urls += ">"
		out, err = twtxtCache.QueryInStatus(urls)
		apiErrCheck(err, r)

	case "tweets":
		out = compositeStatusQuery(query, r)

	default:
		return fmt.Errorf("endpoint query, no cases match")
	}

	out = registry.ReduceToPage(page, out)
	data := parseQueryOut(out)
	etag := fmt.Sprintf("%x", sha256.Sum256(data))

	w.Header().Set("ETag", etag)
	w.Header().Set("Content-Type", txtutf8)
	_, err = w.Write(data)

	return err
}

// joinQueryOuts flattens any number of query result slices into a single
// slice, keeping argument order. Duplicate entries are then removed by
// dedupe, which keeps the first occurrence of each.
func joinQueryOuts(data ...[]string) []string {
	total := 0
	for _, part := range data {
		total += len(part)
	}

	merged := make([]string, 0, total)
	for _, part := range data {
		merged = append(merged, part...)
	}

	return dedupe(merged)
}

// Performs a composite query against the statuses.
func compositeStatusQuery(query string, r *http.Request) []string {
    var wg sync.WaitGroup
    var out, out2, out3 []string
    var err, err2, err3 error

    wg.Add(3)

    query = strings.ToLower(query)
    go func(query string) {
        out, err = twtxtCache.QueryInStatus(query)
        wg.Done()
    }(query)

    query = strings.Title(query)
    go func(query string) {
        out2, err2 = twtxtCache.QueryInStatus(query)
        wg.Done()
    }(query)

    query = strings.ToUpper(query)
    go func(query string) {
        out3, err3 = twtxtCache.QueryInStatus(query)
        wg.Done()
    }(query)

    wg.Wait()

    apiErrCheck(err, r)
    apiErrCheck(err2, r)
    apiErrCheck(err3, r)

    return joinQueryOuts(out, out2, out3)
}