mx/ndarray.go
package mx
import (
"fmt"
"go4ml.xyz/nn/mx/capi"
"reflect"
"runtime"
"unsafe"
)
type NDArray struct {
ctx Context
dim Dimension
dtype Dtype
handle capi.NDArrayHandle
}
func release(a *NDArray) {
if a != nil {
capi.ReleaseNDArry(a.handle)
a.handle = nil
}
}
func (a *NDArray) Release() {
release(a)
}
func Array(tp Dtype, d Dimension) *NDArray {
return CPU.Array(tp, d)
}
func (c Context) Array(tp Dtype, d Dimension, vals ...interface{}) *NDArray {
if !d.Good() {
panic(fmt.Sprintf("failed to create array %v%v: bad dimension", tp.String(), d.String()))
}
a := &NDArray{ctx: c, dim: d, dtype: tp}
a.handle = capi.NewNDArrayHandle(c.DevType(), c.DevNo(), int(tp), d.Shape, d.Len)
if len(vals) > 0 {
a.SetValues(vals...)
}
runtime.SetFinalizer(a, release)
return a
}
func (c Context) CopyAs(a *NDArray, dtype Dtype) *NDArray {
if a == nil || a.handle == nil {
panic("can't copy broken array")
}
b := c.Array(dtype, a.dim)
capi.ImperativeInvokeInOut1(capi.OpCopyTo, a.handle, b.handle)
return b
}
func (a *NDArray) NewLikeThis() *NDArray {
return a.ctx.Array(a.dtype, a.dim)
}
func (a *NDArray) Context() Context {
return a.ctx
}
func (a *NDArray) Dtype() Dtype {
return a.dtype
}
func (a *NDArray) Dim() Dimension {
return a.dim
}
func (a *NDArray) Cast(dt Dtype) *NDArray {
return nil
}
func (a *NDArray) Reshape(dim Dimension) *NDArray {
return nil
}
func (a *NDArray) String() string {
return ""
}
func (a *NDArray) Depth() int {
return a.dim.Len
}
func (a *NDArray) Len(d int) int {
if d < 0 || d >= 3 {
return 0
}
if a.dim.Len <= d {
return 1
}
return a.dim.Shape[d]
}
func (a *NDArray) Size() int {
return a.dim.SizeOf(a.dtype)
}
var typemap = map[Dtype]reflect.Type{
Float64: reflect.TypeOf(float64(0)),
Float32: reflect.TypeOf(float32(0)),
Int8: reflect.TypeOf(int8(0)),
Uint8: reflect.TypeOf(uint8(0)),
Int32: reflect.TypeOf(int32(0)),
Int64: reflect.TypeOf(int64(0)),
}
var rtypemap = map[reflect.Type]Dtype{
reflect.TypeOf(float64(0)): Float64,
reflect.TypeOf(float32(0)): Float32,
reflect.TypeOf(int8(0)): Int8,
reflect.TypeOf(uint8(0)): Uint8,
reflect.TypeOf(int32(0)): Int32,
reflect.TypeOf(int64(0)): Int64,
}
func copyTo(s reflect.Value, n int, v0 reflect.Value, dt reflect.Type) int {
if v0.Kind() == reflect.Interface {
v0 = v0.Elem()
}
if v0.Kind() == reflect.Slice || v0.Kind() == reflect.Array {
if v0.Type() == reflect.SliceOf(dt) && s.Len()-n >= v0.Len() {
n += reflect.Copy(s.Slice(n, s.Len()), v0)
} else {
for i := 0; i < v0.Len(); i++ {
n = copyTo(s, n, v0.Index(i), dt)
}
}
} else {
switch v0.Kind() {
case reflect.Int, reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16,
reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64,
reflect.Float32, reflect.Float64:
if s.Len() <= n {
panic("too many elements to copy")
}
s.Index(n).Set(v0.Convert(dt))
n++
default:
panic("can't initialize with non numeric type " + v0.Type().String())
}
}
return n
}
func (a *NDArray) SetValues(vals ...interface{}) {
if a == nil || a.handle == nil {
panic("can't initialize broken array")
}
if a.dtype == Float16 {
q := CPU.CopyAs(a, Float32)
defer q.Release()
q.SetValues(vals...)
capi.ImperativeInvokeInOut1(capi.OpCopyTo, q.handle, a.handle)
return
}
dt, ok := typemap[a.dtype]
if !ok {
panic(fmt.Sprintf("initialization with dtype %v is unsupportd", a.dtype))
}
s := reflect.ValueOf(vals[0])
if len(vals) != 1 || s.Type() != reflect.SliceOf(dt) || s.Len() != a.dim.Total() {
s = reflect.MakeSlice(reflect.SliceOf(dt), a.dim.Total(), a.dim.Total())
n := copyTo(s, 0, reflect.ValueOf(vals), dt)
if n != a.dim.Total() {
panic("not enough elements to set value")
}
}
capi.SetNDArrayRawData(a.handle, unsafe.Pointer(s.Index(0).UnsafeAddr()), a.dim.Total())
}
func (a *NDArray) Raw() []byte {
ln := a.dim.Total()
bs := make([]byte, ln)
capi.GetNDArrayRawData(a.handle, unsafe.Pointer(&bs[0]), ln)
return bs
}
func (a *NDArray) Values(dtype Dtype) interface{} {
if dtype == Float16 {
panic("can't gate values in Float16 format")
}
q := a
ln := q.dim.Total()
if q.dtype != dtype {
q = CPU.CopyAs(q, dtype)
defer q.Release()
}
vals := reflect.MakeSlice(reflect.SliceOf(typemap[dtype]), ln, ln)
capi.GetNDArrayRawData(q.handle, unsafe.Pointer(vals.Index(0).UnsafeAddr()), ln)
return vals.Interface()
}
func (a *NDArray) ValuesF32() []float32 {
return a.Values(Float32).([]float32)
}
func (a *NDArray) CopyValuesTo(dst interface{}) {
q := a
ln := q.dim.Total()
s := reflect.ValueOf(dst)
t, ok := rtypemap[s.Index(0).Type()]
if !ok {
panic("invalid destination type " + s.Type().String())
}
if q.dtype != t {
q = CPU.CopyAs(q, t)
defer q.Release()
}
capi.GetNDArrayRawData(q.handle, unsafe.Pointer(s.Index(0).UnsafeAddr()), ln)
}
func (a *NDArray) ReCopyValuesTo(dst interface{}) {
q := a
ln := q.dim.Total()
s := reflect.ValueOf(dst)
t, ok := rtypemap[s.Index(0).Type()]
if !ok {
panic("invalid destination type " + s.Type().String())
}
q = CPU.CopyAs(q, t)
defer q.Release()
capi.GetNDArrayRawData(q.handle, unsafe.Pointer(s.Index(0).UnsafeAddr()), ln)
}