go-ml-dev/nn

View on GitHub
output.go

Summary

Maintainability
A
0 mins
Test Coverage
F
0%
package nn

import (
    "fmt"
    "go4ml.xyz/nn/mx"
)

// Output is a layer whose Combine method exposes its input, or a slice of
// it, as a named graph output. Combine returns that output symbol bound
// together with the original input via mx.Bound.
type Output struct {
	// Name is the output name. When empty, Combine generates one of the
	// form "Output<id>" using NextSymbolId.
	Name string
	// Round, when positive, is appended to the name as "$RNN<NN>"
	// (two-digit, zero-padded). Presumably the unrolled RNN step index;
	// confirm with callers.
	Round int
	// Axis, Begin and End are passed to mx.Slice to select the exposed
	// value. No slicing is done when Begin == End.
	Axis, Begin, End int
}

// Combine attaches a named output to the graph that ends at a.
//
// The output label is ly.Name, or "Output<id>" (id from NextSymbolId) when
// Name is empty. A positive ly.Round appends "$RNN%02d" to the label.
//
// When ly.Begin != ly.End, the exposed value is mx.Slice(a, ly.Axis,
// ly.Begin, ly.End), and that slice symbol is also given the label.
// Otherwise a itself is exposed.
//
// The result is mx.Bound of the untouched input a and the mx.Output symbol.
func (ly Output) Combine(a *mx.Symbol) *mx.Symbol {
	label := ly.Name
	if label == "" {
		label = fmt.Sprintf("Output%d", NextSymbolId())
	}
	if ly.Round > 0 {
		// Same string as Sprintf("%s$RNN%02d", label, ly.Round).
		label += fmt.Sprintf("$RNN%02d", ly.Round)
	}

	// Keep the input in a and track the published symbol separately.
	exposed := a
	if ly.Begin != ly.End {
		exposed = mx.Slice(a, ly.Axis, ly.Begin, ly.End)
		exposed.SetName(label)
	}

	return mx.Bound(a, mx.Output(exposed, label))
}