go-ml-dev/nn

View on GitHub
mx/dimension.go

Summary

Maintainability
A
0 mins
Test Coverage
F
57%
package mx

import (
    "fmt"
    "go4ml.xyz/zorros"
    "strconv"
    "strings"
)

const (
    DimRow    = 0
    DimColumn = 1
    DimDepth  = 2
    DimDepth3 = 3
)

// do not change this constant
// code can assume exactly this value
const MaxDimensionCount = 4

// Array Dimension
type Dimension struct {
    Shape [MaxDimensionCount]int `yaml:"shape,flow"`
    Len   int                    `yaml:"len"`
}

func DimensionFromString(s string) (Dimension, error) {
    r := Dimension{}
    if len(s) > 0 && s != "" {
        if s[0] != '(' || s[len(s)-1] != ')' {
            return Dimension{}, zorros.Errorf("invalid dimension string")
        }
        s := s[1 : len(s)-1]
        if len(s) > 0 && s != "" {
            d := strings.Split(s, ",")
            for i, n := range d {
                v, err := strconv.ParseInt(n, 10, 32)
                if err != nil {
                    return Dimension{}, zorros.Errorf("bad dimansion value; %w", err)
                }
                r.Shape[i] = int(v)
                r.Len = i + 1
            }
        }
    }
    return r, nil
}

/*func (d *Dimension) UnmarshalYAML(value *yaml.Node) error {
    if value.Tag != "!!str" {
        return fmt.Errorf("can't decode coin")
    }

    if v, err := DimensionFromString(value.Value); err != nil {
        return err
    } else {
        *d = v
    }

    return nil
}

func (d Dimension) MarshalYAML() (interface{}, error) {
    return d.String(), nil
}*/

func (dim Dimension) Skip(n int) Dimension {
    if dim.Len <= n {
        return Dim()
    }
    d := Dimension{Len: dim.Len - n}
    copy(d.Shape[0:], dim.Shape[n:])
    return d
}

func (dim Dimension) Push(i int) Dimension {
    d := Dimension{Len: dim.Len + 1}
    copy(d.Shape[1:dim.Len+1], dim.Shape[:dim.Len])
    d.Shape[0] = i
    return d
}

func (dim Dimension) Like(b Dimension) Dimension {
    d := dim
    for i, v := range dim.Slice() {
        if v == 0 {
            d.Shape[i] = b.Shape[i]
        } else if v < 0 {
            d.Shape[i] = b.Shape[-v]
        }
    }
    return d
}

func (dim Dimension) Slice() []int {
    return dim.Shape[:dim.Len]
}

// represent array dimension as string
func (dim Dimension) String() string {
    s := "(%d,%d,%d,%d"[0:(3+(dim.Len-1)*3)] + ")"
    q := ([]interface{}{dim.Shape[0], dim.Shape[1], dim.Shape[2], dim.Shape[3]})[:dim.Len]
    return fmt.Sprintf(s, q...)
}

// check array dimension
func (dim Dimension) Good() bool {
    if dim.Len <= 0 || dim.Len > MaxDimensionCount {
        return false
    }
    for _, v := range dim.Shape[:dim.Len] {
        if v <= 0 {
            return false
        }
    }
    return true
}

func (dim Dimension) Empty() bool {
    return dim.Len <= 0
}

// sizeof whole array data
func (dim Dimension) SizeOf(dt Dtype) int {
    return dt.Size() * dim.Total()
}

// total elements in the whole array
func (dim Dimension) Total() int {
    r := 1
    for i := 0; i < dim.Len; i++ {
        r *= dim.Shape[i]
    }
    return r
}

// creates new dimension object
func Dim(a ...int) Dimension {
    var dim Dimension
    if q := len(a); q > 0 && q <= MaxDimensionCount {
        dim.Len = q
        for i, v := range a {
            dim.Shape[i] = v
        }
    }
    return dim
}