deeplearning4j/deeplearning4j

View on GitHub
nd4j/nd4j-onnxruntime/src/main/java/org/nd4j/onnxruntime/util/ONNXUtils.java

Summary

Maintainability
F
4 days
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
package org.nd4j.onnxruntime.util;

import onnx.Onnx;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*;
import org.bytedeco.onnxruntime.*;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.onnxruntime.runner.enums.ONNXType;
import org.nd4j.shade.guava.primitives.Longs;
import org.slf4j.Logger;

import java.util.List;
import java.util.stream.Collectors;

import static org.bytedeco.onnxruntime.global.onnxruntime.*;
import static org.nd4j.linalg.api.buffer.DataType.*;

public class ONNXUtils {


    /**
     * Return the {@link INDArray} from a sequence
     * @param outValue the input sequence to get the ndarrays from
     * @param ortAllocator the allocator to use to retrieve relevant memory
     * @return the equivalent arrays
     */
    public static INDArray[] ndarraysFromSequence(Value outValue,OrtAllocator ortAllocator) {
        Preconditions.checkState(outValue.HasValue(),"No value found in specified value!");
        INDArray[] ret = new INDArray[(int) outValue.GetCount()];
        for(int i = 0; i < ret.length; i++) {
            INDArray retValue = getArray(outValue.GetValue(i,ortAllocator));
            ret[i] = retValue;
        }

        return ret;
    }


    /**
     * Create a sequence from a list of tensors
     * returning a {@link ValueVector} equivalent
     * using {@link #getTensor(INDArray, MemoryInfo)}
     *
     * @param sequence the sequence to get
     * @param memoryInfo the memory info to use for allocation
     * @return
     */
    public static ValueVector getSequence(List<INDArray> sequence,MemoryInfo memoryInfo) {
        ValueVector valueVector = new ValueVector(sequence.size());
        for(int i = 0; i < sequence.size(); i++) {
            valueVector.put(getTensor(sequence.get(i),memoryInfo));
        }

        return valueVector;
    }


    /**
     * Get the onnx type of the output
     * @param session the session to get the input for
     * @param i the index of the output
     * @return
     */
    public static ONNXType getTypeForOutput(Session session,int i) {
        TypeInfo typeInfo = session.GetOutputTypeInfo(i);
        return ONNXType.values()[typeInfo.GetONNXType()];
    }


    /**
     * Get the onnx type of the input
     * @param session the session to get the output type info from
     * @param i the index of the input
     * @return the relevant type information
     */
    public static ONNXType getTypeForInput(Session session,long i) {
        TypeInfo typeInfo = session.GetInputTypeInfo(i);
        return ONNXType.values()[typeInfo.GetONNXType()];
    }


    /**
     * Returns a zeroed array of the input data.
     * This array's shape and data type are determined
     * from {@link Onnx.ValueInfoProto#getType()}
     * tensor field.
     * given the value type. Mainly used for quick debugging/
     * testing.
     * @param valueInfoProto the value info proto
     *                       to get the shape information from
     * @return the sample tensor
     */
    public static INDArray getSampleForValueInfo(Onnx.ValueInfoProto valueInfoProto) {
        Preconditions.checkState(valueInfoProto.hasType(),"Value info must have a type!");
        Onnx.TypeProto.Tensor tensorType = valueInfoProto.getType().getTensorType();
        long[] shape = Longs.toArray(tensorType.getShape().getDimList().stream().map(input -> input.getDimValue()).collect(Collectors.toList()));
        DataType type = dataTypeForOnnxType(tensorType.getElemType());
        return Nd4j.create(type,shape);
    }

    /**
     *
     * @param expected
     * @param array
     */
    public static void validateType(DataType expected, INDArray array) {
        if (!array.dataType().equals(expected))
            throw new RuntimeException("INDArray data type (" + array.dataType() + ") does not match required ONNX data type (" + expected + ")");
    }

    /**
     * Return a {@link DataType}
     * for the onnx data type
     * @param dataType the equivalent nd4j data type
     * @return
     */
    public static DataType dataTypeForOnnxType(int dataType) {
        if(dataType == dataType) {
            return FLOAT;
        } else if(dataType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
            return INT8;
        } else if(dataType == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
            return DOUBLE;
        } else if(dataType == ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL) {
            return BOOL;
        } else if(dataType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) {
            return UINT8;
        } else if(dataType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16) {
            return UINT16;
        } else if(dataType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16) {
            return INT16;
        } else if(dataType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
            return INT32;
        } else if(dataType == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
            return INT64;
        } else if(dataType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
            return FLOAT16;
        } else if(dataType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32) {
            return UINT32;
        } else if(dataType == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64) {
            return UINT64;
        } else if(dataType == ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) {
            return BFLOAT16;
        }
        else
            throw new IllegalArgumentException("Illegal data type " + dataType);
    }

    /**
     * Convert the onnx type for the given data type
     * @param dataType
     * @return
     */
    public static int onnxTypeForDataType(DataType dataType) {
        if(dataType == FLOAT) {
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
        } else if(dataType == INT8) {
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
        } else if(dataType == DOUBLE) {
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
        } else if(dataType == BOOL) {
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
        } else if(dataType == UINT8) {
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
        } else if(dataType == UINT16) {
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
        } else if(dataType == INT16) {
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
        } else if(dataType == INT32) {
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
        } else if(dataType == INT64) {
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
        } else if(dataType == FLOAT16) {
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
        } else if(dataType == UINT32) {
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
        } else if(dataType == UINT64) {
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
        } else if(dataType == BFLOAT16) {
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
        }
        else
            throw new IllegalArgumentException("Illegal data type " + dataType);
    }


    /**
     * Convert an onnx {@link Value}
     *  in to an {@link INDArray}
     * @param value the value to convert
     * @return
     */
    public static INDArray getArray(Value value) {
        DataType dataType = dataTypeForOnnxType(value.GetTypeInfo().GetONNXType());
        LongPointer shape = value.GetTensorTypeAndShapeInfo().GetShape();
        long[] shapeConvert;
        if(shape != null) {
            shapeConvert = new long[(int) value.GetTensorTypeAndShapeInfo().GetDimensionsCount()];
            shape.get(shapeConvert);
        } else {
            shapeConvert = new long[]{1};
        }

        DataBuffer getBuffer = getDataBuffer(value);
        Preconditions.checkState(dataType.equals(getBuffer.dataType()),"Data type must be equivalent as specified by the onnx metadata.");
        return Nd4j.create(getBuffer,shapeConvert,Nd4j.getStrides(shapeConvert),0);
    }


    /**
     * Get the onnx log level relative to the given slf4j logger.
     * Trace or debug will return ORT_LOGGING_LEVEL_VERBOSE
     * Info will return: ORT_LOGGING_LEVEL_INFO
     * Warn returns ORT_LOGGING_LEVEL_WARNING
     * Error returns error ORT_LOGGING_LEVEL_ERROR
     *
     * The default is info
     * @param logger the slf4j logger to get the onnx log level for
     * @return
     */
    public static int getOnnxLogLevelFromLogger(Logger logger) {
        if(logger.isTraceEnabled() || logger.isDebugEnabled()) {
            return ORT_LOGGING_LEVEL_VERBOSE;
        }
        else if(logger.isInfoEnabled()) {
            return ORT_LOGGING_LEVEL_INFO;
        }
        else if(logger.isWarnEnabled()) {
            return ORT_LOGGING_LEVEL_WARNING;
        }
        else if(logger.isErrorEnabled()) {
            return ORT_LOGGING_LEVEL_ERROR;
        }

        return ORT_LOGGING_LEVEL_INFO;

    }

    /**
     * Get an onnx tensor from an ndarray.
     * @param ndArray the ndarray to get the value from
     * @param memoryInfo the {@link MemoryInfo} to use.
     *                   Can be created with:
     *                   MemoryInfo memoryInfo = MemoryInfo.CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
     * @return
     */
    public static Value getTensor(INDArray ndArray, MemoryInfo memoryInfo) {
        if(ndArray == null || ndArray.isEmpty()) {
            /**
             *   static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
             *                             ONNXTensorElementDataType type)
             */
            LongPointer dims = new LongPointer(0);
            Value ret =  Value.CreateTensor(
                    memoryInfo.asOrtMemoryInfo(),
                    new FloatPointer(),
                    0,
                    dims,
                    0,
                    onnxTypeForDataType(FLOAT));
            return ret;
        }

        Pointer inputTensorValuesPtr = ndArray.data().pointer();
        Pointer inputTensorValues = inputTensorValuesPtr;
        long sizeInBytes = ndArray.length() * ndArray.data().getElementSize();

        /**
         *   static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
         *                             ONNXTensorElementDataType type)
         */
        LongPointer dims = new LongPointer(ndArray.shape());
        Value ret =  Value.CreateTensor(
                memoryInfo.asOrtMemoryInfo(),
                inputTensorValues,
                sizeInBytes,
                dims,
                ndArray.rank(),
                onnxTypeForDataType(ndArray.dataType()));
        return  ret;
    }

    /**
     * Get the data buffer from the given value
     * @param tens the values to get
     * @return the equivalent data buffer
     */
    public static DataBuffer getDataBuffer(Value tens) {
        if(tens.isNull())
            throw new IllegalArgumentException("Native underlying tensor value was null!");
        try (PointerScope scope = new PointerScope()) {
            DataBuffer buffer = null;
            int type = tens.GetTensorTypeAndShapeInfo().GetElementType();
            long size = tens.GetTensorTypeAndShapeInfo().GetElementCount();
            switch (type) {
                case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
                    FloatPointer pFloat = tens.GetTensorMutableDataFloat().capacity(size);
                    FloatIndexer floatIndexer = FloatIndexer.create(pFloat);
                    buffer = Nd4j.createBuffer(pFloat, DataType.FLOAT, size, floatIndexer);
                    break;
                case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
                    BytePointer pUint8 = tens.GetTensorMutableDataUByte().capacity(size);
                    Indexer uint8Indexer = ByteIndexer.create(pUint8);
                    buffer = Nd4j.createBuffer(pUint8, DataType.UINT8, size, uint8Indexer);
                    break;
                case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
                    BytePointer pInt8 = tens.GetTensorMutableDataByte().capacity(size);
                    Indexer int8Indexer = ByteIndexer.create(pInt8);
                    buffer = Nd4j.createBuffer(pInt8, DataType.UINT8, size, int8Indexer);
                    break;
                case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
                    ShortPointer pUint16 = tens.GetTensorMutableDataUShort().capacity(size);
                    Indexer uint16Indexer = ShortIndexer.create(pUint16);
                    buffer = Nd4j.createBuffer(pUint16, DataType.UINT16, size, uint16Indexer);
                    break;
                case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
                    ShortPointer pInt16 = tens.GetTensorMutableDataShort().capacity(size);
                    Indexer int16Indexer = ShortIndexer.create(pInt16);
                    buffer = Nd4j.createBuffer(pInt16, INT16, size, int16Indexer);
                    break;
                case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
                    IntPointer pInt32 = tens.GetTensorMutableDataInt().capacity(size);
                    Indexer int32Indexer = IntIndexer.create(pInt32);
                    buffer = Nd4j.createBuffer(pInt32, DataType.INT32, size, int32Indexer);
                    break;
                case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
                    LongPointer pInt64 = tens.GetTensorMutableDataLong().capacity(size);
                    Indexer int64Indexer = LongIndexer.create(pInt64);
                    buffer = Nd4j.createBuffer(pInt64, DataType.INT64, size, int64Indexer);
                    break;
                case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
                    BytePointer pString = tens.GetTensorMutableDataByte().capacity(size);
                    Indexer stringIndexer = ByteIndexer.create(pString);
                    buffer = Nd4j.createBuffer(pString, DataType.INT8, size, stringIndexer);
                    break;
                case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
                    BoolPointer pBool = tens.GetTensorMutableDataBool().capacity(size);
                    Indexer boolIndexer = BooleanIndexer.create(new BooleanPointer(pBool)); //Converting from JavaCPP Bool to Boolean here - C++ bool type size is not defined, could cause problems on some platforms
                    buffer = Nd4j.createBuffer(pBool, DataType.BOOL, size, boolIndexer);
                    break;
                case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
                    ShortPointer pFloat16 = tens.GetTensorMutableDataShort().capacity(size);
                    Indexer float16Indexer = ShortIndexer.create(pFloat16);
                    buffer = Nd4j.createBuffer(pFloat16, DataType.FLOAT16, size, float16Indexer);
                    break;
                case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
                    DoublePointer pDouble = tens.GetTensorMutableDataDouble().capacity(size);
                    Indexer doubleIndexer = DoubleIndexer.create(pDouble);
                    buffer = Nd4j.createBuffer(pDouble, DataType.DOUBLE, size, doubleIndexer);
                    break;
                case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
                    IntPointer pUint32 = tens.GetTensorMutableDataUInt().capacity(size);
                    Indexer uint32Indexer = IntIndexer.create(pUint32);
                    buffer = Nd4j.createBuffer(pUint32, DataType.UINT32, size, uint32Indexer);
                    break;
                case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
                    LongPointer pUint64 = tens.GetTensorMutableDataULong().capacity(size);
                    Indexer uint64Indexer = LongIndexer.create(pUint64);
                    buffer = Nd4j.createBuffer(pUint64, DataType.UINT64, size, uint64Indexer);
                    break;
                case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
                    ShortPointer pBfloat16 = tens.GetTensorMutableDataShort().capacity(size);
                    Indexer bfloat16Indexer = ShortIndexer.create(pBfloat16);
                    buffer = Nd4j.createBuffer(pBfloat16, DataType.BFLOAT16, size, bfloat16Indexer);
                    break;
                default:
                    throw new RuntimeException("Unsupported data type encountered");
            }
            return buffer;
        }
    }

}