codegen/op-codegen/src/main/kotlin/org/nd4j/codegen/dsl/OpBuilder.kt
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.codegen.dsl
import org.nd4j.codegen.api.*
import org.nd4j.codegen.api.doc.DocScope
import org.nd4j.codegen.api.doc.DocSection
import org.nd4j.codegen.ops.SDBaseOps
import java.lang.IllegalStateException
fun Namespace(name: String, block: NamespaceOps.() -> Unit): NamespaceOps {
val ns = NamespaceOps(name)
ns.block()
ns.checkInvariants()
return ns
}
fun Mixin(name: String, block: Mixin.() -> Unit): Mixin {
return Mixin(name).apply(block).also {
it.checkInvariants()
}
}
fun NamespaceOps.Alias(ns:NamespaceOps, opName:String):Op {
val op:Op? = ns.op(opName)
op?.let {
this.parentNamespaceOps[ns.name]?.add(op)
this.ops.add(op)
return op
}
throw IllegalStateException("Failed to create alias for op: $opName")
}
fun NamespaceOps.Op(name: String, block: Op.() -> Unit): Op {
val op = Op(name)
op.libnd4jOpName = name
op.block()
if (!op.isAbstract && op.signatures.isEmpty()) {
op.AllParamSignature()
op.AllDefaultsSignature()
}
op.checkInvariants()
this.ops.add(op)
return op
}
fun NamespaceOps.Op(name: String,
extends: Mixin,
keepArgs: Boolean = true,
keepInputs: Boolean = true,
keepOutputs: Boolean = true,
keepConstraints: Boolean = true,
keepSignatures: Boolean = true,
keepDocs: Boolean = true,
block: (Op.() -> Unit)? = null): Op {
return this.Op(name) {
useMixin(extends, keepArgs = keepArgs, keepInputs = keepInputs, keepOutputs = keepOutputs, keepConstraints = keepConstraints, keepSignatures = keepSignatures, keepDocs = keepDocs)
if (block != null) {
this.block()
}
}
}
fun OpLike.Input(dataType: DataType, name: String, block: (Input.() -> Unit)? = null): Input {
val input = Input(name, dataType)
if (block != null) input.block()
if (!dataType.isTensorDataType()) {
throw IllegalArgumentException("Invalid datatype for input \"$name\" of op ${this.name()}: inputs arrays cannot have type $dataType - wrong type, or should be Arg type?");
}
this.addInput(input)
return input
}
fun OpLike.Arg(dataType: DataType, name: String, block: (Arg.() -> Unit)? = null): Arg {
val input = Arg(name, dataType)
if (block != null) input.block()
this.addArgument(input)
if(dataType == DataType.ENUM){
Registry.registerEnum(input)
}
return input
}
fun OpLike.Output(dataType: DataType, name: String, block: (Output.() -> Unit)? = null): Output {
val output = Output(name, dataType, false)
if (block != null) output.block()
if (!dataType.isTensorDataType()) {
throw IllegalArgumentException("Invalid datatype for output \"$name\" of op ${this.name()}: output arrays cannot have type $dataType");
}
this.addOutput(output)
return output
}
fun OpLike.Doc(language: Language, scope: DocScope, block: DocSection.() -> String): DocSection {
val doc = DocSection().apply {
this.language = language
this.scope = scope
text = this.block()
}
this.addDoc(doc)
return doc
}
fun OpLike.AllParamSignature(withOutput: Boolean = false) {
val allParameters = allParameters()
this.addSignature(Signature(allParameters))
if (withOutput) {
val withOutputParams = mutableListOf<Parameter>().also {
it.addAll(this.outputs())
it.addAll(allParameters)
}
this.addSignature(Signature(withOutputParams))
}
}
fun OpLike.AllDefaultsSignature(withOutput: Boolean = false) {
val allParameters = allParameters()
if (allParameters.find { it.hasDefaultValue() } != null) {
val params = allParameters.filterNot { it.hasDefaultValue() }
this.addSignature(Signature(params))
if (withOutput) {
val withOutputParams = mutableListOf<Parameter>().also {
it.addAll(this.outputs())
it.addAll(allParameters)
}
this.addSignature(Signature(withOutputParams))
}
}
}
fun OpLike.Signature(vararg params: Parameter, block: (Signature.() -> String)? = null): Signature {
if (params.toSet().size < params.size) {
throw IllegalArgumentException("A parameter may not be used twice in a signature!")
}
val paramsAndOutputs = allParameters() + outputs()
if (!paramsAndOutputs.containsAll(params.toList())) {
throw IllegalArgumentException("You can only use parameters in a signature that are actually defined in $this! Did you forget to useMixin(...) a mixin?")
}
val signature = Signature(params.toList())
if (block != null) {
signature.block()
}
this.addSignature(signature)
return signature
}
fun OpLike.Constraint(desc: String, block: ConstraintBuilder.() -> Expression): Constraint {
val check = ConstraintBuilder().block()
val constraint = Constraint(desc, check)
this.addConstraint(constraint)
return constraint
}
fun OpLike.BackendConstraint(desc: String, block: ConstraintBuilder.() -> Expression): Constraint {
val check = ConstraintBuilder().block()
val constraint = BackendConstraint(desc, check)
this.addConstraint(constraint)
return constraint
}
class ConstraintBuilder {
fun broadcastableShapes(vararg inputs: Input) = BroadcastableShapesExpression(inputs.toList())
fun sameShape(vararg inputs: Input) = SameShapeExpression(inputs.toList())
fun sameType(vararg inputs: Input) = SameTypeExpression(inputs.toList())
fun Input.sizeAt(i: Int) = InputShapeReference(this, i)
fun Input.rank() = InputRankReference(this)
fun Input.isScalar() = this.rank() eq 0
fun some(expr: BooleanExpression, vararg exprs: BooleanExpression) = exprs.fold(expr, { acc, cur -> acc or cur })
fun all(expr: BooleanExpression, vararg exprs: BooleanExpression) = exprs.fold(expr, { acc, cur -> acc and cur })
fun not(expr: BooleanExpression) = expr eq false
infix fun BooleanExpression.and(other: BooleanExpression) = BooleanExpression(this, other, BooleanOperation.AND)
infix fun BooleanExpression.or(other: BooleanExpression) = BooleanExpression(this, other, BooleanOperation.OR)
infix fun Reference.eq(other: Reference) = BooleanExpression(this, other, BooleanOperation.EQ)
infix fun Reference.eq(other: Number) = this eq NumberReference(other)
infix fun Reference.eq(other: Boolean) = this eq BooleanReference(other)
infix fun Reference.neq(other: Reference) = BooleanExpression(this, other, BooleanOperation.NEQ)
infix fun <T : Number> Reference.neq(other: T) = this neq NumberReference(other)
infix fun Reference.neq(other: Boolean) = this neq BooleanReference(other)
infix fun Reference.gt(other: Reference) = BooleanExpression(this, other, BooleanOperation.GT)
infix fun <T : Number> Reference.gt(other: T) = this gt NumberReference(other)
infix fun Reference.lt(other: Reference) = BooleanExpression(this, other, BooleanOperation.LT)
infix fun <T : Number> Reference.lt(other: T) = this lt NumberReference(other)
infix fun <T : Number> Reference.gte(other: T) = this gte NumberReference(other)
infix fun Reference.gte(other: Reference) = BooleanExpression(this, other, BooleanOperation.GTE)
infix fun <T : Number> Reference.lte(other: T) = this lte NumberReference(other)
infix fun Reference.lte(other: Reference) = BooleanExpression(this, other, BooleanOperation.LTE)
}
fun NamespaceOps.Config(name: String, block: (Config.() -> Unit)): Config {
val config = Config(name)
config.block()
this.addConfig(config)
Registry.registerConfig(config)
return config
}
fun Config.Input(dataType: DataType, name: String, block: (Input.() -> Unit)? = null): Input {
val input = Input(name, dataType)
if (block != null) input.block()
if (!dataType.isTensorDataType()) {
throw IllegalArgumentException("Invalid datatype for input \"$name\" of config ${this.name}: inputs arrays cannot have type $dataType - wrong type, or should be Arg type?");
}
this.addInput(input)
return input
}
fun Config.Arg(dataType: DataType, name: String, block: (Arg.() -> Unit)? = null): Arg {
val input = Arg(name, dataType)
if (block != null) input.block()
this.addArgument(input)
if(dataType == DataType.ENUM){
Registry.registerEnum(input)
}
return input
}
fun Config.Constraint(desc: String, block: ConstraintBuilder.() -> Expression): Constraint {
val check = ConstraintBuilder().block()
val constraint = Constraint(desc, check)
this.addConstraint(constraint)
return constraint
}
fun Config.BackendConstraint(desc: String, block: ConstraintBuilder.() -> Expression): Constraint {
val check = ConstraintBuilder().block()
val constraint = BackendConstraint(desc, check)
this.addConstraint(constraint)
return constraint
}
fun Config.Doc(language: Language, scope: DocScope, block: DocSection.() -> String): DocSection {
val doc = DocSection().apply {
this.language = language
this.scope = scope
text = this.block()
}
this.addDoc(doc)
return doc
}
fun OpLike.useConfig(config: Config): Config {
this.addConfig(config)
return config
}
fun Op.useMixin(mixin: Mixin,
keepArgs: Boolean = true,
keepInputs: Boolean = true,
keepOutputs: Boolean = true,
keepConstraints: Boolean = true,
keepSignatures: Boolean = true,
keepDocs: Boolean = true,
keepConfigs: Boolean = true
) {
if(mixin.legacyWasSet){
legacy = mixin.legacy
}
if(mixin.javaPackageWasSet){
javaPackage = mixin.javaPackage
}
if (keepArgs) {
args.addOrReplaceAll(mixin.args)
}
if (keepInputs) {
inputs.addOrReplaceAll(mixin.inputs)
}
if (keepOutputs) {
outputs.addOrReplaceAll(mixin.outputs)
}
if (keepConstraints) {
constraints.addAll(mixin.constraints)
}
if (keepSignatures) {
signatures.addAll(mixin.signatures)
}
if (keepDocs) {
doc.addAll(mixin.doc)
}
if(keepConfigs){
configs.addOrReplaceAll(mixin.configs)
}
}
fun Mixin.useMixin(mixin: Mixin,
keepArgs: Boolean = true,
keepInputs: Boolean = true,
keepOutputs: Boolean = true,
keepConstraints: Boolean = true,
keepSignatures: Boolean = true,
keepDocs: Boolean = true,
keepConfigs: Boolean = true) {
if(mixin.legacyWasSet){
legacy = mixin.legacy
}
if(mixin.javaPackageWasSet){
javaPackage = mixin.javaPackage
}
if (keepArgs) {
args.addOrReplaceAll(mixin.args)
}
if (keepInputs) {
inputs.addOrReplaceAll(mixin.inputs)
}
if (keepOutputs) {
outputs.addOrReplaceAll(mixin.outputs)
}
if (keepConstraints) {
constraints.addAll(mixin.constraints)
}
if (keepSignatures) {
signatures.addAll(mixin.signatures)
}
if (keepDocs) {
doc.addAll(mixin.doc)
}
if(keepConfigs){
configs.addOrReplaceAll(mixin.configs)
}
}