omissis/go-jsonschema

View on GitHub
pkg/generator/yaml_formatter.go

Summary

Maintainability
A
1 hr
Test Coverage
package generator

import (
    "strings"

    "github.com/atombender/go-jsonschema/pkg/codegen"
)

const (
    // formatYAML is the format identifier for this formatter. It is
    // interpolated (as-is and upper-cased) into the names and comments of
    // the generated Marshal/Unmarshal methods.
    formatYAML  = "yaml"
    // YAMLPackage is the import path registered by addImport so that
    // generated code can refer to the package through the "yaml" alias.
    YAMLPackage = "gopkg.in/yaml.v3"
)

// yamlFormatter emits the YAML marshalling and unmarshalling methods for
// generated types. It carries no state.
type yamlFormatter struct{}

// generate returns a callback that writes an UnmarshalYAML method for
// declType to the given emitter.
//
// Validators are partitioned by their descriptor: those flagged
// beforeJSONUnmarshal are emitted ahead of the typed decode, all others
// after it. A raw map[string]interface{} decode is emitted first whenever
// there is at least one "before" validator, or any "after" validator sets
// requiresRawAfter.
func (yf *yamlFormatter) generate(declType codegen.TypeDecl, validators []validator) func(*codegen.Emitter) {
    var (
        pre     []validator
        post    []validator
        needRaw bool
    )

    for _, v := range validators {
        d := v.desc()
        if d.beforeJSONUnmarshal {
            pre = append(pre, v)

            continue
        }

        post = append(post, v)

        if d.requiresRawAfter {
            needRaw = true
        }
    }

    // "Before" validators always need the raw map as well.
    needRaw = needRaw || len(pre) > 0

    return func(out *codegen.Emitter) {
        // Statement template shared by the raw-map and the typed decode.
        const decodeStmt = "if err := value.Decode(&%s); err != nil { return err }"

        suffix := strings.ToUpper(formatYAML)

        out.Commentf("Unmarshal%s implements %s.Unmarshaler.", suffix, formatYAML)
        out.Printlnf("func (j *%s) Unmarshal%s(value *yaml.Node) error {", declType.Name, suffix)
        out.Indent(1)

        if needRaw {
            out.Printlnf("var %s map[string]interface{}", varNameRawMap)
            out.Printlnf(decodeStmt, varNameRawMap)
        }

        for i := range pre {
            pre[i].generate(out)
        }

        // Decode through a local alias type so the generated method does
        // not call itself while decoding.
        out.Printlnf("type Plain %s", declType.Name)
        out.Printlnf("var %s Plain", varNamePlainStruct)
        out.Printlnf(decodeStmt, varNamePlainStruct)

        for i := range post {
            post[i].generate(out)
        }

        out.Printlnf("*j = %s(%s)", declType.Name, varNamePlainStruct)
        out.Printlnf("return nil")
        out.Indent(-1)
        out.Printlnf("}")
    }
}

// enumMarshal returns a callback that writes a MarshalYAML method for the
// enum type declType. The emitted method hands j.Value to yaml.Marshal and
// returns its results.
//
// NOTE(review): the emitted comment names "yaml.Marshal"; the interface in
// gopkg.in/yaml.v3 is called Marshaler. The emitted body also returns the
// output of yaml.Marshal rather than the value itself — confirm both are
// intended. They are kept as-is here so generated output does not change.
func (yf *yamlFormatter) enumMarshal(declType codegen.TypeDecl) func(*codegen.Emitter) {
    suffix := strings.ToUpper(formatYAML)

    emit := func(out *codegen.Emitter) {
        out.Commentf("Marshal%s implements %s.Marshal.", suffix, formatYAML)
        out.Printlnf("func (j *%s) Marshal%s() (interface{}, error) {", declType.Name, suffix)

        out.Indent(1)
        out.Printlnf("return %s.Marshal(j.Value)", formatYAML)
        out.Indent(-1)

        out.Printlnf("}")
    }

    return emit
}

// enumUnmarshal returns a callback that writes an UnmarshalYAML method for
// the enum type declType.
//
// The emitted method declares a local v of enumType, decodes the YAML node
// into it (into v.Value when wrapInStruct is set), compares the decoded
// value against every entry of valueConstant with reflect.DeepEqual, and
// returns a fmt.Errorf error when none matches. On success it assigns the
// converted value to *j.
//
// The emitted code refers to the reflect and fmt packages; this function
// does not register those imports itself.
func (yf *yamlFormatter) enumUnmarshal(
    declType codegen.TypeDecl,
    enumType codegen.Type,
    valueConstant *codegen.Var,
    wrapInStruct bool,
) func(*codegen.Emitter) {
    // Expression the generated code decodes into and validates.
    target := "v"
    if wrapInStruct {
        target = "v.Value"
    }

    return func(out *codegen.Emitter) {
        suffix := strings.ToUpper(formatYAML)

        out.Commentf("Unmarshal%s implements %s.Unmarshaler.", suffix, formatYAML)
        out.Printlnf("func (j *%s) Unmarshal%s(value *yaml.Node) error {", declType.Name, suffix)
        out.Indent(1)

        // "var v <enumType>" — the type is rendered by enumType itself.
        out.Printf("var v ")
        enumType.Generate(out)
        out.Newline()

        out.Printlnf("if err := value.Decode(&%s); err != nil { return err }", target)

        // Membership test against the list of allowed values.
        out.Printlnf("var ok bool")
        out.Printlnf("for _, expected := range %s {", valueConstant.Name)
        out.Printlnf("if reflect.DeepEqual(%s, expected) { ok = true; break }", target)
        out.Printlnf("}")

        out.Printlnf("if !ok {")
        out.Printlnf(
            `return fmt.Errorf("invalid value (expected one of %%#v): %%#v", %s, %s)`,
            valueConstant.Name,
            target,
        )
        out.Printlnf("}")

        out.Printlnf(`*j = %s(v)`, declType.Name)
        out.Printlnf(`return nil`)
        out.Indent(-1)
        out.Printlnf("}")
    }
}

// addImport registers YAMLPackage on the output file's package under the
// alias "yaml", which is the qualifier the emitted methods use
// (e.g. *yaml.Node).
func (yf *yamlFormatter) addImport(out *codegen.File) {
    out.Package.AddImport(YAMLPackage, "yaml")
}