nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/definitions/implementations/Resize.kt
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.samediff.frameworkimport.onnx.definitions.implementations
import org.nd4j.autodiff.samediff.SDIndex
import org.nd4j.autodiff.samediff.SDVariable
import org.nd4j.autodiff.samediff.SameDiff
import org.nd4j.autodiff.samediff.internal.SameDiffOp
import org.nd4j.enums.ImageResizeMethod
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.samediff.frameworkimport.ImportGraph
import org.nd4j.samediff.frameworkimport.hooks.PreImportHook
import org.nd4j.samediff.frameworkimport.hooks.annotations.PreHookRule
import org.nd4j.samediff.frameworkimport.registry.OpMappingRegistry
import org.nd4j.shade.protobuf.GeneratedMessageV3
import org.nd4j.shade.protobuf.ProtocolMessageEnum
import java.lang.IllegalArgumentException
/**
* A port of resize.py from onnx tensorflow for samediff:
* https://github.com/onnx/onnx-tensorflow/blob/master/onnx_tf/handlers/backend/resize.py#L195
*
* @author Adam Gibson
*/
@PreHookRule(nodeNames = [],opNames = ["Resize"],frameworkName = "onnx")
class Resize : PreImportHook {
override fun doImport(
sd: SameDiff,
attributes: Map<String, Any>,
outputNames: List<String>,
op: SameDiffOp,
mappingRegistry: OpMappingRegistry<GeneratedMessageV3, GeneratedMessageV3, GeneratedMessageV3, GeneratedMessageV3, ProtocolMessageEnum, GeneratedMessageV3, GeneratedMessageV3>,
importGraph: ImportGraph<GeneratedMessageV3, GeneratedMessageV3, GeneratedMessageV3, GeneratedMessageV3, GeneratedMessageV3, GeneratedMessageV3, ProtocolMessageEnum>,
dynamicVariables: Map<String, GeneratedMessageV3>
): Map<String, List<SDVariable>> {
// Parameter docs below are from the onnx operator docs:
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#resize
var inputVariable = sd.getVariable(op.inputsToOp[0])
val inputShape = sd.shape(inputVariable)
val roi = sd.getVariable(op.inputsToOp[1])
val scales = sd.getVariable(op.inputsToOp[2])
val sizes = sizes(sd,op)
/**
*
* If coordinate_transformation_mode is "half_pixel",
x_original = (x_resized + 0.5) / scale - 0.5,
if coordinate_transformation_mode is "pytorch_half_pixel",
x_original = length_resized > 1 ? (x_resized + 0.5) / scale - 0.5 : 0,
if coordinate_transformation_mode is "align_corners",
x_original = x_resized * (length_original - 1) / (length_resized - 1),
if coordinate_transformation_mode is "asymmetric",
x_original = x_resized / scale,
if coordinate_transformation_mode is "tf_crop_and_resize",
x_original = length_resized > 1 ? start_x * (length_original - 1) + x_resized * (end_x - start_x) * (length_original - 1) / (length_resized - 1) : 0.5 * (start_x + end_x) * (length_original - 1).
*/
val coordTransformationMode = attributes.getOrDefault("coordinate_transformation_mode","half_pixel") as String
val extrapolationValue = attributes.getOrDefault("extrapolation_value",0.0) as Double
/**
* Three interpolation modes: nearest (default), linear and cubic. The "linear" mode includes linear
* interpolation for 1D tensor and N-linear interpolation for N-D tensor (for example, bilinear interpolation for 2D tensor).
* The "cubic" mode includes cubic interpolation for 1D tensor
* and N-cubic interpolation for N-D tensor (for example, bicubic interpolation for 2D tensor).
*/
val mode = attributes.getOrDefault("mode","nearest") as String
val outputVarName = outputNames[0]
val outputSize = outputSize(sd, op, inputVariable, scales, sizes,inputShape)
outputSize!!.setShape(2)
//switch to NWHC (tensorflow format) and then back to NCHW (onnx format)
inputVariable = sd.permute(inputVariable,0,2,3,1)
var result: SDVariable? = null
when (coordTransformationMode) {
"tf_crop_and_resize" -> {
val indices = mutableListOf<Int>()
val rank = inputVariable.arr.rank()
for(i in 2 until rank) {
indices.add(i - 2,i)
indices.add(i,i + rank)
}
val boxes = sd.expandDims(sd.gather(roi,indices.toIntArray(),0),0)
val boxIndices = sd.range(0.0,inputVariable.shape[0] as Double,1.0, DataType.INT64)
result = sd.image().cropAndResize(inputVariable,boxes,boxIndices,outputSize,extrapolationValue)
}
"align_corners" -> {
result = invokeResize(mode, sd, inputVariable, outputSize, true, false)
}
"asymmetric" -> {
result = invokeResize(mode, sd, inputVariable, outputSize, false, false)
}
else -> {
when(mode) {
"nearest" -> {
result = sd.image().imageResize(inputVariable,outputSize,false,false,ImageResizeMethod.ResizeNearest)
}
"cubic" -> {
result = sd.image().imageResize(inputVariable,outputSize,false,false,ImageResizeMethod.ResizeBicubic)
}
"linear" -> {
result = sd.image().imageResize(inputVariable,outputSize,false,false,ImageResizeMethod.ResizeBilinear)
}
}
if(result == null) {
throw IllegalArgumentException("Illegal mode found $mode")
}
}
}
val finalOutput = sd.permute(outputVarName,result,0,3,1,2)
return mapOf(finalOutput.name() to listOf(finalOutput))
}
fun invokeResize(
type: String,
sd: SameDiff,
input: SDVariable,
size: SDVariable,
alignCorners: Boolean,
halfPixelCenters: Boolean
): SDVariable? {
return when (type) {
"linear" -> {
val height = size.arr.getInt(0)
val width = size.arr.getInt(1)
sd.image().resizeBiLinear(input,height,width, alignCorners, halfPixelCenters)
}
"cubic" -> {
sd.image().resizeBiCubic(input,size,alignCorners,halfPixelCenters)
}
else -> {
sd.image().imageResize(input,size,true,true,ImageResizeMethod.ResizeNearest)
}
}
}
fun outputSize(
sd: SameDiff,
op: SameDiffOp,
input: SDVariable,
scales: SDVariable,
sizes: SDVariable,
inputVariableShape: SDVariable
): SDVariable? {
var ret: SDVariable? = null
ret = if(op.inputsToOp.size == 3) {
val heightWidthScale = scales.get(SDIndex.interval(2,-1))
val subGet = inputVariableShape.get(SDIndex.interval(2,-1))
val heightWidthShape = sd.castTo(subGet,heightWidthScale.dataType())
val scaled = sd.castTo(sd.math.mul(heightWidthScale,heightWidthShape),DataType.INT32)
scaled
} else {
sizes.get(SDIndex.interval(2, 1,input.rank().arr.getInt(0)))
}
if(ret.shape.size < 2) {
var newRet = sd.zero(null,DataType.INT32,2)
ret = newRet.add(ret.arr.getInt(0).toDouble())
}
return ret.castTo(DataType.INT32)
}
fun alignCornersFor(coordTransformationMode: String): Boolean {
//note this includes the coordTransformationMode == "asymmetric"
return coordTransformationMode == "align_corners"
}
fun sizes(sd: SameDiff,op: SameDiffOp): SDVariable {
if(op.inputsToOp.size == 4)
return sd.getVariable(op.inputsToOp[3])
else
return sd.constant(Nd4j.empty())
}
}