codegen/op-codegen/src/main/ops/org/nd4j/codegen/ops/Linalg.kt
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.codegen.ops
import org.nd4j.codegen.api.DataType
import org.nd4j.codegen.api.DataType.*
import org.nd4j.codegen.api.Language
import org.nd4j.codegen.api.doc.DocScope
import org.nd4j.codegen.dsl.*
import org.nd4j.codegen.api.Range
fun Linalg() = Namespace("Linalg") {
//val namespaceJavaPackage = "org.nd4j.linalg"
Op("Cholesky") {
javaPackage = "org.nd4j.linalg.api.ops.impl.transforms"
javaOpClass = "Cholesky"
Input(DataType.NUMERIC, "input") { description = "Input tensor with inner-most 2 dimensions forming square matrices" }
Output(DataType.NUMERIC, "output"){ description = "Transformed tensor" }
Doc(Language.ANY, DocScope.ALL){
"""
Computes the Cholesky decomposition of one or more square matrices.
""".trimIndent()
}
}
Op("Lstsq") {
javaPackage = "org.nd4j.linalg.api.ops.custom"
javaOpClass = "Lstsq"
Input(DataType.NUMERIC, "matrix") {description = "input tensor"}
Input(DataType.NUMERIC, "rhs") {description = "input tensor"}
Arg(DataType.FLOATING_POINT, "l2_reguralizer") {description = "regularizer"}
Arg(DataType.BOOL, "fast") {description = "fast mode, defaults to True"; defaultValue = true}
Output(DataType.FLOATING_POINT, "output"){ description = "Transformed tensor" }
Doc(Language.ANY, DocScope.ALL){
"""
Solver for linear squares problems.
""".trimIndent()
}
}
Op("Solve") {
javaPackage = "org.nd4j.linalg.api.ops.custom"
javaOpClass = "LinearSolve"
Input(DataType.NUMERIC, "matrix") {description = "input tensor"}
Input(DataType.NUMERIC, "rhs") {description = "input tensor"}
Arg(DataType.BOOL, "adjoint") {description = "adjoint mode, defaults to False"; defaultValue = false}
Output(FLOATING_POINT, "output"){ description = "Output tensor" }
Doc(Language.ANY, DocScope.ALL){
"""
Solver for systems of linear equations.
""".trimIndent()
}
}
Op("TriangularSolve") {
javaPackage = "org.nd4j.linalg.api.ops.custom"
javaOpClass = "TriangularSolve"
Input(DataType.NUMERIC, "matrix") {description = "input tensor"}
Input(DataType.NUMERIC, "rhs") {description = "input tensor"}
Arg(DataType.BOOL, "lower") {description = "defines whether innermost matrices in matrix are lower or upper triangular"}
Arg(DataType.BOOL, "adjoint") {description = "adjoint mode"}
Output(DataType.FLOATING_POINT, "output")
Doc(Language.ANY, DocScope.ALL){
"""
Solver for systems of linear questions.
""".trimIndent()
}
}
Op("Lu") {
javaPackage = "org.nd4j.linalg.api.ops.custom"
javaOpClass = "Lu"
Input(DataType.NUMERIC, "input") {description = "input tensor"}
Output(FLOATING_POINT, "output")
Doc(Language.ANY, DocScope.ALL){
"""
Computes LU decomposition.
""".trimIndent()
}
}
Op("Matmul") {
javaPackage = "org.nd4j.linalg.api.ops.impl.reduce"
javaOpClass = "Mmul"
Input(DataType.NUMERIC, "a") {description = "input tensor"}
Input(DataType.NUMERIC, "b") {description = "input tensor"}
Arg(DataType.FLOATING_POINT,"alpha",{defaultValue = 1.0; description = "Defaults to 1.0: the scalar multiplier for the product of a* b "})
Arg(DataType.FLOATING_POINT,"beta",{defaultValue = 1.0; description = "Defaults to 1.0: the scalar multiplier for c "})
Arg(DataType.BOOL,"transA",{defaultValue = false; description = "Whether to transpose a when running multiply "})
Arg(DataType.BOOL,"transB",{defaultValue = false; description = "Whether to transpose b when running multiply "})
Output(DataType.FLOATING_POINT, "output")
Doc(Language.ANY, DocScope.ALL){
"""
Performs matrix multiplication on input tensors.
""".trimIndent()
}
}
Op("Qr") {
javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
javaOpClass = "Qr"
Input(DataType.NUMERIC, "input") {description = "input tensor"}
Arg(DataType.BOOL, "full") {description = "full matrices mode"; defaultValue = false}
Output(FLOATING_POINT, "outputQ")
Output(FLOATING_POINT, "outputR")
Doc(Language.ANY, DocScope.ALL){
"""
Computes the QR decompositions of input matrix.
""".trimIndent()
}
}
Op("MatrixBandPart") {
javaPackage = "org.nd4j.linalg.api.ops.custom"
javaOpClass = "MatrixBandPart"
Input(DataType.NUMERIC, "input") { description = "input tensor" }
Arg(DataType.INT, "minLower") { description = "lower diagonal count" }
Arg(DataType.INT, "maxUpper") { description = "upper diagonal count" }
Output(DataType.FLOATING_POINT, "output1")
Output(DataType.FLOATING_POINT, "output2")
Doc(Language.ANY, DocScope.ALL){
"""
Copy a tensor setting outside a central band in each innermost matrix.
""".trimIndent()
}
}
Op("cross") {
javaPackage = "org.nd4j.linalg.api.ops.impl.shape"
javaOpClass = "Cross"
Input(DataType.NUMERIC, "a") {"Input tensor a"}
Input(DataType.NUMERIC, "b") {"Input tensor b"}
Output(FLOATING_POINT, "output")
Doc(Language.ANY, DocScope.ALL){
"""
Computes pairwise cross product.
""".trimIndent()
}
}
Op("diag") {
javaPackage = "org.nd4j.linalg.api.ops.impl.shape"
javaOpClass = "Diag"
Input(DataType.NUMERIC, "input") {"Input tensor"}
Output(DataType.FLOATING_POINT, "output")
Doc(Language.ANY, DocScope.ALL){
"""
Calculates diagonal tensor.
""".trimIndent()
}
}
Op("diag_part") {
javaPackage = "org.nd4j.linalg.api.ops.impl.shape"
javaOpClass = "DiagPart"
Input(DataType.NUMERIC, "input") {"Input tensor"}
Output(DataType.FLOATING_POINT, "output")
Doc(Language.ANY, DocScope.ALL){
"""
Calculates diagonal tensor.
""".trimIndent()
}
}
Op("matrixDeterminant") {
javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
javaOpClass = "MatrixDeterminant"
Input(DataType.NUMERIC, "input") {"Input tensor"}
Output(FLOATING_POINT, "output")
Doc(Language.ANY, DocScope.ALL){
"""
Calculates matrix determinant.
""".trimIndent()
}
}
Op("logdet") {
javaPackage = "org.nd4j.linalg.api.ops.custom"
javaOpClass = "Logdet"
Input(DataType.NUMERIC, "input") {"Input tensor"}
Output(FLOATING_POINT, "output")
Doc(Language.ANY, DocScope.ALL){
"""
Calculates log of determinant.
""".trimIndent()
}
}
Op("matrixInverse") {
javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
javaOpClass = "MatrixInverse"
Input(DataType.NUMERIC, "input") {"Input tensor"}
Output(FLOATING_POINT, "output")
Doc(Language.ANY, DocScope.ALL){
"""
Inverts a matrix
""".trimIndent()
}
}
Op("eig") {
javaPackage = "org.nd4j.linalg.api.ops.custom"
javaOpClass = "Eig"
Input(DataType.NUMERIC, "input") {"Input tensor"}
Output(FLOATING_POINT, "eigenValues")
Output(FLOATING_POINT, "eigenVectors")
Doc(Language.ANY, DocScope.ALL){
"""
Calculates eigen values
""".trimIndent()
}
}
Op("svd") {
javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
javaOpClass = "Svd"
Input(DataType.NUMERIC, "input") {"Input tensor"}
Arg(DataType.BOOL, "fullUV") {"Full matrices mode"}
Arg(DataType.BOOL, "computeUV") {"Compute U and V"}
Arg(DataType.INT, "switchNum") {"Switch number"; defaultValue = 16}
Output(FLOATING_POINT, "output")
Doc(Language.ANY, DocScope.ALL){
"""
Calculates singular value decomposition.
""".trimIndent()
}
}
Op("tri") {
javaPackage = "org.nd4j.linalg.api.ops.custom"
javaOpClass = "Tri"
Arg(DATA_TYPE, "dataType") { description = "Data type"; defaultValue = org.nd4j.linalg.api.buffer.DataType.FLOAT }
Arg(INT, "row") {"Number of rows in the array"; }
Arg(INT, "column") {"Number of columns in the array"; }
Arg(INT, "diagonal") {"The sub-diagonal at and below which the array is filled. k = 0 is the main diagonal, while k < 0 is below it, and k > 0 is above. The default is 0."; defaultValue = 0}
Output(FLOATING_POINT, "output")
Doc(Language.ANY, DocScope.ALL) {
"""
An array with ones at and below the given diagonal and zeros elsewhere.
""".trimIndent()
}
}
Op("triu") {
javaPackage = "org.nd4j.linalg.api.ops.custom"
javaOpClass = "Triu"
Input(DataType.NUMERIC, "input") {"Input tensor"}
Arg(DataType.INT, "diag") {"diagonal"; defaultValue = 0}
Output(FLOATING_POINT, "output")
Doc(Language.ANY, DocScope.ALL){
"""
Upper triangle of an array. Return a copy of a input tensor with the elements below the k-th diagonal zeroed.
""".trimIndent()
}
}
Alias(SDBaseOps(), "mmul")
}