resource/preparer.go
// Copyright © 2016 Asteris, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package resource
import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
"github.com/Sirupsen/logrus"
"github.com/arbovm/levenshtein"
multierror "github.com/hashicorp/go-multierror"
"github.com/pkg/errors"
"golang.org/x/net/context"
)
// Reflected types that convertValue checks for explicitly, before falling back
// to conversion by reflect.Kind.
var (
	durationType = reflect.TypeOf(time.Second)
	timeType     = reflect.TypeOf((*time.Time)(nil)).Elem()
)

// Layouts tried by convertTime after RFC 3339 fails.
const (
	// longForm is a calendar date plus time of day, with no zone offset.
	longForm = "2006-01-02T15:04:05"
	// shortForm is a bare calendar date.
	shortForm = "2006-01-02"
)
// Preparer wraps and implements resource.Resource. Its Prepare method decodes
// the raw Source values into the fields of Destination, then hands off to
// Destination's own Prepare.
type Preparer struct {
	// Source holds raw, undecoded values keyed by field name: the name part
	// of a field's "hcl" struct tag when present, else the Go field name.
	Source map[string]interface{}

	// Destination is the resource whose fields are populated from Source.
	Destination Resource
}
// NewPreparer wraps r in a Preparer whose Source is an empty, non-nil map.
func NewPreparer(r Resource) *Preparer {
	prep := new(Preparer)
	prep.Destination = r
	prep.Source = map[string]interface{}{}
	return prep
}
// NewPreparerWithSource wraps r in a Preparer that will decode from source.
// The map is used as-is (not copied).
func NewPreparerWithSource(r Resource, source map[string]interface{}) *Preparer {
	return &Preparer{
		Source:      source,
		Destination: r,
	}
}
// Prepare populates Destination from Source and then asks Destination to
// prepare itself.
//
// Unknown keys in Source are rejected first (see validateExtra). Then every
// non-embedded field of the destination struct is filled via
// getValueForField. A pointer Destination is modified in place. A struct
// value held in an interface is not addressable, so it could never have its
// fields set; for that case an addressable copy is populated and passed on
// instead.
func (p *Preparer) Prepare(ctx context.Context, r Renderer) (Task, error) {
	if p.Destination == nil {
		return nil, errors.New("Preparer has no destination to prepare")
	}

	value := reflect.ValueOf(p.Destination)
	typ := value.Type()
	wasPtr := typ.Kind() == reflect.Ptr // so we can re-wrap later if we need to
	if wasPtr {
		typ = typ.Elem()
	}
	if typ.Kind() != reflect.Struct {
		return nil, errors.New("Preparer can only wrap structs")
	}

	if wasPtr {
		if value.IsNil() {
			return nil, errors.New("Preparer can't wrap a nil pointer")
		}
		value = value.Elem() // addressable: fields can be set in place
	} else {
		// Without this copy every CanSet below would be false and all fields
		// would be silently dropped.
		settable := reflect.New(typ).Elem()
		settable.Set(value)
		value = settable
	}

	if err := p.validateExtra(typ); err != nil {
		return nil, err
	}

	for i := 0; i < typ.NumField(); i++ {
		field := typ.Field(i)
		if field.Anonymous {
			continue
		}
		val, err := p.getValueForField(r, field)
		if err != nil {
			return nil, err
		}
		// unexported fields can't be set through reflection; skip them
		if fieldValue := value.Field(i); fieldValue.CanSet() {
			fieldValue.Set(val)
		}
	}

	if wasPtr {
		value = value.Addr()
	}

	resource, ok := value.Interface().(Resource)
	if !ok {
		return nil, errors.New("unwrapped was not a Resource")
	}
	return resource.Prepare(ctx, r)
}
// validateExtra reports every key in Source that does not name a field of typ
// (or one of the special keys "depends" and "group"). For each unknown key,
// known names within a Levenshtein distance of 5 are suggested. Suggestions
// are listed in field declaration order, so the message is stable between
// runs. Errors are accumulated with multierror; nil means no unknown keys.
func (p *Preparer) validateExtra(typ reflect.Type) error {
	if typ.Kind() != reflect.Struct {
		return errors.New("can't validate extra on a non-struct type")
	}

	// known answers membership questions. ordered preserves declaration order
	// for suggestions; ranging over the map would randomize it.
	known := make(map[string]struct{}, typ.NumField()+2)
	ordered := make([]string, 0, typ.NumField()+2)
	addName := func(name string) {
		if _, seen := known[name]; seen {
			return
		}
		known[name] = struct{}{}
		ordered = append(ordered, name)
	}
	for i := 0; i < typ.NumField(); i++ {
		addName(p.getFieldName(typ.Field(i)))
	}
	// add special fields
	addName("depends")
	addName("group")

	var err error
	for key := range p.Source {
		if _, ok := known[key]; ok {
			continue
		}

		// check for spelling errors. Deploy the Levenshtein distance algorithm!
		var candidates []string
		for _, candidate := range ordered {
			if levenshtein.Distance(key, candidate) <= 5 {
				candidates = append(candidates, candidate)
			}
		}

		var msg string
		if len(candidates) > 0 {
			msg = " Maybe you meant: " + strings.Join(candidates, ", ")
		}

		err = multierror.Append(
			err,
			fmt.Errorf("I don't have a field named %q.%s", key, msg),
		)
	}
	return err
}
// getValueForField looks up field's raw value in Source and converts it to the
// field's type. Checks run in this order:
//   - required: error if the field is tagged required and no value was given
//   - if the key is absent, the zero value is returned untouched
//   - mutually_exclusive, then the numeric base tag
//   - conversion (on failure the converter's value is returned with its error)
//   - nonempty and valid_values checks on the converted value
//
// On any validation error the zero value of the field type is returned.
func (p *Preparer) getValueForField(r Renderer, field reflect.StructField) (reflect.Value, error) {
	zero := reflect.Zero(field.Type)
	name := p.getFieldName(field)
	raw, present := p.Source[name]

	if err := p.validateRequired(field, raw); err != nil {
		return zero, err
	}
	if !present {
		return zero, nil
	}
	// only a field that is actually set can conflict with its exclusives
	if err := p.validateMutuallyExclusive(field); err != nil {
		return zero, err
	}

	base, baseErr := p.getBase(field)
	if baseErr != nil {
		return zero, baseErr
	}

	converted, convErr := p.convertValue(field.Type, r, name, raw, base)
	if convErr != nil {
		return converted, convErr
	}

	if err := p.validateNonempty(field, converted); err != nil {
		return zero, err
	}
	if err := p.validateValidValues(field, r, base, converted); err != nil {
		return zero, err
	}
	return converted, nil
}
// getFieldName returns the key under which field is looked up in Source. This
// is the name part of the "hcl" struct tag (everything before the first
// comma), or the Go field name when there is no tag or its name part is empty.
func (p *Preparer) getFieldName(field reflect.StructField) string {
	if raw, ok := field.Tag.Lookup("hcl"); ok {
		// n=2 splits at the first comma only. n=1 returns the input unsplit,
		// which would leak tag options such as ",squash" into the name.
		if name := strings.SplitN(raw, ",", 2)[0]; name != "" {
			return name
		}
	}
	return field.Name
}
// getBase returns the numeric base from field's "base" tag, used when parsing
// numbers out of strings. Without the tag the base is 10. A tag that is not an
// integer is an error.
func (p *Preparer) getBase(field reflect.StructField) (int, error) {
	raw, tagged := field.Tag.Lookup("base")
	if !tagged {
		return 10, nil
	}
	parsed, err := strconv.Atoi(raw)
	if err != nil {
		return 0, errors.Wrap(err, "could not convert base tag to int")
	}
	return parsed, nil
}
// validateRequired returns an error when field is tagged `required:"true"` and
// val is nil (not provided).
func (p *Preparer) validateRequired(field reflect.StructField, val interface{}) error {
	if val != nil {
		return nil
	}
	if required, ok := field.Tag.Lookup("required"); !ok || required != "true" {
		return nil
	}
	return fmt.Errorf("%q is required", p.getFieldName(field))
}
// validateNonempty returns an error when field is tagged `nonempty:"true"` and
// value is empty. Strings, slices and maps are empty when their length is
// zero; any other value is empty when it deeply equals its type's zero value.
func (p *Preparer) validateNonempty(field reflect.StructField, value reflect.Value) error {
	if nonempty, ok := field.Tag.Lookup("nonempty"); !ok || nonempty != "true" {
		return nil
	}

	var empty bool
	switch value.Kind() {
	case reflect.String, reflect.Slice, reflect.Map:
		// A length check rather than ==. Comparing slices or maps through
		// interfaces panics, and an empty non-nil slice (as built by
		// convertSlice) would otherwise not count as empty.
		empty = value.Len() == 0
	default:
		// DeepEqual tolerates non-comparable types (e.g. structs holding
		// slices) that would make == panic.
		empty = reflect.DeepEqual(value.Interface(), reflect.Zero(field.Type).Interface())
	}

	if empty {
		return fmt.Errorf("%q must be nonempty", p.getFieldName(field))
	}
	return nil
}
// validateMutuallyExclusive uses field's comma-separated "mutually_exclusive"
// tag. If any listed name other than the field's own is also present in
// Source, it returns an error naming the whole group. The group is formatted
// as `"a" or "b"` for two names and `"a", "b", or "c"` for more.
func (p *Preparer) validateMutuallyExclusive(field reflect.StructField) error {
	tag, ok := field.Tag.Lookup("mutually_exclusive")
	if !ok {
		return nil
	}

	self := p.getFieldName(field)
	group := strings.Split(tag, ",")
	for _, other := range group {
		if other == self {
			continue
		}
		if _, set := p.Source[other]; !set {
			continue
		}

		quoted := make([]string, len(group))
		for i, member := range group {
			quoted[i] = `"` + member + `"`
		}

		var list string
		switch n := len(quoted); {
		case n == 2:
			list = quoted[0] + " or " + quoted[1]
		case n > 2:
			list = strings.Join(quoted[:n-1], ", ") + ", or " + quoted[n-1]
		default:
			list = strings.Join(quoted, ", ")
		}
		return errors.New("only one of " + list + " can be set")
	}
	return nil
}
// validateValidValues uses field's comma-separated "valid_values" tag. Each
// listed value is converted with the same rules as the field itself, and
// value must equal one of them. Without the tag, any value is accepted.
func (p *Preparer) validateValidValues(field reflect.StructField, r Renderer, base int, value reflect.Value) error {
	valids, ok := field.Tag.Lookup("valid_values")
	if !ok {
		return nil
	}

	name := p.getFieldName(field)
	actual := value.Interface()
	for _, valid := range strings.Split(valids, ",") {
		parsed, err := p.convertValue(field.Type, r, name, valid, base)
		if err != nil {
			return errors.Wrapf(err, "invalid value for %s: %s", field.Type.Kind(), valid)
		}
		// DeepEqual rather than ==: pointer fields get a fresh allocation per
		// conversion (so == never matched), and == panics on
		// non-comparable types.
		if reflect.DeepEqual(parsed.Interface(), actual) {
			return nil
		}
	}
	// Sprint first: %q applied directly to a numeric value prints it as a
	// quoted rune (5 -> '\x05') rather than the number.
	return fmt.Errorf("value did not pass validation. Must be one of %q, was %q", valids, fmt.Sprint(actual))
}
// convertValue converts val to a reflect.Value of type typ.
//
// time.Duration and time.Time get dedicated converters. Every other type is
// dispatched on its reflect.Kind, and unsupported kinds log a warning and
// yield the zero value. A successful result is passed through realias so that
// named types (e.g. `type State string`) come back as typ itself. On a
// conversion error, the converter's value is returned as-is with the error.
func (p *Preparer) convertValue(typ reflect.Type, r Renderer, name string, val interface{}, base int) (reflect.Value, error) {
	var (
		converted reflect.Value
		err       error
	)

	switch kind := typ.Kind(); {
	case typ == durationType:
		converted, err = p.convertDuration(typ, r, name, val, base)
	case typ == timeType:
		converted, err = p.convertTime(typ, r, name, val, base)
	case kind == reflect.Bool:
		converted, err = p.convertBool(r, name, val)
	case kind == reflect.String:
		converted, err = p.convertString(r, name, val)
	case kind == reflect.Interface:
		converted, err = p.convertInterface(typ, r, name, val, base)
	case kind == reflect.Map:
		converted, err = p.convertMap(typ, r, name, val, base)
	case kind == reflect.Slice:
		converted, err = p.convertSlice(typ, r, name, val, base)
	case kind == reflect.Int, kind == reflect.Int8, kind == reflect.Int16, kind == reflect.Int32, kind == reflect.Int64,
		kind == reflect.Uint, kind == reflect.Uint8, kind == reflect.Uint16, kind == reflect.Uint32, kind == reflect.Uint64,
		kind == reflect.Float32, kind == reflect.Float64:
		converted, err = p.convertNumber(typ, r, name, val, base)
	case kind == reflect.Ptr:
		converted, err = p.convertPointer(typ, r, name, val, base)
	default:
		logrus.WithFields(logrus.Fields{
			"field": name,
			"type":  kind,
		}).Warn("could not render field type, using zero value")
		converted = reflect.Zero(typ)
	}

	if err != nil {
		return converted, err
	}
	return p.realias(converted, typ)
}
// realias restores type information lost during conversion. Converters work
// from the kind of a type, so a named type such as `type State string` comes
// back as a plain string. Converting the value to typ restores the name.
// Values already of type typ are returned unchanged. Values that cannot be
// converted are returned with an error.
func (p *Preparer) realias(val reflect.Value, typ reflect.Type) (reflect.Value, error) {
	current := val.Type()
	switch {
	case current == typ:
		return val, nil
	case current.ConvertibleTo(typ):
		return val.Convert(typ), nil
	default:
		return val, fmt.Errorf("cannot re-alias %s to %s", current, typ)
	}
}
// convertDuration converts a value to time.Duration:
//
//   - integers (any signed or unsigned kind) are a whole number of seconds
//   - floats are a possibly fractional number of seconds
//   - strings (including named string types) are rendered like every other
//     string input, then parsed with time.ParseDuration (e.g. "1m30s")
//
// Values whose duration does not fit in an int64 count of nanoseconds are
// rejected instead of silently overflowing.
func (p *Preparer) convertDuration(typ reflect.Type, r Renderer, name string, val interface{}, base int) (reflect.Value, error) {
	if val == nil {
		return reflect.Zero(typ), nil
	}

	// maxSeconds is the largest whole-second magnitude a time.Duration holds.
	const maxSeconds = (1<<63 - 1) / int64(time.Second)

	raw := reflect.ValueOf(val)
	switch raw.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		num, err := p.convertNumber(typ, r, name, val, base)
		if err != nil {
			return reflect.Zero(typ), errors.Wrapf(err, "could not convert %v to duration", val)
		}
		secs := num.Int()
		if secs > maxSeconds || secs < -maxSeconds {
			return reflect.Zero(typ), fmt.Errorf("could not convert %v to duration: out of range", val)
		}
		return reflect.ValueOf(time.Duration(secs) * time.Second), nil

	case reflect.Float32, reflect.Float64:
		nanos := raw.Float() * float64(time.Second)
		// Written as a negated in-range test so NaN (which fails every
		// comparison) is rejected too.
		if !(nanos < 1<<63 && nanos >= -(1<<63)) {
			return reflect.Zero(typ), fmt.Errorf("could not convert %v to duration: out of range", val)
		}
		return reflect.ValueOf(time.Duration(nanos)), nil

	case reflect.String:
		// raw.String() instead of val.(string): a named string type would
		// make the assertion panic.
		rendered, err := r.Render(name, raw.String())
		if err != nil {
			return reflect.Zero(typ), errors.Wrapf(err, "error rendering field %s", name)
		}
		dur, err := time.ParseDuration(rendered)
		if err != nil {
			return reflect.Zero(typ), errors.Wrapf(err, "could not convert %s to duration", rendered)
		}
		return reflect.ValueOf(dur), nil

	default:
		return reflect.Zero(typ), fmt.Errorf("cannot handle duration conversion of %v", raw.Kind())
	}
}
// convertTime converts a value to time.Time.
//
// val must be a string. It is rendered first, like other string inputs. Parsing
// is then attempted with up to three layouts in the following order:
//  1. RFC 3339; the zone given in the value is used
//  2. YYYY-MM-DDThh:mm:ss, interpreted in the system's local time zone
//  3. YYYY-MM-DD, interpreted in the system's local time zone
//
// time.Local is used rather than a fixed zone built from today's offset. That
// way a date on the other side of a daylight-saving change gets the offset in
// effect on that date.
func (p *Preparer) convertTime(typ reflect.Type, r Renderer, name string, val interface{}, base int) (reflect.Value, error) {
	if val == nil {
		return reflect.Zero(typ), nil
	}
	if typ != timeType {
		return reflect.Zero(typ), fmt.Errorf("cannot handle time conversion of %v", reflect.ValueOf(val).Kind())
	}

	// A checked assertion: a non-string input is an error, not a panic.
	raw, ok := val.(string)
	if !ok {
		return reflect.Zero(typ), fmt.Errorf("cannot handle time conversion of %v", reflect.ValueOf(val).Kind())
	}
	rendered, err := r.Render(name, raw)
	if err != nil {
		return reflect.Zero(typ), errors.Wrapf(err, "error rendering field %s", name)
	}

	// parse the time with the zone provided
	zoneTime, ztErr := time.Parse(time.RFC3339, rendered)
	if ztErr == nil {
		return reflect.ValueOf(zoneTime), nil
	}
	// parse the time in the system zone
	longTime, ltErr := time.ParseInLocation(longForm, rendered, time.Local)
	if ltErr == nil {
		return reflect.ValueOf(longTime), nil
	}
	// parse the time as a date in the system zone
	shortTime, stErr := time.ParseInLocation(shortForm, rendered, time.Local)
	if stErr == nil {
		return reflect.ValueOf(shortTime), nil
	}

	return reflect.Zero(typ), fmt.Errorf("could not convert %s to time.Time any of:\n1. %v\n2. %v\n3. %v", name, ztErr, ltErr, stErr)
}
// convertBool converts a value to bool:
//
//   - nil is false
//   - a bool is passed through unchanged
//   - a string is rendered, then any capitalization of "t" or "true" is true
//     and every other string is false
//   - any other type is an error
func (p *Preparer) convertBool(r Renderer, name string, val interface{}) (reflect.Value, error) {
	if val == nil {
		return reflect.ValueOf(false), nil
	}
	if b, isBool := val.(bool); isBool {
		return reflect.ValueOf(b), nil
	}

	s, isString := val.(string)
	if !isString {
		return reflect.ValueOf(false), fmt.Errorf("don't know how to convert %T to bool", val)
	}
	rendered, err := r.Render(name, s)
	if err != nil {
		return reflect.ValueOf(false), errors.Wrapf(err, "error rendering field %s", name)
	}
	lowered := strings.ToLower(rendered)
	return reflect.ValueOf(lowered == "t" || lowered == "true"), nil
}
// convertNumber converts val to a number of typ's kind (signed, unsigned or
// float, at typ's bit size).
//
// Strings are rendered and then parsed in the given base. Values that are
// already numbers are formatted as base-10 text and re-parsed: JSON (and thus
// HCL) numbers may not carry the signedness or bit size of the target field,
// and the round trip gives strconv's range checking. Floats are formatted
// without an exponent, so integral values such as 1e6 still parse into
// integer fields.
func (p *Preparer) convertNumber(typ reflect.Type, r Renderer, name string, val interface{}, base int) (reflect.Value, error) {
	if val == nil {
		return reflect.Zero(typ), nil
	}

	var num string
	switch t := val.(type) {
	case string:
		rendered, err := r.Render(name, t)
		if err != nil {
			return reflect.Zero(typ), errors.Wrapf(err, "error rendering field %s", name)
		}
		num = rendered
	case float64:
		// %v would print 1e6 as "1e+06", which ParseInt rejects.
		num = strconv.FormatFloat(t, 'f', -1, 64)
		base = 10
	case float32:
		num = strconv.FormatFloat(float64(t), 'f', -1, 32)
		base = 10
	default:
		num = fmt.Sprintf("%v", val)
		// if we already have a number type, we can assume it is already in base 10
		base = 10
	}

	// parse the number back out, depending on the type
	switch typ.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		raw, err := strconv.ParseInt(num, base, typ.Bits())
		if err != nil {
			return reflect.Zero(typ), errors.Wrapf(err, "could not convert %s to %s", num, typ.Kind())
		}
		switch typ.Kind() {
		case reflect.Int:
			return reflect.ValueOf(int(raw)), nil
		case reflect.Int8:
			return reflect.ValueOf(int8(raw)), nil
		case reflect.Int16:
			return reflect.ValueOf(int16(raw)), nil
		case reflect.Int32:
			return reflect.ValueOf(int32(raw)), nil
		case reflect.Int64:
			return reflect.ValueOf(raw), nil
		}

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		raw, err := strconv.ParseUint(num, base, typ.Bits())
		if err != nil {
			return reflect.Zero(typ), errors.Wrapf(err, "could not convert %s to %s", num, typ.Kind())
		}
		switch typ.Kind() {
		case reflect.Uint:
			return reflect.ValueOf(uint(raw)), nil
		case reflect.Uint8:
			return reflect.ValueOf(uint8(raw)), nil
		case reflect.Uint16:
			return reflect.ValueOf(uint16(raw)), nil
		case reflect.Uint32:
			return reflect.ValueOf(uint32(raw)), nil
		case reflect.Uint64:
			return reflect.ValueOf(raw), nil
		}

	case reflect.Float32, reflect.Float64:
		// ParseFloat has no base parameter; the base tag does not apply.
		raw, err := strconv.ParseFloat(num, typ.Bits())
		if err != nil {
			return reflect.Zero(typ), errors.Wrapf(err, "could not convert %s to %s", num, typ.Kind())
		}
		switch typ.Kind() {
		case reflect.Float32:
			return reflect.ValueOf(float32(raw)), nil
		case reflect.Float64:
			return reflect.ValueOf(raw), nil
		}
	}

	return reflect.Zero(typ), fmt.Errorf("can't parse a number from %s", typ.Kind())
}
// convertString renders val, which must be a string, and returns the result.
// nil becomes the empty string. Any non-string input is an error.
func (p *Preparer) convertString(r Renderer, name string, val interface{}) (reflect.Value, error) {
	if val == nil {
		return reflect.ValueOf(""), nil
	}
	switch s := val.(type) {
	case string:
		rendered, err := r.Render(name, s)
		if err != nil {
			return reflect.ValueOf(""), errors.Wrapf(err, "error rendering field %s", name)
		}
		return reflect.ValueOf(rendered), nil
	default:
		return reflect.ValueOf(""), fmt.Errorf("value was not a string: %v", val)
	}
}
// convertInterface handles fields of interface type. A non-nil val is
// converted according to its own dynamic type via convertValue. A one-element
// slice of maps in the result is then unwrapped to that map (see
// maybeUnwrapMap). nil yields the interface's zero value.
func (p *Preparer) convertInterface(typ reflect.Type, r Renderer, name string, val interface{}, base int) (reflect.Value, error) {
	if val != nil {
		dynamic := reflect.TypeOf(val)
		out, err := p.convertValue(dynamic, r, name, val, base)
		if err == nil {
			out = p.maybeUnwrapMap(out)
		}
		return out, err
	}
	return reflect.Zero(typ), nil
}
// convertMap builds a new map of type typ by converting (and rendering) every
// key and value of val. A one-element slice wrapping a map is unwrapped first.
// Keys and values are named "<name>.<i>.key" and "<name>.<i>.value" for
// rendering and error messages. Here i is the position in MapKeys, which
// follows Go's unspecified map iteration order.
func (p *Preparer) convertMap(typ reflect.Type, r Renderer, name string, val interface{}, base int) (reflect.Value, error) {
	if val == nil {
		return reflect.Zero(typ), nil
	}

	source := p.maybeUnwrapMap(reflect.ValueOf(val))
	if source.Kind() != reflect.Map {
		return reflect.Zero(typ), fmt.Errorf("expected map for %q, got %T", name, val)
	}

	keyType, elemType := typ.Key(), typ.Elem()
	result := reflect.MakeMap(typ)
	keys := source.MapKeys()
	for idx := range keys {
		rawKey := keys[idx]

		keyName := fmt.Sprintf("%s.%d.key", name, idx)
		convertedKey, err := p.convertValue(keyType, r, keyName, rawKey.Interface(), base)
		if err != nil {
			return reflect.Zero(typ), errors.Wrapf(err, "could not render %s", keyName)
		}

		valueName := fmt.Sprintf("%s.%d.value", name, idx)
		convertedValue, err := p.convertValue(elemType, r, valueName, source.MapIndex(rawKey).Interface(), base)
		if err != nil {
			return reflect.Zero(typ), errors.Wrapf(err, "could not render %s", valueName)
		}

		result.SetMapIndex(convertedKey, convertedValue)
	}
	return result, nil
}
// maybeUnwrapMap returns the single element of val when val is a slice of
// maps with exactly one element; otherwise val is returned unchanged. This
// undoes the HCL decoder's habit of wrapping maps in a one-element list.
func (p *Preparer) maybeUnwrapMap(val reflect.Value) reflect.Value {
	t := val.Type()
	isSliceOfMaps := t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Map
	if !isSliceOfMaps || val.Len() != 1 {
		return val
	}
	return val.Index(0)
}
// convertSlice builds a new slice of type typ by converting (and rendering)
// each element of val in order. Element i is named "<name>.<i>" for rendering
// and error messages. A non-slice val is an error.
func (p *Preparer) convertSlice(typ reflect.Type, r Renderer, name string, val interface{}, base int) (reflect.Value, error) {
	if val == nil {
		return reflect.Zero(typ), nil
	}

	src := reflect.ValueOf(val)
	if src.Kind() != reflect.Slice {
		return reflect.Zero(typ), fmt.Errorf("expected slice for %q, got %T", name, val)
	}

	n := src.Len()
	dst := reflect.MakeSlice(typ, n, src.Cap())
	elemType := typ.Elem()
	for idx := 0; idx < n; idx++ {
		elemName := fmt.Sprintf("%s.%d", name, idx)
		elem, err := p.convertValue(elemType, r, elemName, src.Index(idx).Interface(), base)
		if err != nil {
			return reflect.Zero(typ), errors.Wrapf(err, "could not render %s", elemName)
		}
		dst.Index(idx).Set(elem)
	}
	return dst, nil
}
// convertPointer converts val to typ's element type and returns a pointer to
// a fresh copy of the result. A nil val yields a nil pointer rather than a
// pointer to a zero value. This matches the other converters in this file,
// which all return their type's zero value for nil input.
func (p *Preparer) convertPointer(typ reflect.Type, r Renderer, name string, val interface{}, base int) (reflect.Value, error) {
	if val == nil {
		return reflect.Zero(typ), nil
	}
	inner, err := p.convertValue(typ.Elem(), r, name, val, base)
	if err != nil {
		return inner, err
	}
	out := reflect.New(typ.Elem())
	out.Elem().Set(inner)
	return out, nil
}