model/diff.go

Summary

Maintainability
C
1 day
Test Coverage
package model

import (
    "errors"
    "reflect"
    "strconv"
    "time"
    "unsafe"
)

// Struct types that deepValueEqual compares with dedicated logic instead of
// a plain field-by-field walk.
var (
    timeType         = reflect.TypeOf((*time.Time)(nil)).Elem()
    codeConsumerType = reflect.TypeOf((*CodeConsumer)(nil)).Elem()
    codeType         = reflect.TypeOf((*Code)(nil)).Elem()
)

// visit is the key of the visited set used by deepValueEqual. It holds a
// pair of value addresses (lower address first) and the type being compared.
// An existing entry means that comparison is already in progress.
type visit struct {
    a1, a2 unsafe.Pointer
    typ    reflect.Type
}

// DiffResult is the outcome of Equal.
type DiffResult struct {
    // DiffMap maps the name of each differing field to that field's value
    // in the second argument passed to Equal.
    DiffMap map[string]interface{}
    // Equal is true when no compared field differs.
    Equal bool
}

// Equal compares the exported fields of two values of the same struct type
// (either argument may be a pointer to that struct) and reports which fields
// differ.
//
// Unexported fields are skipped, as are fields whose `diffignore` tag parses
// as true. Each field is compared with deepValueEqual. For every differing
// field, DiffMap maps the field name to the value taken from y. Equal is true
// only when no compared field differs.
//
// An error is returned when either argument is nil or a nil pointer, when the
// dereferenced types differ, or when the dereferenced type is not a struct.
func Equal(x, y interface{}) (*DiffResult, error) {
    if x == nil || y == nil {
        return nil, errors.New("use of Equal with nil value")
    }
    v1 := handlePtr(reflect.ValueOf(x))
    v2 := handlePtr(reflect.ValueOf(y))
    // A typed nil pointer passes the check above but dereferences to an
    // invalid Value, on which Type() would panic.
    if !v1.IsValid() || !v2.IsValid() {
        return nil, errors.New("use of Equal with nil value")
    }
    if v1.Type() != v2.Type() {
        return nil, errors.New("use of Equal with different type values")
    }
    // NumField panics on anything other than a struct.
    if v1.Kind() != reflect.Struct {
        return nil, errors.New("use of Equal with non-struct values")
    }

    typ := v1.Type()
    result := &DiffResult{
        DiffMap: make(map[string]interface{}),
        Equal:   true,
    }

    for i, n := 0, typ.NumField(); i < n; i++ {
        field := typ.Field(i)
        // Ignore unexported fields and fields tagged `diffignore:"true"`.
        if !v1.Field(i).CanInterface() || checkTag(field) {
            continue
        }
        // Each field gets a fresh visited set. A pair recorded while comparing
        // one field would otherwise short-circuit to "equal" in another field.
        if !deepValueEqual(v1.Field(i), v2.Field(i), make(map[visit]bool)) {
            result.DiffMap[field.Name] = v2.Field(i).Interface()
            result.Equal = false
        }
    }

    return result, nil
}

// handlePtr removes one level of pointer indirection. A non-pointer value is
// returned unchanged, and a nil pointer yields the zero (invalid) Value.
func handlePtr(v reflect.Value) reflect.Value {
    return reflect.Indirect(v)
}

// checkTag reports whether the struct field carries a `diffignore` tag whose
// value strconv.ParseBool accepts as true. A missing, empty or unparsable
// tag means the field is not ignored.
func checkTag(v reflect.StructField) bool {
    raw, present := v.Tag.Lookup("diffignore")
    if !present {
        return false
    }
    ignore, err := strconv.ParseBool(raw)
    // A malformed tag value is deliberately treated as "not ignored".
    return err == nil && ignore
}

// deepValueEqual reports whether v1 and v2 are deeply equal. It differs from
// a plain recursive comparison in these ways:
//   - a nil slice equals an empty slice (maps still distinguish nil from empty);
//   - CodeConsumer values are compared by their Codes();
//   - Code values are compared by codeSpace and Value();
//   - time.Time values are compared with Time.Equal, ignoring the time zone;
//   - non-nil funcs are never equal.
//
// The three special struct cases need Interface, which is unavailable for
// values reached through unexported struct fields. Such values fall back to
// a field-by-field comparison instead of panicking. Scalar leaves are compared
// through kind-specific accessors for the same reason.
//
// visited holds pairs of addressable arrays, maps, slices and structs whose
// comparison is already in progress. Meeting such a pair again returns true,
// so cyclic data terminates.
func deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool) bool {
    if !v1.IsValid() || !v2.IsValid() {
        return v1.IsValid() == v2.IsValid()
    }
    if v1.Type() != v2.Type() {
        return false
    }

    hard := func(k reflect.Kind) bool {
        switch k {
        case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
            return true
        }
        return false
    }

    if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
        addr1 := unsafe.Pointer(v1.UnsafeAddr())
        addr2 := unsafe.Pointer(v2.UnsafeAddr())
        if uintptr(addr1) > uintptr(addr2) {
            // Canonicalize order to reduce number of entries in visited.
            // Assumes non-moving garbage collector.
            addr1, addr2 = addr2, addr1
        }

        // Short circuit if references are already seen.
        v := visit{addr1, addr2, v1.Type()}
        if visited[v] {
            return true
        }

        // Remember for later.
        visited[v] = true
    }

    switch v1.Kind() {
    case reflect.Array:
        // Same type implies same length. Compare element-wise so that
        // non-comparable elements do not panic and special cases (e.g.
        // time.Time) still apply.
        for i := 0; i < v1.Len(); i++ {
            if !deepValueEqual(v1.Index(i), v2.Index(i), visited) {
                return false
            }
        }
        return true
    case reflect.Slice:
        // We treat a nil slice the same as an empty slice.
        if v1.Len() != v2.Len() {
            return false
        }
        if v1.Pointer() == v2.Pointer() {
            return true
        }
        for i := 0; i < v1.Len(); i++ {
            if !deepValueEqual(v1.Index(i), v2.Index(i), visited) {
                return false
            }
        }
        return true
    case reflect.Interface:
        if v1.IsNil() || v2.IsNil() {
            return v1.IsNil() == v2.IsNil()
        }
        return deepValueEqual(v1.Elem(), v2.Elem(), visited)
    case reflect.Ptr:
        // Elem of a nil pointer is invalid; the validity check above handles it.
        return deepValueEqual(v1.Elem(), v2.Elem(), visited)
    case reflect.Struct:
        // The special cases below need Interface. Values reached through
        // unexported fields cannot provide it and fall through to the
        // field-wise walk.
        if v1.CanInterface() && v2.CanInterface() {
            switch v1.Type() {
            case codeConsumerType:
                t1 := v1.Interface().(CodeConsumer)
                t2 := v2.Interface().(CodeConsumer)
                return deepValueEqual(reflect.ValueOf(t1.Codes()), reflect.ValueOf(t2.Codes()), visited)
            case codeType:
                t1 := v1.Interface().(Code)
                t2 := v2.Interface().(Code)
                return t1.codeSpace == t2.codeSpace && t1.Value() == t2.Value()
            case timeType:
                // Special case for time - we ignore the time zone.
                t1 := v1.Interface().(time.Time)
                t2 := v2.Interface().(time.Time)
                return t1.Equal(t2)
            }
        }
        for i, n := 0, v1.NumField(); i < n; i++ {
            if !deepValueEqual(v1.Field(i), v2.Field(i), visited) {
                return false
            }
        }
        return true
    case reflect.Map:
        if v1.IsNil() != v2.IsNil() {
            return false
        }
        if v1.Len() != v2.Len() {
            return false
        }
        if v1.Pointer() == v2.Pointer() {
            return true
        }
        for _, k := range v1.MapKeys() {
            val1 := v1.MapIndex(k)
            val2 := v2.MapIndex(k)
            // An invalid result means the key has no entry in that map.
            if !val1.IsValid() || !val2.IsValid() || !deepValueEqual(val1, val2, visited) {
                return false
            }
        }
        return true
    case reflect.Func:
        if v1.IsNil() && v2.IsNil() {
            return true
        }
        // Can't do better than this:
        return false
    default:
        return scalarValueEqual(v1, v2)
    }
}

// scalarValueEqual compares two values of the same non-composite kind. It
// reads them through kind-specific accessors rather than Interface, which
// panics for values obtained via unexported struct fields.
func scalarValueEqual(v1, v2 reflect.Value) bool {
    switch v1.Kind() {
    case reflect.Bool:
        return v1.Bool() == v2.Bool()
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        return v1.Int() == v2.Int()
    case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
        return v1.Uint() == v2.Uint()
    case reflect.Float32, reflect.Float64:
        return v1.Float() == v2.Float()
    case reflect.Complex64, reflect.Complex128:
        return v1.Complex() == v2.Complex()
    case reflect.String:
        return v1.String() == v2.String()
    case reflect.Chan, reflect.UnsafePointer:
        // Identity comparison, as == on these kinds does.
        return v1.Pointer() == v2.Pointer()
    default:
        // Fallback for any kind not listed above.
        return v1.Interface() == v2.Interface()
    }
}