cloudfoundry/cf-k8s-controllers

View on GitHub
api/payloads/validation/validator.go

Summary

Maintainability
A
0 mins
Test Coverage
package validation

import (
    "encoding/json"
    "errors"
    "fmt"
    "net/http"
    "net/url"
    "regexp"
    "sort"
    "strings"

    apierrors "code.cloudfoundry.org/korifi/api/errors"

    "github.com/jellydator/validation"
    "golang.org/x/text/cases"
    "golang.org/x/text/language"
    "gopkg.in/yaml.v3"
)

type KeyedPayload interface {
    SupportedKeys() []string
    DecodeFromURLValues(url.Values) error
}

type IgnoredKeysPayload interface {
    IgnoredKeys() []*regexp.Regexp
}

type DecoderValidator struct{}

func NewDefaultDecoderValidator() DecoderValidator {
    return DecoderValidator{}
}

func (dv DecoderValidator) DecodeAndValidateJSONPayload(r *http.Request, object any) error {
    decoder := json.NewDecoder(r.Body)
    defer r.Body.Close()
    decoder.DisallowUnknownFields()
    err := decoder.Decode(object)
    if err != nil {
        var unmarshalTypeError *json.UnmarshalTypeError
        switch {
        case errors.As(err, &unmarshalTypeError):
            titler := cases.Title(language.AmericanEnglish)
            return apierrors.NewUnprocessableEntityError(err, fmt.Sprintf("%v must be a %v", titler.String(unmarshalTypeError.Field), unmarshalTypeError.Type))
        case strings.HasPrefix(err.Error(), "json: unknown field"):
            // check whether the message matches an "unknown field" error. If so, 422. Else, 400
            return apierrors.NewUnprocessableEntityError(err, fmt.Sprintf("invalid request body: %s", err.Error()))
        default:
            return apierrors.NewMessageParseError(err)
        }
    }

    return dv.validatePayload(object)
}

func (dv DecoderValidator) DecodeAndValidateYAMLPayload(r *http.Request, object any) error {
    decoder := yaml.NewDecoder(r.Body)
    defer r.Body.Close()
    decoder.KnownFields(false) // TODO: change this to true once we've added all manifest fields to payloads.Manifest
    err := decoder.Decode(object)
    if err != nil {
        return apierrors.NewMessageParseError(err)
    }

    return dv.validatePayload(object)
}

func (dv DecoderValidator) DecodeAndValidateURLValues(r *http.Request, object KeyedPayload) error {
    if err := r.ParseForm(); err != nil {
        return err
    }
    if err := checkKeysAreSupported(object, r.Form); err != nil {
        return apierrors.NewUnknownKeyError(err, object.SupportedKeys())
    }
    if err := object.DecodeFromURLValues(r.Form); err != nil {
        return apierrors.NewMessageParseError(err)
    }
    return dv.validatePayload(object)
}

func checkKeysAreSupported(payloadObject KeyedPayload, values url.Values) error {
    supportedKeys := map[string]bool{}
    for _, key := range payloadObject.SupportedKeys() {
        supportedKeys[key] = true
    }

    for key := range values {
        if !supportedKeys[key] && !isIgnored(payloadObject, key) {
            return fmt.Errorf("unsupported query parameter: %s", key)
        }
    }

    return nil
}

func isIgnored(payload KeyedPayload, key string) bool {
    ignoredKeysPayload, ok := payload.(IgnoredKeysPayload)
    if !ok {
        return false
    }

    for _, re := range ignoredKeysPayload.IgnoredKeys() {
        if re.MatchString(key) {
            return true
        }
    }

    return false
}

func (dv *DecoderValidator) validatePayload(object any) error {
    t, ok := object.(validation.Validatable)
    if !ok {
        return nil
    }

    if err := t.Validate(); err != nil {
        return apierrors.NewUnprocessableEntityError(err, strings.Join(errorMessages(err), ", "))
    }

    return nil
}

func errorMessages(err error) []string {
    errs := prefixedErrorMessages("", err)
    sort.Strings(errs)
    return errs
}

var arrayIndexRegexp = regexp.MustCompile(`^\d+$`)

func prefixedErrorMessages(field string, err error) []string {
    errors, ok := err.(validation.Errors)
    if !ok {
        return []string{field + " " + err.Error()}
    }

    prefix := ""
    if field != "" {
        if arrayIndexRegexp.MatchString(field) {
            prefix = "[" + field + "]."
        } else {
            prefix = field + "."
        }
    }

    var messages []string
    for f, err := range errors {
        if arrayIndexRegexp.MatchString(f) {
            prefix = strings.TrimSuffix(prefix, ".")
        }

        ems := prefixedErrorMessages(f, err)
        for _, e := range ems {
            messages = append(messages, prefix+e)
        }
    }

    return messages
}