deeplearning4j/deeplearning4j

View on GitHub
nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java

Summary

Maintainability
F
1 wk
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.nativeblas;

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.factory.BaseNDArrayFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.memory.MemcpyDirection;

import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.util.HashMap;
import java.util.Map;

@Slf4j
public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory {

    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();

    public BaseNativeNDArrayFactory(DataType dtype, Character order) {
        super(dtype, order);
    }

    public BaseNativeNDArrayFactory(DataType dtype, char order) {
        super(dtype, order);
    }

    public BaseNativeNDArrayFactory() {}


    @Override
    public DataBuffer convertToNumpyBuffer(INDArray array) {
        Pointer pointer = NativeOpsHolder.getInstance().getDeviceNativeOps().numpyFromNd4j(array.data().addressPointer(), array.shapeInfoDataBuffer().pointer(), array.data().getElementSize());
        Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.HOST);
        long len = NativeOpsHolder.getInstance().getDeviceNativeOps().numpyHeaderLength(array.data().opaqueBuffer(),array.shapeInfoDataBuffer().pointer());
        pointer.capacity(len + array.length() * array.data().getElementSize());
        pointer.limit(len + array.length() * array.data().getElementSize());
        BytePointer wrapper = new BytePointer(pointer);
        wrapper.capacity(len + array.length() * array.data().getElementSize());
        wrapper.limit(len + array.length() * array.data().getElementSize());
        DataBuffer buffer = Nd4j.createBuffer(wrapper,len  + array.length() * array.data().getElementSize(),DataType.INT8);
        return buffer;
    }

    @Override
    public Pointer convertToNumpy(INDArray array) {
        DataBuffer dataBuffer = convertToNumpyBuffer(array);
        OpaqueDataBuffer opaqueDataBuffer = dataBuffer.opaqueBuffer();
        opaqueDataBuffer.capacity(dataBuffer.length());
        opaqueDataBuffer.limit(dataBuffer.length());
        return opaqueDataBuffer.primaryBuffer();
    }

    /**
     * Create from an in memory numpy pointer.
     * Note that this is heavily used
     * in our python library jumpy.
     *
     * @param pointer the pointer to the
     *                numpy array
     * @return an ndarray created from the in memory
     * numpy pointer
     */
    @Override
    public INDArray createFromNpyPointer(Pointer pointer) {
        Pointer dataPointer = nativeOps.dataPointForNumpy(pointer);
        DataBuffer data = null;
        Pointer shapeBufferPointer = nativeOps.shapeBufferForNumpy(pointer);
        int length = nativeOps.lengthForShapeBufferPointer(shapeBufferPointer);
        shapeBufferPointer.capacity(8 * length);
        shapeBufferPointer.limit(8 * length);
        shapeBufferPointer.position(0);


        val intPointer = new LongPointer(shapeBufferPointer);
        val newPointer = new LongPointer(length);

        val perfD = PerformanceTracker.getInstance().helperStartTransaction();

        Pointer.memcpy(newPointer, intPointer, shapeBufferPointer.limit());

        PerformanceTracker.getInstance().helperRegisterTransaction(0, perfD, shapeBufferPointer.limit(), MemcpyDirection.HOST_TO_HOST);

        DataBuffer shapeBuffer = Nd4j.createBuffer(
                newPointer,
                DataType.INT64,
                length,
                LongIndexer.create(newPointer));

        val jvmShapeInfo = shapeBuffer.asLong();
        val dtype = ArrayOptionsHelper.dataType(jvmShapeInfo);

        //set the location to copy from to the actual data buffer passed the header
        long dataBufferLength = Shape.length(jvmShapeInfo);

        long totalBytesToCopy = dtype.width() * dataBufferLength;
        Pointer pointer1 = nativeOps.dataPointForNumpyHeader(pointer);
        pointer1.capacity(dataBufferLength);

        switch (dtype) {
            case BOOL: {
                val dPointer = new BooleanPointer(dataBufferLength);
                val perfX = PerformanceTracker.getInstance().helperStartTransaction();

                Pointer.memcpy(dPointer, dataPointer,totalBytesToCopy);

                PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX,totalBytesToCopy, MemcpyDirection.HOST_TO_HOST);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        dataBufferLength,
                        BooleanIndexer.create(dPointer));
            }
            break;
            case UBYTE: {
                val dPointer = new BytePointer(dataBufferLength);
                val perfX = PerformanceTracker.getInstance().helperStartTransaction();

                Pointer.memcpy(dPointer, pointer1, totalBytesToCopy);

                PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, totalBytesToCopy, MemcpyDirection.HOST_TO_HOST);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        dataBufferLength,
                        UByteIndexer.create(dPointer));
            }
            break;
            case BYTE: {
                val dPointer = new BytePointer(dataBufferLength);
                val perfX = PerformanceTracker.getInstance().helperStartTransaction();

                Pointer.memcpy(dPointer, dataPointer, totalBytesToCopy);

                PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, totalBytesToCopy, MemcpyDirection.HOST_TO_HOST);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        dataBufferLength,
                        ByteIndexer.create(dPointer));
            }
            break;
            case UINT64:
            case LONG: {
                val dPointer = new LongPointer(dataBufferLength);
                val perfX = PerformanceTracker.getInstance().helperStartTransaction();

                Pointer.memcpy(dPointer, dataPointer, totalBytesToCopy);

                PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, totalBytesToCopy, MemcpyDirection.HOST_TO_HOST);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        dataBufferLength,
                        LongIndexer.create(dPointer));
            }
            break;
            case UINT32: {
                val dPointer = new IntPointer(dataBufferLength);
                val perfX = PerformanceTracker.getInstance().helperStartTransaction();

                Pointer.memcpy(dPointer, dataPointer,totalBytesToCopy);

                PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, totalBytesToCopy, MemcpyDirection.HOST_TO_HOST);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        dataBufferLength,
                        UIntIndexer.create(dPointer));
            }
            break;
            case INT: {
                val dPointer = new IntPointer(dataBufferLength);
                val perfX = PerformanceTracker.getInstance().helperStartTransaction();

                Pointer.memcpy(dPointer, dataPointer,totalBytesToCopy);

                PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, totalBytesToCopy, MemcpyDirection.HOST_TO_HOST);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        dataBufferLength,
                        IntIndexer.create(dPointer));
            }
            break;
            case UINT16: {
                val dPointer = new ShortPointer(dataBufferLength);
                val perfX = PerformanceTracker.getInstance().helperStartTransaction();

                Pointer.memcpy(dPointer, dataPointer, totalBytesToCopy);

                PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, totalBytesToCopy, MemcpyDirection.HOST_TO_HOST);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        dataBufferLength,
                        UShortIndexer.create(dPointer));
            }
            break;
            case SHORT: {
                val dPointer = new ShortPointer(dataBufferLength);
                val perfX = PerformanceTracker.getInstance().helperStartTransaction();

                Pointer.memcpy(dPointer, dataPointer, totalBytesToCopy);

                PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, totalBytesToCopy, MemcpyDirection.HOST_TO_HOST);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        dataBufferLength,
                        ShortIndexer.create(dPointer));
            }
            break;
            case BFLOAT16:
            case HALF: {
                val dPointer = new ShortPointer(dataBufferLength);
                val perfX = PerformanceTracker.getInstance().helperStartTransaction();

                Pointer.memcpy(dPointer, dataPointer, totalBytesToCopy);

                PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, totalBytesToCopy, MemcpyDirection.HOST_TO_HOST);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        dataBufferLength,
                        HalfIndexer.create(dPointer));
            }
            break;
            case FLOAT: {
                val dPointer = new FloatPointer(dataBufferLength);
                val perfX = PerformanceTracker.getInstance().helperStartTransaction();

                Pointer.memcpy(dPointer, dataPointer, totalBytesToCopy);

                PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, totalBytesToCopy, MemcpyDirection.HOST_TO_HOST);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        dataBufferLength,
                        FloatIndexer.create(dPointer));
            }
            break;
            case DOUBLE: {
                val dPointer = new DoublePointer(dataBufferLength);
                val perfX = PerformanceTracker.getInstance().helperStartTransaction();

                Pointer.memcpy(dPointer, dataPointer, totalBytesToCopy);

                PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, totalBytesToCopy, MemcpyDirection.HOST_TO_HOST);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        dataBufferLength,
                        DoubleIndexer.create(dPointer));
            }
            break;
        }

        INDArray ret = Nd4j.create(data,
                Shape.shape(shapeBuffer),
                Shape.strideArr(shapeBuffer),
                0,
                Shape.order(shapeBuffer));

        Nd4j.getAffinityManager().tagLocation(ret, AffinityManager.Location.DEVICE);

        return ret;
    }

    @Override
    public INDArray createFromNpyHeaderPointer(Pointer pointer) {
        val dtype = DataType.fromInt(nativeOps.dataTypeFromNpyHeader(pointer));

        Pointer dataPointer = nativeOps.dataPointForNumpyHeader(pointer);
        int dataBufferElementSize = nativeOps.elementSizeForNpyArrayHeader(pointer);
        DataBuffer data = null;
        Pointer shapeBufferPointer = nativeOps.shapeBufferForNumpyHeader(pointer);
        int length = nativeOps.lengthForShapeBufferPointer(shapeBufferPointer);
        shapeBufferPointer.capacity(8 * length);
        shapeBufferPointer.limit(8 * length);
        shapeBufferPointer.position(0);


        val intPointer = new LongPointer(shapeBufferPointer);
        val newPointer = new LongPointer(length);

        val perfD = PerformanceTracker.getInstance().helperStartTransaction();

        Pointer.memcpy(newPointer, intPointer, shapeBufferPointer.limit());

        PerformanceTracker.getInstance().helperRegisterTransaction(0, perfD, shapeBufferPointer.limit(), MemcpyDirection.HOST_TO_HOST);

        DataBuffer shapeBuffer = Nd4j.createBuffer(
                newPointer,
                DataType.INT64,
                length,
                LongIndexer.create(newPointer));

        dataPointer.position(0);
        long dataNumElements =  Shape.length(shapeBuffer);
        long dataLength = dataBufferElementSize * Shape.length(shapeBuffer);
        dataPointer.limit(dataLength);
        dataPointer.capacity(dataLength);

        val perfX = PerformanceTracker.getInstance().helperStartTransaction();

        switch (dtype) {
            case BYTE: {
                val dPointer = new BytePointer(dataNumElements);
                Pointer.memcpy(dPointer, dataPointer, dataNumElements);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        Shape.length(shapeBuffer),
                        ByteIndexer.create(dPointer));
            }
            break;
            case SHORT: {
                val dPointer = new ShortPointer(dataNumElements);
                Pointer.memcpy(dPointer, dataPointer, dataNumElements);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        Shape.length(shapeBuffer),
                        ShortIndexer.create(dPointer));
            }
            break;
            case INT: {
                val dPointer = new IntPointer(dataNumElements);
                Pointer.memcpy(dPointer, dataPointer, dataNumElements);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        Shape.length(shapeBuffer),
                        IntIndexer.create(dPointer));
            }
            break;
            case LONG: {
                val dPointer = new LongPointer(dataNumElements);
                Pointer.memcpy(dPointer, dataPointer, dataNumElements);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        Shape.length(shapeBuffer),
                        LongIndexer.create(dPointer));
            }
            break;
            case UBYTE: {
                val dPointer = new BytePointer(dataNumElements);
                Pointer.memcpy(dPointer, dataPointer,dataNumElements);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        Shape.length(shapeBuffer),
                        UByteIndexer.create(dPointer));
            }
            break;
            case UINT16: {
                val dPointer = new ShortPointer(dataNumElements);
                Pointer.memcpy(dPointer, dataPointer, dataNumElements);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        Shape.length(shapeBuffer),
                        UShortIndexer.create(dPointer));
            }
            break;
            case UINT32: {
                val dPointer = new IntPointer(dataNumElements);
                Pointer.memcpy(dPointer, dataPointer,dataNumElements);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        Shape.length(shapeBuffer),
                        IntIndexer.create(dPointer));
            }
            break;
            case UINT64: {
                val dPointer = new LongPointer(dataNumElements);
                Pointer.memcpy(dPointer, dataPointer, dataNumElements);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        Shape.length(shapeBuffer),
                        LongIndexer.create(dPointer));
            }
            break;
            case HALF: {
                val dPointer = new ShortPointer(dataNumElements);
                Pointer.memcpy(dPointer, dataPointer,dataNumElements);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        Shape.length(shapeBuffer),
                        HalfIndexer.create(dPointer));
            }
            break;
            case FLOAT: {
                // TODO: we might want to skip copy, and use existing pointer/data here
                val dPointer = new FloatPointer(dataNumElements);
                Pointer.memcpy(dPointer, dataPointer, dataNumElements);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        Shape.length(shapeBuffer),
                        FloatIndexer.create(dPointer));
            }
            break;
            case DOUBLE: {
                // TODO: we might want to skip copy, and use existing pointer/data here
                val dPointer = new DoublePointer(dataNumElements);
                Pointer.memcpy(dPointer, dataPointer,dataNumElements);

                data = Nd4j.createBuffer(dPointer,
                        dtype,
                        Shape.length(shapeBuffer),
                        DoubleIndexer.create(dPointer));
            }
            break;
            default:
                throw new RuntimeException("Unsupported data type: [" + dtype + "]");
        }

        PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST);

        INDArray ret = Nd4j.create(data,
                Shape.shape(shapeBuffer),
                Shape.strideArr(shapeBuffer),
                0,
                Shape.order(shapeBuffer));

        return ret;
    }


    /**
     * Create from a given numpy file.
     *
     * @param file the file to create the ndarray from
     * @return the created ndarray
     */
    @Override
    public INDArray createFromNpyFile(File file) {
        byte[] pathBytes = file.getAbsolutePath().getBytes(Charset.forName("UTF-8"));
        ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length).order(ByteOrder.nativeOrder());
        directBuffer.put(pathBytes);
        ((Buffer) directBuffer).rewind();
        ((Buffer) directBuffer).position(0);
        Pointer pointer = nativeOps.numpyFromFile(new BytePointer(directBuffer));

        INDArray result = createFromNpyPointer(pointer);

        // releasing original pointer here
        nativeOps.releaseNumpy(pointer);
        return result;
    }

    @Override
    public Map<String, INDArray> createFromNpzFile(File file) throws Exception{

        // TODO error checks
        HashMap<String, INDArray> map = new HashMap<>();
        InputStream is = new FileInputStream(file);
        while(true){
            byte[] localHeader = new byte[30];
            is.read(localHeader);
            if ((int)localHeader[2] != 3 || (int)localHeader[3] != 4){
                if(map.isEmpty()) {
                    throw new IllegalStateException("Found malformed NZP file header: File is not a npz file? " + file.getPath());
                } else {
                    break;
                }
            }
            int fNameLength = localHeader[26];
            byte[] fNameBytes = new byte[fNameLength];
            is.read(fNameBytes);
            String fName = "";
            for (int i=0; i < fNameLength - 4; i++){
                fName += (char)fNameBytes[i];
            }
            int extraFieldLength = localHeader[28];
            if (extraFieldLength > 0){
                is.read(new byte[extraFieldLength]);
            }
            is.read(new byte[11]);

            String headerStr = "";
            int b;
            while((b = is.read()) != ((int)'\n')){
                headerStr += (char)b;
            }

            int idx;
            String typeStr;
            if(headerStr.contains("<")){
                idx = headerStr.indexOf("'<") + 2;
            } else {
                idx = headerStr.indexOf("'|") + 2;
            }
            typeStr = headerStr.substring(idx, idx + 2);

            int elemSize;
            DataType dt;
            if (typeStr.equals("f8")){
                elemSize = 8;
                dt = DataType.DOUBLE;
            } else if (typeStr.equals("f4")){
                elemSize = 4;
                dt = DataType.FLOAT;
            } else if(typeStr.equals("f2")){
                elemSize = 2;
                dt = DataType.HALF;
            } else if(typeStr.equals("i8")){
                elemSize = 8;
                dt = DataType.LONG;
            } else if (typeStr.equals("i4")){
                elemSize = 4;
                dt = DataType.INT;
            } else if(typeStr.equals("i2")){
                elemSize = 2;
                dt = DataType.SHORT;
            } else if(typeStr.equals("i1")){
                elemSize = 1;
                dt = DataType.BYTE;
            } else if(typeStr.equals("u1")){
                elemSize = 1;
                dt = DataType.UBYTE;
            } else{
                throw new Exception("Unsupported data type: " + typeStr);
            }
            idx = headerStr.indexOf("'fortran_order': ");
            char order = (headerStr.charAt(idx + "'fortran_order': ".length()) == 'F')? 'c' : 'f';

            String shapeStr = headerStr.substring(headerStr.indexOf("(") + 1, headerStr.indexOf(")"));

            shapeStr = shapeStr.replace(" ", "");
            String[] dims = shapeStr.split(",");
            long[] shape = new long[dims.length];
            long size = 1;
            for (int i =0; i < dims.length; i++){
                long d = Long.parseLong(dims[i]);
                shape[i] = d;
                size *= d;
            }


            // TODO support long shape

            int numBytes = (int)(size * elemSize);
            byte[] data = new byte[numBytes];
            is.read(data);
            ByteBuffer bb = ByteBuffer.wrap(data);

            if (dt == DataType.DOUBLE){
                double[] doubleData = new double[(int)size];
                for (int i = 0; i < size; i++) {
                    long l = bb.getLong(8 * i);
                    l = Long.reverseBytes(l);
                    doubleData[i] = Double.longBitsToDouble(l);
                }
                map.put(fName, Nd4j.create(doubleData, shape, order));
            } else if(dt == DataType.FLOAT) {
                float[] floatData = new float[(int)size];
                for (int i = 0; i < size; i++) {
                    int i2 = bb.getInt(4 * i);
                    i2 = Integer.reverseBytes(i2);
                    float f = Float.intBitsToFloat(i2);
                    floatData[i] = f;
                }
                map.put(fName, Nd4j.create(floatData, shape, order));
            } else if(dt == DataType.HALF || dt == DataType.FLOAT16) {
                INDArray arr = Nd4j.create(DataType.FLOAT16, size);
                ByteBuffer bb2 = arr.data().pointer().asByteBuffer();
                for( int i = 0; i < size; i++ ) {
                    short s = bb.getShort(2*i);
                    bb2.put((byte)((s >> 8) & 0xff));
                    bb2.put((byte)(s & 0xff));
                }
                Nd4j.getAffinityManager().tagLocation(arr, AffinityManager.Location.HOST);
                map.put(fName, arr.reshape(order, shape));
            } else if(dt == DataType.LONG || dt == DataType.INT64){
                long[] d = new long[(int)size];
                for (int i = 0; i < size; i++){
                    long l = bb.getLong(8 * i);
                    l = Long.reverseBytes(l);
                    d[i] = l;
                }
                map.put(fName, Nd4j.createFromArray(d).reshape(order, shape));
            } else if(dt == DataType.INT || dt == DataType.INT32) {
                int[] d = new int[(int)size];
                for (int i = 0; i < size; i++) {
                    int l = bb.getInt(4 * i);
                    l = Integer.reverseBytes(l);
                    d[i] = l;
                }
                map.put(fName, Nd4j.createFromArray(d).reshape(order, shape));
            } else if(dt == DataType.SHORT) {
                short[] d = new short[(int)size];
                for (int i = 0; i < size; i++) {
                    short l = bb.getShort(2 * i);
                    l = Short.reverseBytes(l);
                    d[i] = l;
                }
                map.put(fName, Nd4j.createFromArray(d).reshape(order, shape));
            } else if(dt == DataType.BYTE || dt == DataType.INT8) {
                map.put(fName, Nd4j.createFromArray(data).reshape(order, shape));
            } else if(dt == DataType.UBYTE) {
                short[] d = new short[(int)size];
                for (int i = 0; i < size; i++) {
                    short l = ((short) (bb.get(i) & (short) 0xff));
                    d[i] = l;
                }
                map.put(fName, Nd4j.createFromArray(d).reshape(order, shape).castTo(DataType.UBYTE));
            }

        }

        return map;

    }
    public Map<String, INDArray> _createFromNpzFile(File file) throws Exception{

        // TODO: Fix libnd4j implementation
        byte[] pathBytes = file.getAbsolutePath().getBytes(Charset.forName("UTF-8"));
        ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length).order(ByteOrder.nativeOrder());
        directBuffer.put(pathBytes);
        ((Buffer) directBuffer).rewind();
        ((Buffer) directBuffer).position(0);
        Pointer pointer = nativeOps.mapFromNpzFile(new BytePointer(directBuffer));
        int n = nativeOps.getNumNpyArraysInMap(pointer);
        HashMap<String, INDArray> map = new HashMap<>();

        for (int i=0; i < n; i++) {
            //pre allocate 255 chars, only use up to null terminated
            //create a null terminated string buffer pre allocated for use
            //with the buffer
            byte[] buffer = new byte[255];
            for(int j = 0; j < buffer.length; j++) {
                buffer[j] = '\0';
            }

            BytePointer charPointer = new BytePointer(buffer);
            String arrName = nativeOps.getNpyArrayNameFromMap(pointer, i,charPointer);
            Pointer arrPtr = nativeOps.getNpyArrayFromMap(pointer, i);
            int ndim = nativeOps.getNpyArrayRank(arrPtr);
            long[] shape = new long[ndim];
            LongPointer shapePtr = nativeOps.getNpyArrayShape(arrPtr);

            long length = 1;
            for (int j = 0; j < ndim; j++) {
                shape[j] = shapePtr.get(j);
                length *= shape[j];
            }

            int numBytes = nativeOps.getNpyArrayElemSize(arrPtr);

            int elemSize = numBytes * 8;

            char order = nativeOps.getNpyArrayOrder(arrPtr);

            Pointer dataPointer = nativeOps.dataPointForNumpyStruct(arrPtr);


            dataPointer.position(0);

            long size = elemSize * length;
            dataPointer.limit(size);
            dataPointer.capacity(size);

            INDArray arr;
            if (elemSize == Float.SIZE){
                FloatPointer dPointer = new FloatPointer(dataPointer.limit() / elemSize);
                DataBuffer data = Nd4j.createBuffer(dPointer,
                        DataType.FLOAT,
                        length,
                        FloatIndexer.create(dPointer));

                arr = Nd4j.create(data, shape, Nd4j.getStrides(shape, order), 0, order, DataType.FLOAT);

            }
            else if (elemSize == Double.SIZE){
                DoublePointer dPointer = new DoublePointer(dataPointer.limit() / elemSize);
                DataBuffer data = Nd4j.createBuffer(dPointer,
                        DataType.DOUBLE,
                        length,
                        DoubleIndexer.create(dPointer));
                arr = Nd4j.create(data, shape, Nd4j.getStrides(shape, order), 0, order, DataType.DOUBLE);
            }

            else{
                throw new Exception("Unsupported data type: " + String.valueOf(elemSize));
            }


            map.put(arrName, arr);
        }

        return map;

    }

}