nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/DefaultImportRunner.kt
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.samediff.frameworkimport.runner
import org.nd4j.autodiff.functions.DifferentialFunction
import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.autodiff.samediff.VariableType
import org.nd4j.common.io.ReflectionUtils
import org.nd4j.graph.OpType
import org.nd4j.ir.OpNamespace
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.CustomOp
import org.nd4j.linalg.api.ops.DynamicCustomOp
import org.nd4j.linalg.api.ops.Op
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.samediff.frameworkimport.context.MappingContext
import org.nd4j.samediff.frameworkimport.convertNd4jDataTypeFromNameSpaceTensorDataType
import org.nd4j.samediff.frameworkimport.ndarrayFromNameSpaceTensor
import org.nd4j.samediff.frameworkimport.setNameForFunctionFromDescriptors
import org.nd4j.shade.protobuf.GeneratedMessageV3
import org.nd4j.shade.protobuf.ProtocolMessageEnum
import org.nd4j.weightinit.impl.ZeroInitScheme
import java.lang.IllegalArgumentException
import java.lang.reflect.Modifier
/**
 * The default implementation of [ImportRunner].
 *
 * [initAttributes] copies the arguments of a mapped [OpNamespace.OpDescriptor] onto an op instance.
 * Custom/logic ops receive them as integer, floating point, boolean, data type and output arguments.
 * Scalar and other legacy ops have matching fields populated via reflection.
 *
 * @author Adam Gibson
 */
class DefaultImportRunner<GRAPH_TYPE: GeneratedMessageV3,
        NODE_TYPE : GeneratedMessageV3,
        OP_DEF_TYPE : GeneratedMessageV3,
        TENSOR_TYPE : GeneratedMessageV3,
        ATTR_DEF_TYPE : GeneratedMessageV3,
        ATTR_VALUE_TYPE : GeneratedMessageV3,
        DATA_TYPE: ProtocolMessageEnum> : ImportRunner<GRAPH_TYPE,
        NODE_TYPE,
        OP_DEF_TYPE,
        TENSOR_TYPE,
        ATTR_DEF_TYPE,
        ATTR_VALUE_TYPE,
        DATA_TYPE> {

    /**
     * Configures [df] from the argument descriptors in `descriptorAndContext.second`.
     *
     * - [Op.Type.CUSTOM] / [Op.Type.LOGIC]:
     *   - Descriptors are grouped by [OpNamespace.ArgDescriptor.ArgType] and each group is sorted by `argIndex`.
     *   - Each group is added to the [DynamicCustomOp] as integer, floating point, boolean, data type
     *     or output arguments.
     *   - Non-logic ops may also have their input variables in [sd] renamed (":0" suffix removal).
     *   - Finally `configureFromArguments()` and `configureWithSameDiff(sd)` are invoked.
     * - [Op.Type.SCALAR]:
     *   - `x`/`y`/`z` fields are set from the mapping context's tensor inputs.
     *   - Other descriptors are written to the op's `scalarValue` field, cast to the data type of
     *     the op's first input.
     * - Any other op type:
     *   - `x`/`y`/`z` and `keepDims` fields are set by reflection.
     *   - When the op has dimensions, the `dimensions`, `dimensionz` and `isEmptyReduce` fields are
     *     populated if present.
     *
     * In every case [setNameForFunctionFromDescriptors] is called last to set any remaining fields.
     *
     * @param df the op instance to configure
     * @param sd the graph the op belongs to
     * @param descriptorAndContext the mapping context of the node being imported, paired with its mapped op descriptor
     * @throws IllegalArgumentException if a custom op is not registered in [sd], or a descriptor has an
     * unsupported argument type
     */
    override fun <GRAPH_TYPE : GeneratedMessageV3, NODE_TYPE : GeneratedMessageV3, OP_DEF_TYPE : GeneratedMessageV3, TENSOR_TYPE : GeneratedMessageV3, ATTR_DEF_TYPE : GeneratedMessageV3, ATTR_VALUE_TYPE : GeneratedMessageV3, DATA_TYPE : ProtocolMessageEnum> initAttributes(
        df: DifferentialFunction,
        sd: SameDiff,
        descriptorAndContext: Pair<MappingContext<GRAPH_TYPE, NODE_TYPE, OP_DEF_TYPE, TENSOR_TYPE, ATTR_DEF_TYPE, ATTR_VALUE_TYPE, DATA_TYPE>, OpNamespace.OpDescriptor>
    ) {
        val (mappingContext, opDescriptor) = descriptorAndContext
        when (df.opType()) {
            Op.Type.CUSTOM, Op.Type.LOGIC -> {
                val dynamicCustomOp = df as DynamicCustomOp
                // Group descriptors by argument type, each group ordered by argIndex.
                // groupBy/mapValues preserve first-appearance order, so groups are visited deterministically
                // (a HashMap keyed by an enum iterates in identity-hash order, which varies between runs).
                val argsByType = opDescriptor.argDescriptorList
                    .groupBy { descriptor -> descriptor.argType }
                    .mapValues { (_, descriptors) -> descriptors.sortedBy { descriptor -> descriptor.argIndex } }

                argsByType.forEach { (argType, listOfArgsSortedByIndex) ->
                    when (argType) {
                        OpNamespace.ArgDescriptor.ArgType.INPUT_TENSOR -> {
                            if (df.opType() != Op.Type.LOGIC) {
                                val opInputs = requireNotNull(sd.ops[dynamicCustomOp.ownName]) {
                                    "No op with name ${dynamicCustomOp.ownName} found!"
                                }
                                for (input in opInputs.inputsToOp) {
                                    val name = if (mappingContext.graph().hasConstantInitializer(input)) input else "${input}:0"
                                    // Removes the ":0" suffix. If `input` is not yet a SameDiff variable, but the
                                    // suffixed form is a graph constant or an existing variable, rename it to `input`.
                                    // NOTE(review): when `input` itself is a graph constant, `name == input` and this
                                    // becomes a self-rename; confirm intent.
                                    if (!sd.hasVariable(input) &&
                                        (mappingContext.graph().hasConstantInitializer("${input}:0") || sd.hasVariable(name))
                                    ) {
                                        sd.renameVariable(name, input)
                                    }
                                }

                                // Any op input that is not an ARRAY variable and has no shape yet gets an empty shape.
                                val args = dynamicCustomOp.args()
                                val inputCount = minOf(listOfArgsSortedByIndex.size, args.size)
                                for (index in 0 until inputCount) {
                                    val arg = args[index]
                                    if (arg.variableType != VariableType.ARRAY && arg.shape == null) {
                                        arg.setShape(*LongArray(0))
                                    }
                                }
                                // Input arrays are deliberately not attached here: addInputArgument requires
                                // all the input arrays to be added at once.
                            }
                        }
                        OpNamespace.ArgDescriptor.ArgType.INT64, OpNamespace.ArgDescriptor.ArgType.INT32 -> {
                            listOfArgsSortedByIndex.forEach { descriptor -> dynamicCustomOp.addIArgument(descriptor.int64Value) }
                        }
                        OpNamespace.ArgDescriptor.ArgType.DOUBLE, OpNamespace.ArgDescriptor.ArgType.FLOAT -> {
                            listOfArgsSortedByIndex.forEach { descriptor -> dynamicCustomOp.addTArgument(descriptor.doubleValue) }
                        }
                        OpNamespace.ArgDescriptor.ArgType.OUTPUT_TENSOR -> {
                            listOfArgsSortedByIndex.forEach { descriptor ->
                                dynamicCustomOp.addOutputArgument(ndarrayFromNameSpaceTensor(descriptor.inputValue))
                            }
                        }
                        // Strings are allowed, but only for cases of setting a value in Java.
                        OpNamespace.ArgDescriptor.ArgType.STRING -> {}
                        OpNamespace.ArgDescriptor.ArgType.BOOL -> {
                            listOfArgsSortedByIndex.forEach { descriptor -> dynamicCustomOp.addBArgument(descriptor.boolValue) }
                        }
                        OpNamespace.ArgDescriptor.ArgType.DATA_TYPE -> {
                            // Every non-static, non-final field of the concrete op class (superclasses are not
                            // scanned) that holds a DataType is overwritten with each converted data type in turn.
                            val dtypeJavaClass = Class.forName("org.nd4j.linalg.api.buffer.DataType")
                            val dtypeFields = df.javaClass.declaredFields.filter { field ->
                                !Modifier.isStatic(field.modifiers) &&
                                        !Modifier.isFinal(field.modifiers) &&
                                        dtypeJavaClass.isAssignableFrom(field.type)
                            }
                            listOfArgsSortedByIndex.forEach { descriptor ->
                                val dtype = convertNd4jDataTypeFromNameSpaceTensorDataType(descriptor.dataTypeValue)
                                dynamicCustomOp.addDArgument(dtype)
                                dtypeFields.forEach { field ->
                                    field.isAccessible = true
                                    ReflectionUtils.setField(field, df, dtype)
                                }
                            }
                        }
                        else -> {
                            throw IllegalArgumentException("Illegal type $argType")
                        }
                    }
                    // Set any left over fields if they're found.
                    setNameForFunctionFromDescriptors(listOfArgsSortedByIndex, df)
                }
                val customOp = df as CustomOp
                // Important to call this as we may not have configured all fields.
                customOp.configureFromArguments()
                df.configureWithSameDiff(sd)
            }
            Op.Type.SCALAR -> {
                val descriptors = opDescriptor.argDescriptorList
                val numericArgTypes = setOf(
                    OpNamespace.ArgDescriptor.ArgType.INT64,
                    OpNamespace.ArgDescriptor.ArgType.DOUBLE,
                    OpNamespace.ArgDescriptor.ArgType.INT32,
                    OpNamespace.ArgDescriptor.ArgType.FLOAT
                )
                descriptors.forEach { argDescriptor ->
                    val field = ReflectionUtils.findField(df.javaClass, argDescriptor.name)
                    if (field != null) {
                        field.isAccessible = true
                        when (argDescriptor.name) {
                            "x", "y", "z" -> {
                                val createdNDArray = mappingContext.tensorInputFor(argDescriptor.name).toNd4jNDArray()
                                ReflectionUtils.setField(field, df, createdNDArray)
                            }
                            else -> {
                                val scalarField = ReflectionUtils.findField(df.javaClass, "scalarValue")
                                scalarField.isAccessible = true
                                // The first input should already be set; the scalar is cast to its data type.
                                val dtype = df.arg(0).dataType()
                                ReflectionUtils.setField(scalarField, df, scalarFromArgDescriptor(argDescriptor).castTo(dtype))
                            }
                        }
                    } else if (argDescriptor.argType in numericArgTypes) {
                        val scalarField = ReflectionUtils.findField(df.javaClass, "scalarValue")
                        scalarField.isAccessible = true
                        // The node's first input should already exist in the graph; the scalar is cast to its data type.
                        val irNode = mappingContext.irNode()
                        val dtype = sd.getVariable(irNode.inputAt(0)).dataType()
                        ReflectionUtils.setField(scalarField, df, scalarFromArgDescriptor(argDescriptor).castTo(dtype))
                    }
                }
                // Set any left over fields if they're found.
                setNameForFunctionFromDescriptors(descriptors, df)
            }
            else -> {
                val opType = df.opType()
                val descriptors = opDescriptor.argDescriptorList
                // `&&` binds tighter than `||`: the args().size check applies to INDEXREDUCE only.
                val isDimensionalReduction = opType == Op.Type.REDUCE_LONG ||
                        opType == Op.Type.REDUCE_BOOL ||
                        opType == Op.Type.REDUCE_FLOAT ||
                        opType == Op.Type.REDUCE_SAME ||
                        (opType == Op.Type.INDEXREDUCE && df.args().size > 1)
                val hasDimensions = isDimensionalReduction ||
                        descriptors.any { descriptor -> descriptor.name == "dimensions" }

                descriptors.forEach { argDescriptor ->
                    val field = ReflectionUtils.findField(df.javaClass, argDescriptor.name)
                    if (field != null) {
                        field.isAccessible = true
                        when (argDescriptor.name) {
                            "x", "y", "z" -> {
                                val createdNDArray = mappingContext.tensorInputFor(argDescriptor.name).toNd4jNDArray()
                                ReflectionUtils.setField(field, df, createdNDArray)
                            }
                            "keepDims" -> ReflectionUtils.setField(field, df, argDescriptor.boolValue)
                            else -> {
                                // Other fields are left to setNameForFunctionFromDescriptors below.
                            }
                        }
                    }
                }

                if (hasDimensions) {
                    // Prefer a concrete second input array; otherwise use the "dimensions" descriptors sorted by index.
                    val dimArgs: IntArray = if (df.args().size > 1 && df.arg(1).arr != null) {
                        df.arg(1).arr.toIntVector()
                    } else {
                        descriptors.filter { argDescriptor -> argDescriptor.name.contains("dimensions") }
                            .sortedBy { argDescriptor -> argDescriptor.argIndex }
                            .map { argDescriptor -> argDescriptor.int64Value.toInt() }
                            .toIntArray()
                    }
                    val dimensionsField = ReflectionUtils.findField(df.javaClass, "dimensions")
                    val dimensionzField = ReflectionUtils.findField(df.javaClass, "dimensionz")
                    val isEmptyReduce = ReflectionUtils.findField(df.javaClass, "isEmptyReduce")
                    if (dimensionsField != null) {
                        dimensionsField.isAccessible = true
                        if (IntArray::class.java.isAssignableFrom(dimensionsField.type)) {
                            ReflectionUtils.setField(dimensionsField, df, dimArgs)
                        }
                    }
                    if (dimensionzField != null) {
                        dimensionzField.isAccessible = true
                        if (INDArray::class.java.isAssignableFrom(dimensionzField.type)) {
                            val createdArr = Nd4j.create(Nd4j.createBuffer(dimArgs))
                            ReflectionUtils.setField(dimensionzField, df, createdArr)
                        }
                    }
                    if (isEmptyReduce != null) {
                        isEmptyReduce.isAccessible = true
                        // Only ever flipped to true; no dimensions means the reduction is flagged as empty.
                        if (dimArgs.isEmpty()) {
                            ReflectionUtils.setField(isEmptyReduce, df, true)
                        }
                    }
                }
                // Set any left over fields if they're found.
                setNameForFunctionFromDescriptors(descriptors, df)
            }
        }
    }

    /**
     * Wraps the value carried by [argDescriptor] in an [INDArray] via `Nd4j.scalar`.
     * The field read depends on the descriptor's [OpNamespace.ArgDescriptor.ArgType].
     * Callers cast the result to the data type they need.
     *
     * @throws IllegalArgumentException if the argument type is not DOUBLE, FLOAT, INT32, INT64, BOOL or STRING
     */
    private fun scalarFromArgDescriptor(argDescriptor: OpNamespace.ArgDescriptor): INDArray =
        when (argDescriptor.argType) {
            OpNamespace.ArgDescriptor.ArgType.DOUBLE -> Nd4j.scalar(argDescriptor.doubleValue)
            OpNamespace.ArgDescriptor.ArgType.FLOAT -> Nd4j.scalar(argDescriptor.floatValue)
            OpNamespace.ArgDescriptor.ArgType.INT32 -> Nd4j.scalar(argDescriptor.int32Value)
            OpNamespace.ArgDescriptor.ArgType.INT64 -> Nd4j.scalar(argDescriptor.int64Value)
            OpNamespace.ArgDescriptor.ArgType.BOOL -> Nd4j.scalar(argDescriptor.boolValue)
            OpNamespace.ArgDescriptor.ArgType.STRING -> Nd4j.scalar(argDescriptor.stringValue)
            else -> throw IllegalArgumentException("Trying to convert invalid argument type " + argDescriptor.argType)
        }
}