
View on GitHub


0 mins
Test Coverage
package main

import (


// Training-loop sizes for the toy regression example.
const (
	epochs   = 100  // number of full passes over the synthetic dataset
	examples = 1000 // number of (x, objective(x)) samples visited per epoch
)

// Linear defines a Linear module
type Linear struct {
    W *nn.Param
    B *nn.Param

// NewLinear creates a new Linear module with the specified input and output dimensions
func NewLinear[T float.DType](in, out int) *Linear {
    return &Linear{
        W: nn.NewParam(mat.NewDense[T](mat.WithShape(out, in))),
        B: nn.NewParam(mat.NewDense[T](mat.WithShape(out))),

// InitRandom initializes the Linear module with random weights using the Xavier uniform distribution
func (m *Linear) InitRandom(seed uint64) *Linear {
    initializers.XavierUniform(m.W.Value().(mat.Matrix), 1.0, rand.NewLockedRand(seed))
    return m

// Forward applies the forward pass of the Linear module to the input x
func (m *Linear) Forward(x mat.Tensor) mat.Tensor {
    return ag.Add(ag.Mul(m.W, x), m.B)

// T is the floating-point element type used by main for the model
// (NewLinear[T]) and for the normalize/objective helpers.
type T = float64

func main() {
    m := NewLinear[T](1, 1).InitRandom(42)

    strategy := sgd.New[T](sgd.NewConfig(0.001, 0.9, true))
    optimizer := optimizers.New(nn.Parameters(m), strategy)

    normalize := func(x T) T { return x / T(examples) }
    objective := func(x T) T { return 3*x + 1 }
    criterion := losses.MSE

    learn := func(input, expected T) (T, error) {
        x, target := mat.Scalar(input), mat.Scalar(expected)
        y := m.Forward(x)
        loss := criterion(y, target, true)
        if err := ag.Backward(loss); err != nil {
            return 0, err
        return float.ValueOf[T](loss.Value().Item()), nil

    for epoch := 0; epoch < epochs; epoch++ {
        for i := 0; i < examples; i++ {
            x := normalize(T(i))
            loss, err := learn(x, objective(x))
            if err != nil {
            if i%100 == 0 {
        if err := optimizer.Optimize(); err != nil {

    fmt.Printf("\n\nTraining completed!\n\n")

    fmt.Printf("Model parameters:\n")
    fmt.Printf("W: %.2f | B: %.2f\n\n", m.W.Value().Item().F64(), m.B.Value().Item().F64())

    // -- Enable this code to save the trained model to a file --
    // fmt.Printf("Saving the trained model to the file...\n")
    // err := nn.DumpToFile(m, "model.bin")
    // if err != nil {
    //     log.Fatal(err)
    // }