nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/IRProtobufExtensions.kt
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.samediff.frameworkimport
import org.bytedeco.javacpp.indexer.Bfloat16ArrayIndexer
import org.bytedeco.javacpp.indexer.HalfIndexer
import org.nd4j.autodiff.functions.DifferentialFunction
import org.nd4j.autodiff.samediff.SDVariable
import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.autodiff.samediff.VariableType
import org.nd4j.common.io.ReflectionUtils
import org.nd4j.common.util.ArrayUtil
import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMappers
import org.nd4j.ir.OpNamespace
import org.nd4j.ir.TensorNamespace
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.nativeblas.NativeOpsHolder
import org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoaderHolder
import org.nd4j.shade.protobuf.ByteString
import java.lang.IllegalArgumentException
import java.nio.ByteBuffer
import java.nio.charset.Charset
import java.util.*
import kotlin.collections.ArrayList
import java.lang.reflect.Field
import java.nio.Buffer
/**
 * Tells whether [opDescriptor] declares an argument called [name] whose type is
 * neither an input tensor nor an output tensor.
 *
 * @param name argument name to search for
 * @param opDescriptor op descriptor whose argument list is scanned
 * @return true if at least one non-tensor argument has exactly this name
 */
fun isOutputFrameworkAttributeName(name: String, opDescriptor: OpNamespace.OpDescriptor): Boolean =
    opDescriptor.argDescriptorList.any { candidate ->
        val type = candidate.argType
        type != OpNamespace.ArgDescriptor.ArgType.INPUT_TENSOR &&
            type != OpNamespace.ArgDescriptor.ArgType.OUTPUT_TENSOR &&
            candidate.name == name
    }
/**
 * Tells whether [opDescriptor] declares an input-tensor argument called [name].
 *
 * @param name argument name to search for
 * @param opDescriptor op descriptor whose argument list is scanned
 * @return true if at least one INPUT_TENSOR argument has exactly this name
 */
fun isNd4jTensorName(name: String, opDescriptor: OpNamespace.OpDescriptor): Boolean =
    opDescriptor.argDescriptorList.any { candidate ->
        candidate.argType == OpNamespace.ArgDescriptor.ArgType.INPUT_TENSOR && candidate.name == name
    }
/**
 * Looks up the argument type of the first argument on [opDescriptor] named [name].
 *
 * @param name argument name to search for
 * @param opDescriptor op descriptor whose argument list is scanned
 * @return the type of the first matching argument
 * @throws IndexOutOfBoundsException if no argument carries this name
 */
fun argDescriptorType(name: String, opDescriptor: OpNamespace.OpDescriptor): OpNamespace.ArgDescriptor.ArgType {
    val matches = opDescriptor.argDescriptorList.filter { candidate -> candidate.name == name }
    // Indexing the filtered list keeps the IndexOutOfBoundsException when nothing matched.
    val firstMatch = matches[0]
    return firstMatch.argType
}
/**
 * Finds the op descriptor in this list whose name equals [opName].
 *
 * @param opName name of the op to look up
 * @return the first descriptor with a matching name
 * @throws IllegalArgumentException if no descriptor has this name
 */
fun OpNamespace.OpDescriptorList.findOp(opName: String): OpNamespace.OpDescriptor =
    opListList.firstOrNull { candidate -> candidate.name == opName }
        ?: throw IllegalArgumentException("Op name $opName not found!")
/**
 * Builder-style factory: creates a fresh [OpNamespace.ArgDescriptor] builder,
 * lets [block] configure it, and returns the built descriptor.
 *
 * @param block configuration applied to the new builder
 * @return the built argument descriptor
 */
fun ArgDescriptor(block: OpNamespace.ArgDescriptor.Builder.() -> Unit): OpNamespace.ArgDescriptor {
    val builder = OpNamespace.ArgDescriptor.newBuilder()
    builder.block()
    return builder.build()
}
/**
 * Builder-style factory: creates a fresh [TensorNamespace.TensorProto] builder,
 * lets [block] configure it, and returns the built tensor message.
 *
 * @param block configuration applied to the new builder
 * @return the built tensor proto
 */
fun NameSpaceTensor(block: TensorNamespace.TensorProto.Builder.() -> Unit): TensorNamespace.TensorProto {
    val builder = TensorNamespace.TensorProto.newBuilder()
    builder.block()
    return builder.build()
}
/**
 * Sets the raw-data field of this builder to a byte-string copy of [rawData].
 *
 * @param rawData bytes to store
 */
fun TensorNamespace.TensorProto.Builder.RawData(rawData: ByteArray) {
    val payload = ByteString.copyFrom(rawData)
    setRawData(payload)
}
/**
 * Appends every element of [intData] to this builder's int32 data list.
 *
 * @param intData values to append, in order
 */
fun TensorNamespace.TensorProto.Builder.IntData(intData: List<Int>) {
    addAllInt32Data(intData)
}
/**
 * Appends every element of [floatData] to this builder's float data list.
 *
 * @param floatData values to append, in order
 */
fun TensorNamespace.TensorProto.Builder.FloatData(floatData: List<Float>) {
    addAllFloatData(floatData)
}
/**
 * Appends every element of [doubleData] to this builder's double data list.
 *
 * @param doubleData values to append, in order
 */
fun TensorNamespace.TensorProto.Builder.DoubleData(doubleData: List<Double>) {
    addAllDoubleData(doubleData)
}
/**
 * Appends every string of [stringData] to this builder's string data list,
 * each encoded as UTF-8 bytes.
 *
 * UTF-8 is used explicitly (instead of the platform default charset) so the
 * stored bytes do not depend on the JVM's locale settings and agree with
 * readers that decode string data as UTF-8.
 *
 * @param stringData strings to append, in order
 */
fun TensorNamespace.TensorProto.Builder.StringData(stringData: List<String>) {
    addAllStringData(stringData.map { input -> ByteString.copyFrom(input.toByteArray(Charsets.UTF_8)) })
}
/**
 * Appends every element of [intData] to this builder's int64 data list.
 *
 * @param intData values to append, in order
 */
fun TensorNamespace.TensorProto.Builder.Int64Data(intData: List<Long>) {
    addAllInt64Data(intData)
}
/**
 * Appends every element of [shape] to this builder's dims list.
 *
 * @param shape dimension sizes to append, in order
 */
fun TensorNamespace.TensorProto.Builder.Dims(shape: List<Long>) {
    addAllDims(shape)
}
/**
 * Maps a [TensorNamespace.DataType] to the corresponding ND4J [DataType].
 *
 * STRING maps to UTF8; UNDEFINED and UNRECOGNIZED map to UNKNOWN.
 *
 * @param dataType tensor-namespace data type to convert
 * @return the matching ND4J data type
 * @throws IllegalArgumentException for any data type without a mapping below
 */
fun convertNd4jDataTypeFromNameSpaceTensorDataType(dataType: TensorNamespace.DataType): DataType {
    return when (dataType) {
        TensorNamespace.DataType.UINT32 -> DataType.UINT32
        TensorNamespace.DataType.UINT8 -> DataType.UINT8
        TensorNamespace.DataType.INT64 -> DataType.INT64
        TensorNamespace.DataType.INT16 -> DataType.INT16
        TensorNamespace.DataType.UINT64 -> DataType.UINT64
        TensorNamespace.DataType.DOUBLE -> DataType.DOUBLE
        TensorNamespace.DataType.FLOAT -> DataType.FLOAT
        TensorNamespace.DataType.FLOAT16 -> DataType.FLOAT16
        TensorNamespace.DataType.INT32 -> DataType.INT32
        TensorNamespace.DataType.STRING -> DataType.UTF8
        TensorNamespace.DataType.BOOL -> DataType.BOOL
        TensorNamespace.DataType.BFLOAT16 -> DataType.BFLOAT16
        TensorNamespace.DataType.INT8 -> DataType.INT8
        TensorNamespace.DataType.UINT16 -> DataType.UINT16
        TensorNamespace.DataType.UNDEFINED, TensorNamespace.DataType.UNRECOGNIZED -> DataType.UNKNOWN
        // Anything not listed above has no ND4J counterpart here.
        else -> throw IllegalArgumentException("Illegal data type $dataType")
    }
}
/**
 * Maps an ND4J [DataType] to the corresponding [TensorNamespace.DataType].
 *
 * This is intended to be the inverse of
 * `convertNd4jDataTypeFromNameSpaceTensorDataType`, so 16-bit signed integers
 * (INT16 / SHORT) map to INT16 rather than INT8.
 *
 * NOTE(review): several ND4J names used below are believed to be aliases of one
 * another (e.g. INT16/SHORT, INT8/BYTE, UINT8/UBYTE, FLOAT16/HALF) — confirm
 * against the ND4J DataType enum. Branch order matters where aliases overlap:
 * the INT8 branch is deliberately placed before the BYTE/UINT8/UBYTE branch.
 *
 * @param dataType ND4J data type to convert
 * @return the matching tensor-namespace data type
 * @throws IllegalArgumentException for any data type without a mapping below
 */
fun convertNameSpaceTensorDataTypeFromNd4jDataType(dataType: DataType): TensorNamespace.DataType {
    return when (dataType) {
        DataType.UINT32 -> TensorNamespace.DataType.UINT32
        DataType.INT64, DataType.LONG -> TensorNamespace.DataType.INT64
        DataType.UINT64 -> TensorNamespace.DataType.UINT64
        DataType.DOUBLE -> TensorNamespace.DataType.DOUBLE
        DataType.FLOAT -> TensorNamespace.DataType.FLOAT
        DataType.FLOAT16, DataType.HALF -> TensorNamespace.DataType.FLOAT16
        DataType.INT32, DataType.INT -> TensorNamespace.DataType.INT32
        DataType.UTF8 -> TensorNamespace.DataType.STRING
        DataType.BOOL -> TensorNamespace.DataType.BOOL
        DataType.BFLOAT16 -> TensorNamespace.DataType.BFLOAT16
        DataType.INT16, DataType.SHORT -> TensorNamespace.DataType.INT16
        DataType.INT8 -> TensorNamespace.DataType.INT8
        DataType.UINT16 -> TensorNamespace.DataType.UINT16
        DataType.BYTE, DataType.UINT8, DataType.UBYTE -> TensorNamespace.DataType.UINT8
        else -> throw IllegalArgumentException("Illegal data type $dataType")
    }
}
/**
 * Converts a [TensorNamespace.TensorProto] into an [INDArray].
 *
 * The tensor's data type selects which typed value list is read (float, double,
 * half, int64, int32, bool or string data). INT16, INT8, UINT8 and UINT16 are
 * read from the int32 list and cast afterwards; FLOAT16 and BFLOAT16 are read
 * from the half-value list. For each typed list:
 *  - an empty list falls back to [loadDataBufferFromRawData];
 *  - an empty (filtered) shape yields a scalar built from the first value;
 *  - a single value with a larger shape is broadcast to fill the shape;
 *  - otherwise the value count must equal the shape's element count.
 *
 * UNKNOWN yields an empty array; any other data type is loaded from raw data.
 *
 * @param inputTensor tensor message to convert
 * @return the array described by [inputTensor]
 * @throws IllegalArgumentException if the number of typed values is neither 1
 * nor the number of elements implied by the shape
 */
fun ndarrayFromNameSpaceTensor(inputTensor: TensorNamespace.TensorProto): INDArray {
    // The stored data type integer is used as a positional index into the enum values.
    val dtype = convertNd4jDataTypeFromNameSpaceTensorDataType(TensorNamespace.DataType.values()[inputTensor.dataType])
    // Only strictly positive dimensions are kept when building the shape.
    val shape = inputTensor.dimsList.filter { input -> input > 0 }.toLongArray()
    val totalLen = ArrayUtil.prod(*shape)
    //note for all cases here scalars can be either zero shape with 1 element or rank >= 1 with 1 element
    return when (dtype) {
        DataType.FLOAT -> {
            val floatArray = inputTensor.floatDataList.toFloatArray()
            arrayFromTypedValues(inputTensor, shape, totalLen, floatArray.size,
                scalar = { Nd4j.scalar(floatArray[0]) },
                broadcast = { Nd4j.valueArrayOf(shape, floatArray[0]) },
                dense = { Nd4j.create(Nd4j.createBuffer(floatArray)).reshape(*shape) })
        }
        DataType.DOUBLE -> {
            val doubleArray = inputTensor.doubleDataList.toDoubleArray()
            arrayFromTypedValues(inputTensor, shape, totalLen, doubleArray.size,
                scalar = { Nd4j.scalar(doubleArray[0]) },
                broadcast = { Nd4j.valueArrayOf(shape, doubleArray[0]) },
                dense = { Nd4j.create(Nd4j.createBuffer(doubleArray)).reshape(*shape) })
        }
        DataType.FLOAT16, DataType.HALF -> {
            val halfArray = inputTensor.halfValList.toIntArray()
            arrayFromTypedValues(inputTensor, shape, totalLen, halfArray.size,
                scalar = { Nd4j.scalar(HalfIndexer.toFloat(halfArray[0])).castTo(DataType.FLOAT16) },
                broadcast = { Nd4j.valueArrayOf(shape, HalfIndexer.toFloat(halfArray[0])).castTo(DataType.FLOAT16) },
                dense = {
                    // The stored ints are narrowed to shorts and handed over as a HALF buffer.
                    val dataBuffer = Nd4j.createTypedBuffer(
                        halfArray.map { input -> input.toShort() }.toShortArray(),
                        DataType.HALF)
                    Nd4j.create(dataBuffer).reshape(*shape)
                })
        }
        DataType.BFLOAT16 -> {
            val halfArray = inputTensor.halfValList.toIntArray()
            arrayFromTypedValues(inputTensor, shape, totalLen, halfArray.size,
                scalar = { Nd4j.scalar(Bfloat16ArrayIndexer.toFloat(halfArray[0])).castTo(DataType.BFLOAT16) },
                broadcast = {
                    Nd4j.valueArrayOf(shape, Bfloat16ArrayIndexer.toFloat(halfArray[0])).castTo(DataType.BFLOAT16)
                },
                dense = {
                    val dataBuffer = Nd4j.createTypedBuffer(
                        halfArray.map { input -> input.toShort() }.toShortArray(),
                        DataType.BFLOAT16)
                    Nd4j.create(dataBuffer).reshape(*shape)
                })
        }
        DataType.INT64 -> {
            val longArray = inputTensor.int64DataList.toLongArray()
            arrayFromTypedValues(inputTensor, shape, totalLen, longArray.size,
                scalar = { Nd4j.scalar(longArray[0]) },
                broadcast = {
                    // Fill a long buffer directly so the value never passes through a
                    // floating point array, which could lose precision for large longs.
                    val filled = LongArray(totalLen) { longArray[0] }
                    Nd4j.create(Nd4j.createBuffer(filled)).reshape(*shape)
                },
                dense = { Nd4j.create(Nd4j.createBuffer(longArray)).reshape(*shape) })
        }
        DataType.INT32 -> {
            val intArray = inputTensor.int32DataList.toIntArray()
            arrayFromTypedValues(inputTensor, shape, totalLen, intArray.size,
                scalar = { Nd4j.scalar(intArray[0]) },
                broadcast = { Nd4j.valueArrayOf(shape, intArray[0]) },
                dense = { Nd4j.create(Nd4j.createBuffer(intArray)).reshape(*shape) })
        }
        DataType.INT16 -> {
            val intArray = inputTensor.int32DataList.toIntArray()
            arrayFromTypedValues(inputTensor, shape, totalLen, intArray.size,
                scalar = { Nd4j.scalar(intArray[0]).castTo(DataType.INT16) },
                broadcast = { Nd4j.valueArrayOf(shape, intArray[0]).castTo(DataType.INT16) },
                dense = { Nd4j.create(Nd4j.createBuffer(intArray)).reshape(*shape).castTo(DataType.INT16) })
        }
        DataType.INT8 -> {
            val intArray = inputTensor.int32DataList.toIntArray()
            arrayFromTypedValues(inputTensor, shape, totalLen, intArray.size,
                scalar = { Nd4j.scalar(intArray[0]).castTo(DataType.INT8) },
                broadcast = { Nd4j.valueArrayOf(shape, intArray[0]).castTo(DataType.INT8) },
                dense = { Nd4j.create(Nd4j.createBuffer(intArray)).reshape(*shape).castTo(DataType.INT8) })
        }
        DataType.UINT8 -> {
            val intArray = inputTensor.int32DataList.toIntArray()
            arrayFromTypedValues(inputTensor, shape, totalLen, intArray.size,
                scalar = { Nd4j.scalar(intArray[0]).castTo(DataType.UINT8) },
                broadcast = { Nd4j.valueArrayOf(shape, intArray[0]).castTo(DataType.UINT8) },
                dense = { Nd4j.create(Nd4j.createBuffer(intArray)).reshape(*shape).castTo(DataType.UINT8) })
        }
        DataType.UINT16 -> {
            val intArray = inputTensor.int32DataList.toIntArray()
            arrayFromTypedValues(inputTensor, shape, totalLen, intArray.size,
                scalar = { Nd4j.scalar(intArray[0]).castTo(DataType.UINT16) },
                broadcast = { Nd4j.valueArrayOf(shape, intArray[0]).castTo(DataType.UINT16) },
                dense = { Nd4j.create(Nd4j.createBuffer(intArray)).reshape(*shape).castTo(DataType.UINT16) })
        }
        DataType.BOOL -> {
            val boolArray = inputTensor.boolValList.toBooleanArray()
            arrayFromTypedValues(inputTensor, shape, totalLen, boolArray.size,
                scalar = { Nd4j.scalar(boolArray[0]) },
                broadcast = {
                    val repeated = (0 until totalLen).map { boolArray[0] }.toBooleanArray()
                    Nd4j.create(repeated).reshape(*shape)
                },
                dense = { Nd4j.create(boolArray).reshape(*shape) })
        }
        DataType.UTF8 -> {
            val stringList = inputTensor.stringDataList.map { input -> input.toStringUtf8() }
            arrayFromTypedValues(inputTensor, shape, totalLen, stringList.size,
                scalar = { Nd4j.scalar(stringList[0]) },
                broadcast = {
                    val repeated = (0 until totalLen).map { stringList[0] }
                    Nd4j.create(repeated).reshape(*shape)
                },
                dense = { Nd4j.create(stringList).reshape(*shape) })
        }
        DataType.UNKNOWN -> Nd4j.empty()
        else -> loadDataBufferFromRawData(inputTensor)
    }
}

/**
 * Shared decision ladder used by [ndarrayFromNameSpaceTensor] for every typed value list.
 *
 * @param inputTensor tensor message, used for the raw-data fallback
 * @param shape shape the result should have
 * @param totalLen element count computed from [shape]
 * @param valueCount number of typed values present in the message
 * @param scalar builds a scalar from the first value
 * @param broadcast builds an array of [shape] filled with the single value
 * @param dense builds an array of [shape] from all values
 * @return the array chosen by the ladder
 * @throws IllegalArgumentException if [valueCount] is neither 1 nor [totalLen]
 */
private fun arrayFromTypedValues(
    inputTensor: TensorNamespace.TensorProto,
    shape: LongArray,
    totalLen: Int,
    valueCount: Int,
    scalar: () -> INDArray,
    broadcast: () -> INDArray,
    dense: () -> INDArray
): INDArray {
    return when {
        // No typed values: the content, if any, lives in the raw-data field.
        valueCount == 0 -> loadDataBufferFromRawData(inputTensor)
        totalLen <= 1 && shape.isEmpty() -> scalar()
        totalLen != valueCount -> {
            //broadcast case
            if (valueCount == 1) {
                broadcast()
            } else {
                throw IllegalArgumentException("Shape of ${Arrays.toString(shape)} did not match length ${valueCount}")
            }
        }
        else -> dense()
    }
}
/**
 * Builds an [INDArray] from the raw-data bytes of [inputTensor].
 *
 * Behavior by case:
 *  - no elements and no bytes: zeros of the given shape (cast to the tensor's
 *    data type) when a shape is present, otherwise an empty array of that type;
 *  - UTF8: when a shape with at least one element is present and the bytes are
 *    non-empty, the bytes are decoded as one UTF-8 string and wrapped in an
 *    array; with no bytes an empty array is returned; without a usable shape
 *    the UTF-8 buffer itself is wrapped;
 *  - any other type: the bytes are copied into a direct buffer sized for the
 *    element count (at least one element) and wrapped, reshaped in 'c' order
 *    when a shape is present.
 *
 * @param inputTensor tensor message whose raw data, dims and data type are read
 * @return the array described by the raw data
 */
fun loadDataBufferFromRawData(inputTensor: TensorNamespace.TensorProto): INDArray {
    val shape = inputTensor.dimsList.toLongArray()
    // The stored data type integer is used as a positional index into the enum values.
    val dtype = convertNd4jDataTypeFromNameSpaceTensorDataType(TensorNamespace.DataType.values()[inputTensor.dataType])
    val byteArray = inputTensor.rawData.toByteArray()
    //note: scalar can be zero
    val totalLen = ArrayUtil.prod(*shape)
    if (totalLen < 1 && byteArray.isEmpty()) {
        return if (shape.isNotEmpty()) Nd4j.zeros(*shape).castTo(dtype) else Nd4j.empty(dtype)
    }

    if (dtype == DataType.UTF8) {
        val rawDataBuffer = Nd4j.getDataBufferFactory().createUtf8Buffer(byteArray, byteArray.size.toLong())
        if (shape.isNotEmpty() && totalLen > 0) {
            if (rawDataBuffer.length() > 0) {
                // Decode explicitly as UTF-8 rather than with the platform default charset,
                // so the result does not depend on JVM locale settings.
                val stringInput = String(byteArray, Charsets.UTF_8)
                return Nd4j.create(stringInput)
            }
            return Nd4j.empty(dtype)
        }
        return Nd4j.create(rawDataBuffer)
    }

    //sometimes data isn't empty but the shape is still a scalar
    val elementCount = if (totalLen < 1) 1 else totalLen
    val byteBuffer = ByteBuffer.allocateDirect(elementCount * dtype.width())
    byteBuffer.put(byteArray)
    //See: https://github.com/apache/felix/pull/114
    (byteBuffer as Buffer).rewind()
    val rawDataBuffer = Nd4j.createBuffer(byteBuffer, dtype, elementCount, 0)
    // elementCount is always >= 1 here, so only the presence of a shape decides the reshape.
    if (shape.isNotEmpty()) {
        if (rawDataBuffer.length() > 0)
            return Nd4j.create(rawDataBuffer).reshape('c', *shape)
        return Nd4j.empty(dtype)
    }
    return Nd4j.create(rawDataBuffer)
}
/**
 * Converts an [INDArray] into a [TensorNamespace.TensorProto].
 *
 * The proto's data type is stored as the ordinal of the converted
 * tensor-namespace data type. Only INT64, INT32, DOUBLE, FLOAT and UTF8 arrays
 * are supported; their values go into the matching typed list and the array's
 * shape is written to the dims.
 *
 * @param ndarray array to convert
 * @return tensor message holding the array's type, values and shape
 * @throws IllegalArgumentException if the array's data type is not supported
 */
fun nameSpaceTensorFromNDarray(ndarray: INDArray): TensorNamespace.TensorProto {
    val nameSpaceDataType = convertNameSpaceTensorDataTypeFromNd4jDataType(ndarray.dataType()).ordinal
    return NameSpaceTensor {
        dataType = nameSpaceDataType
        when (ndarray.dataType()) {
            DataType.INT64 -> Int64Data(ndarray.data().asLong().toList())
            DataType.INT32 -> IntData(ndarray.data().asInt().toList())
            DataType.DOUBLE -> DoubleData(ndarray.data().asDouble().toList())
            DataType.FLOAT -> FloatData(ndarray.data().asFloat().toList())
            // Strings are read element by element from the array.
            DataType.UTF8 -> StringData((0 until ndarray.length()).map { index -> ndarray.getString(index) })
            else -> throw IllegalArgumentException("Illegal data type ${ndarray.dataType()}")
        }
        Dims(ndarray.shape().asList())
    }
}
/**
 * Finds the argument index of the argument named [argDescriptorName] with type
 * [argDescriptorType] on the nd4j op called [opDescriptorName].
 *
 * @param argDescriptorName name of the argument to look up
 * @param opDescriptorName name of the op whose descriptor is searched
 * @param argDescriptorType required argument type
 * @return the argument's index, or -1 if the name exists on the op but not with this type
 * @throws IllegalArgumentException if the op is unknown or has no argument with this name
 */
fun lookupIndexForArgDescriptor(
    argDescriptorName: String,
    opDescriptorName: String,
    argDescriptorType: OpNamespace.ArgDescriptor.ArgType
): Int {
    val op = OpDescriptorLoaderHolder.nd4jOpDescriptor.findOp(opDescriptorName)
    val names = op.argDescriptorList.map { candidate -> candidate.name }
    require(argDescriptorName in names) {
        "Invalid name $argDescriptorName for op $opDescriptorName passed in. $argDescriptorName not found in $opDescriptorName. Available names were ${names}"
    }
    val match = op.argDescriptorList.firstOrNull { candidate ->
        candidate.name == argDescriptorName && candidate.argType == argDescriptorType
    }
    return match?.argIndex ?: -1
}
/**
 * Creates an [SDVariable] with the given name, type, owning graph, shape and data type.
 *
 * @param varName name of the variable
 * @param varType kind of variable to create
 * @param sameDiff graph instance passed to the variable's constructor
 * @param shape shape of the variable; converted to a long array
 * @param dataType data type of the variable
 * @return the newly constructed variable
 */
fun createVariable(varName: String, varType: VariableType, sameDiff: SameDiff, shape: List<Long>, dataType: DataType): SDVariable =
    SDVariable(varName, varType, sameDiff, shape.toLongArray(), dataType)
/**
 * Collects every descriptor in [argDescriptors] whose name equals [name].
 *
 * @param name argument name to match
 * @param argDescriptors descriptors to search
 * @return matching descriptors in iteration order; empty if none match
 */
fun descriptorsForName(
    name: String,
    argDescriptors: Collection<OpNamespace.ArgDescriptor>
): List<OpNamespace.ArgDescriptor> =
    argDescriptors.filter { candidate -> candidate.name == name }
/**
 * Copies values from [argDescriptors] into same-named fields of [func] via reflection.
 *
 * Fields declared on the function's class and on its direct superclass are
 * considered. For every field whose name matches a descriptor, each matching
 * descriptor is applied if the field's declared type is compatible with the
 * descriptor's argument type:
 *  - BOOL: boolean fields (boxed or primitive);
 *  - STRING: fields that can hold a String;
 *  - INT64 / INT32: int, long (boxed or primitive) and [DataType] fields;
 *  - FLOAT / DOUBLE: float and double fields (boxed or primitive);
 *  - DATA_TYPE: [DataType] fields;
 *  - INPUT_TENSOR / OUTPUT_TENSOR: [INDArray] fields, read from the descriptor's
 *    input value or output value respectively.
 *
 * Incompatible fields are left untouched.
 *
 * NOTE(review): INT32 descriptors are read through int64Value and FLOAT
 * descriptors through doubleValue, as before — confirm that is where those
 * values are stored.
 *
 * @param argDescriptors descriptors providing the values
 * @param func function whose fields are populated
 * @throws IllegalArgumentException if a matching descriptor has an UNRECOGNIZED argument type
 */
fun setNameForFunctionFromDescriptors(argDescriptors: Collection<OpNamespace.ArgDescriptor>, func: DifferentialFunction) {
    val fields = ArrayList<Field>()
    fields.addAll(func.javaClass.declaredFields.toList())
    fields.addAll(func.javaClass.superclass.declaredFields.toList())

    for (field in fields) {
        if (!hasArgDescriptorWithNameAndType(argDescriptors, field.name)) continue
        val fieldType = field.type

        // True when the field is declared with the given wrapper class or its primitive
        // counterpart. The wrapper must come from `X::class.javaObjectType`: `X.javaClass`
        // would be the class of X's companion object and never match a field.
        fun fieldIs(boxed: Class<*>, primitive: Class<*>?): Boolean =
            boxed.isAssignableFrom(fieldType) || primitive?.isAssignableFrom(fieldType) == true

        fun assign(value: Any?) {
            field.isAccessible = true
            ReflectionUtils.setField(field, func, value)
        }

        for (descriptor in descriptorsForName(field.name, argDescriptors)) {
            when (descriptor.argType) {
                OpNamespace.ArgDescriptor.ArgType.BOOL -> {
                    if (fieldIs(Boolean::class.javaObjectType, Boolean::class.javaPrimitiveType)) {
                        assign(descriptor.boolValue)
                    }
                }
                OpNamespace.ArgDescriptor.ArgType.STRING -> {
                    if (fieldType.isAssignableFrom(String::class.java)) {
                        assign(descriptor.stringValue)
                    }
                }
                OpNamespace.ArgDescriptor.ArgType.INT64, OpNamespace.ArgDescriptor.ArgType.INT32 -> {
                    if (fieldIs(Int::class.javaObjectType, Int::class.javaPrimitiveType)) {
                        assign(descriptor.int64Value.toInt())
                    }
                    if (fieldIs(Long::class.javaObjectType, Long::class.javaPrimitiveType)) {
                        assign(descriptor.int64Value)
                    }
                    // Integer-encoded data types are converted through DataType.fromInt.
                    if (DataType::class.java.isAssignableFrom(fieldType)) {
                        assign(DataType.fromInt(descriptor.int64Value.toInt()))
                    }
                }
                OpNamespace.ArgDescriptor.ArgType.FLOAT, OpNamespace.ArgDescriptor.ArgType.DOUBLE -> {
                    if (fieldIs(Float::class.javaObjectType, Float::class.javaPrimitiveType)) {
                        assign(descriptor.doubleValue.toFloat())
                    }
                    if (fieldIs(Double::class.javaObjectType, Double::class.javaPrimitiveType)) {
                        assign(descriptor.doubleValue)
                    }
                }
                OpNamespace.ArgDescriptor.ArgType.DATA_TYPE -> {
                    if (DataType::class.java.isAssignableFrom(fieldType)) {
                        assign(convertNd4jDataTypeFromNameSpaceTensorDataType(descriptor.dataTypeValue))
                    }
                }
                OpNamespace.ArgDescriptor.ArgType.INPUT_TENSOR -> {
                    if (INDArray::class.java.isAssignableFrom(fieldType)) {
                        assign(ndarrayFromNameSpaceTensor(descriptor.inputValue))
                    }
                }
                OpNamespace.ArgDescriptor.ArgType.OUTPUT_TENSOR -> {
                    // Output tensors carry their payload in outputValue, not inputValue.
                    if (INDArray::class.java.isAssignableFrom(fieldType)) {
                        assign(ndarrayFromNameSpaceTensor(descriptor.outputValue))
                    }
                }
                OpNamespace.ArgDescriptor.ArgType.UNRECOGNIZED -> throw IllegalArgumentException("Illegal type ${field.type}")
            }
        }
    }
}
/**
 * Tells whether any descriptor in [argDescriptors] is named [name].
 *
 * Despite the function's name, only the descriptor name is compared; the
 * argument type is not inspected.
 *
 * @param argDescriptors descriptors to search
 * @param name argument name to look for
 * @return true if at least one descriptor has this name
 */
fun hasArgDescriptorWithNameAndType(argDescriptors: Collection<OpNamespace.ArgDescriptor>, name: String): Boolean =
    argDescriptors.any { candidate -> candidate.name == name }
/**
 * Drops a single leading "^" (the marker used for control dependencies) from [name].
 *
 * @param name possibly "^"-prefixed name
 * @return [name] without its first character when it starts with "^", otherwise [name] unchanged
 */
fun stripControl(name: String): String = name.removePrefix("^")
/**
 * Turns a variable name into its op name by cutting off a trailing ":<digits>"
 * output suffix such as ":1".
 *
 * @param varName variable name, with or without a numeric suffix
 * @return everything before the last ':' when the name ends in ":<digits>",
 * otherwise [varName] unchanged
 */
fun stripVarSuffix(varName: String): String {
    val endsWithOutputIndex = Regex(".*:\\d+").matches(varName)
    return if (endsWithOutputIndex) varName.substringBeforeLast(':') else varName
}