nlpodyssey/spago

View on GitHub
ag/gradientblocker.go

Summary

Maintainability
A
0 mins
Test Coverage
// Copyright 2019 spaGO Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ag

import (
    "encoding/gob"

    "github.com/nlpodyssey/spago/mat"
)

// GradientBlocker wraps any mat.Tensor implementation and disables its
// gradient handling.
//
// The wrapped tensor is embedded, so its methods are promoted unchanged, while
// the gradient-related methods declared on *GradientBlocker (Grad, AccGrad,
// HasGrad, RequiresGrad and ZeroGrad) replace them with variants that report
// no gradient and ignore any accumulation request.
type GradientBlocker struct {
    // Tensor is the wrapped value whose non-gradient behaviour is preserved.
    mat.Tensor
}

// init makes *GradientBlocker known to the gob package, registering it as a
// concrete type that may be transmitted as an interface value.
func init() {
    gob.Register(new(GradientBlocker))
}

// StopGrad wraps t in a freshly allocated GradientBlocker and returns it as a
// mat.Tensor. The result exposes t's values but reports no gradient and
// discards any gradient accumulated on it, so nothing flows back to t through
// the wrapper.
func StopGrad(t mat.Tensor) mat.Tensor {
    blocker := new(GradientBlocker)
    blocker.Tensor = t
    return blocker
}

// Grad reports that no gradient is available: for a GradientBlocker the
// result is always nil, regardless of the wrapped tensor.
func (r *GradientBlocker) Grad() mat.Tensor {
    return nil
}

// AccGrad is a no-op on a GradientBlocker: the supplied gradient is ignored
// and nothing is accumulated.
func (r *GradientBlocker) AccGrad(_ mat.Tensor) {
    // Intentionally empty: gradients must not accumulate past this wrapper.
}

// HasGrad reports whether a gradient is present; a GradientBlocker never
// holds one, so the answer is always false.
func (r *GradientBlocker) HasGrad() bool {
    return false
}

// RequiresGrad reports whether gradients should be computed for this value;
// a GradientBlocker always answers false.
func (r *GradientBlocker) RequiresGrad() bool {
    return false
}

// ZeroGrad is a no-op on a GradientBlocker, which never holds a gradient of
// its own to reset.
func (r *GradientBlocker) ZeroGrad() {
    // Intentionally empty: there is no gradient state to clear.
}