go-ml-dev/nn

View on GitHub
block.go

Summary

Maintainability
A
0 mins
Test Coverage
F
29%
package nn

import (
    "go4ml.xyz/nn/mx"
)

// Block is a composable piece of a network graph. Combine receives the symbol
// that feeds into the block and returns the symbol the block produces.
// Composite blocks in this file (BlockConnect, BlockConcat, BlockStack,
// ResidualBlock) are built by calling Combine on each of their children.
type Block interface {
	Combine(input *mx.Symbol) *mx.Symbol
}

// Combine builds the symbol graph for nn, feeding it the symbol returned by
// mx.Input() as its input, and returns the resulting output symbol.
//
// The whole build runs while holding symbolMu, so two Combine calls never
// build concurrently. Before the input symbol is created, resetSymbolId(0)
// is called. Presumably this restarts symbol id numbering for each graph;
// confirm against resetSymbolId.
func Combine(nn Block) *mx.Symbol {
	symbolMu.Lock()
	defer symbolMu.Unlock()

	resetSymbolId(0)
	// The input symbol must be created after the reset, so keep this order.
	input := mx.Input()
	return nn.Combine(input)
}

// BlockConnect chains blocks sequentially: the output symbol of each block
// becomes the input of the next one. Construct it with Sequence.
type BlockConnect struct {
	blocks []Block
}

// Combine feeds s through every block in order and returns the output of the
// last one.
//
// nil blocks are skipped and act as identity steps. This matches how
// BlockConcat treats nil entries, so a layer can be switched off
// conditionally by passing nil to Sequence. Previously a nil entry caused a
// nil-interface method-call panic. If there are no non-nil blocks, s is
// returned unchanged.
func (bc *BlockConnect) Combine(s *mx.Symbol) *mx.Symbol {
	for _, b := range bc.blocks {
		if b == nil {
			continue
		}
		s = b.Combine(s)
	}
	return s
}

// Sequence returns a Block that applies the given blocks one after another.
// nil entries are ignored.
func Sequence(b ...Block) Block {
	return &BlockConnect{b}
}

// BlockConcat applies several blocks to the same input and joins their
// outputs with mx.Concat. Construct it with Concat.
type BlockConcat struct {
	blocks []Block
}

// Combine applies every non-nil block to s and passes the outputs, in block
// order, to mx.Concat. nil entries are ignored.
//
// NOTE(review): if every entry is nil, mx.Concat is called with an empty,
// non-nil slice. How mx.Concat behaves in that case is not visible here.
func (bc *BlockConcat) Combine(s *mx.Symbol) *mx.Symbol {
	outputs := make([]*mx.Symbol, 0, len(bc.blocks))
	for i := range bc.blocks {
		blk := bc.blocks[i]
		if blk == nil {
			continue
		}
		outputs = append(outputs, blk.Combine(s))
	}
	return mx.Concat(outputs...)
}

// Concat returns a Block whose output concatenates the outputs of the given
// blocks (via mx.Concat). nil entries are ignored.
func Concat(b ...Block) Block {
	return &BlockConcat{blocks: b}
}

// BlockStack applies several blocks to the same input and stacks their
// outputs with mx.Stack, or with mx.Stack1 when axis1 is set. Construct it
// with Stack or TransStack.
type BlockStack struct {
	blocks []Block
	axis1  bool // stack with mx.Stack1 instead of mx.Stack
}

// Combine applies each non-nil block to s and stacks the outputs in block
// order.
//
// nil blocks are skipped, consistent with BlockConcat. Previously a nil entry
// caused a nil-interface method-call panic.
func (bc *BlockStack) Combine(s *mx.Symbol) *mx.Symbol {
	outs := make([]*mx.Symbol, 0, len(bc.blocks))
	for _, v := range bc.blocks {
		if v == nil {
			continue
		}
		outs = append(outs, v.Combine(s))
	}
	if bc.axis1 {
		return mx.Stack1(outs...)
	}
	return mx.Stack(outs...)
}

// TransStack is like Stack, but combines the outputs with mx.Stack1
// (presumably stacking along axis 1; confirm in package mx).
func TransStack(b ...Block) Block {
	return &BlockStack{b, true}
}

// Stack returns a Block that stacks the outputs of the given blocks with
// mx.Stack. nil entries are ignored.
func Stack(b ...Block) Block {
	return &BlockStack{b, false}
}

// ResidualBlock applies a chain of residual (skip-connection) steps.
// Construct it with Residual.
type ResidualBlock struct {
	blocks []Block
}

// Residual returns a Block that applies each given block as a residual step:
// x = mx.Add(x, block(x)), in order. nil entries are ignored.
func Residual(a ...Block) Block {
	return &ResidualBlock{a}
}

// Combine applies the residual steps in order. For each non-nil block n, the
// running symbol a is replaced by mx.Add(a, n.Combine(a)).
//
// nil blocks are skipped and act as identity steps, consistent with
// BlockConcat. Previously a nil entry caused a nil-interface method-call
// panic.
func (rcb *ResidualBlock) Combine(a *mx.Symbol) *mx.Symbol {
	for _, n := range rcb.blocks {
		if n == nil {
			continue
		}
		a = mx.Add(a, n.Combine(a))
	}
	return a
}