deeplearning4j/deeplearning4j
nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRGraph.kt
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
package org.nd4j.samediff.frameworkimport.onnx.ir

import onnx.Onnx
import org.nd4j.ir.OpNamespace
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.samediff.frameworkimport.context.MappingContext
import org.nd4j.samediff.frameworkimport.ir.IRDataType
import org.nd4j.samediff.frameworkimport.ir.IRGraph
import org.nd4j.samediff.frameworkimport.ir.IRNode
import org.nd4j.samediff.frameworkimport.ir.importInfoForEachNodeInGraph
import org.nd4j.samediff.frameworkimport.onnx.*
import org.nd4j.samediff.frameworkimport.onnx.context.OnnxMappingContext
import org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoaderHolder
import org.nd4j.samediff.frameworkimport.registry.OpMappingRegistry
import org.nd4j.samediff.frameworkimport.stripVarSuffix
import java.lang.IllegalArgumentException
import java.lang.IllegalStateException

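/**
 * An [IRGraph] implementation backed by an onnx [Onnx.GraphProto].
 *
 * Wraps the protobuf graph definition and exposes it through the framework import IR:
 * nodes, graph inputs/outputs, initializers (constants), and synthetic "Placeholder"
 * nodes created for graph inputs that have no initializer.
 *
 * A minimal usage sketch (assumes a parsed [Onnx.GraphProto] and a configured onnx
 * [OpMappingRegistry] are available; the variable names here are illustrative):
 *
 * ```
 * val irGraph = OnnxIRGraph(graphProto, onnxOpMappingRegistry)
 * val nodes = irGraph.nodeList()
 * val importInfo = irGraph.importInfoForEachNode(mutableMapOf())
 * ```
 */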
class OnnxIRGraph(graphDef: Onnx.GraphProto,opMappingRegistry: OpMappingRegistry<Onnx.GraphProto,
        Onnx.NodeProto,Onnx.NodeProto,Onnx.TensorProto,Onnx.TensorProto.DataType,Onnx.AttributeProto,
        Onnx.AttributeProto>): IRGraph<
        Onnx.GraphProto, Onnx.NodeProto,
        Onnx.NodeProto, Onnx.TensorProto, Onnx.AttributeProto, Onnx.AttributeProto,
        Onnx.TensorProto.DataType> {

    var graphDef = graphDef
    val opList = graphDef.nodeList
    val opMappingRegistry = opMappingRegistry
    var cachedNodeList: List<IRNode<Onnx.NodeProto, Onnx.TensorProto, Onnx.AttributeProto, Onnx.AttributeProto, Onnx.TensorProto.DataType>>
    var inputList = ArrayList<String>()
    var outputList = ArrayList<String>()
    var variableList = ArrayList<String>()
    var initializerSet = HashSet<String>()
    val nodeNames: Set<String>
    val inputsOutputs = HashSet<String>()


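    /**
     * Looks up the raw [Onnx.NodeProto] for the given node name.
     * Any trailing tensorflow style :0 suffix is stripped before the lookup.
     * Throws [IllegalStateException] if no node with that name exists in the cached node list.
     */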
    override fun nodeByName(input: String): Onnx.NodeProto {
        //models converted to onnx from tensorflow sometimes carry tensorflow's :0 variable suffix
        val input2 = stripVarSuffix(input)
        if(cachedNodeList.none { node -> node.nodeName() == input2 }) {
            throw IllegalStateException("No node found with name $input")
        }
        return cachedNodeList.first { inputNode -> inputNode.nodeName() == input2 }.internalValue()
    }

    init {
        //sometimes onnx nodes will have empty names, ensure that each node has a deterministically generated name
        val indexToNode = HashMap<Int,Onnx.NodeProto>()
        val opTypes = HashMap<String,String>()

        nodeNames = HashSet()
        preProcessZeroSuffixes()

        cachedNodeList = nodeList()


        cachedNodeList.forEachIndexed { index,node ->
            if(node.nodeName().isEmpty()) {
                val newNodeBuilder = node.internalValue().toBuilder()
                if(node.numOutputs() > 1) {
                    println("Found node with no name and > 1 input.  Node was $node. Using first output as name.")
                }
                val newName = node.outputAt(0)
                newNodeBuilder.name = newName.replace(":0","")
                val newNode = newNodeBuilder.build()
                indexToNode[index] = newNode
            }

            node.inputs().forEach { inputsOutputs.add(it.replace(":0","")) }
            node.outputs().forEach { inputsOutputs.add(it.replace(":0","")) }
            nodeNames.add(node.nodeName().replace(":0",""))
            opTypes[node.nodeName()] = node.opName()


        }


        val initializers = this.graphDef.initializerList.map { input -> input.name.replace(":0","") }
        println(initializers)
        val inputList = this.graphDef.inputList.filter { input -> !opTypes.containsKey(input.name.replace(":0","")) && !initializers.contains(input.name.replace(":0",""))}.map { input -> input.name.replace(":0","") }
        val varList = this.graphDef.inputList.filter { input -> initializers.contains(input.name.replace(":0","")) }.map { input -> input.name.replace(":0","") }
        println("Inputs $inputList")
        println("Variables $varList")
        this.inputList.addAll(inputList)
        this.variableList.addAll(varList)
        initializerSet.addAll(initializers)
        outputList.addAll(this.graphDef.outputList.filter { valueInfo -> !outputList.contains(valueInfo.name.replace(":0","")) }
            .map { input -> input.name.replace(":0","") })
    }

    /**
     * Strips the tensorflow style :0 suffix from initializer, node, input, and output names.
     * Models converted to onnx from tensorflow (e.g. via tf onnx) commonly carry these suffixes,
     * which would otherwise break name matching during import.
     */
    fun preProcessZeroSuffixes() {
        val graphDefBuilder = graphDef.toBuilder()
        val initializerList = ArrayList<Onnx.TensorProto>()
        //ensure we prune all :0 suffixes which may come from tf onnx
        graphDefBuilder.initializerList.forEach { currInitializer ->
            val builder = currInitializer.toBuilder()
            builder.name = currInitializer.name.replace(":0","")
            initializerList.add(builder.build())
        }
        graphDefBuilder.clearInitializer()


        graphDefBuilder.nodeBuilderList.forEach {
            it.name = it.name.replace(":0","")
            val inputList = it.inputList.toMutableList()
            val outputList = it.outputList.toMutableList()
            it.clearInput()
            it.clearOutput()
            it.addAllInput(inputList.map { input -> input.replace(":0","") })
            it.addAllOutput(outputList.map { input -> input.replace(":0","") })
        }

        initializerList.forEach { graphDefBuilder.addInitializer(it) }
        this.graphDef = graphDefBuilder.build()
    }


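    /**
     * Returns the IR view of the graph's nodes, building and caching it on first use.
     * Graph inputs without an initializer are added as synthetic "Placeholder" nodes
     * (not a real onnx op) so downstream import logic can treat them like tensorflow placeholders.
     */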
    override fun nodeList(): List<IRNode<Onnx.NodeProto, Onnx.TensorProto, Onnx.AttributeProto, Onnx.AttributeProto, Onnx.TensorProto.DataType>> {
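        //cachedNodeList is only assigned at the end of this method and in the init block,
        //so despite the non null declared type it is still null on the first call during construction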
        if(cachedNodeList != null) {
            return cachedNodeList
        }

        val ret2 =
            ArrayList<IRNode<Onnx.NodeProto, Onnx.TensorProto, Onnx.AttributeProto, Onnx.AttributeProto, Onnx.TensorProto.DataType>>()
        //add all inputs, outputs, initializers together as "nodes" similar to TF

        val identityOp =  OpDescriptorLoaderHolder.listForFramework<Onnx.NodeProto>("onnx")["Constant"]!!

        //for model import purposes, add identity ops as dummies similar to how tensorflow does placeholders/constants
        val initializerListNames = graphDef.initializerList.map { input -> input.name.replace(":0","") }
        graphDef.inputList.filter { input -> !initializerListNames.contains(input.name.replace(":0","")) }.forEach { input ->
            //note: "Placeholder" is not a real onnx op. It is only used to flag graph inputs during import.
            //add dummy tensor values describing each placeholder's declared shape and data type
            val tensorBuilder = Onnx.TensorProto.newBuilder()
            tensorBuilder.name = input.name
            tensorBuilder.dataType = input.type.tensorType.elemType
            input.type.tensorType.shape.dimList.forEach {
                tensorBuilder.addDims(it.dimValue)
            }
            val nodeToAdd = NodeProto {
                opType = "Placeholder"
                name = input.name.replace(":0","")
                Attribute(
                    Onnx.AttributeProto.newBuilder().setName("value")
                        .addTensors(tensorBuilder.build())
                        .build()
                )
            }

            ret2.add(OnnxIRNode(nodeToAdd, identityOp,opMappingRegistry))
        }

        //add inputs and outputs for use cases like placeholder detection
        inputList.addAll(graphDef.inputList.filter { input -> !initializerListNames.contains(input.name) }.map { input -> input.name })
        outputList.addAll(graphDef.outputList.filter { valueInfo -> !outputList.contains(valueInfo.name) }.map { input -> input.name })
        val frameworkList =  OpDescriptorLoaderHolder.listForFramework<Onnx.NodeProto>("onnx")
        graphDef.nodeList.forEach {
            val opDefOrNull = if(!frameworkList.containsKey(it.opType)) {
                //use Constant as a stand in descriptor for any op missing from the framework op list; such ops are usually handled by a custom import implementation
                frameworkList["Constant"]!!
            } else {
                frameworkList[it.opType]!!
            }
            ret2.add(OnnxIRNode(it, opDefOrNull!!,opMappingRegistry))
        }

        //collect the nodes whose outputs are also graph outputs
        val outputNames = graphDef.outputList.map { input -> input.name }.toSet()
        val outputNodes = ArrayList<Onnx.NodeProto>()
        graphDef.nodeList.forEach { nodeProto ->
            val outputList = nodeProto.outputList.map { input -> input.toString() }.toSet()
            val containsAny = outputNames.intersect(outputList)
            if(containsAny.isNotEmpty()) {
                outputNodes.add(nodeProto)
            }
        }




        this.cachedNodeList = ret2
        return ret2
    }


    fun graphDef(): Onnx.GraphProto {
        return graphDef
    }

    override fun internalValue(): Onnx.GraphProto {
        return graphDef
    }



    override fun createMappingContext(
        opDef: Onnx.NodeProto,
        node: Onnx.NodeProto,
        dynamicVariables: MutableMap<String, Onnx.TensorProto>
    ): MappingContext<Onnx.GraphProto, Onnx.NodeProto, Onnx.NodeProto, Onnx.TensorProto, Onnx.AttributeProto, Onnx.AttributeProto, Onnx.TensorProto.DataType> {
        return OnnxMappingContext(opDef = opDef, node = node, graph = this, dynamicVariables = dynamicVariables)
    }

    override fun frameworkName(): String {
        return "onnx"
    }

    override fun nd4jNameForInternalOpName(name: String): String {
        return opMappingRegistry.lookupOpMappingProcess(name).opName()
    }

    override fun isConstantOpName(name: String): Boolean {
        return name == "Constant"
    }

    override fun isConstant(opName: String): Boolean {
        return opName == "Constant"
    }

    override fun isPlaceHolder(opName: String): Boolean {
        //note: this is a dummy op only used for import, it's not a real onnx op
        return opName == "Placeholder"
    }

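    /**
     * Returns the shape of the given variable, or null if it is unknown.
     * Initializer shapes are returned as is; for placeholders the shape is read from the
     * synthetic value attribute, with unknown (0) dimensions mapped to -1 to match
     * samediff's dynamic shape convention.
     */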
    override fun shapeOfInput(varName: String): LongArray? {
        val firstOrNull = graphDef.initializerList.firstOrNull { inputNode -> inputNode.name == varName }
        if(firstOrNull != null)
            return firstOrNull.dimsList.toLongArray()
        else if(nodeIsPlaceHolder(stripVarSuffix(varName))) {
            val placeHolder = irNodeByName(stripVarSuffix(varName))
            val attrValue = placeHolder.attributeMap()["value"]!!.tensorValue().shape()
            val ret =  attrValue.toLongArray()
            for(i in ret.indices) {
                //missing dimension, probably dynamic, infer as -1 to match dynamic shape behavior in samediff
                if(ret[i] == 0L) {
                    ret[i] = -1
                }
            }

            return ret
        }
        return null
    }

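    /**
     * Resolves the data type for a variable name (with any :0 suffix stripped) by checking,
     * in order: initializers, placeholder inputs (tensor or sequence typed, falling back to the
     * placeholder's value attribute), other graph inputs, and finally UNDEFINED.
     */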
    override fun dataTypeForVariable(varName: String): IRDataType<Onnx.TensorProto.DataType> {
        val varNameStripped = stripVarSuffix(varName)
        val firstOrNull = graphDef.initializerList.firstOrNull {
                inputNode -> inputNode.name == varNameStripped }
        val input = graphDef.inputList.firstOrNull { input2 ->
            input2.name == varNameStripped
        }
        if(firstOrNull != null)
            return OnnxIRDataType(Onnx.TensorProto.DataType.forNumber(firstOrNull.dataType))
        else if(nodeIsPlaceHolder(varNameStripped)) {
            if(input != null && input.type.hasTensorType()) {
                return OnnxIRDataType(Onnx.TensorProto.DataType.forNumber(input.type.tensorType.elemType))
            } else if(input != null && input.type.hasSequenceType()) {
                return OnnxIRDataType(Onnx.TensorProto.DataType.forNumber(input.type.sequenceType.elemType.tensorType.elemType))

            }

            val placeHolder = irNodeByName(varNameStripped)
            return placeHolder.attributeMap()["value"]!!.tensorValue().dataType()
        }
        else if(input != null)
            return OnnxIRDataType(Onnx.TensorProto.DataType.forNumber(input.type.tensorType.elemType))
        else
            return OnnxIRDataType(Onnx.TensorProto.DataType.UNDEFINED)
    }

    override fun importInfoForEachNode(dynamicVariables: MutableMap<String, Onnx.TensorProto>): Map<String, Pair<MappingContext<Onnx.GraphProto, Onnx.NodeProto, Onnx.NodeProto, Onnx.TensorProto, Onnx.AttributeProto, Onnx.AttributeProto, Onnx.TensorProto.DataType>, OpNamespace.OpDescriptor>> {
        return importInfoForEachNodeInGraph(graph = this,dynamicVariables = dynamicVariables)
    }

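    /**
     * Returns true if the given name (with or without a :0 suffix) is a graph input
     * without an initializer, i.e. a placeholder.
     */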
    override fun nodeIsPlaceHolder(nodeName: String): Boolean {
        val realName = if(nodeName.endsWith(":0")) {
            nodeName.replace(":0","")
        } else {
            nodeName
        }


        return this.inputList.contains(realName) || this.inputList.contains("$realName:0")
    }

    override fun opMappingRegistry(): OpMappingRegistry<Onnx.GraphProto, Onnx.NodeProto, Onnx.NodeProto, Onnx.TensorProto, Onnx.TensorProto.DataType, Onnx.AttributeProto, Onnx.AttributeProto> {
        return opMappingRegistry
    }

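    /**
     * Adds the given array to the graph as a constant: the value is appended as an
     * initializer along with a value info entry describing its shape, then the cached
     * graph definition is rebuilt.
     */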
    override fun addConstantNode(name: String, value: INDArray) {
        val graphBuilder = graphDef.toBuilder()
        val converted = convertToOnnxTensor(value,name)
        graphBuilder.addInitializer(converted)

        val tensorShapeInfo = TensorTypeProto {
            shape = OnnxShapeProto {
                OnnxShape(value.shape().toList())
            }

        }

        val valueType = TypeProto {
            tensorType = tensorShapeInfo
        }

        val newValueInfo = ValueInfoProto {
            Type(valueType)
        }

        graphBuilder.addValueInfo(newValueInfo)
        this.graphDef = graphBuilder.build()
    }

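    /**
     * Replaces the node with the same name as the given IR node in the underlying graph
     * definition. Throws [IllegalStateException] if no node with that name exists.
     */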
    override fun updateNode(node: IRNode<Onnx.NodeProto, Onnx.TensorProto, Onnx.AttributeProto, Onnx.AttributeProto, Onnx.TensorProto.DataType>) {
        val graphBuilder = graphDef.toBuilder()
        val indexOfNode = graphBuilder.nodeList.map { input -> input.name }.indexOf(node.nodeName())
        if(indexOfNode < 0) {
            throw IllegalStateException("No node of name ${node.nodeName()} was found")
        }
        graphBuilder.setNode(indexOfNode,node.internalValue())
        this.graphDef = graphBuilder.build()
    }

    override fun graphOutputs(): List<String> {
        return outputList
    }

    override fun outputAt(index: Int): String {
        return outputList[index]
    }

    override fun setOutputs(outputs: List<String>) {
        this.outputList = ArrayList(outputs)
    }

    override fun graphInputs(): List<String> {
        return inputList
    }

    override fun inputAt(index: Int): String {
        return inputList[index]
    }

    override fun setInputs(inputs: List<String>) {
        this.inputList = ArrayList(inputs)
    }

    override fun isVariable(nodeName: String): Boolean {
        val realName = if(nodeName.endsWith(":0")) {
            nodeName.replace(":0","")
        } else {
            nodeName
        }

        return variableList.contains(realName) || variableList.contains("$realName:0")
    }

    override fun isVariableOpName(name: String): Boolean {
        return name != "Constant"
    }

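    /**
     * Returns the nd4j array for a named constant, resolved from the graph initializers or
     * from a Constant node's value attribute. Throws [IllegalArgumentException] if a node
     * with the given name exists but is not a Constant.
     */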
    override fun getConstantArrayForName(name: String): INDArray {
        val check = graphDef.initializerList.map { input ->input.name }
        if(!check.contains(name)) {
            //initializer not found, see if there is a constant node
            if (this.nodeNames.contains(name)) {
                val constNode = nodeByName(name)
                if (constNode.opType == "Constant") {
                    //every constant should have a tensor value
                    val getValue = constNode.getAttribute(0).t
                    return OnnxIRTensor(getValue).toNd4jNDArray()
                } else {
                    throw IllegalArgumentException("Constant of name $name not found!")

                }

            }
        }

        return OnnxIRTensor(graphDef.initializerList.first { input -> input.name == name }).toNd4jNDArray()
    }

    override fun hasConstantInitializer(name: String): Boolean {
        return initializerSet.contains(name)
    }

    override fun indexOfNode(input: String): Int {
        return cachedNodeList.map { inputNode -> inputNode.nodeName() }.indexOf(input)
    }
    override fun nodesWithInput(name: String): List<IRNode<Onnx.NodeProto, Onnx.TensorProto, Onnx.AttributeProto, Onnx.AttributeProto, Onnx.TensorProto.DataType>> {
        return cachedNodeList.filter { input -> input.inputs().contains(name) }
    }

    override fun irNodeByName(input: String): IRNode<Onnx.NodeProto, Onnx.TensorProto, Onnx.AttributeProto, Onnx.AttributeProto, Onnx.TensorProto.DataType> {
        val node = nodeByName(input)
        return OnnxIRNode(node,opMappingRegistry.lookupInputFrameworkOpDef(node.opType),opMappingRegistry)
    }

    override fun hasNode(nodeName: String): Boolean {
        return nodeNames.contains(nodeName)
    }

    override fun addGraphOutputsAsProcessingNodes(): Boolean {
        return true
    }

    override fun convertToNDArray(tensorTypeInput: Onnx.TensorProto): INDArray {
        return OnnxIRTensor(tensorTypeInput).toNd4jNDArray()
    }

    override fun isInputOrOutput(name: String): Boolean {
        val realName = if(name.endsWith(":0")) {
            name.replace(":0","")
        } else {
            name
        }

        return inputsOutputs.contains(name) || inputsOutputs.contains(realName)
    }

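    /**
     * Replaces the cached IR node list and rebuilds the underlying graph definition so that
     * its node list matches the given nodes.
     */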
    override fun updateNodeCacheWith(nodeList: List<IRNode<Onnx.NodeProto, Onnx.TensorProto, Onnx.AttributeProto, Onnx.AttributeProto, Onnx.TensorProto.DataType>>) {
        this.cachedNodeList = nodeList
        val graphDefBuilder = graphDef.toBuilder()
        for(i in 0 until graphDefBuilder.nodeCount) {
            graphDefBuilder.removeNode(0)
        }
        nodeList.forEach {
            graphDefBuilder.addNode(it.internalValue())
        }

        this.graphDef = graphDefBuilder.build()
    }

    override fun convertToTensor(ndarrayInput: INDArray, tensorName: String): Onnx.TensorProto {
        return convertToOnnxTensor(ndarrayInput,tensorName)
    }
}