// codegen/op-codegen/src/main/kotlin/org/nd4j/codegen/api/Variables.kt
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.codegen.api
import org.nd4j.codegen.api.doc.DocSection
import java.util.*
/**
 * Pairs an [Expression] to be checked with an optional message.
 *
 * The class is `open` so that more specific constraint kinds can derive from it.
 *
 * @property message optional text attached to the constraint; `null` when omitted
 * @property check the expression carried by this constraint
 */
open class Constraint(
    var message: String? = null,
    var check: Expression
)
/**
 * A [Constraint] subtype that adds no state of its own; both constructor arguments are
 * forwarded unchanged to the base class.
 *
 * @param message optional text attached to the constraint; `null` when omitted
 * @param check the expression carried by this constraint
 */
class BackendConstraint(
    message: String? = null,
    check: Expression
) : Constraint(message, check)
// Used in Constraint Expressions
/** Closed base type for anything that can appear as an operand of a [BooleanExpression]. */
sealed class Reference

/** Wraps a literal numeric [value]. */
data class NumberReference<T : Number>(val value: T) : Reference()

/** Wraps a literal boolean [value]. */
data class BooleanReference(val value: Boolean) : Reference()

/**
 * Refers to [input] together with an integer index [idx].
 * NOTE(review): presumably `idx` selects a dimension of the input's shape - confirm against the generators.
 */
data class InputShapeReference(val input: Input, val idx: Int) : Reference()

/**
 * Refers to [input] alone.
 * NOTE(review): by its name this stands for the rank of the input - confirm against the generators.
 */
data class InputRankReference(val input: Input) : Reference()
/**
 * Closed base type for constraint expressions.
 *
 * Every expression is itself a [Reference], so expressions can be nested as operands
 * of a [BooleanExpression].
 */
sealed class Expression : Reference()

/** Combines the operands [left] and [right] with the operator [op]. */
data class BooleanExpression(
    val left: Reference,
    val right: Reference,
    val op: BooleanOperation
) : Expression()

/** Expression over a list of [inputs]; by its name, a "same data type" requirement (interpretation happens elsewhere). */
data class SameTypeExpression(val inputs: List<Input>) : Expression()

/** Expression over a list of [inputs]; by its name, a "same shape" requirement (interpretation happens elsewhere). */
data class SameShapeExpression(val inputs: List<Input>) : Expression()

/** Expression over a list of [inputs]; by its name, a "broadcastable shapes" requirement (interpretation happens elsewhere). */
data class BroadcastableShapesExpression(val inputs: List<Input>) : Expression()
/** Operators usable in a [BooleanExpression]. Declaration order is part of the enum's contract (ordinals). */
enum class BooleanOperation {
    EQ,
    NEQ,
    LT,
    LTE,
    GT,
    GTE,
    AND,
    OR
}
// Used to define array sizes
/** Closed set of ways to describe how many elements an array-like parameter may have. */
sealed class Count

/** Bounds given by [from] and [to]. */
data class Range(
    val from: Int,
    val to: Int
) : Count()

/** A lower bound of [min] with no upper bound. */
data class AtLeast(val min: Int) : Count()

/** An upper bound of [max] with no lower bound. */
data class AtMost(val max: Int) : Count()

/** A fixed size of [count]. */
data class Exactly(val count: Int) : Count()
// Actual parameters
/** Common contract of everything that can appear in a [Signature]. */
interface Parameter {
    /** The parameter's name. */
    fun name(): String

    /** The configured default value; may be `null`. */
    fun defaultValue(): Any?

    /** Whether a default value has been configured for this parameter. */
    fun hasDefaultValue(): Boolean

    /** Whether this parameter is a vararg. */
    fun isVararg(): Boolean

    /**
     * Decides whether this parameter's default value can be used with a signature made of [otherParams].
     *
     * A default is applicable when it is a literal value, or when the value it refers to is either
     * directly part of the signature or reachable through a chain of defaults that ends in something
     * that is part of the signature.
     *
     * @param otherParams the parameters that make up the signature under consideration
     * @return `false` when no default value is set or the default has an unsupported type
     */
    fun defaultValueIsApplicable(otherParams: List<Parameter>): Boolean {
        if (!hasDefaultValue()) return false

        return when (val candidate = defaultValue()) {
            // Literals never depend on any other parameter.
            null,
            is Number,
            is Boolean,
            is String,
            is IntArray,
            is BooleanArray,
            is DoubleArray,
            is org.nd4j.linalg.api.buffer.DataType,
            is org.nd4j.codegen.api.LossReduce -> true

            // References: either present in the signature or resolvable through their own default.
            is Parameter ->
                candidate in otherParams || candidate.defaultValueIsApplicable(otherParams)
            is TensorDataTypeValue ->
                candidate.tensor in otherParams || candidate.tensor.defaultValueIsApplicable(otherParams)
            is TensorShapeValue ->
                candidate.tensor in otherParams || candidate.tensor.defaultValueIsApplicable(otherParams)

            else -> false
        }
    }
}
/** Marker interface for the [Parameter]s that are tensors; it declares no members of its own. */
interface Tensor : Parameter
/**
 * A [Parameter] that is not a [Tensor] and that can also be used as a [Reference] in constraint expressions.
 *
 * Only the constructor properties take part in the generated `equals`/`hashCode`/`copy`;
 * [defaultValue], [possibleValues] and [count] are validated body properties.
 *
 * @property name the argument's name
 * @property type the argument's data type; it drives all validation done by the setters below
 * @property description optional description text
 * @property isVargarg whether the argument is a vararg (reported through [isVararg])
 */
data class Arg(
    val name: String,
    val type: DataType,
    var description: String? = null,
    var isVargarg: Boolean = false
) : Reference(), Parameter {
    override fun name(): String = name

    override fun defaultValue(): Any? = defaultValue

    override fun hasDefaultValue(): Boolean = defaultValueIsSet

    override fun isVararg(): Boolean = isVargarg

    // Set to true by the defaultValue setter; this lets an explicitly assigned `null`
    // default be told apart from "no default configured".
    private var defaultValueIsSet = false

    /**
     * Default value of this argument.
     *
     * Any successful assignment (including `null`) marks the argument as having a default.
     *
     * @throws IllegalArgumentException if the value is not compatible with [type] and [count]
     */
    var defaultValue: Any? = null
        set(value) {
            require(isAssignableFrom(value)) {
                "Illegal default value for $this. Got ${value.toDescriptiveString()} (${value?.javaClass?.name})"
            }
            field = value
            defaultValueIsSet = true
        }

    /**
     * The values an ENUM typed argument may take.
     *
     * @throws IllegalArgumentException if [type] is not ENUM, or if an empty list is assigned
     */
    var possibleValues: List<String>? = null
        set(value) {
            require(type == DataType.ENUM) { "$this: Can not set possibleValues on non ENUM typed Arg." }
            require(value == null || value.isNotEmpty()) { "$this: Can not set empty possibleValues." }
            field = value
        }

    /**
     * How many elements this argument holds. `null` and `Exactly(1)` both mean "not an array"
     * (see [isArray]).
     *
     * @throws IllegalArgumentException if an array count is assigned to an ENUM typed argument
     */
    var count: Count? = null
        set(value) {
            // `null` is a non-array count just like Exactly(1), so it has to be accepted for ENUM too;
            // otherwise an ENUM Arg could never be reset to its initial state.
            require(type != DataType.ENUM || value == null || value == Exactly(1)) {
                "$this: ENUM typed Arg can not be array"
            }
            field = value
        }

    /** Checks a scalar [value] against [type]; types without a scalar literal form never match. */
    private fun matchesDataType(value: Any?): Boolean = when (type) {
        DataType.FLOATING_POINT -> value is Double
        DataType.INT, DataType.LONG -> value is Int || value is Long
        DataType.NUMERIC -> value is Number
        DataType.BOOL -> value is Boolean
        else -> false
    }

    /** Decides whether [value] is an acceptable default for this argument, given [type] and [count]. */
    private fun isAssignableFrom(value: Any?): Boolean = when (value) {
        null -> true
        is TensorShapeValue -> isArray() && type == DataType.INT
        is TensorDataTypeValue -> type == DataType.DATA_TYPE
        is Number, is Boolean -> matchesDataType(value)
        // isArray() guarantees a non-null count before countMatches() dereferences it.
        is IntArray ->
            isArray() && (type == DataType.INT || type == DataType.NUMERIC) && countMatches(value.size)
        is DoubleArray ->
            isArray() && (type == DataType.FLOATING_POINT || type == DataType.NUMERIC) && countMatches(value.size)
        is BooleanArray -> isArray() && type == DataType.BOOL && countMatches(value.size)
        // Another Arg may serve as default only when it has the same count and type.
        is Arg -> value.count == count && value.type == type
        is String ->
            type == DataType.STRING || (type == DataType.ENUM && possibleValues?.contains(value) == true)
        is org.nd4j.linalg.api.buffer.DataType -> type == DataType.DATA_TYPE
        is org.nd4j.codegen.api.LossReduce -> type == DataType.LOSS_REDUCE
        else -> false
    }

    /** `true` when a [count] is set and it is anything other than `Exactly(1)`. */
    fun isArray(): Boolean = count != null && count != Exactly(1)

    /**
     * Checks whether an array of [size] elements satisfies [count].
     *
     * @throws NullPointerException if [count] is `null`; callers are expected to check [isArray] first
     */
    fun countMatches(size: Int): Boolean = when (val bound = count!!) {
        is Range -> size >= bound.from && size <= bound.to
        is AtLeast -> size >= bound.min
        is AtMost -> size <= bound.max
        is Exactly -> size == bound.count
    }

    /** Wraps this tensor in a [TensorShapeValue], for use as a default value. */
    fun Tensor.shape(): TensorShapeValue = TensorShapeValue(this)

    /** Wraps this tensor in a [TensorDataTypeValue], for use as a default value. */
    fun Tensor.dataType(): TensorDataTypeValue = TensorDataTypeValue(this)

    override fun toString(): String {
        val typeLabel = if (type == DataType.ENUM) {
            "ENUM(${possibleValues?.joinToString(", ")})"
        } else {
            type.toString()
        }
        val countSuffix = if (count != null) "{ count = $count }" else ""
        return "Arg($typeLabel, $name)$countSuffix"
    }
}
/**
 * A [Tensor] parameter that is fed into an op.
 *
 * @property name the input's name
 * @property type the input's data type
 * @property description optional description text
 * @property count optional element count of the input
 */
data class Input(
    val name: String,
    val type: DataType,
    var description: String? = null,
    var count: Count? = null
) : Parameter, Tensor {
    // Set to true by the defaultValue setter, so an explicit `null` default still counts as "set".
    private var defaultValueIsSet = false

    /**
     * Another input used as the default for this one.
     *
     * @throws IllegalArgumentException if the assigned input has a different [type]
     */
    var defaultValue: Input? = null
        set(value) {
            require(matchesDataType(value)) {
                "Illegal default value for Input($name). Allowed values have to match data type $type, but got ${value.toDescriptiveString()} (${value?.javaClass?.name})"
            }
            field = value
            defaultValueIsSet = true
        }

    override fun name(): String = name

    override fun defaultValue(): Any? = defaultValue

    override fun hasDefaultValue(): Boolean = defaultValueIsSet

    override fun isVararg(): Boolean = false

    /** `null` always matches; any other input must have exactly this input's [type]. */
    private fun matchesDataType(value: Input?): Boolean = value == null || value.type == type
}
/**
 * A [Tensor] parameter produced by an op. Outputs never have a default value and are never vararg.
 *
 * @property name the output's name
 * @property type the output's data type
 * @property multiOutput flag stored on the output (NOTE(review): presumably marks an op yielding several arrays - confirm)
 * @property description optional description text
 */
data class Output(
    var name: String,
    var type: DataType,
    var multiOutput: Boolean,
    var description: String? = null
) : Parameter, Tensor {
    override fun name(): String = name

    override fun defaultValue(): Any? = null

    override fun hasDefaultValue(): Boolean = false

    override fun isVararg(): Boolean = false
}
/**
 * An ordered list of [Parameter]s with an optional description.
 *
 * @property parameters the parameters that make up the signature
 * @property description optional description text
 */
data class Signature(
    val parameters: List<Parameter>,
    val description: String? = null
) {
    /** Renders only the parameter names, e.g. `Signature(a, b)`; the description is left out. */
    override fun toString(): String {
        val names = parameters.joinToString { param -> param.name() }
        return "Signature($names)"
    }
}
// Used in defining default values
/**
 * Default-value placeholder that stands for the shape of [tensor].
 *
 * Its string form is the tensor's name followed by `.shape()`.
 */
data class TensorShapeValue(val tensor: Tensor) {
    override fun toString(): String = tensor.name() + ".shape()"
}
/**
 * Default-value placeholder that stands for the data type of [tensor].
 *
 * Its string form is the tensor's name followed by `.dataType()`.
 */
data class TensorDataTypeValue(val tensor: Tensor) {
    override fun toString(): String = tensor.name() + ".dataType()"
}
/**
 * Renders the receiver in a human readable form, intended for error messages.
 *
 * Arrays are printed by content rather than by identity: every primitive array type is
 * supported, and object arrays are rendered recursively so that nested arrays are readable too.
 *
 * @return `"null"` for a `null` receiver, the element listing for arrays, `toString()` otherwise
 */
fun Any?.toDescriptiveString(): String = when (this) {
    null -> "null"
    is IntArray -> Arrays.toString(this)
    is LongArray -> Arrays.toString(this)
    is ShortArray -> Arrays.toString(this)
    is ByteArray -> Arrays.toString(this)
    is CharArray -> Arrays.toString(this)
    is DoubleArray -> Arrays.toString(this)
    is FloatArray -> Arrays.toString(this)
    is BooleanArray -> Arrays.toString(this)
    // deepToString matches toString for flat arrays and also expands nested arrays.
    is Array<*> -> Arrays.deepToString(this)
    else -> this.toString()
}
/**
 * A named bundle of inputs, arguments, constraints and documentation that is itself a [Parameter].
 * A config never has a default value and is never vararg.
 *
 * @property name the config's name
 * @property inputs the inputs collected in this config
 * @property args the arguments collected in this config
 * @property constraints the constraints collected in this config
 * @property doc the documentation sections collected in this config
 */
data class Config(
    val name: String,
    val inputs: MutableList<Input> = mutableListOf(),
    val args: MutableList<Arg> = mutableListOf(),
    val constraints: MutableList<Constraint> = mutableListOf(),
    val doc: MutableList<DocSection> = mutableListOf()
) : Parameter {
    /** Free-form override string; empty by default. It is not part of `equals`/`hashCode`. */
    var javaClassOverride: String = ""

    override fun name(): String = name

    override fun defaultValue(): Any? = null

    override fun hasDefaultValue(): Boolean = false

    override fun isVararg(): Boolean = false

    /** Appends [input] to [inputs]. */
    fun addInput(input: Input) {
        inputs.add(input)
    }

    /** Appends [arg] to [args]. */
    fun addArgument(arg: Arg) {
        args.add(arg)
    }

    /** Appends [constraint] to [constraints]. */
    fun addConstraint(constraint: Constraint) {
        constraints.add(constraint)
    }

    /** Appends the given section to this config's documentation list. */
    fun addDoc(doc: DocSection) {
        this.doc.add(doc)
    }
}