pkg/generator/validator.go
package generator
import (
"fmt"
"reflect"
"strings"
"github.com/pkg/errors"
"github.com/sanity-io/litter"
"github.com/atombender/go-jsonschema/pkg/codegen"
"github.com/atombender/go-jsonschema/pkg/mathutils"
)
// validator is implemented by every constraint that the generator can turn
// into Go validation code for the generated unmarshalling logic.
type validator interface {
	// generate writes the Go statements performing this check to out.
	generate(out *codegen.Emitter)
	// desc reports how the emitted statements must be wired into the
	// surrounding generated code (see validatorDesc).
	desc() *validatorDesc
}
// validatorDesc describes properties of a validator's generated code that the
// caller needs in order to place that code correctly.
type validatorDesc struct {
	// hasError is true when the generated code may `return` an error
	// (defaultValidator, which only assigns a value, sets it to false).
	hasError bool
	// beforeJSONUnmarshal is true for checks that only inspect the raw map
	// (e.g. requiredValidator). Presumably they are emitted before the value is
	// decoded into the plain struct; TODO confirm against the caller.
	beforeJSONUnmarshal bool
	// requiresRawAfter is set by defaultValidator (which reads the raw map) and
	// nullTypeValidator. NOTE(review): the exact meaning is defined by the code
	// that consumes this flag, which is not visible here; confirm there.
	requiresRawAfter bool
}
// Compile-time checks that each concrete validator type implements the
// validator interface through its pointer receiver.
var (
	_ validator = (*requiredValidator)(nil)
	_ validator = (*nullTypeValidator)(nil)
	_ validator = (*defaultValidator)(nil)
	_ validator = (*arrayValidator)(nil)
	_ validator = (*stringValidator)(nil)
	_ validator = (*numericValidator)(nil)
)
// requiredValidator emits a check that a required property is present in the
// raw JSON object map.
type requiredValidator struct {
	jsonName string // property name as written in the JSON document
	declName string // name of the enclosing declaration, reported in the error
}

// generate emits an `if` that returns an error when the raw map exists but
// lacks the required key.
func (v *requiredValidator) generate(out *codegen.Emitter) {
	raw := varNameRawMap

	// A nil raw map means the container itself was null (e.g. a type of
	// ["null", "object"]), so none of its properties are present. That is not a
	// failure here; it is acceptable whenever the container may be null. Only a
	// non-nil map with the key missing is reported.
	out.Printlnf(`if _, ok := %s["%s"]; %s != nil && !ok {`, raw, v.jsonName, raw)
	out.Indent(1)
	out.Printlnf(`return fmt.Errorf("field %s in %s: required")`, v.jsonName, v.declName)
	out.Indent(-1)
	out.Printlnf("}")
}

// desc marks the check as able to fail and as operating on the raw map only.
func (v *requiredValidator) desc() *validatorDesc {
	d := &validatorDesc{}
	d.hasError = true
	d.beforeJSONUnmarshal = true
	return d
}
// nullTypeValidator emits a check that a field declared with JSON type "null"
// (possibly nested inside arrays) holds nil.
type nullTypeValidator struct {
	jsonName   string // property name used in the error message
	fieldName  string // Go field on the plain struct; "" means the struct itself
	arrayDepth int    // number of range loops wrapped around the nil check
}

// generate emits arrayDepth nested range loops (variables i0, i1, ...) and,
// inside them, an error return for any non-nil value.
func (v *nullTypeValidator) generate(out *codegen.Emitter) {
	target := getPlainName(v.fieldName)
	label := v.jsonName
	loopVars := make([]string, 0, v.arrayDepth)

	// Descend one array level per loop. The label gains a %d placeholder per
	// level so the error names the offending element.
	for depth := 0; depth < v.arrayDepth; depth++ {
		loopVar := fmt.Sprintf("i%d", depth)
		out.Printlnf(`for %s := range %s {`, loopVar, target)
		out.Indent(1)

		loopVars = append(loopVars, loopVar)
		target += "[" + loopVar + "]"
		label += "[%d]"
	}

	// Inside loops the field label becomes a runtime fmt.Sprintf call.
	labelExpr := `"` + label + `"`
	if len(loopVars) != 0 {
		labelExpr = "fmt.Sprintf(" + labelExpr + ", " + strings.Join(loopVars, ", ") + ")"
	}

	out.Printlnf(`if %s != nil {`, target)
	out.Indent(1)
	out.Printlnf(`return fmt.Errorf("field %%s: must be null", %s)`, labelExpr)
	out.Indent(-1)
	out.Printlnf("}")

	// Close every loop opened above, innermost first.
	for range loopVars {
		out.Indent(-1)
		out.Printlnf("}")
	}
}

// desc marks the check as able to fail and as needing the raw map afterwards.
func (v *nullTypeValidator) desc() *validatorDesc {
	return &validatorDesc{
		hasError:         true,
		requiresRawAfter: true,
	}
}
// defaultValidator emits code that assigns a schema default value to a field
// when the property is missing (or nil) in the raw map.
type defaultValidator struct {
	jsonName         string       // property name looked up in the raw map
	fieldName        string       // Go field on the plain struct receiving the default
	defaultValueType codegen.Type // field type, used to spell slice composite literals
	defaultValue     interface{}  // default value as taken from the schema
}

// generate emits `if v, ok := raw["name"]; !ok || v == nil { plain.Field = <default> }`.
func (v *defaultValidator) generate(out *codegen.Emitter) {
	literal, dumpErr := v.tryDumpDefaultSlice(out.MaxLineLength())
	if dumpErr != nil {
		// Not an []interface{} default: let litter render the value generically.
		literal = litter.Sdump(v.defaultValue)
	}

	target := getPlainName(v.fieldName)

	out.Printlnf(`if v, ok := %s["%s"]; !ok || v == nil {`, varNameRawMap, v.jsonName)
	out.Indent(1)
	out.Printlnf(`%s = %s`, target, literal)
	out.Indent(-1)
	out.Printlnf("}")
}

// tryDumpDefaultSlice renders an []interface{} default as a typed composite
// literal (`<type>{\n elem,\n ...}`), dumping each element with litter.
// It returns an error for non-slices and for slices of any other element type.
func (v *defaultValidator) tryDumpDefaultSlice(maxLineLen int32) (string, error) {
	buf := codegen.NewEmitter(maxLineLen)
	v.defaultValueType.Generate(buf)
	buf.Printlnf("{")

	if reflect.ValueOf(v.defaultValue).Kind() != reflect.Slice {
		return "", errors.New("didn't find a slice to dump")
	}

	items, isGeneric := v.defaultValue.([]any)
	if !isGeneric {
		return "", errors.New("invalid default value")
	}

	for _, item := range items {
		buf.Printlnf("%s,", litter.Sdump(item))
	}

	buf.Printf("}")

	return buf.String(), nil
}

// desc marks the code as infallible and as reading the raw map.
func (v *defaultValidator) desc() *validatorDesc {
	return &validatorDesc{
		requiresRawAfter: true,
	}
}
// arrayValidator emits minItems / maxItems checks for an array field.
// A bound of 0 means "not set".
type arrayValidator struct {
	jsonName   string // property name used in error messages
	fieldName  string // Go field on the plain struct
	arrayDepth int    // arrayDepth-1 range loops are emitted around the checks
	minItems   int
	maxItems   int
}

// generate emits the length checks, wrapped in arrayDepth-1 nested range loops
// (variables i1, i2, ...). It emits nothing when neither bound is set.
func (v *arrayValidator) generate(out *codegen.Emitter) {
	if v.minItems == 0 && v.maxItems == 0 {
		return
	}

	target := getPlainName(v.fieldName)
	label := v.jsonName
	var loopVars []string

	for level := 1; level < v.arrayDepth; level++ {
		loopVar := fmt.Sprintf("i%d", level)
		out.Printlnf(`for %s := range %s {`, loopVar, target)
		out.Indent(1)

		loopVars = append(loopVars, loopVar)
		target += "[" + loopVar + "]"
		label += "[%d]"
	}

	labelExpr := `"` + label + `"`
	if len(loopVars) != 0 {
		labelExpr = "fmt.Sprintf(" + labelExpr + ", " + strings.Join(loopVars, ", ") + ")"
	}

	// A nil slice (absent property) is exempt from minItems; an empty
	// non-nil one is not.
	if v.minItems != 0 {
		out.Printlnf(`if %s != nil && len(%s) < %d {`, target, target, v.minItems)
		out.Indent(1)
		out.Printlnf(`return fmt.Errorf("field %%s length: must be >= %%d", %s, %d)`, labelExpr, v.minItems)
		out.Indent(-1)
		out.Printlnf("}")
	}

	if v.maxItems != 0 {
		out.Printlnf(`if len(%s) > %d {`, target, v.maxItems)
		out.Indent(1)
		out.Printlnf(`return fmt.Errorf("field %%s length: must be <= %%d", %s, %d)`, labelExpr, v.maxItems)
		out.Indent(-1)
		out.Printlnf("}")
	}

	for range loopVars {
		out.Indent(-1)
		out.Printlnf("}")
	}
}

// desc marks the check as able to fail.
func (v *arrayValidator) desc() *validatorDesc {
	d := &validatorDesc{hasError: true}
	return d
}
// stringValidator emits pattern, minLength and maxLength checks for a string
// field. A length bound of 0 and an empty pattern mean "not set".
type stringValidator struct {
	jsonName   string // property name used in error messages
	fieldName  string // Go field on the plain struct
	minLength  int
	maxLength  int
	isNillable bool   // field is a *string; checks are skipped when it is nil
	pattern    string // regular expression the value must match
}

// generate writes the checks to out. For nillable fields, every check is
// guarded by a nil test and operates on the dereferenced value.
func (v *stringValidator) generate(out *codegen.Emitter) {
	value := getPlainName(v.fieldName)
	checkPointer := ""
	pointerPrefix := ""

	if v.isNillable {
		checkPointer = fmt.Sprintf("%s != nil && ", value)
		pointerPrefix = "*"
	}

	if len(v.pattern) != 0 {
		if v.isNillable {
			out.Printlnf("if %s != nil {", value)
			out.Indent(1)
		}

		// %q renders the pattern as a valid Go string literal. Regexes containing
		// backslashes (e.g. `\d`) or double quotes therefore still produce
		// compilable code.
		// NOTE(review): the MatchString error is discarded. A pattern that Go's
		// RE2 engine rejects therefore makes every value fail validation.
		out.Printlnf(`if matched, _ := regexp.MatchString(%q, string(%s%s)); !matched {`, v.pattern, pointerPrefix, value)
		out.Indent(1)
		// Argument order matches the format: field name first, then the pattern.
		out.Printlnf(`return fmt.Errorf("field %%s pattern match: must match %%s", "%s", %q)`, v.jsonName, v.pattern)
		out.Indent(-1)
		out.Printlnf("}")

		if v.isNillable {
			out.Indent(-1)
			out.Printlnf("}")
		}
	}

	if v.minLength == 0 && v.maxLength == 0 {
		return
	}

	// NOTE(review): len() counts bytes, while JSON Schema defines string length
	// in characters. Multi-byte UTF-8 values are measured longer than the spec
	// intends.
	fieldName := v.jsonName

	if v.minLength != 0 {
		out.Printlnf(`if %slen(%s%s) < %d {`, checkPointer, pointerPrefix, value, v.minLength)
		out.Indent(1)
		out.Printlnf(`return fmt.Errorf("field %%s length: must be >= %%d", "%s", %d)`, fieldName, v.minLength)
		out.Indent(-1)
		out.Printlnf("}")
	}

	if v.maxLength != 0 {
		out.Printlnf(`if %slen(%s%s) > %d {`, checkPointer, pointerPrefix, value, v.maxLength)
		out.Indent(1)
		out.Printlnf(`return fmt.Errorf("field %%s length: must be <= %%d", "%s", %d)`, fieldName, v.maxLength)
		out.Indent(-1)
		out.Printlnf("}")
	}
}

// desc marks the checks as able to fail.
func (v *stringValidator) desc() *validatorDesc {
	return &validatorDesc{
		hasError:            true,
		beforeJSONUnmarshal: false,
	}
}
// numericValidator emits multipleOf and minimum/maximum checks for a numeric
// field. Exclusive bounds are reconciled with mathutils.NormalizeBounds.
type numericValidator struct {
	jsonName         string   // property name used in error messages
	fieldName        string   // Go field on the plain struct
	isNillable       bool     // field is a pointer; checks are skipped when nil
	multipleOf       *float64 // nil when not set
	maximum          *float64
	exclusiveMaximum *any // presumably a bool or a number depending on schema draft; normalized below
	minimum          *float64
	exclusiveMinimum *any
	roundToInt       bool // emit constants as int64 literals (see valueOf)
}

// generate writes the multipleOf check followed by the upper and lower bound
// checks to out.
func (v *numericValidator) generate(out *codegen.Emitter) {
	value := getPlainName(v.fieldName)
	checkPointer := ""
	pointerPrefix := ""

	if v.isNillable {
		checkPointer = fmt.Sprintf("%s != nil && ", value)
		pointerPrefix = "*"
	}

	if v.multipleOf != nil {
		if v.roundToInt {
			// NOTE(review): valueOf truncates a fractional multipleOf. 0.5 becomes
			// `% 0`, which does not compile; 1.5 becomes `% 1`, which accepts
			// everything.
			out.Printlnf(`if %s %s%s %% %v != 0 {`, checkPointer, pointerPrefix, value, v.valueOf(*v.multipleOf))
		} else {
			step := *v.multipleOf
			if step < 0 {
				step = -step
			}

			// In binary floating point, math.Mod can leave a remainder just below
			// the divisor instead of just above zero: math.Mod(0.3, 0.1) is about
			// 0.09999999999999998. A remainder within 1e-10 of either 0 or
			// |multipleOf| therefore counts as an exact multiple.
			remainder := fmt.Sprintf("math.Abs(math.Mod(%s%s, %v))", pointerPrefix, value, v.valueOf(*v.multipleOf))
			out.Printlnf(`if %s %s > 1e-10 && %s < %v-1e-10 {`, checkPointer, remainder, remainder, step)
		}

		out.Indent(1)
		out.Printlnf(`return fmt.Errorf("field %%s: must be a multiple of %%v", "%s", %f)`, v.jsonName, *v.multipleOf)
		out.Indent(-1)
		out.Printlnf("}")
	}

	nMin, nMax, nMinExclusive, nMaxExclusive := mathutils.NormalizeBounds(
		v.minimum, v.maximum, v.exclusiveMinimum, v.exclusiveMaximum,
	)

	v.genBoundary(out, checkPointer, pointerPrefix, value, nMax, nMaxExclusive, "<")
	v.genBoundary(out, checkPointer, pointerPrefix, value, nMin, nMinExclusive, ">")
}

// genBoundary emits a check that fails when value lies beyond boundary; it
// emits nothing for a nil boundary. sign is "<" for an upper bound and ">" for
// a lower bound.
//
// The generated condition puts the boundary first (`boundary <op> value`).
// An exclusive bound therefore uses the inclusive operator, which rejects
// value == boundary, and an inclusive bound uses the strict one. The error
// text uses the opposite choice to state the relation the value must satisfy.
func (v *numericValidator) genBoundary(
	out *codegen.Emitter,
	checkPointer,
	pointerPrefix,
	value string,
	boundary *float64,
	exclusive bool,
	sign string,
) {
	if boundary == nil {
		return
	}

	// Technically, this should be based on schema version, but that information is lost.
	comp := sign
	if exclusive {
		// We're putting the other number first, so we need the = if it's exclusive.
		comp += "="
	} else {
		sign += "="
	}

	out.Printlnf(`if %s%v %s%s %s {`, checkPointer, v.valueOf(*boundary), comp, pointerPrefix, value)
	out.Indent(1)
	out.Printlnf(`return fmt.Errorf("field %%s: must be %s %%v", "%s", %v)`, sign, v.jsonName, v.valueOf(*boundary))
	out.Indent(-1)
	out.Printlnf("}")
}

// desc marks the checks as able to fail.
func (v *numericValidator) desc() *validatorDesc {
	return &validatorDesc{
		hasError:            true,
		beforeJSONUnmarshal: false,
	}
}

// valueOf converts val to the form printed into generated code: int64 when
// roundToInt is set, otherwise the float64 unchanged.
// NOTE(review): int64() truncates toward zero. A fractional minimum such as
// 2.5 on an integer field becomes 2, so 2 is wrongly accepted.
func (v *numericValidator) valueOf(val float64) any {
	if v.roundToInt {
		return int64(val)
	}

	return val
}
// getPlainName returns the Go expression that selects fieldName on the plain
// struct variable. It returns the bare variable name when fieldName is empty.
func getPlainName(fieldName string) string {
	if fieldName != "" {
		// Sprint inserts no separators between string operands.
		return fmt.Sprint(varNamePlainStruct, ".", fieldName)
	}

	return varNamePlainStruct
}