go-ml-dev/nn

View on GitHub
sgd.go

Summary

Maintainability
A
0 mins
Test Coverage
F
0%
package nn

import "go4ml.xyz/nn/mx"

// SGD holds the user-facing settings for the stochastic gradient descent
// optimizer. Call Init to obtain a working Optimizer from it.
type SGD struct {
    // Lr is the learning rate. When left at zero, Init picks one through
    // locateLr using LrMap, with 0.01 passed as the default.
    Lr float64
    // Mom is the momentum coefficient. A non-zero value makes Update use
    // mx.SgdMomUpdate together with a per-parameter state array.
    Mom float64
    // Decay is forwarded by Update to the underlying mx update call
    // (presumably weight decay — confirm against the mx package).
    Decay float64

    // LrMap is handed to locateLr by Init when Lr is zero
    // (presumably an epoch -> learning rate schedule — TODO confirm).
    LrMap map[int]float64
}

// Init builds the working optimizer for these SGD settings.
//
// The receiver is copied into the returned instance together with an empty
// per-parameter state map. If no learning rate was configured (Lr == 0), one
// is obtained from locateLr using e, opt.LrMap and a default of 0.01; an
// explicitly set Lr is kept as is and LrMap is not consulted.
// NOTE(review): e is presumably the current epoch index — confirm with callers.
func (opt SGD) Init(e int) Optimizer {
    impl := implSGD{
        SGD:    opt,
        States: map[*mx.NDArray]*mx.NDArray{},
    }
    if impl.Lr == 0 {
        impl.Lr = locateLr(e, opt.LrMap, 0.01)
    }
    return &impl
}

// implSGD is the optimizer instance produced by SGD.Init. It embeds the SGD
// settings and carries the per-parameter state used by Update.
type implSGD struct {
    // SGD is the embedded configuration (Lr, Mom, Decay, LrMap).
    SGD
    // States maps a parameter array (by pointer) to the state array that
    // Update passes to mx.SgdMomUpdate. Entries are created in Update and
    // released in Release.
    States map[*mx.NDArray]*mx.NDArray
}

// Release calls Release on every state array stored in opt.States.
// The map entries themselves are left in place.
func (opt *implSGD) Release() {
    for _, state := range opt.States {
        state.Release()
    }
}

// Update applies one SGD step to params using grads.
//
// Exactly one update call is made per invocation:
//   - opt.Mom == 0: mx.SgdUpdate with opt.Lr and opt.Decay.
//   - opt.Mom != 0: mx.SgdMomUpdate with a state array that is created via
//     params.NewLikeThis().Zeros() on first use and cached in opt.States
//     under the params pointer for subsequent calls.
//
// Previously the momentum branch fell through to the plain update as well,
// so the same gradients were applied twice on every step.
func (opt *implSGD) Update(params *mx.NDArray, grads *mx.NDArray) {
    if opt.Mom == 0 {
        mx.SgdUpdate(params, grads, opt.Lr, opt.Decay)
        return
    }

    st, ok := opt.States[params]
    if !ok {
        // First step for this parameter: start from an all-zero state.
        st = params.NewLikeThis().Zeros()
        opt.States[params] = st
    }
    // NOTE(review): the last argument is assumed to be the weight-decay term,
    // mirroring mx.SgdUpdate; it used to be a literal 0 with decay applied
    // only by the fall-through plain update. Confirm against mx.SgdMomUpdate.
    mx.SgdMomUpdate(params, grads, st, opt.Lr, opt.Mom, opt.Decay)
}