utils/utils.go
package utils
import (
"database/sql/driver"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"reflect"
"regexp"
"runtime"
"runtime/debug"
"sort"
"time"
"github.com/gosimple/slug"
"github.com/jinzhu/gorm"
"github.com/jinzhu/now"
"github.com/microcosm-cc/bluemonday"
"github.com/qor/qor"
"strings"
)
// AppRoot app root path
var AppRoot, _ = os.Getwd()
// ContextKey defined type used for context's key
type ContextKey string
// ContextDBName db name used for context
var ContextDBName ContextKey = "ContextDB"
// HTMLSanitizer html sanitizer to avoid XSS
var HTMLSanitizer = bluemonday.UGCPolicy()
func init() {
HTMLSanitizer.AllowStandardAttributes()
if path := os.Getenv("WEB_ROOT"); path != "" {
AppRoot = path
}
}
// GOPATH return GOPATH from env
func GOPATH() []string {
paths := strings.Split(os.Getenv("GOPATH"), string(os.PathListSeparator))
if len(paths) == 0 {
fmt.Println("GOPATH doesn't exist")
}
return paths
}
// GetDBFromRequest get database from request
var GetDBFromRequest = func(req *http.Request) *gorm.DB {
db := req.Context().Value(ContextDBName)
if tx, ok := db.(*gorm.DB); ok {
return tx
}
return nil
}
// HumanizeString Humanize separates string based on capitalizd letters
// e.g. "OrderItem" -> "Order Item"
func HumanizeString(str string) string {
var human []rune
for i, l := range str {
if i > 0 && isUppercase(byte(l)) {
if (!isUppercase(str[i-1]) && str[i-1] != ' ') || (i+1 < len(str) && !isUppercase(str[i+1]) && str[i+1] != ' ' && str[i-1] != ' ') {
human = append(human, rune(' '))
}
}
human = append(human, l)
}
return strings.Title(string(human))
}
func isUppercase(char byte) bool {
return 'A' <= char && char <= 'Z'
}
var asicsiiRegexp = regexp.MustCompile("^(\\w|\\s|-|!)*$")
// ToParamString replaces spaces and separates words (by uppercase letters) with
// underscores in a string, also downcase it
// e.g. ToParamString -> to_param_string, To ParamString -> to_param_string
func ToParamString(str string) string {
if asicsiiRegexp.MatchString(str) {
return gorm.ToDBName(strings.Replace(str, " ", "_", -1))
}
return slug.Make(str)
}
// PatchURL updates the query part of the request url.
// PatchURL("google.com","key","value") => "google.com?key=value"
func PatchURL(originalURL string, params ...interface{}) (patchedURL string, err error) {
url, err := url.Parse(originalURL)
if err != nil {
return
}
query := url.Query()
for i := 0; i < len(params)/2; i++ {
// Check if params is key&value pair
key := fmt.Sprintf("%v", params[i*2])
value := fmt.Sprintf("%v", params[i*2+1])
if value == "" {
query.Del(key)
} else {
query.Set(key, value)
}
}
url.RawQuery = query.Encode()
patchedURL = url.String()
return
}
// JoinURL updates the path part of the request url.
// JoinURL("google.com", "admin") => "google.com/admin"
// JoinURL("google.com?q=keyword", "admin") => "google.com/admin?q=keyword"
func JoinURL(originalURL string, paths ...interface{}) (joinedURL string, err error) {
u, err := url.Parse(originalURL)
if err != nil {
return
}
var urlPaths = []string{u.Path}
for _, p := range paths {
urlPaths = append(urlPaths, fmt.Sprint(p))
}
if strings.HasSuffix(strings.Join(urlPaths, ""), "/") {
u.Path = path.Join(urlPaths...) + "/"
} else {
u.Path = path.Join(urlPaths...)
}
joinedURL = u.String()
return
}
// SetCookie set cookie for context
func SetCookie(cookie http.Cookie, context *qor.Context) {
cookie.HttpOnly = true
// set https cookie
if context.Request != nil && context.Request.URL.Scheme == "https" {
cookie.Secure = true
}
// set default path
if cookie.Path == "" {
cookie.Path = "/"
}
http.SetCookie(context.Writer, &cookie)
}
// Stringify stringify any data, if it is a struct, will try to use its Name, Title, Code field, else will use its primary key
func Stringify(object interface{}) string {
if obj, ok := object.(interface {
Stringify() string
}); ok {
return obj.Stringify()
}
scope := gorm.Scope{Value: object}
for _, column := range []string{"Name", "Title", "Code"} {
if field, ok := scope.FieldByName(column); ok {
if field.Field.IsValid() {
result := field.Field.Interface()
if valuer, ok := result.(driver.Valuer); ok {
if result, err := valuer.Value(); err == nil {
return fmt.Sprint(result)
}
}
return fmt.Sprint(result)
}
}
}
if scope.PrimaryField() != nil {
if scope.PrimaryKeyZero() {
return ""
}
return fmt.Sprintf("%v#%v", scope.GetModelStruct().ModelType.Name(), scope.PrimaryKeyValue())
}
return fmt.Sprint(reflect.Indirect(reflect.ValueOf(object)).Interface())
}
// ModelType get value's model type
func ModelType(value interface{}) reflect.Type {
reflectType := reflect.Indirect(reflect.ValueOf(value)).Type()
for reflectType.Kind() == reflect.Ptr || reflectType.Kind() == reflect.Slice {
reflectType = reflectType.Elem()
}
return reflectType
}
// ParseTagOption parse tag options to hash
func ParseTagOption(str string) map[string]string {
tags := strings.Split(str, ";")
setting := map[string]string{}
for _, value := range tags {
v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0]))
if len(v) == 2 {
setting[k] = v[1]
} else {
setting[k] = k
}
}
return setting
}
// ExitWithMsg debug error messages and print stack
func ExitWithMsg(msg interface{}, value ...interface{}) {
fmt.Printf("\n"+filenameWithLineNum()+"\n"+fmt.Sprint(msg)+"\n", value...)
debug.PrintStack()
}
// FileServer file server that disabled file listing
func FileServer(dir http.Dir) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p := path.Join(string(dir), r.URL.Path)
if f, err := os.Stat(p); err == nil && !f.IsDir() {
http.ServeFile(w, r, p)
return
}
http.NotFound(w, r)
})
}
func filenameWithLineNum() string {
var total = 10
var results []string
for i := 2; i < 15; i++ {
if _, file, line, ok := runtime.Caller(i); ok {
total--
results = append(results[:0],
append(
[]string{fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line)},
results[0:]...)...)
if total == 0 {
return strings.Join(results, "\n")
}
}
}
return ""
}
// GetLocale get locale from request, cookie, after get the locale, will write the locale to the cookie if possible
// Overwrite the default logic with
// utils.GetLocale = func(context *qor.Context) string {
// // ....
// }
var GetLocale = func(context *qor.Context) string {
if locale := context.Request.Header.Get("Locale"); locale != "" {
return locale
}
if locale := context.Request.URL.Query().Get("locale"); locale != "" {
if context.Writer != nil {
context.Request.Header.Set("Locale", locale)
SetCookie(http.Cookie{Name: "locale", Value: locale, Expires: time.Now().AddDate(1, 0, 0)}, context)
}
return locale
}
if locale, err := context.Request.Cookie("locale"); err == nil {
return locale.Value
}
return ""
}
// ParseTime parse time from string
// Overwrite the default logic with
// utils.ParseTime = func(timeStr string, context *qor.Context) (time.Time, error) {
// // ....
// }
var ParseTime = func(timeStr string, context *qor.Context) (time.Time, error) {
return now.Parse(timeStr)
}
// FormatTime format time to string
// Overwrite the default logic with
// utils.FormatTime = func(time time.Time, format string, context *qor.Context) string {
// // ....
// }
var FormatTime = func(date time.Time, format string, context *qor.Context) string {
return date.Format(format)
}
var replaceIdxRegexp = regexp.MustCompile(`\[\d+\]`)
// SortFormKeys sort form keys
func SortFormKeys(strs []string) {
sort.Slice(strs, func(i, j int) bool { // true for first
str1 := strs[i]
str2 := strs[j]
matched1 := replaceIdxRegexp.FindAllStringIndex(str1, -1)
matched2 := replaceIdxRegexp.FindAllStringIndex(str2, -1)
for x := 0; x < len(matched1); x++ {
prefix1 := str1[:matched1[x][0]]
prefix2 := str2
if len(matched2) >= x+1 {
prefix2 = str2[:matched2[x][0]]
}
if prefix1 != prefix2 {
return strings.Compare(prefix1, prefix2) < 0
}
if len(matched2) < x+1 {
return false
}
number1 := str1[matched1[x][0]:matched1[x][1]]
number2 := str2[matched2[x][0]:matched2[x][1]]
if number1 != number2 {
if len(number1) != len(number2) {
return len(number1) < len(number2)
}
return strings.Compare(number1, number2) < 0
}
}
return strings.Compare(str1, str2) < 0
})
}
// GetAbsURL get absolute URL from request, refer: https://stackoverflow.com/questions/6899069/why-are-request-url-host-and-scheme-blank-in-the-development-server
func GetAbsURL(req *http.Request) url.URL {
if req.URL.IsAbs() {
return *req.URL
}
var result *url.URL
if domain := req.Header.Get("Origin"); domain != "" {
result, _ = url.Parse(domain)
} else {
if req.TLS == nil {
result, _ = url.Parse("http://" + req.Host)
} else {
result, _ = url.Parse("https://" + req.Host)
}
}
result.Parse(req.RequestURI)
return *result
}
// Indirect returns last value that v points to
func Indirect(v reflect.Value) reflect.Value {
for v.Kind() == reflect.Ptr {
v = reflect.Indirect(v)
}
return v
}
// SliceUniq removes duplicate values in given slice
func SliceUniq(s []string) []string {
for i := 0; i < len(s); i++ {
for i2 := i + 1; i2 < len(s); i2++ {
if s[i] == s[i2] {
// delete
s = append(s[:i2], s[i2+1:]...)
i2--
}
}
}
return s
}
// SafeJoin safe join https://snyk.io/research/zip-slip-vulnerability#go
func SafeJoin(paths ...string) (string, error) {
result := path.Join(paths...)
// check filepath
if !strings.HasPrefix(result, filepath.Clean(paths[0])+string(os.PathSeparator)) {
return "", errors.New("invalid filepath")
}
return result, nil
}