/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//

#ifndef LIBND4J_HEADERS_UPDATERS_H
#define LIBND4J_HEADERS_UPDATERS_H
#include <execution/Threads.h>
#include <helpers/ConstantTadHelper.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/headers/common.h>
#include <ops/declarable/helpers/updatersHelpers.h>

namespace sd {
namespace ops {

/**
 * SGD updater
 * Input arrays:
 * 0 - input array with gradients.
 * Optional:
 * 1 - scalar learning rate value
 * Optional:
 * T args
 * 0 - scalar learning rate value
 */
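// Reference semantics, a minimal sketch assuming the standard SGD rule (the actual
// kernel is implemented in ops/declarable/helpers/updatersHelpers.h):
//
//   update = lr * g
//
// Hypothetical test-style invocation, assuming the DeclarableOp::evaluate overloads
// used throughout the libnd4j test suite:
//
//   sd::ops::sgd_updater op;
//   auto grad   = NDArrayFactory::create<float>('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f});
//   auto result = op.evaluate({&grad}, {0.1}, {});   // T arg 0 = learning rate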
#if NOT_EXCLUDED(OP_sgd_updater)
DECLARE_CONFIGURABLE_OP(sgd_updater, 1, 1, true, 0, 0);
#endif

/**
 * RmsProp updater
 * Input arrays:
 * 0 - input array with gradients.
 * 1 - Initial state
 * Optional:
 * 2 - scalar learning rate value
 * 3 - scalar rms decay
 * 4 - epsilon
 * Optional:
 * T args
 * 0 - scalar learning rate value
 * 1 - scalar rms decay
 * 2 - epsilon
 */
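// Reference semantics, a sketch assuming the standard RMSProp rule (the actual kernel
// lives in ops/declarable/helpers/updatersHelpers.h and may differ in epsilon placement):
//
//   state  = decay * state + (1 - decay) * g^2
//   update = lr * g / (sqrt(state) + eps)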
#if NOT_EXCLUDED(OP_rms_prop_updater)
DECLARE_CONFIGURABLE_OP(rms_prop_updater, 2, 2, true, 0, 0);
#endif
/**
 * AdaGrad updater
 * Input arrays:
 *  0 - input array with gradients.
 *  1 - historical grad state
 * Optional :
 * 2 - scalar learning rate value
 * 3 - epsilon
 * Optional:
 * T args
 * 0 - scalar learning rate value
 * 1 - epsilon
 */
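// Reference semantics, a sketch assuming the standard AdaGrad rule (see
// ops/declarable/helpers/updatersHelpers.h for the actual kernel):
//
//   state  = state + g^2
//   update = lr * g / (sqrt(state) + eps)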
#if NOT_EXCLUDED(OP_ada_grad_updater)
DECLARE_CONFIGURABLE_OP(ada_grad_updater, 2, 2, true, 0, 0);
#endif
/**
 * AdaMax updater
 * Input arrays:
 *  0 - input array with gradients.
 *  1 - gradient state V
 *  2 - gradient state M
 * Optional :
 * 3 - scalar learning rate value
 * 4 - beta 1 value
 * 5 - beta 2 value
 * 6 - epsilon
 * Optional:
 * T args
 * 0 - scalar learning rate value
 * 1 - beta 1 value
 * 2 - beta 2 value
 * 3 - epsilon
 * Optional:
 * I args
 * 0 - iteration
 */
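// Reference semantics, a sketch assuming the standard AdaMax rule with iteration t taken
// from I arg 0 (actual kernel: ops/declarable/helpers/updatersHelpers.h):
//
//   m      = beta1 * m + (1 - beta1) * g
//   u      = max(beta2 * u, |g|)
//   update = (lr / (1 - beta1^t)) * m / (u + eps)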
#if NOT_EXCLUDED(OP_ada_max_updater)
DECLARE_CONFIGURABLE_OP(ada_max_updater, 3, 3, true, 0, 0);
#endif
/**
 * Nesterov's momentum updater
 * Input arrays:
 *  0 - input array with gradients.
 *  1 - V grad state
 * Optional :
 * 2 - scalar learning rate value
 * 3 - scalar momentum value
 * Optional:
 * T args
 * 0 - learning rate value
 * 1 - momentum value
 */
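// Reference semantics, a sketch of one common Nesterov-momentum formulation, with
// "update" being the quantity subtracted from the parameters (sign conventions vary
// between implementations; the actual kernel lives in
// ops/declarable/helpers/updatersHelpers.h):
//
//   vNew   = momentum * v - lr * g
//   update = momentum * v - (1 + momentum) * vNew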
#if NOT_EXCLUDED(OP_nesterovs_updater)
DECLARE_CONFIGURABLE_OP(nesterovs_updater, 2, 2, true, 0, 0);
#endif
/**
 * Adam updater
 * Input arrays:
 *  0 - input array with gradients.
 *  1 - gradient state V
 *  2 - gradient state M
 * Optional :
 * 3 - scalar learning rate value
 * 4 - beta 1 value
 * 5 - beta 2 value
 * 6 - epsilon
 * Optional:
 * T args
 * 0 - scalar learning rate value
 * 1 - beta 1 value
 * 2 - beta 2 value
 * 3 - epsilon
 * Optional:
 * I args
 * 0 - iteration
 */
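// Reference semantics, a sketch assuming the standard Adam rule with bias correction
// driven by iteration t from I arg 0 (actual kernel: ops/declarable/helpers/updatersHelpers.h):
//
//   m      = beta1 * m + (1 - beta1) * g
//   v      = beta2 * v + (1 - beta2) * g^2
//   update = lr * (m / (1 - beta1^t)) / (sqrt(v / (1 - beta2^t)) + eps)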
#if NOT_EXCLUDED(OP_adam_updater)
DECLARE_CONFIGURABLE_OP(adam_updater, 3, 3, true, 0, 0);
#endif
/**
 * AdaBelief updater
 * Input arrays:
 *  0 - input array with gradients.
 *  1 - gradient state V
 *  2 - gradient state M
 * Optional :
 * 3 - scalar learning rate value
 * 4 - beta 1 value
 * 5 - beta 2 value
 * 6 - epsilon
 * Optional:
 * T args
 * 0 - scalar learning rate value
 * 1 - beta 1 value
 * 2 - beta 2 value
 * 3 - epsilon
 * Optional:
 * I args
 * 0 - iteration
 */
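// Reference semantics, a sketch assuming the published AdaBelief rule, which replaces
// Adam's second moment with the variance of the gradient around its running mean
// (actual kernel: ops/declarable/helpers/updatersHelpers.h):
//
//   m      = beta1 * m + (1 - beta1) * g
//   s      = beta2 * s + (1 - beta2) * (g - m)^2 + eps
//   update = lr * (m / (1 - beta1^t)) / (sqrt(s / (1 - beta2^t)) + eps)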
#if NOT_EXCLUDED(OP_adabelief_updater)
DECLARE_CONFIGURABLE_OP(adabelief_updater, 3, 3, true, 0, 0);
#endif
/**
 * AdaDelta updater
 * Input arrays:
 *  0 - input array with gradients.
 *  1 - gradient state V
 *  2 - gradient state M
 * Optional :
 * 3 - rho value
 * 4 - epsilon
 * Optional:
 * T args
 * 0 - rho
 * 1 - epsilon
 */
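// Reference semantics, a sketch assuming the standard AdaDelta rule, where the two state
// arrays accumulate squared gradients and squared updates respectively (actual kernel:
// ops/declarable/helpers/updatersHelpers.h):
//
//   msg    = rho * msg + (1 - rho) * g^2
//   update = g * sqrt(msdx + eps) / sqrt(msg + eps)
//   msdx   = rho * msdx + (1 - rho) * update^2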
#if NOT_EXCLUDED(OP_ada_delta_updater)
DECLARE_CONFIGURABLE_OP(ada_delta_updater, 3, 3, true, 0, 0);
#endif
/**
 * Nadam updater
 * Input arrays:
 *  0 - input array with gradients.
 *  1 - gradient state V
 *  2 - gradient state M
 * Optional :
 * 3 - scalar learning rate value
 * 4 - beta 1 value
 * 5 - beta 2 value
 * 6 - epsilon
 * Optional:
 * T args
 * 0 - scalar learning rate value
 * 1 - beta 1 value
 * 2 - beta 2 value
 * 3 - epsilon
 * Optional:
 * I args
 * 0 - iteration
 */
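// Reference semantics, a sketch of one common Nadam formulation (Adam with Nesterov
// momentum folded into the first-moment term; actual kernel:
// ops/declarable/helpers/updatersHelpers.h):
//
//   m      = beta1 * m + (1 - beta1) * g
//   v      = beta2 * v + (1 - beta2) * g^2
//   mNes   = beta1 * m / (1 - beta1^(t+1)) + (1 - beta1) * g / (1 - beta1^t)
//   update = lr * mNes / (sqrt(v / (1 - beta2^t)) + eps)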
#if NOT_EXCLUDED(OP_nadam_updater)
DECLARE_CONFIGURABLE_OP(nadam_updater, 3, 3, true, 0, 0);
#endif
/**
 * AmsGrad updater
 * Input arrays:
 *  0 - input array with gradients.
 *  1 - gradient state V - squared gradients (second moment)
 *  2 - gradient state M - moving average (first moment)
 *  3 - gradient state H - running max of V
 * Optional :
 * 4 - scalar learning rate value
 * 5 - beta 1 value
 * 6 - beta 2 value
 * 7 - epsilon
 * Optional:
 * T args
 * 0 - scalar learning rate value
 * 1 - beta 1 value
 * 2 - beta 2 value
 * 3 - epsilon
 * Optional:
 * I args
 * 0 - iteration
 */
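// Reference semantics, a sketch assuming the standard AMSGrad rule (Adam with a running
// elementwise max H of the second moment; the actual kernel in
// ops/declarable/helpers/updatersHelpers.h may also apply Adam-style bias correction):
//
//   m      = beta1 * m + (1 - beta1) * g
//   v      = beta2 * v + (1 - beta2) * g^2
//   h      = max(h, v)
//   update = lr * m / (sqrt(h) + eps)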
#if NOT_EXCLUDED(OP_ams_grad_updater)
DECLARE_CONFIGURABLE_OP(ams_grad_updater, 4, 4, true, 0, 0);
#endif
}  // namespace ops
}  // namespace sd

#endif