// codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/RNN.kt
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.codegen.ops
import org.nd4j.codegen.api.DataType.*
import org.nd4j.codegen.api.Language
import org.nd4j.codegen.api.doc.DocScope
import org.nd4j.codegen.dsl.*
/**
 * Op and config definitions for the "RNN" namespace of the ND4J code generator.
 *
 * Everything here is declarative DSL: `Config` blocks describe reusable
 * argument/weight bundles, and `Op` blocks describe the generated API methods
 * (GRU, LSTM block/cell/layer, SRU) together with their generated javadoc.
 * The `description` strings become user-facing documentation in the generated
 * code, so their wording matters.
 */
fun SDRNN() = Namespace("RNN") {

    val LSTMConfiguration = Config("LSTMConfiguration") {
        Arg(ENUM, "RnnDataFormat") {
            possibleValues = listOf("TNS", "NST", "NTS")
            description = " The data format of the input. Input shape depends on data format (in config):<br>\n" +
                    " TNS -> [timeSteps, batchSize, inSize]<br>\n" +
                    " NST -> [batchSize, inSize, timeSteps]<br>\n" +
                    " NTS -> [batchSize, timeSteps, inSize]<br>"
        }
        Arg(BOOL, "peepHole") { description = "Whether to provide peephole connections" }
        Arg(NUMERIC, "forgetBias") { description = "The bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training." }
        // FIX: this description was an accidental copy of the forgetBias text.
        Arg(NUMERIC, "clippingCellValue") { description = "The value to clip the cell state to: if non-zero, cell state values are clipped to [-clippingCellValue, clippingCellValue]." }
        javaClassOverride = "org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration"
    }

    // Activation function names shared by the gateAct/cellAct/outAct args below
    // (previously this list was duplicated verbatim three times).
    // NOTE: "THRESHHOLD_RELU" and "SCALED_TAHN" are misspelled on purpose —
    // they must match the constants of the backing Java enum exactly.
    val lstmActivations = listOf(
            "TANH",
            "RELU",
            "SIGMOID",
            "AFFINE",
            "LEAKY_RELU",
            "THRESHHOLD_RELU",
            "SCALED_TAHN",
            "HARD_SIGMOID",
            "ELU",
            "SOFTSIGN",
            "SOFTPLUS")

    val LSTMLayerConfig = Config("LSTMLayerConfig") {
        Arg(ENUM, "LSTMDataFormat") {
            possibleValues = listOf("TNS", "NST", "NTS", "T2NS")
            description = "for unidirectional:" +
                    " TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as \"time major\"<br>\n" +
                    " NST: shape [numExamples, inOutSize, timeLength]<br>\n" +
                    " NTS: shape [numExamples, timeLength, inOutSize] - TF \"time_major=false\" layout<br>" +
                    " for bidirectional:\n" +
                    " T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX)"
        }
        Arg(ENUM, "LSTMDirectionMode") {
            possibleValues = listOf("FWD", "BWD", "BIDIR_SUM", "BIDIR_CONCAT", "BIDIR_EXTRA_DIM")
            description = "direction <br>\n" +
                    " FWD: 0 = fwd\n" +
                    " BWD: 1 = bwd\n" +
                    " BIDIR_SUM: 2 = bidirectional sum\n" +
                    " BIDIR_CONCAT: 3 = bidirectional concat\n" +
                    " BIDIR_EXTRA_DIM: 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)"
        }
        Arg(ENUM, "gateAct") { possibleValues = lstmActivations; description = "Activations" }
        Arg(ENUM, "cellAct") { possibleValues = lstmActivations; description = "Activations" }
        Arg(ENUM, "outAct") { possibleValues = lstmActivations; description = "Activations" }
        Arg(BOOL, "retFullSequence") { description = "indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}"; defaultValue = true }
        Arg(BOOL, "retLastH") {
            description = "indicates whether to return output at last time step only,\n" +
                    " in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)"
            defaultValue = false
        }
        Arg(BOOL, "retLastC") {
            description = "indicates whether to return cells state at last time step only,\n" +
                    " in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)"
            defaultValue = false
        }
        Arg(NUMERIC, "cellClip") { description = "Cell clipping value, if it = 0 then do not apply clipping"; defaultValue = 0.0 }
        // Alpha/beta pairs below are presumably the parameters of the
        // parameterised activations (AFFINE, LEAKY_RELU, ...) selected above —
        // TODO(review): confirm against LSTMLayerConfig and add descriptions.
        Arg(NUMERIC, "gateAlpha") { defaultValue = 0.0 }
        Arg(NUMERIC, "gateBeta") { defaultValue = 0.0 }
        Arg(NUMERIC, "cellAlpha") { defaultValue = 0.0 }
        Arg(NUMERIC, "cellBeta") { defaultValue = 0.0 }
        Arg(NUMERIC, "outAlpha") { defaultValue = 0.0 }
        Arg(NUMERIC, "outBeta") { defaultValue = 0.0 }
        javaClassOverride = "org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig"
    }

    val GRUWeights = Config("GRUWeights") {
        Input(NUMERIC, "ruWeight")
        Input(NUMERIC, "cWeight")
        Input(NUMERIC, "ruBias")
        Input(NUMERIC, "cBias")
        javaClassOverride = "org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights"
    }

    val SRUWeights = Config("SRUWeights") {
        Input(NUMERIC, "weights")
        Input(NUMERIC, "bias")
        javaClassOverride = "org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights"
    }

    val LSTMWeights = Config("LSTMWeights") {
        Input(NUMERIC, "ruWeight")
        Input(NUMERIC, "inputPeepholeWeights")
        Input(NUMERIC, "forgetPeepholeWeights")
        Input(NUMERIC, "outputPeepholeWeights")
        Input(NUMERIC, "bias")
        javaClassOverride = "org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights"
    }

    val LSTMLayerWeights = Config("LSTMLayerWeights") {
        Input(NUMERIC, "inputWeights") {
            description = "input weights Wx:\n" +
                    " 1) shapes `[nIn, 4*nOut]` for FWD,BWD " +
                    " 2) shapes `[2, nIn, 4*nOut]` BIDIR_SUM, BIDIR_CONCAT and BIDIR_EXTRA_DIM"
        }
        // FIX: recurrent weights map nOut -> 4*nOut (the doc previously said
        // `[nIn, 4*nOut]`, copy-pasted from the input weights above); see the
        // libnd4j lstmLayer op documentation for the weight shape conventions.
        Input(NUMERIC, "recurrentWeights") {
            description = "recurrent weights Wr:\n" +
                    " 1) shapes `[nOut, 4*nOut]` for FWD, BWD " +
                    " 2) shapes `[2, nOut, 4*nOut]` BIDIR_SUM, BIDIR_CONCAT and BIDIR_EXTRA_DIM"
        }
        Input(NUMERIC, "biases") {
            description = "biases\n" +
                    " 1) shapes `[4*nOut]` for FWD, BWD " +
                    " 2) shapes `[2, 4*nOut]` for BIDIR_SUM, BIDIR_CONCAT and BIDIR_EXTRA_DIM"
            defaultValue = null
        }
        Input(NUMERIC, "peepholeWeights") {
            description = "peephole weights Wp:\n" +
                    " 1) `[3*nOut]` when directionMode < 2\n" +
                    " 2) `[2, 3*nOut]` when directionMode >= 2"
            defaultValue = null
        }
        javaClassOverride = "org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights"
    }

    val namespaceJavaPackage = "org.nd4j.linalg.api.ops.impl.layers.recurrent"

    Op("gruCell") {
        javaPackage = namespaceJavaPackage
        javaOpClass = "GRUCell"
        Input(NUMERIC, "x") { description = "Input, with shape [batchSize, inSize]" }
        Input(NUMERIC, "hLast") { description = "Output of the previous cell/time step, with shape [batchSize, numUnits]" }
        useConfig(GRUWeights)
        Output(NUMERIC, "r") { description = "Reset gate output" }
        Output(NUMERIC, "u") { description = "Update gate output" }
        Output(NUMERIC, "c") { description = "Cell gate output" }
        Output(NUMERIC, "h") { description = "Cell output" }
        Doc(Language.ANY, DocScope.ALL) {
            """
            The GRU cell. Does a single time step operation
            """.trimIndent()
        }
    }

    Op("gru") {
        javaPackage = namespaceJavaPackage
        javaOpClass = "GRU"
        Input(NUMERIC, "x") { description = "input [time, bS, nIn]" }
        Input(NUMERIC, "hLast") { description = "initial cell output (at time step = 0) [bS, nOut]" }
        Input(NUMERIC, "Wx") { description = "input-to-hidden weights, [nIn, 3*nOut]" }
        Input(NUMERIC, "Wh") { description = "hidden-to-hidden weights, [nOut, 3*nOut]" }
        Input(NUMERIC, "biases") { description = "biases, [3*nOut]" }
        Output(NUMERIC, "h") { description = "cell outputs [time, bS, nOut], that is per each time step" }
        Doc(Language.ANY, DocScope.ALL) {
            """
            The GRU operation. Gated Recurrent Unit - Cho et al. 2014.
            """.trimIndent()
        }
    }

    Op("lstmCell") {
        javaPackage = namespaceJavaPackage
        javaOpClass = "LSTMBlockCell"
        Input(NUMERIC, "x") { description = "Input, with shape [batchSize, inSize]" }
        Input(NUMERIC, "cLast") { description = "Previous cell state, with shape [batchSize, numUnits]" }
        // FIX: typo "revious cell output".
        Input(NUMERIC, "yLast") { description = "Previous cell output, with shape [batchSize, numUnits]" }
        useConfig(LSTMWeights)
        useConfig(LSTMConfiguration)
        Output(NUMERIC, "i") { description = "Output - input modulation gate activations [batchSize, numUnits]." }
        Output(NUMERIC, "c") { description = "Output - Activations, cell state (pre tanh) [batchSize, numUnits]." }
        Output(NUMERIC, "f") { description = "Output - forget gate activations [batchSize, numUnits]." }
        Output(NUMERIC, "o") { description = "Output - output gate activations [batchSize, numUnits]." }
        Output(NUMERIC, "z") { description = "Output - input gate activations [batchSize, numUnits]." }
        Output(NUMERIC, "h") { description = "Cell state, post tanh [batchSize, numUnits]." }
        Output(NUMERIC, "y") { description = "Current cell output [batchSize, numUnits]." }
        Doc(Language.ANY, DocScope.ALL) {
            """
            The LSTM cell. Does a single time step operation.
            """.trimIndent()
        }
    }

    Op("lstmblock") {
        javaPackage = namespaceJavaPackage
        javaOpClass = "LSTMBlock"
        // TODO(review): missing description — presumably the max time-step
        // length per batch element, as in lstmLayer below; confirm and document.
        Input(NUMERIC, "maxTSLength") { defaultValue = null }
        Input(NUMERIC, "x") { description = " Input, with shape dependent on the data format (in config)." }
        Input(NUMERIC, "cLast") { description = "Previous/initial cell state, with shape [batchSize, numUnits]"; defaultValue = null }
        Input(NUMERIC, "yLast") { description = "Previous/initial cell output, with shape [batchSize, numUnits]"; defaultValue = null }
        useConfig(LSTMWeights)
        useConfig(LSTMConfiguration)
        Output(NUMERIC, "output") { description = "The layer's outputs." }
        Doc(Language.ANY, DocScope.ALL) {
            """
            The LSTM block
            """.trimIndent()
        }
    }

    Op("lstmLayer") {
        javaPackage = namespaceJavaPackage
        javaOpClass = "LSTMLayer"
        Input(NUMERIC, "x") { description = " Input, with shape dependent on the data format (in config)." }
        Input(NUMERIC, "cLast") { description = "Previous/initial cell state, with shape [batchSize, numUnits]"; defaultValue = null }
        Input(NUMERIC, "yLast") { description = "Previous/initial cell output, with shape [batchSize, numUnits]"; defaultValue = null }
        Input(NUMERIC, "maxTSLength") { description = "maxTSLength with shape [batchSize]"; defaultValue = null }
        useConfig(LSTMLayerWeights)
        useConfig(LSTMLayerConfig)
        //TODO these are optional
        Output(NUMERIC, "output") { description = "The layer's outputs - full time series" }
        Output(NUMERIC, "yLast") { description = "The layer's outputs - last time step activations (yLast)" }
        Output(NUMERIC, "cLast") { description = "The layer's outputs - last time step cell state (cLast)" }
        Doc(Language.ANY, DocScope.ALL) {
            """
            Long Short-Term Memory layer - Hochreiter 1997.
            SUPPORTS following data formats:
            for unidirectional:
            TNS: shapes [timeLength, numExamples, inOutSize]
            NST: shapes [numExamples, inOutSize, timeLength]
            NTS: shapes [numExamples, timeLength, inOutSize]
            for bidirectional:
            T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)
            SUPPORTS following direction modes:
            FWD: forward
            BWD: backward
            BIDIR_SUM: bidirectional sum
            BIDIR_CONCAT: bidirectional concat
            BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)
            You may use different gate configurations:
            specify gate/cell/out alpha/beta and numbers of activations for gate/cell/out described in activations enum
            ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")
            Also this layer supports MKLDNN (DNNL) and cuDNN acceleration
            """.trimIndent()
        }
    }

    Op("sruCell") {
        javaPackage = namespaceJavaPackage
        javaOpClass = "SRUCell"
        Input(NUMERIC, "x") { description = "Input, with shape [batchSize, inSize]" }
        Input(NUMERIC, "cLast") { description = "Previous cell state, with shape [batchSize, inSize]" }
        useConfig(SRUWeights)
        Output(NUMERIC, "output") { description = "The cell's outputs." }
        Doc(Language.ANY, DocScope.ALL) {
            """
            The SRU layer. Does a single time step operation.
            """.trimIndent()
        }
    }

    Op("sru") {
        javaPackage = namespaceJavaPackage
        javaOpClass = "SRU"
        Input(NUMERIC, "x") { description = "Input, with shape [batchSize, inSize]" }
        Input(NUMERIC, "initialC") { description = "Initial cell state, with shape [batchSize, inSize]" }
        Input(NUMERIC, "mask") { description = "An optional dropout mask, with shape [batchSize, inSize]"; defaultValue = null }
        useConfig(SRUWeights)
        // FIX: trailing ".." typo.
        Output(NUMERIC, "output") { description = "The cell's outputs." }
        // FIX: the doc was a copy of sruCell's single-time-step wording; SRU is
        // the full-sequence layer (Simple Recurrent Unit - Lei et al. 2018).
        Doc(Language.ANY, DocScope.ALL) {
            """
            The SRU layer. Applies the simple recurrent unit operation across the whole input sequence.
            """.trimIndent()
        }
    }
}