// nn/gnn/slstm/slstm.go
// Copyright 2019 spaGO Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package slstm implements a Sentence-State LSTM graph neural network.
//
// Reference: "Sentence-State LSTM for Text Representation" by Zhang et al, 2018.
// (https://arxiv.org/pdf/1805.02474.pdf)
package slstm
import (
"encoding/gob"
"github.com/nlpodyssey/spago/ag"
"github.com/nlpodyssey/spago/mat"
"github.com/nlpodyssey/spago/mat/float"
"github.com/nlpodyssey/spago/nn"
)
// Compile-time check that *Model satisfies the nn.Model interface.
var _ nn.Model = (*Model)(nil)
// TODO(1): code refactoring using a structure to maintain states.
// TODO(2): use a gradient policy (i.e. reinforcement learning) to increase the context with dynamic skip connections.

// Model contains the serializable parameters of a Sentence-State LSTM.
//
// The seven HyperLinear4 gates operate on each word node: input gate,
// left/right neighbor cell gates, local cell gate, sentence cell gate,
// output gate, and the candidate input activation. The three HyperLinear3
// gates update the non-local sentence-level state. StartH and EndH are
// learned padding states used in place of the out-of-range neighbors of the
// first and last node; InitValue initializes every hidden, cell, and
// sentence state at step 0.
type Model struct {
	nn.Module
	Config                 Config
	InputGate              *HyperLinear4
	LeftCellGate           *HyperLinear4
	RightCellGate          *HyperLinear4
	CellGate               *HyperLinear4
	SentCellGate           *HyperLinear4
	OutputGate             *HyperLinear4
	InputActivation        *HyperLinear4
	NonLocalSentCellGate   *HyperLinear3
	NonLocalInputGate      *HyperLinear3
	NonLocalSentOutputGate *HyperLinear3
	StartH                 *nn.Param
	EndH                   *nn.Param
	InitValue              *nn.Param
}
// Config provides configuration settings for a Sentence-State LSTM Model.
type Config struct {
	InputSize  int // size of each input vector
	OutputSize int // size of the hidden, cell, and sentence state vectors
	Steps      int // number of recurrent steps performed by Forward
}
const windowSize = 3 // the window is fixed in this implementation: left neighbor, the node itself, right neighbor
// HyperLinear4 groups multiple params for an affine transformation.
//
// As used by the per-node gates: W is applied to the concatenated window of
// hidden states, U to the node's input, V to the previous sentence state,
// and B is the bias.
type HyperLinear4 struct {
	nn.Module
	W *nn.Param
	U *nn.Param
	V *nn.Param
	B *nn.Param
}
// HyperLinear3 groups multiple params for an affine transformation.
//
// As used by the sentence-level (non-local) gates: W is applied to the
// previous sentence state, U to a hidden state (or the average of all hidden
// states), and B is the bias.
type HyperLinear3 struct {
	nn.Module
	W *nn.Param
	U *nn.Param
	B *nn.Param
}
// State contains nodes used during the forward step.
//
// The xU* slices (one entry per input) hold the U-projections of the inputs;
// they are computed once by computeUx and shared among all steps. The
// V*PrevG fields hold the V-projections of the previous sentence state; they
// are recomputed by computeVg at each step and shared among all nodes of
// that step.
type State struct {
	xUi []mat.Tensor
	xUl []mat.Tensor
	xUr []mat.Tensor
	xUf []mat.Tensor
	xUs []mat.Tensor
	xUo []mat.Tensor
	xUu []mat.Tensor

	ViPrevG mat.Tensor
	VlPrevG mat.Tensor
	VrPrevG mat.Tensor
	VfPrevG mat.Tensor
	VsPrevG mat.Tensor
	VoPrevG mat.Tensor
	VuPrevG mat.Tensor
}
// init registers Model with gob so it can be (de)serialized.
func init() {
	gob.Register(new(Model))
}
// New returns a new model with parameters initialized to zeros.
func New[T float.DType](config Config) *Model {
	in := config.InputSize
	out := config.OutputSize

	// Helper for the three standalone vector parameters (StartH, EndH, InitValue).
	vec := func() *nn.Param {
		return nn.NewParam(mat.NewDense[T](mat.WithShape(out)))
	}

	m := &Model{
		Config:    config,
		StartH:    vec(),
		EndH:      vec(),
		InitValue: vec(),
	}

	// Per-node gates.
	m.InputGate = newGate4[T](in, out)
	m.LeftCellGate = newGate4[T](in, out)
	m.RightCellGate = newGate4[T](in, out)
	m.CellGate = newGate4[T](in, out)
	m.SentCellGate = newGate4[T](in, out)
	m.OutputGate = newGate4[T](in, out)
	m.InputActivation = newGate4[T](in, out)

	// Sentence-level (non-local) gates.
	m.NonLocalSentCellGate = newGate3[T](out)
	m.NonLocalInputGate = newGate3[T](out)
	m.NonLocalSentOutputGate = newGate3[T](out)

	return m
}
// newGate4 builds a zero-initialized HyperLinear4 gate.
// W spans the three-node window, U projects the input, V projects the
// sentence state, and B is the bias.
func newGate4[T float.DType](in, out int) *HyperLinear4 {
	g := new(HyperLinear4)
	g.W = nn.NewParam(mat.NewDense[T](mat.WithShape(out, out*windowSize)))
	g.U = nn.NewParam(mat.NewDense[T](mat.WithShape(out, in)))
	g.V = nn.NewParam(mat.NewDense[T](mat.WithShape(out, out)))
	g.B = nn.NewParam(mat.NewDense[T](mat.WithShape(out)))
	return g
}
// newGate3 builds a zero-initialized HyperLinear3 gate with square W and U
// matrices of the given size and a bias vector of the same size.
func newGate3[T float.DType](size int) *HyperLinear3 {
	g := new(HyperLinear3)
	g.W = nn.NewParam(mat.NewDense[T](mat.WithShape(size, size)))
	g.U = nn.NewParam(mat.NewDense[T](mat.WithShape(size, size)))
	g.B = nn.NewParam(mat.NewDense[T](mat.WithShape(size)))
	return g
}
// Forward performs the forward step for each input node and returns the
// hidden states of the last step, one tensor per input.
//
// h[t][i] and c[t][i] are the hidden and cell states of node i at step t;
// g[t] and cg[t] are the sentence-level state and cell shared by all nodes.
// Everything is initialized to InitValue at step 0.
func (m *Model) Forward(xs ...mat.Tensor) []mat.Tensor {
	steps := m.Config.Steps
	n := len(xs)
	h := make([][]mat.Tensor, steps)
	c := make([][]mat.Tensor, steps)
	g := make([]mat.Tensor, steps)
	cg := make([]mat.Tensor, steps)
	h[0] = make([]mat.Tensor, n)
	c[0] = make([]mat.Tensor, n)
	g[0] = m.InitValue
	cg[0] = m.InitValue
	for i := 0; i < n; i++ {
		h[0][i] = m.InitValue
		c[0][i] = m.InitValue
	}
	s := &State{}
	m.computeUx(s, xs) // the result is shared among all steps
	// Was `t < m.Config.Steps`: use the local `steps` consistently.
	for t := 1; t < steps; t++ {
		m.computeVg(s, g[t-1]) // the result is shared among all nodes of the same step
		h[t], c[t] = m.updateHiddenNodes(s, h[t-1], c[t-1], g[t-1])
		// NOTE(review): the sentence state is updated from the PREVIOUS
		// step's hidden/cell states (h[t-1], c[t-1]), not the ones just
		// computed — confirm against the reference paper.
		g[t], cg[t] = m.updateSentenceState(h[t-1], c[t-1], g[t-1])
	}
	return h[steps-1]
}
// computeUx pre-computes the U-projection of every input through each of the
// seven per-node gates; the results are shared among all steps.
func (m *Model) computeUx(s *State, xs []mat.Tensor) {
	size := len(xs)
	s.xUi = make([]mat.Tensor, size)
	s.xUl = make([]mat.Tensor, size)
	s.xUr = make([]mat.Tensor, size)
	s.xUf = make([]mat.Tensor, size)
	s.xUs = make([]mat.Tensor, size)
	s.xUo = make([]mat.Tensor, size)
	s.xUu = make([]mat.Tensor, size)
	for idx, x := range xs {
		s.xUi[idx] = ag.Mul(m.InputGate.U, x)
		s.xUl[idx] = ag.Mul(m.LeftCellGate.U, x)
		s.xUr[idx] = ag.Mul(m.RightCellGate.U, x)
		s.xUf[idx] = ag.Mul(m.CellGate.U, x)
		s.xUs[idx] = ag.Mul(m.SentCellGate.U, x)
		s.xUo[idx] = ag.Mul(m.OutputGate.U, x)
		s.xUu[idx] = ag.Mul(m.InputActivation.U, x)
	}
}
// computeVg pre-computes the V-projection of the previous sentence state g
// through each of the seven per-node gates; the results are shared among all
// nodes of the same step.
func (m *Model) computeVg(s *State, prevG mat.Tensor) {
	s.ViPrevG = ag.Mul(m.InputGate.V, prevG)
	s.VlPrevG = ag.Mul(m.LeftCellGate.V, prevG)
	s.VrPrevG = ag.Mul(m.RightCellGate.V, prevG)
	s.VfPrevG = ag.Mul(m.CellGate.V, prevG)
	s.VsPrevG = ag.Mul(m.SentCellGate.V, prevG)
	s.VoPrevG = ag.Mul(m.OutputGate.V, prevG)
	// Fix: was m.InputActivation.U — U has shape (out, in) and breaks when
	// InputSize != OutputSize; V (out, out) matches the sibling projections.
	s.VuPrevG = ag.Mul(m.InputActivation.V, prevG)
}
// updateHiddenNodes computes the next hidden and cell state of every node
// from the previous states and the previous sentence state.
func (m *Model) updateHiddenNodes(s *State, prevH []mat.Tensor, prevC []mat.Tensor, prevG mat.Tensor) ([]mat.Tensor, []mat.Tensor) {
	nextH := make([]mat.Tensor, len(prevH))
	nextC := make([]mat.Tensor, len(prevH))
	for idx := range prevH {
		nextH[idx], nextC[idx] = m.processNode(s, idx, prevH, prevC, prevG)
	}
	return nextH, nextC
}
// updateSentenceState computes the next sentence-level state g and cell cg
// from the previous hidden/cell states of all nodes and the previous g.
func (m *Model) updateSentenceState(prevH []mat.Tensor, prevC []mat.Tensor, prevG mat.Tensor) (mat.Tensor, mat.Tensor) {
	avgH := ag.Mean(prevH)
	fG := ag.Sigmoid(ag.Affine(m.NonLocalSentCellGate.B, m.NonLocalSentCellGate.W, prevG, m.NonLocalSentCellGate.U, avgH))
	oG := ag.Sigmoid(ag.Affine(m.NonLocalSentOutputGate.B, m.NonLocalSentOutputGate.W, prevG, m.NonLocalSentOutputGate.U, avgH))

	// Per-node input gates share the same affine term on prevG; accumulate
	// the gated cells directly instead of materializing the gate slice.
	shared := ag.Affine(m.NonLocalInputGate.B, m.NonLocalInputGate.W, prevG)
	var gatedCells mat.Tensor // nil start: ag.Add treats it as the identity, as in the original accumulation
	for idx := range prevH {
		gate := ag.Sigmoid(ag.Add(shared, ag.Mul(m.NonLocalInputGate.U, prevH[idx])))
		gatedCells = ag.Add(gatedCells, ag.Prod(gate, prevC[idx]))
	}

	cg := ag.Add(ag.Prod(fG, prevG), gatedCells)
	gt := ag.Prod(oG, ag.Tanh(cg))
	return gt, cg
}
// processNode computes the next hidden and cell state of node i, combining
// its window (left neighbor, itself, right neighbor), the previous sentence
// state, and the pre-computed projections stored in s.
func (m *Model) processNode(s *State, i int, prevH []mat.Tensor, prevC []mat.Tensor, prevG mat.Tensor) (h mat.Tensor, c mat.Tensor) {
	// Out-of-range neighbors fall back to the learned boundary state
	// (StartH is used for both the hidden and the cell of the left padding,
	// EndH likewise on the right).
	var leftH, leftC mat.Tensor = m.StartH, m.StartH
	if j := i - 1; j >= 0 {
		leftH, leftC = prevH[j], prevC[j]
	}
	var rightH, rightC mat.Tensor = m.EndH, m.EndH
	if k := i + 1; k < len(prevH) {
		rightH, rightC = prevH[k], prevC[k]
	}

	window := ag.Concat(leftH, prevH[i], rightH)

	iG := ag.Sigmoid(ag.Sum(m.InputGate.B, ag.Mul(m.InputGate.W, window), s.ViPrevG, s.xUi[i]))
	lG := ag.Sigmoid(ag.Sum(m.LeftCellGate.B, ag.Mul(m.LeftCellGate.W, window), s.VlPrevG, s.xUl[i]))
	rG := ag.Sigmoid(ag.Sum(m.RightCellGate.B, ag.Mul(m.RightCellGate.W, window), s.VrPrevG, s.xUr[i]))
	fG := ag.Sigmoid(ag.Sum(m.CellGate.B, ag.Mul(m.CellGate.W, window), s.VfPrevG, s.xUf[i]))
	sG := ag.Sigmoid(ag.Sum(m.SentCellGate.B, ag.Mul(m.SentCellGate.W, window), s.VsPrevG, s.xUs[i]))
	oG := ag.Sigmoid(ag.Sum(m.OutputGate.B, ag.Mul(m.OutputGate.W, window), s.VoPrevG, s.xUo[i]))
	uG := ag.Tanh(ag.Sum(m.InputActivation.B, ag.Mul(m.InputActivation.W, window), s.VuPrevG, s.xUu[i]))

	// New cell: gated mix of the neighbor cells, the node's own previous
	// cell, the sentence cell, and the candidate activation.
	c = ag.Sum(
		ag.Prod(lG, leftC),
		ag.Prod(fG, prevC[i]),
		ag.Prod(rG, rightC),
		ag.Prod(sG, prevG),
		ag.Prod(iG, uG),
	)
	h = ag.Prod(oG, ag.Tanh(c))
	return h, c
}