deeplearning4j/deeplearning4j

View on GitHub
nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/definitions/implementations/Conv.kt

Summary

Maintainability
F
5 days
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
package org.nd4j.samediff.frameworkimport.onnx.definitions.implementations

import org.apache.commons.lang3.StringUtils
import org.nd4j.autodiff.samediff.SDIndex
import org.nd4j.autodiff.samediff.SDVariable
import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.autodiff.samediff.internal.SameDiffOp
import org.nd4j.common.util.ArrayUtil
import org.nd4j.enums.Mode
import org.nd4j.enums.WeightsFormat
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.samediff.frameworkimport.ImportGraph
import org.nd4j.samediff.frameworkimport.ImportUtils
import org.nd4j.samediff.frameworkimport.hooks.PreImportHook
import org.nd4j.samediff.frameworkimport.hooks.annotations.PreHookRule
import org.nd4j.samediff.frameworkimport.registry.OpMappingRegistry
import org.nd4j.shade.protobuf.GeneratedMessageV3
import org.nd4j.shade.protobuf.ProtocolMessageEnum

/**
 * Import hook for the ONNX `Conv` op, covering 1D, 2D and 3D convolutions as well as grouped and
 * depthwise 2D convolutions.
 *
 * A port of the convolution handler from onnx-tensorflow:
 * https://github.com/onnx/onnx-tensorflow/blob/master/onnx_tf/handlers/backend/conv.py
 * (the shared logic lives in `conv_mixin.py` in the same directory).
 *
 * The input is permuted from the storage format to the compute format returned by
 * [ImportUtils.getDataFormat], convolved group by group, concatenated along the last (channel)
 * axis, optionally biased, and permuted back to the storage format. The convolution configs below
 * assume the compute format is channels-last (`NWC` / `NHWC` / `NDHWC`).
 *
 * @author Adam Gibson
 */
@PreHookRule(nodeNames = [],opNames = ["Conv"],frameworkName = "onnx")
class Conv : PreImportHook  {

    /**
     * Maps one ONNX `Conv` node onto SameDiff convolution ops.
     *
     * Inputs are read from [op]: `X` at index 0, the weights `W` at index 1 (per the ONNX spec laid
     * out as `(M x C/group x k1 x ... x kn)`), and an optional bias `B` at index 2.
     * Attributes read: `kernel_shape`, `dilations`, `strides`, `pads`, `auto_pad` and `group`.
     *
     * @return a single entry mapping the name of the output variable (named `outputNames[0]`) to it
     * @throws IllegalArgumentException if the weights rank is not 3, 4 or 5, or if `auto_pad` is `SAME_LOWER`
     */
    override fun doImport(
        sd: SameDiff,
        attributes: Map<String, Any>,
        outputNames: List<String>,
        op: SameDiffOp,
        mappingRegistry: OpMappingRegistry<GeneratedMessageV3, GeneratedMessageV3, GeneratedMessageV3, GeneratedMessageV3, ProtocolMessageEnum, GeneratedMessageV3, GeneratedMessageV3>,
        importGraph: ImportGraph<GeneratedMessageV3, GeneratedMessageV3, GeneratedMessageV3, GeneratedMessageV3, GeneratedMessageV3, GeneratedMessageV3, ProtocolMessageEnum>,
        dynamicVariables: Map<String, GeneratedMessageV3>
    ): Map<String, List<SDVariable>> {
        val inWeights = sd.getVariable(op.inputsToOp[1])
        // The weights rank (spatial dims + 2) decides between a 1D, 2D or 3D convolution.
        val rank = inWeights.shape.size
        require(rank in 3..5) {
            "ONNX Conv import supports 1D, 2D and 3D convolutions only; got weights of rank $rank"
        }
        val spatialSize = rank - 2

        var inputVariable = sd.getVariable(op.inputsToOp[0])
        val xShape = inputVariable.shape
        val storageComputeFormat = ImportUtils.getDataFormat(rank)

        // (M x C x k1 .. kn) -> (k1 .. kn x C x M), i.e. perm = [2, 3, .., rank - 1, 1, 0]
        val perm = ((2 until rank).map { it.toLong() } + listOf(1L, 0L)).toLongArray()
        // Without the attribute, the kernel size is the spatial part of the weights shape.
        val kernelShape = longListAttribute(attributes, "kernel_shape")?.map { it.toInt() }?.toIntArray()
            ?: inWeights.shape.drop(2).map { it.toInt() }.toIntArray()

        var weights = sd.permute(inWeights, *perm)
        val dilations = longListAttribute(attributes, "dilations") ?: List(spatialSize) { 1L }
        val strides = longListAttribute(attributes, "strides") ?: List(spatialSize) { 1L }
        // ONNX layout: [x1_begin, .., xn_begin, x1_end, .., xn_end]
        val pads = longListAttribute(attributes, "pads") ?: defaultPads(spatialSize).map { it.toLong() }

        val autoPad = (attributes["auto_pad"] as String?) ?: "NOTSET"
        val padMode = when (autoPad) {
            "NOTSET" -> {
                // Our padding modes are not quite the same as ONNX's: explicit pads are applied to the
                // input up front, after which NOTSET is convolved as VALID with zero padding.
                if (pads.any { it != 0L }) {
                    inputVariable = paddingOp(sd, inputVariable, pads)
                }
                "NOTSET"
            }
            "SAME_UPPER" -> "SAME"
            "SAME_LOWER" -> throw IllegalArgumentException("Unable to convert model running SAME_LOWER")
            else -> autoPad
        }

        val groups = (attributes.getOrDefault("group", 1) as Number).toLong()
        // Depthwise: a 2D convolution where every input channel forms its own group
        // (mirrors onnx-tf: `depthwise = bool(group == x.get_shape().as_list()[1])`).
        var depthWise = rank == 4 && groups != 1L
        if (depthWise && xShape != null && xShape[1].toInt() != -1) {
            depthWise = groups == xShape[1]
        }

        val xs: List<SDVariable>
        val weightGroups: List<SDVariable>
        if (depthWise) {
            // (kH x kW x C/group x M) is reshaped to (kH x kW x C x M/group); this relies on
            // C == group for depthwise convolutions.
            val inWeightsShape = ArrayUtil.permute(ArrayUtil.copy(inWeights.shape), perm)
            val depthWiseFilterShape = intArrayOf(
                inWeightsShape[0].toInt(),
                inWeightsShape[1].toInt(),
                -1,
                Math.floorDiv(weights.shape[3].toInt(), groups.toInt())
            )
            weights = weights.reshape(*depthWiseFilterShape)
            inputVariable = sd.permute(inputVariable, *ImportUtils.getPermFromFormats(storageComputeFormat.first, storageComputeFormat.second))
            xs = listOf(inputVariable)
            weightGroups = listOf(weights)
        } else {
            // Grouped convolution: split weights (output channels) and input channels into
            // `groups` independent convolutions.
            weightGroups = sd.split(weights, groups.toInt(), -1).toList()
            inputVariable = sd.permute(inputVariable, *ImportUtils.getPermFromFormats(storageComputeFormat.first, storageComputeFormat.second))
            xs = if (groups == 1L) listOf(inputVariable) else sd.split(inputVariable, groups.toInt(), -1).toList()
        }

        val paddingMode = padModeForName(padMode)
        // NOTSET was handled by explicitly padding the input above, so it uses zero op-level padding;
        // otherwise the begin pad of the given spatial dimension is passed through.
        fun beginPad(spatialIdx: Int): Long = if (padMode == "NOTSET") 0L else pads[spatialIdx]

        val convolvedList: List<SDVariable> = if (depthWise) {
            val convConfig = Conv2DConfig.builder()
                .kH(kernelShape[0].toLong())
                .kW(kernelShape[1].toLong())
                .sH(strides[0])
                .sW(strides[1])
                .dH(dilations[0])
                .dW(dilations[1])
                .dataFormat("NHWC")
                .weightsFormat(WeightsFormat.YXIO)
                .paddingMode(paddingMode)
                .build()
            xs.zip(weightGroups) { x, w -> sd.cnn().depthWiseConv2d(x, w, convConfig) }
        } else {
            when (rank) {
                3 -> {
                    // No output cropping: any explicit pads were already applied to the input.
                    val oneDConfig = Conv1DConfig.builder()
                        .k(kernelShape[0].toLong())
                        .dataFormat("NWC")
                        .d(dilations[0])
                        .p(beginPad(0))
                        .s(strides[0])
                        .paddingMode(paddingMode)
                        .build()
                    xs.zip(weightGroups) { x, w -> sd.cnn().conv1d(x, w, oneDConfig) }
                }
                4 -> {
                    val convConfig = Conv2DConfig.builder()
                        .kH(kernelShape[0].toLong())
                        .kW(kernelShape[1].toLong())
                        .sH(strides[0])
                        .sW(strides[1])
                        .pH(beginPad(0))
                        .pW(beginPad(1))
                        .dH(dilations[0])
                        .dW(dilations[1])
                        .dataFormat("NHWC")
                        .weightsFormat(WeightsFormat.YXIO)
                        .paddingMode(paddingMode)
                        .build()
                    xs.zip(weightGroups) { x, w -> sd.cnn().conv2d(x, w, convConfig) }
                }
                else -> {
                    val threeDConfig = Conv3DConfig.builder()
                        .kD(kernelShape[0].toLong())
                        .kH(kernelShape[1].toLong())
                        .kW(kernelShape[2].toLong())
                        .sD(strides[0])
                        .sH(strides[1])
                        .sW(strides[2])
                        .dD(dilations[0])
                        .dH(dilations[1])
                        .dW(dilations[2])
                        .pD(beginPad(0))
                        .pH(beginPad(1))
                        .pW(beginPad(2))
                        .biasUsed(false)
                        .dataFormat("NDHWC")
                        .paddingMode(paddingMode)
                        .build()
                    xs.zip(weightGroups) { x, w -> sd.cnn().conv3d(x, w, threeDConfig) }
                }
            }
        }

        var output = sd.concat(-1, *convolvedList.toTypedArray())
        // Grouped convolutions need the bias added once, after the groups are concatenated along the channel axis.
        if (op.inputsToOp.size > 2) {
            val bias = sd.getVariable(op.inputsToOp[2])
            output = output.add(bias)
        }
        output = sd.permute(outputNames[0], output, *ImportUtils.getPermFromFormats(storageComputeFormat.second, storageComputeFormat.first))
        return mapOf(output.name() to listOf(output))
    }

    /**
     * Reads an integer-list attribute as longs, accepting either `Int` or `Long` elements.
     *
     * @return the values, or null when the attribute is absent
     */
    private fun longListAttribute(attributes: Map<String, Any>, name: String): List<Long>? =
        (attributes[name] as List<*>?)?.map { (it as Number).toLong() }

    /**
     * Maps a (normalized) padding mode name to a [PaddingMode].
     *
     * `"NOTSET"` maps to [PaddingMode.VALID]: by the time it is used, explicit pads have already been
     * applied to the input. Any unrecognized name falls back to [PaddingMode.CAUSAL].
     */
    fun padModeForName(name: String): PaddingMode {
        return when(name) {
            "VALID" -> PaddingMode.VALID
            "SAME" -> PaddingMode.SAME
            "NOTSET" -> PaddingMode.VALID
            else -> PaddingMode.CAUSAL
        }
    }

    /**
     * Builds one [SDIndex] per character of [dataFormat]. The `W`/`H`/`D` axes get
     * `SDIndex.interval(pads[i], -pads[i] - 1)`, where `i` is the axis position in [dataFormat];
     * every other axis gets [SDIndex.all]. Dispatches on `pads.size`: 3 -> width only,
     * 4 -> width and height, 5 -> width, height and depth. Any other size yields an empty list.
     *
     * NOTE(review): ONNX pads hold 2 entries per spatial dim (2/4/6 for 1D/2D/3D). They are also
     * indexed by axis position rather than spatial index here. Confirm before reusing.
     * [doImport] no longer uses this.
     */
    fun indicesForPads(dataFormat: String,pads: List<Long>): List<SDIndex> {
        val ret = ArrayList<SDIndex>()
        val rank = dataFormat.length
        when(pads.size) {
            //1D cnn
            3 -> {
                val widthIdx = dataFormat.indexOf("W")
                for(i in 0 until rank) {
                    if(i == widthIdx) {
                        ret.add(SDIndex.interval(pads[i], - pads[i] - 1))
                    } else {
                        ret.add(SDIndex.all())

                    }
                }
            }
            //2d CNN
            4 -> {
                val widthIdx = dataFormat.indexOf("W")
                val heightIdx = dataFormat.indexOf("H")
                for(i in 0 until rank) {
                    if(i == widthIdx) {
                        ret.add(SDIndex.interval(pads[i], - pads[i] - 1))

                    } else if(i == heightIdx) {
                        ret.add(SDIndex.interval(pads[i],- pads[i] - 1))

                    } else {
                        ret.add(SDIndex.all())

                    }
                }
            }
            //3D CNN
            5 -> {
                val widthIdx = dataFormat.indexOf("W")
                val heightIdx = dataFormat.indexOf("H")
                val depthIdx = dataFormat.indexOf("D")
                for(i in 0 until rank) {
                    if(i == widthIdx) {
                        ret.add(SDIndex.interval(pads[i], - pads[i] - 1))

                    } else if(i == heightIdx) {
                        ret.add(SDIndex.interval(pads[i], - pads[i] - 1))

                    } else if(i == depthIdx) {
                        ret.add(SDIndex.interval(pads[i], - pads[i] - 1))

                    } else {
                        ret.add(SDIndex.all())
                    }
                }

            }
        }

        return ret
    }

    /**
     * Reorders a 4-element pads list `[a, b, c, d]` into `[a, c, b, d]`. For ONNX 2D pads
     * `[h_begin, w_begin, h_end, w_end]` this yields `[h_begin, h_end, w_begin, w_end]`.
     * Lists of any other size are returned unchanged.
     */
    fun adaptPads(inputPads: List<Long>): List<Long> {
        if(inputPads.size == 4) {
            return listOf(inputPads[0], inputPads[2], inputPads[1], inputPads[3])
        }

        return inputPads
    }

    /**
     * ONNX's default `pads` value: a zero begin and a zero end entry for each of the
     * [spatialSize] spatial dimensions (`2 * spatialSize` zeros in total).
     */
    fun defaultPads(spatialSize : Int): List<Int> = List(2 * spatialSize) { 0 }

    /**
     * Explicitly zero-pads the spatial dimensions of [x].
     *
     * The first two dimensions of [x] are never padded. [doImport] calls this before permuting to
     * the compute format, i.e. presumably on `(N x C x d1 .. dn)` input.
     *
     * @param pads ONNX-style pads `[d1_begin, .., dn_begin, d1_end, .., dn_end]`
     * @return the padded variable
     * @throws IllegalArgumentException if [pads] has an odd number of entries
     */
    fun paddingOp(sd: SameDiff,x: SDVariable,pads: List<Long>): SDVariable {
        require(pads.size % 2 == 0) { "Pads need a begin and an end value per spatial dimension, got $pads" }
        val numDim = pads.size / 2
        // Row-major (numDim + 2) x 2 padding matrix: [0, 0] for the first two dims,
        // then [begin, end] for each spatial dim.
        val flatPadding = LongArray(2 * (numDim + 2))
        for (i in 0 until numDim) {
            flatPadding[2 * (i + 2)] = pads[i]
            flatPadding[2 * (i + 2) + 1] = pads[i + numDim]
        }
        val paddingMatrix = Nd4j.create(Nd4j.createBuffer(flatPadding)).reshape('c', numDim + 2, 2)
        val inputPadding = sd.constant(paddingMatrix.castTo(DataType.INT32))
        return sd.image().pad(x, inputPadding, Mode.CONSTANT, 0.0)
    }

}