
View on GitHub


0 mins
Test Coverage
// Copyright 2019 spaGO Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ag

import (


var (
    // forceSyncExecution, when true, forces every operator to run
    // synchronously, overriding any "async" flag passed to Run().
    // It is written by SetForceSyncExecution and read by Operator.Run.
    // This can be particularly useful for debugging.
    forceSyncExecution = false

// SetForceSyncExecution enables or disables the forcing of synchronous execution for all operators.
// When enabled, Run() performs the forward pass synchronously, regardless of its "async" argument.
// This setting can be particularly useful for debugging.
//
// NOTE(review): the flag is a plain package-level variable, written here and
// read in Run() without synchronization — presumably it is meant to be set
// before operators are run from multiple goroutines; confirm.
func SetForceSyncExecution(enable bool) {
    forceSyncExecution = enable

// backwardState is an enumeration type associated to an Operator, to keep
// track of its visited status among different backpropagation phases.
// It is declared as an alias of uint32 so that the sync/atomic Uint32
// functions can be applied directly to the Operator's state field.
type backwardState = uint32

const (
    // idle reports that gradient propagation is not pending for an
    // operator node.
    // Being the first iota value it is zero, hence the default state of a
    // newly created operator; it is also the final value set from the
    // backward step (see setBackwardIdle) once gradients have been
    // propagated.
    // As soon as a backward operation is performed, the status will change to
    // pending (see trySetBackwardPending).
    idle backwardState = iota
    // pending is set on an operator node from the preparatory phase
    // of the backward step.
    // It reports that the node has been marked as a candidate for gradients
    // propagation and the number of pendingGrads has been computed.
    // The next logical state is ongoing (see trySetBackwardOngoing).
    //
    // ongoing is set on an operator node from the core phase of the
    // backward step. It reports that the node has been visited once for
    // performing its Operator.backward method.
    // This status remains set until the gradients of all dependents have been
    // resolved, and the node's own gradients have been propagated too.
    // After that, the status is set back to idle.
    //
    // NOTE(review): pending and ongoing are used by trySetBackwardPending and
    // trySetBackwardOngoing, but their declarations are not visible in this
    // block — confirm they follow idle in this iota sequence.

// AutoGradFunction represents a function with automatic differentiation features.
// It's used to define a new operator: an Operator wraps one AutoGradFunction
// and drives it through the methods below.
type AutoGradFunction interface {
    // Forward computes the output of the function.
    // The Operator stores the returned tensor as its value.
    Forward() (mat.Tensor, error)
    // Backward computes the backward pass given the gradient of the output.
    // The Operator calls it with its own accumulated gradients.
    Backward(gy mat.Tensor) error
    // Operands returns the list of operands.
    // The Operator requests the list once and memoizes the returned slice.
    Operands() []mat.Tensor

// forwardGuard is a buffered channel that acts as a semaphore to limit the concurrency
// of async forward operations in the Run function. Its buffer size determines the maximum
// number of forward operations that can run concurrently. Acquiring and releasing slots
// in the semaphore ensures that the concurrency level stays within the desired limit.
//
// Run acquires a slot by sending on the channel before it starts the forward
// goroutine, so that send blocks the caller while every slot is taken.
// The channel itself is allocated in init.
var forwardGuard chan struct{}

// concurrencyLimit is the buffer size given to forwardGuard, i.e. the maximum
// number of semaphore slots available to async forward operations.
// Using runtime.NumCPU() * 2 is a common heuristic for setting the number of
// concurrent goroutines or the concurrency level in a Go program.
var concurrencyLimit = runtime.NumCPU() * 2

// init allocates the forwardGuard semaphore with room for concurrencyLimit slots.
func init() {
    forwardGuard = make(chan struct{}, concurrencyLimit)

// Operator is a type of node.
// It's used to represent a function with automatic differentiation features.
type Operator struct {
    // value stores the result of the forward evaluation, as a mat.Tensor.
    // It's assigned by executeForward().
    // Use the Value() method to read it: that method waits for an async
    // forward pass to complete before returning.
    // The accumulated gradients are reached through it as well (Grad()
    // returns Value().Grad()).
    value mat.Tensor
    // onceOperands is used to initialize the operands only once.
    onceOperands sync.Once
    // AutoGradFunction's operands are memoized here after the first request.
    operands []mat.Tensor
    // fn is the wrapped AutoGradFunction: its Forward, Backward and Operands
    // methods are the ones the operator executes.
    fn AutoGradFunction
    // broadcast is the channel used to broadcast the result of the forward pass.
    // Run() creates it only in async mode (nil means a synchronous operator)
    // and executeForward() closes it once the value has been stored.
    broadcast chan struct{}
    // broadcastGrad is the channel used to broadcast the result of the backward pass.
    // It is initialized only when the backward pass is performed
    // (prepareBackwardPass) and closed by AccGrad when pendingGrads drops to zero.
    broadcastGrad chan struct{}
    // pendingGrads is the number of pending gradients to be accumulated. (default: 0)
    // It is accessed through sync/atomic.
    pendingGrads int64
    // onceRequiresGrad is used to initialize the requiresGrad only once.
    onceRequiresGrad sync.Once
    // requiresGrad is a flag that indicates whether the operator requires gradients.
    // Use the RequiresGrad() method to get the actual value.
    requiresGrad bool
    // backwardState is the state of the backward pass (idle, pending or ongoing).
    // It is read and updated atomically through isBackwardIdle,
    // setBackwardIdle, trySetBackwardPending and trySetBackwardOngoing.
    backwardState backwardState

// NewOperator creates a new operator with the given AutoGradFunction.
// Only the function is stored at this point; no forward evaluation happens here.
// Note that the operator's Value() can only be accessed after calling the Run() function.
func NewOperator(f AutoGradFunction) *Operator {
    return &Operator{fn: f}

// SetAt sets the value at the given indices.
// The call is forwarded to the tensor returned by Value(), so for an async
// operator it first waits for the forward pass to complete.
// It panics if the given indices are out of range.
//
// NOTE(review): the parameter list looks truncated — indices is expanded with
// "..." in the body but has no declared type here; restore the signature
// before building.
func (o *Operator) SetAt(m mat.Tensor, indices {
    o.Value().SetAt(m, indices...)

// At returns the value at the given indices.
// The call is forwarded to the tensor returned by Value(), so for an async
// operator it first waits for the forward pass to complete.
// It panics if the given indices are out of range.
//
// NOTE(review): the parameter list looks truncated — indices is expanded with
// "..." in the body but has no declared type here; restore the signature
// before building.
func (o *Operator) At(indices mat.Tensor {
    return o.Value().At(indices...)

// Run starts the execution of the operator, performing the forward pass.
//
// The forward pass is asynchronous only when the first optional async
// argument is true and synchronous execution has not been forced through
// SetForceSyncExecution; arguments after the first are ignored.
// The function returns a pointer to the Operator, allowing for method chaining.
//
// NOTE(review): the body of the goroutine and the synchronous path are not
// visible in this copy of the function — presumably both end up calling
// executeForward(); verify against the complete source.
func (o *Operator) Run(async ...bool) *Operator {
    isAsync := !forceSyncExecution && len(async) > 0 && async[0]

    if isAsync {
        // The broadcast channel exists only in async mode: Value() and
        // executeForward() treat a nil channel as "synchronous operator".
        //lint:ignore S1019 explicitly set the buffer size to 0 as the channel is used as a signal
        o.broadcast = make(chan struct{}, 0)
        // Take a semaphore slot before spawning the goroutine; this send
        // blocks the caller while forwardGuard is full.
        forwardGuard <- struct{}{}
        go func() {
        return o

    return o

// executeForward executes the forward function, stores its result in o.value
// and, for async operators, informs all goroutines that have been waiting for
// the result.
// An error returned by Forward currently terminates the program via log.Fatalf.
func (o *Operator) executeForward() {
    value, err := o.fn.Forward()
    if err != nil {
        log.Fatalf("ag: error during forward pass: %v", err) // TODO: handle error
    o.value = value

    if o.broadcast != nil { // if nil, it means that the operator is not async
        close(o.broadcast) // inform all goroutines that have been waiting for the result

// Value returns the result of the function.
// For an async operator it blocks until executeForward() has closed the
// broadcast channel; receiving from a closed channel never blocks, so any
// later call returns immediately.
// For a synchronous operator it simply returns the stored value.
func (o *Operator) Value() mat.Tensor {
    if o.broadcast != nil { // if nil, it means that the operator is not async
        <-o.broadcast // wait for the forward goroutine to finish
    return o.value

// Item returns the result of calling Item() on the operator's value.
// Like Value(), it waits for an async forward pass to complete first.
func (o *Operator) Item() float.Float {
    return o.Value().Item()

// Grad returns the gradients accumulated during the backward pass.
//
// When the operator's backward state is idle, or no gradient is pending
// (pendingGrads is zero), the gradients currently held by the value are
// returned right away. Otherwise the call blocks until broadcastGrad is
// closed, which AccGrad does when the pending counter reaches zero.
func (o *Operator) Grad() mat.Tensor {
    if o.isBackwardIdle() || atomic.LoadInt64(&o.pendingGrads) == 0 {
        return o.Value().Grad()

    <-o.broadcastGrad // wait until AccGrad signals that no gradient is pending anymore
    return o.Value().Grad()

// HasGrad returns true if there are accumulated gradients.
// It goes through Grad(), so it may block while gradients are still pending,
// and it evaluates the result with isNil rather than a plain nil comparison.
func (o *Operator) HasGrad() bool {
    return !isNil(o.Grad()) // safety wait for the backward goroutine to finish

// RequiresGrad returns true if the node requires gradients, that is, if at
// least one of its operands reports RequiresGrad() as true.
// The operands are inspected on the first call only (guarded by
// onceRequiresGrad); the outcome is memoized in requiresGrad and returned by
// every later call.
func (o *Operator) RequiresGrad() bool {
    o.onceRequiresGrad.Do(func() {
        for _, op := range o.Operands() {
            if op.RequiresGrad() {
                o.requiresGrad = true // memoize the result
    return o.requiresGrad

// Operands returns the operands of the operator.
// The list is requested from the wrapped AutoGradFunction on the first call
// only (guarded by onceOperands) and the same slice is returned afterwards.
func (o *Operator) Operands() []mat.Tensor {
    o.onceOperands.Do(func() {
        o.operands = o.fn.Operands() // memoize the result
    return o.operands

// ZeroGrad clears the gradients.
// The operation is guarded by HasGrad(), which may block while gradients are
// still pending.
//
// NOTE(review): the statement that actually clears the gradients is not
// visible in this copy of the function — verify against the complete source.
func (o *Operator) ZeroGrad() {
    if o.HasGrad() {

// AccGrad accumulates the gradients to the node itself.
//
// While a backward pass is in progress (state not idle) every call also
// decrements pendingGrads atomically; the call that brings the counter to
// zero closes broadcastGrad, releasing the goroutines blocked in Grad().
//
// NOTE(review): the statement that accumulates grad is not visible in this
// copy of the function — verify against the complete source.
func (o *Operator) AccGrad(grad mat.Tensor) {

    // Don't decrement the counter if the backward pass is not running.
    if !o.isBackwardIdle() && atomic.AddInt64(&o.pendingGrads, -1) == 0 {
        close(o.broadcastGrad) // notify all goroutines that have been waiting for the gradients

// assignOutputGradient checks that the operator has an output gradient to
// start from.
//
// It returns nil when the value already carries a gradient (as judged by
// isNil) or when the value has exactly one element; for any other value
// without a gradient it returns an error.
//
// NOTE(review): presumably the single-element case is where a default output
// gradient gets assigned, but no such statement is visible in this copy —
// verify against the complete source.
func (o *Operator) assignOutputGradient() error {
    grad := o.Value().Grad()

    if !isNil(grad) {
        return nil

    if o.Value().Size() == 1 {
        return nil

    return fmt.Errorf("ag: missing gradient for %v", o)

// prepareBackwardPass is the preparatory phase of the backward step for this
// operator.
//
// It checks that the operator requires gradients, attempts the
// idle -> pending state transition, creates the broadcastGrad channel and
// then walks the operands looking for those that are themselves *Operator.
//
// NOTE(review): the bodies of the two guards and of the operand loop are not
// visible in this copy — presumably the guards return early and the loop
// prepares the operand operators; verify against the complete source.
func (o *Operator) prepareBackwardPass() {
    if !o.RequiresGrad() {

    if !o.trySetBackwardPending() {

    //lint:ignore S1019 explicitly set the buffer size to 0 as the channel is used as a signal
    o.broadcastGrad = make(chan struct{}, 0)

    for _, operand := range o.Operands() {
        if oo, ok := operand.(*Operator); ok {

// processBackwardPass is the core phase of the backward step for this
// operator.
//
// It checks that the operator requires gradients and attempts the
// pending -> ongoing state transition, then registers one unit of work on wg
// and starts executeBackward in its own goroutine. Afterwards it walks the
// operands looking for those that are themselves *Operator.
//
// NOTE(review): the body of the guard and of the operand loop are not visible
// in this copy — presumably the guard returns early and the loop processes
// the operand operators; verify against the complete source.
func (o *Operator) processBackwardPass(wg *sync.WaitGroup) {
    if !o.RequiresGrad() || !o.trySetBackwardOngoing() {

    wg.Add(1) // matched by the deferred wg.Done() in executeBackward
    go o.executeBackward(wg)

    for _, operand := range o.Operands() {
        if oo, ok := operand.(*Operator); ok {

// executeBackward hands the operator's accumulated gradients to the wrapped
// function's Backward method. It is started as a goroutine by
// processBackwardPass, which registers it on wg beforehand.
//
// Deferred calls run in reverse order on return: the backward state goes back
// to idle first, then wg is notified.
// An error returned by Backward currently terminates the program via log.Fatalf.
func (o *Operator) executeBackward(wg *sync.WaitGroup) {
    defer wg.Done()
    defer o.setBackwardIdle()

    grad := o.Grad() // wait until the accumulated gradients are ready
    // NOTE(review): this is a plain interface comparison, whereas HasGrad and
    // assignOutputGradient use isNil, which also detects a nil pointer wrapped
    // in the interface — confirm which check is intended here.
    if grad == nil {
        return // no gradients to propagate

    if err := o.fn.Backward(grad); err != nil {
        log.Fatalf("ag: error during backward pass: %v", err) // TODO: handle error

// isBackwardIdle reports whether the operator's backward state, read
// atomically, is idle.
func (o *Operator) isBackwardIdle() bool {
    return atomic.LoadUint32(&o.backwardState) == idle

// setBackwardIdle atomically sets the operator's backward state to idle,
// whatever the current state is.
func (o *Operator) setBackwardIdle() {
    atomic.StoreUint32(&o.backwardState, idle)

// trySetBackwardPending atomically moves the backward state from idle to
// pending. It returns false, leaving the state untouched, when the current
// state is not idle.
func (o *Operator) trySetBackwardPending() bool {
    return atomic.CompareAndSwapUint32(&o.backwardState, idle, pending)

// trySetBackwardOngoing atomically moves the backward state from pending to
// ongoing. It returns false, leaving the state untouched, when the current
// state is not pending.
func (o *Operator) trySetBackwardOngoing() bool {
    return atomic.CompareAndSwapUint32(&o.backwardState, pending, ongoing)

// isNil reports whether grad holds no usable value: either it is an untyped
// nil interface, or it wraps a nil pointer, map, slice, channel, function or
// interface (the "typed nil" case that a plain == nil comparison misses).
//
// Values of kinds that cannot be nil (structs, numbers, strings, arrays, ...)
// are reported as non-nil. reflect.Value.IsNil panics for those kinds, so the
// kind is checked before calling it.
func isNil(grad any) bool {
    if grad == nil {
        return true
    }
    switch v := reflect.ValueOf(grad); v.Kind() {
    case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
        reflect.Pointer, reflect.Slice, reflect.UnsafePointer:
        return v.IsNil()
    default:
        return false
    }
}